diff --git a/.github/bot_config.yml b/.github/bot_config.yml index b90b4f52c56d0f..9ddb1c272bbf1e 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -18,6 +18,7 @@ assignees: - sushreebarsa - SuryanarayanaY - tilakrayal + - Varsha-anjanappa # A list of assignees for compiler folder compiler_assignees: - joker-eph diff --git a/.github/workflows/arm-cd.yml b/.github/workflows/arm-cd.yml index a191c65a98f35f..0af184576f4489 100644 --- a/.github/workflows/arm-cd.yml +++ b/.github/workflows/arm-cd.yml @@ -19,6 +19,8 @@ on: push: tags: - v2.** + branches: + - r2.** schedule: - cron: '0 8 * * *' @@ -30,7 +32,7 @@ jobs: strategy: fail-fast: false matrix: - pyver: ['3.8', '3.9', '3.10'] + pyver: ['3.9', '3.10'] experimental: [false] include: - pyver: '3.11' @@ -66,5 +68,6 @@ jobs: CI_DOCKER_BUILD_EXTRA_PARAMS="--build-arg py_major_minor_version=${{ matrix.pyver }} --build-arg is_nightly=${is_nightly} --build-arg tf_project_name=${tf_project_name}" \ ./tensorflow/tools/ci_build/ci_build.sh cpu.arm64 bash tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh - name: Upload pip wheel to PyPI + if: github.event_name == 'schedule' || (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v2')) # only if it is a scheduled nightly or tagged shell: bash run: python3 -m twine upload --verbose /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/whl/* -u "__token__" -p ${{ secrets.AWS_PYPI_ACCOUNT_TOKEN }} diff --git a/.github/workflows/arm-ci-extended.yml b/.github/workflows/arm-ci-extended.yml index 7c386590addf70..8dd4f437cde18b 100644 --- a/.github/workflows/arm-ci-extended.yml +++ b/.github/workflows/arm-ci-extended.yml @@ -29,7 +29,7 @@ jobs: strategy: fail-fast: false matrix: - pyver: ['3.8', '3.9', '3.10', '3.11'] + pyver: ['3.9', '3.10', '3.11'] steps: - name: Stop old running containers (if any) shell: bash diff --git a/WORKSPACE b/WORKSPACE index 389a4e5788011e..fb3af8a2bea085 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,9 +14,9 @@ http_archive( http_archive( name = "rules_python", - sha256 = "29a801171f7ca190c543406f9894abf2d483c206e14d6acbd695623662320097", - strip_prefix = "rules_python-0.18.1", - url = "https://github.com/bazelbuild/rules_python/releases/download/0.18.1/rules_python-0.18.1.tar.gz", + sha256 = "84aec9e21cc56fbc7f1335035a71c850d1b9b5cc6ff497306f84cced9a769841", + strip_prefix = "rules_python-0.23.1", + url = "https://github.com/bazelbuild/rules_python/releases/download/0.23.1/rules_python-0.23.1.tar.gz", ) load("@rules_python//python:repositories.bzl", "python_register_toolchains") diff --git a/ci/official/any.sh b/ci/official/any.sh new file mode 100755 index 00000000000000..fd031ee8c757c9 --- /dev/null +++ b/ci/official/any.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +source "${BASH_SOURCE%/*}/utilities/setup.sh" + +# Parse options and build targets into arrays, so that shelllint doesn't yell +# about readability. We can't pipe into 'read -ra' to create an array because +# piped commands run in subshells, which can't store variables outside of the +# subshell environment. +# Ignore grep failures since we're using it for basic filtering +set +e +filtered_build_targets=( $(echo "$BUILD_TARGETS" | tr ' ' '\n' | grep .) ) +nonpip_targets=( $(echo "$TEST_TARGETS" | tr ' ' '\n' | grep -E "^//tensorflow/" ) ) +config=( $(echo "$CONFIG_OPTIONS" ) ) +test_flags=( $(echo "$TEST_FLAGS" ) ) +set -e + +if [[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]]; then + tfrun nvidia-smi +fi + +if [[ "${#filtered_build_targets[@]}" -ne 0 ]]; then + tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" "${config[@]}" "${filtered_build_targets[@]}" +fi + +if [[ "${PIP_WHEEL}" -eq "1" ]]; then + # Update the version numbers to build a "nightly" package + if [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]]; then + tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly + fi + + tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_COMMON_ARGS[@]}" tensorflow/tools/pip_package:build_pip_package + tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package build "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" + tfrun ./ci/official/utilities/rename_and_verify_wheels.sh +fi + +if [[ "${#nonpip_targets[@]}" -ne 0 ]]; then + tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${config[@]}" "${test_flags[@]}" "${nonpip_targets[@]}" +fi diff --git a/ci/official/bazelrcs/cpu.bazelrc b/ci/official/bazelrcs/cpu.bazelrc new file mode 100644 index 00000000000000..f5094ce2289371 --- /dev/null +++ b/ci/official/bazelrcs/cpu.bazelrc @@ -0,0 +1,110 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This bazelrc can build a CPU-supporting TF package. + +# Convenient cache configurations +# Use a cache directory mounted to /tf/cache. Very useful! +build:sigbuild_local_cache --disk_cache=/tf/cache +# Use the public-access TF DevInfra cache (read only) +build:sigbuild_remote_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false +# Write to the TF DevInfra cache (only works for internal TF CI) +build:sigbuild_remote_cache_push --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --google_default_credentials +# Change the value of CACHEBUSTER when upgrading the toolchain, or when testing +# different compilation methods. E.g. for a PR to test a new CUDA version, set +# the CACHEBUSTER to the PR number. 
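As a rough sketch of how the sigbuild_local_cache config described here is meant to be used, a local build inside the tensorflow/build container might look like the following; the host cache path and docker flags are assumptions for illustration, not part of this change:

```bash
# Hypothetical local invocation: /tf/cache inside the container is backed by a
# reusable cache directory on the host, so repeat builds hit the disk cache.
docker run --rm -it \
  -v "$HOME/tf-bazel-cache:/tf/cache" \
  -v "$PWD:/tf/tensorflow" -w /tf/tensorflow \
  tensorflow/build:latest-python3.9 \
  bazel --bazelrc=ci/official/bazelrcs/cpu.bazelrc \
    build --config=sigbuild_local_cache \
    //tensorflow/tools/pip_package:build_pip_package
```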
+build --action_env=CACHEBUSTER=501872366 + +# Use Python 3.X as installed in container image +build --action_env PYTHON_BIN_PATH="/usr/bin/python3" +build --action_env PYTHON_LIB_PATH="/usr/lib/tf_python" +build --python_path="/usr/bin/python3" + +# Build TensorFlow v2 +build --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1 + +# Target the AVX instruction set +build --copt=-mavx --host_copt=-mavx + +# Use lld as the linker +build --linkopt="-fuse-ld=lld" +build --linkopt="-lm" + +# Disable clang extention that rejects type definitions within offsetof. +# This was added in clang-16 by https://reviews.llvm.org/D133574. +# Can be removed once upb is updated, since a type definition is used within +# offset of in the current version of ubp. +# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. +build --copt=-Wno-gnu-offsetof-extensions + +# Store performance profiling log in the mounted artifact directory. +# The profile can be viewed by visiting chrome://tracing in a Chrome browser. +# See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling +build --profile=build/profile.json.gz + +# Use the NVCC toolchain to compile for manylinux2014 +build --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" + +# Test-related settings below this point. +test --build_tests_only --keep_going --test_output=errors --verbose_failures=true +test --local_test_jobs=HOST_CPUS +test --test_env=LD_LIBRARY_PATH +# Give only the list of failed tests at the end of the log +test --test_summary=short + +# "nonpip" tests are regular py_test tests. +# Pass --config=nonpip to run the same suite of tests. If you want to run just +# one test for investigation, you don't need --config=nonpip; just run the +# bazel test invocation as normal. +test:nonpip_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... + +# For building libtensorflow archives +test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test +build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip + +# For outputting Build Event Protocol files +build:build_event_export --build_event_json_file=build/bep.json + +# For Remote Build Execution. 
+build:rbe --google_default_credentials +build:rbe --bes_backend=buildeventservice.googleapis.com +build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" +build:rbe --bes_timeout=600s +build:rbe --define=EXECUTOR=remote +build:rbe --jobs=800 +build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:rbe --remote_timeout=3600 +build:rbe --spawn_strategy=remote,worker,standalone,local +build:rbe --remote_download_toplevel +build:rbe --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" +build:rbe --linkopt=-lrt --host_linkopt=-lrt --linkopt=-lm --host_linkopt=-lm # Unclear why this is here +build:rbe --host_crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" +build:rbe --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" +build:rbe --extra_toolchains="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe --extra_execution_platforms="@sigbuild-r2.14-clang_config_platform//:platform" +build:rbe --host_platform="@sigbuild-r2.14-clang_config_platform//:platform" +build:rbe --platforms="@sigbuild-r2.14-clang_config_platform//:platform" +# Python config is the same across all containers because the binary is the same +build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python" +build:rbe --remote_instance_name=projects/tensorflow-testing/instances/default_instance +build:rbe --project_id="tensorflow-testing" + +# For continuous builds +test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/ci/official/bazelrcs/cpu_gcc.bazelrc b/ci/official/bazelrcs/cpu_gcc.bazelrc new file mode 100644 index 00000000000000..ff120786635cc2 --- /dev/null +++ b/ci/official/bazelrcs/cpu_gcc.bazelrc @@ -0,0 +1,99 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This bazelrc can build a CPU-supporting TF package. + +# Convenient cache configurations +# Use a cache directory mounted to /tf/cache. Very useful! 
+build:sigbuild_local_cache --disk_cache=/tf/cache +# Use the public-access TF DevInfra cache (read only) +build:sigbuild_remote_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false +# Write to the TF DevInfra cache (only works for internal TF CI) +build:sigbuild_remote_cache_push --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --google_default_credentials +# Change the value of CACHEBUSTER when upgrading the toolchain, or when testing +# different compilation methods. E.g. for a PR to test a new CUDA version, set +# the CACHEBUSTER to the PR number. +build --action_env=CACHEBUSTER=501872366 + +# Use Python 3.X as installed in container image +build --action_env PYTHON_BIN_PATH="/usr/bin/python3" +build --action_env PYTHON_LIB_PATH="/usr/lib/tf_python" +build --python_path="/usr/bin/python3" + +# Build TensorFlow v2 +build --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1 + +# Target the AVX instruction set +build --copt=-mavx --host_copt=-mavx + +# Store performance profiling log in the mounted artifact directory. +# The profile can be viewed by visiting chrome://tracing in a Chrome browser. +# See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling +build --profile=build/profile.json.gz + +# Use the NVCC toolchain to compile for manylinux2014 +build --crosstool_top="@sigbuild-r2.14_config_cuda//crosstool:toolchain" + +# Test-related settings below this point. +test --build_tests_only --keep_going --test_output=errors --verbose_failures=true +test --local_test_jobs=HOST_CPUS +test --test_env=LD_LIBRARY_PATH +# Give only the list of failed tests at the end of the log +test --test_summary=short + +# "nonpip" tests are regular py_test tests. +# Pass --config=nonpip to run the same suite of tests. If you want to run just +# one test for investigation, you don't need --config=nonpip; just run the +# bazel test invocation as normal. +test:nonpip_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... + +# For building libtensorflow archives +test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test +build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip + +# For outputting Build Event Protocol files +build:build_event_export --build_event_json_file=build/bep.json + +# For Remote Build Execution. 
+build:rbe --google_default_credentials +build:rbe --bes_backend=buildeventservice.googleapis.com +build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" +build:rbe --bes_timeout=600s +build:rbe --define=EXECUTOR=remote +build:rbe --jobs=800 +build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:rbe --remote_timeout=3600 +build:rbe --spawn_strategy=remote,worker,standalone,local +build:rbe --remote_download_toplevel +build:rbe --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" +build:rbe --linkopt=-lrt --host_linkopt=-lrt --linkopt=-lm --host_linkopt=-lm # Unclear why this is here +build:rbe --host_crosstool_top="@sigbuild-r2.14_config_cuda//crosstool:toolchain" +build:rbe --crosstool_top="@sigbuild-r2.14_config_cuda//crosstool:toolchain" +build:rbe --extra_toolchains="@sigbuild-r2.14_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe --extra_execution_platforms="@sigbuild-r2.14_config_platform//:platform" +build:rbe --host_platform="@sigbuild-r2.14_config_platform//:platform" +build:rbe --platforms="@sigbuild-r2.14_config_platform//:platform" +# Python config is the same across all containers because the binary is the same +build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14_config_python" +build:rbe --remote_instance_name=projects/tensorflow-testing/instances/default_instance +build:rbe --project_id="tensorflow-testing" + +# For continuous builds +test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/ci/official/bazelrcs/nvidia.bazelrc b/ci/official/bazelrcs/nvidia.bazelrc new file mode 100644 index 00000000000000..f90cd2a5d3860d --- /dev/null +++ b/ci/official/bazelrcs/nvidia.bazelrc @@ -0,0 +1,141 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This bazelrc can build a GPU-supporting TF package. + +# Convenient cache configurations +# Use a cache directory mounted to /tf/cache. Very useful! 
+build:sigbuild_local_cache --disk_cache=/tf/cache +# Use the public-access TF DevInfra cache (read only) +build:sigbuild_remote_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false +# Write to the TF DevInfra cache (only works for internal TF CI) +build:sigbuild_remote_cache_push --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --google_default_credentials +# Change the value of CACHEBUSTER when upgrading the toolchain, or when testing +# different compilation methods. E.g. for a PR to test a new CUDA version, set +# the CACHEBUSTER to the PR number. +build --action_env=CACHEBUSTER=501872366 + +# Use Python 3.X as installed in container image +build --action_env PYTHON_BIN_PATH="/usr/bin/python3" +build --action_env PYTHON_LIB_PATH="/usr/lib/tf_python" +build --python_path="/usr/bin/python3" + +# Build TensorFlow v2 +build --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1 + +# Target the AVX instruction set +build --copt=-mavx --host_copt=-mavx + +# Disable clang extention that rejects type definitions within offsetof. +# This was added in clang-16 by https://reviews.llvm.org/D133574. +# Can be removed once upb is updated, since a type definition is used within +# offset of in the current version of ubp. +# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. +build --copt=-Wno-gnu-offsetof-extensions + +# Use lld as the linker +build --linkopt="-fuse-ld=lld" +build --linkopt="-lm" + +# Store performance profiling log in the mounted artifact directory. +# The profile can be viewed by visiting chrome://tracing in a Chrome browser. +# See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling +build --profile=build/profile.json.gz + +# CUDA: Set up compilation CUDA version and paths +build --@local_config_cuda//:enable_cuda +build --@local_config_cuda//:cuda_compiler=clang +build --repo_env TF_NEED_CUDA=1 +build --config cuda_clang +build --action_env=TF_CUDA_VERSION="11" +build --action_env=TF_CUDNN_VERSION="8" +build --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.8" +build --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +build --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" +build --action_env=TF_CUDA_CLANG="1" +build --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" +build --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" + +# CUDA: Enable TensorRT optimizations +# https://developer.nvidia.com/tensorrt +build --repo_env TF_NEED_TENSORRT=1 + +# CUDA: Select supported compute capabilities (supported graphics cards). +# This is the same as the official TensorFlow builds. +# See https://developer.nvidia.com/cuda-gpus#compute +# TODO(angerson, perfinion): What does sm_ vs compute_ mean? +# TODO(angerson, perfinion): How can users select a good value for this? +build --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" + +# Test-related settings below this point. 
+test --build_tests_only --keep_going --test_output=errors --verbose_failures=true +test --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +# Local test jobs has to be 4 because parallel_gpu_execute is fragile, I think +test --test_timeout=300,450,1200,3600 --local_test_jobs=4 --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute +# Give only the list of failed tests at the end of the log +test --test_summary=short + +# "nonpip" tests are regular py_test tests. +# Pass --config=nonpip to run the same suite of tests. If you want to run just +# one test for investigation, you don't need --config=nonpip; just run the +# bazel test invocation as normal. +test:nonpip_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... + +# For building libtensorflow archives +test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test +build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip + +# For outputting Build Event Protocol files +build:build_event_export --build_event_json_file=build/bep.json + +# For Remote Build Execution. 
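The nonpip comments above draw a distinction between running the whole filtered suite and investigating a single test; a minimal sketch of the two invocations, with a hypothetical test target, would be:

```bash
# Whole "nonpip" suite, with the GPU tag and size filters applied by the config:
bazel --bazelrc=ci/official/bazelrcs/nvidia.bazelrc \
  test --config=sigbuild_local_cache --config=nonpip

# A single test for investigation: no --config=nonpip needed, just name the
# target directly (the target below is only an example).
bazel --bazelrc=ci/official/bazelrcs/nvidia.bazelrc \
  test --config=sigbuild_local_cache //tensorflow/python/kernel_tests/nn_ops:relu_op_test
```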
+build:rbe --google_default_credentials +build:rbe --bes_backend=buildeventservice.googleapis.com +build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" +build:rbe --bes_timeout=600s +build:rbe --define=EXECUTOR=remote +build:rbe --jobs=800 +build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:rbe --remote_timeout=3600 +build:rbe --spawn_strategy=remote,worker,standalone,local +build:rbe --remote_download_toplevel +build:rbe --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" +build:rbe --linkopt=-lrt --host_linkopt=-lrt --linkopt=-lm --host_linkopt=-lm # Unclear why this is here +build:rbe --host_crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" +build:rbe --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" +build:rbe --extra_toolchains="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe --extra_execution_platforms="@sigbuild-r2.14-clang_config_platform//:platform" +build:rbe --host_platform="@sigbuild-r2.14-clang_config_platform//:platform" +build:rbe --platforms="@sigbuild-r2.14-clang_config_platform//:platform" +# Python config is the same across all containers because the binary is the same +build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python" +build:rbe --remote_instance_name=projects/tensorflow-testing/instances/default_instance +build:rbe --project_id="tensorflow-testing" + +# For Remote build execution -- GPU configuration +build:rbe --repo_env=REMOTE_GPU_TESTING=1 +test:rbe --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +build:rbe --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.14-clang_config_cuda" +build:rbe --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.14-clang_config_tensorrt" +build:rbe --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.14-clang_config_nccl" +build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python" + +# For continuous builds +test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/ci/official/code_check_changed_files.sh b/ci/official/code_check_changed_files.sh new file mode 100755 index 00000000000000..50241e6bf6b3a9 --- /dev/null +++ b/ci/official/code_check_changed_files.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +source "${BASH_SOURCE%/*}/utilities/setup.sh" + +tfrun bats ./ci/official/utilities/code_check_changed_files.bats --timing --output build diff --git a/ci/official/code_check_full.sh b/ci/official/code_check_full.sh new file mode 100755 index 00000000000000..4fb08a897431f2 --- /dev/null +++ b/ci/official/code_check_full.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +source "${BASH_SOURCE%/*}/utilities/setup.sh" + +tfrun bats ./ci/official/utilities/code_check_full.bats --timing --output build diff --git a/ci/official/envs/local_cpu b/ci/official/envs/local_cpu new file mode 100644 index 00000000000000..79e64bfc28bcf9 --- /dev/null +++ b/ci/official/envs/local_cpu @@ -0,0 +1,21 @@ +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache) +TFCI_BUILD_PIP_PACKAGE_ARGS=("--cpu") +TFCI_COPYBARA_ENABLE=0 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=() +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 +TFCI_DOCKER_PULL_ENABLE= +TFCI_GIT_DIR=. 
+TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-cpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= +TFCI_NVIDIA_SMI_ENABLE= +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI= +TFCI_UPLOAD_LIB_URI= +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS= +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_cpu_py310 b/ci/official/envs/nightly_cpu_py310 new file mode 100644 index 00000000000000..eabe2dcc845a1e --- /dev/null +++ b/ci/official/envs/nightly_cpu_py310 @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_COPYBARA_ENABLE=1 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=() +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-cpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_cpu_py311 b/ci/official/envs/nightly_cpu_py311 new file mode 100644 index 00000000000000..0201e5aa44c0d4 --- /dev/null +++ b/ci/official/envs/nightly_cpu_py311 @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_COPYBARA_ENABLE=1 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=() +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-cpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_cpu_py39 b/ci/official/envs/nightly_cpu_py39 new file mode 100644 index 00000000000000..436bd41e169143 --- /dev/null +++ b/ci/official/envs/nightly_cpu_py39 @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) 
+TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_COPYBARA_ENABLE=1 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=() +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-cpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_nvidia_py310 b/ci/official/envs/nightly_nvidia_py310 new file mode 100644 index 00000000000000..214efe40d42db3 --- /dev/null +++ b/ci/official/envs/nightly_nvidia_py310 @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/nvidia.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--gpu --nightly_flag) +TFCI_COPYBARA_ENABLE=1 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=(--gpus all) +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-gpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_nvidia_py311 b/ci/official/envs/nightly_nvidia_py311 new file mode 100644 index 00000000000000..9a4a8f173eb2a6 --- /dev/null +++ b/ci/official/envs/nightly_nvidia_py311 @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/nvidia.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--gpu --nightly_flag) +TFCI_COPYBARA_ENABLE=1 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=(--gpus all) +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-gpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository 
testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_nvidia_py39 b/ci/official/envs/nightly_nvidia_py39 new file mode 100644 index 00000000000000..4e729536b1d60e --- /dev/null +++ b/ci/official/envs/nightly_nvidia_py39 @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/nvidia.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--gpu --nightly_flag) +TFCI_COPYBARA_ENABLE=1 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=(--gpus all) +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-gpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/libtensorflow.sh b/ci/official/libtensorflow.sh new file mode 100755 index 00000000000000..0457a405a08852 --- /dev/null +++ b/ci/official/libtensorflow.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +source "${BASH_SOURCE%/*}/utilities/setup.sh" + +# Record GPU count and CUDA version status +if [[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]]; then + tfrun nvidia-smi +fi + +# Update the version numbers for Nightly only +if [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]]; then + tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly +fi + +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_COMMON_ARGS[@]}" --config=libtensorflow_test +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_COMMON_ARGS[@]}" --config=libtensorflow_build + +tfrun ./ci/official/utilities/repack_libtensorflow.sh build "$TFCI_LIB_SUFFIX" + +if [[ "$TFCI_UPLOAD_LIB_ENABLE" == 1 ]]; then + gsutil cp build/*.tar.gz "$TFCI_UPLOAD_LIB_GCS_URI" + if [[ "$TFCI_UPLOAD_LIB_LATEST_ENABLE" == 1 ]]; then + gsutil cp build/*.tar.gz "$TFCI_UPLOAD_LIB_LATEST_GCS_URI" + fi +fi diff --git a/tensorflow/python/autograph/converters/__init__.py b/ci/official/pycpp.sh old mode 100644 new mode 100755 similarity index 50% rename from tensorflow/python/autograph/converters/__init__.py rename to ci/official/pycpp.sh index fc8ae684c2a2a8..f29fd5e4d4329b --- a/tensorflow/python/autograph/converters/__init__.py +++ b/ci/official/pycpp.sh @@ -1,4 +1,5 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +#!/bin/bash +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,19 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Code converters used by Autograph.""" +source "${BASH_SOURCE%/*}/utilities/setup.sh" -# Naming conventions: -# * each converter should specialize on a single idiom; be consistent with -# the Python reference for naming -# * all converters inherit core.converter.Base -# * module names describe the idiom that the converter covers, plural -# * the converter class is named consistent with the module, singular and -# includes the word Transformer -# -# Example: -# -# lists.py -# class ListTransformer(converter.Base) +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_COMMON_ARGS[@]}" --config=pycpp -from tensorflow.python.autograph.converters import list_comprehensions +tfrun bazel analyze-profile build/profile.json.gz diff --git a/ci/official/utilities/code_check_changed_files.bats b/ci/official/utilities/code_check_changed_files.bats new file mode 100644 index 00000000000000..8704ddb53064a9 --- /dev/null +++ b/ci/official/utilities/code_check_changed_files.bats @@ -0,0 +1,76 @@ +# vim: filetype=bash +# +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +setup_file() { + cd "$TFCI_GIT_DIR" + bazel version # Start the bazel server + # Without this, git errors if /tf/tensorflow directory owner is different + git config --global --add safe.directory "$TFCI_GIT_DIR" + # Note that you could generate a list of all the affected targets with e.g.: + # bazel query $(paste -sd "+" $BATS_FILE_TMPDIR/changed_files) --keep_going + # Only shows Added, Changed, Modified, Renamed, and Type-changed files + if [[ "$(git rev-parse --abbrev-ref HEAD)" == "pull_branch" ]]; then + # TF's CI runs 'git fetch origin "pull/PR#/merge:pull_branch"' + # To get the as-merged branch during the CI tests + git diff --diff-filter ACMRT --name-only pull_branch^ pull_branch > $BATS_FILE_TMPDIR/changed_files + else + # If the branch is not present, then diff against origin/master + git diff --diff-filter ACMRT --name-only origin/master > $BATS_FILE_TMPDIR/changed_files + fi +} + +# Note: this is excluded on the full code base, since any submitted code must +# have passed Google's internal style guidelines. +@test "Check buildifier formatting on BUILD files" { + echo "buildifier formatting is recommended. Here are the suggested fixes:" + echo "=============================" + grep -e 'BUILD' $BATS_FILE_TMPDIR/changed_files \ + | xargs buildifier -v -mode=diff -diff_command="git diff --no-index" +} + +# Note: this is excluded on the full code base, since any submitted code must +# have passed Google's internal style guidelines. +@test "Check formatting for C++ files" { + skip "clang-format doesn't match internal clang-format checker" + echo "clang-format is recommended. Here are the suggested changes:" + echo "=============================" + grep -e '\.h$' -e '\.cc$' $BATS_FILE_TMPDIR/changed_files > $BATS_TEST_TMPDIR/files || true + if [[ ! -s $BATS_TEST_TMPDIR/files ]]; then return 0; fi + xargs -a $BATS_TEST_TMPDIR/files -i -n1 -P $(nproc --all) \ + bash -c 'clang-format-12 --style=Google {} | git diff --no-index {} -' \ + | tee $BATS_TEST_TMPDIR/needs_help.txt + echo "You can use clang-format --style=Google -i to apply changes to a file." + [[ ! -s $BATS_TEST_TMPDIR/needs_help.txt ]] +} + +# Note: this is excluded on the full code base, since any submitted code must +# have passed Google's internal style guidelines. +@test "Check pylint for Python files" { + echo "Python formatting is recommended. Here are the pylint errors:" + echo "=============================" + grep -e "\.py$" $BATS_FILE_TMPDIR/changed_files > $BATS_TEST_TMPDIR/files || true + if [[ ! -s $BATS_TEST_TMPDIR/files ]]; then return 0; fi + xargs -a $BATS_TEST_TMPDIR/files -n1 -P $(nproc --all) \ + python -m pylint --rcfile=tensorflow/tools/ci_build/pylintrc --score false \ + | grep -v "**** Module" \ + | tee $BATS_TEST_TMPDIR/needs_help.txt + [[ ! -s $BATS_TEST_TMPDIR/needs_help.txt ]] +} + +teardown_file() { + bazel shutdown +} diff --git a/ci/official/utilities/code_check_full.bats b/ci/official/utilities/code_check_full.bats new file mode 100644 index 00000000000000..c963fd850fc34f --- /dev/null +++ b/ci/official/utilities/code_check_full.bats @@ -0,0 +1,307 @@ +# vim: filetype=bash +# +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +setup_file() { + cd $TFCI_GIT_DIR + bazel version # Start the bazel server +} + +# Do a bazel query specifically for the licenses checker. It searches for +# targets matching the provided query, which start with // or @ but not +# //tensorflow (so it looks for //third_party, //external, etc.), and then +# gathers the list of all packages (i.e. directories) which contain those +# targets. +license_query() { + bazel cquery --experimental_cc_shared_library "$1" --keep_going \ + | grep -e "^//" -e "^@" \ + | grep -E -v "^//tensorflow" \ + | sed -e 's|:.*||' \ + | sort -u +} + +# Verify that, given a build target and a license-list generator target, all of +# the dependencies of that target which include a license notice file are then +# included when generating that license. Necessary because the license targets +# in TensorFlow are manually enumerated rather than generated automatically. +do_external_licenses_check(){ + BUILD_TARGET="$1" + LICENSES_TARGET="$2" + + # grep patterns for targets which are allowed to be missing from the licenses + cat > $BATS_TEST_TMPDIR/allowed_to_be_missing < $BATS_TEST_TMPDIR/allowed_to_be_extra < $BATS_TEST_TMPDIR/expected_licenses + license_query "deps($LICENSES_TARGET)" > $BATS_TEST_TMPDIR/actual_licenses + + # Column 1 is left only, Column 2 is right only, Column 3 is shared lines + # Select lines unique to actual_licenses, i.e. extra licenses. + comm -1 -3 $BATS_TEST_TMPDIR/expected_licenses $BATS_TEST_TMPDIR/actual_licenses | grep -v -f $BATS_TEST_TMPDIR/allowed_to_be_extra > $BATS_TEST_TMPDIR/actual_extra_licenses || true + # Select lines unique to expected_licenses, i.e. missing licenses + comm -2 -3 $BATS_TEST_TMPDIR/expected_licenses $BATS_TEST_TMPDIR/actual_licenses | grep -v -f $BATS_TEST_TMPDIR/allowed_to_be_missing > $BATS_TEST_TMPDIR/actual_missing_licenses || true + + if [[ -s $BATS_TEST_TMPDIR/actual_extra_licenses ]]; then + echo "Please remove the following extra licenses from $LICENSES_TARGET:" + cat $BATS_TEST_TMPDIR/actual_extra_licenses + fi + + if [[ -s $BATS_TEST_TMPDIR/actual_missing_licenses ]]; then + echo "Please include the missing licenses for the following packages in $LICENSES_TARGET:" + cat $BATS_TEST_TMPDIR/actual_missing_licenses + fi + + # Fail if either of the two "extras" or "missing" lists are present. If so, + # then the user will see the above error messages. + [[ ! -s $BATS_TEST_TMPDIR/actual_extra_licenses ]] && [[ ! 
-s $BATS_TEST_TMPDIR/actual_missing_licenses ]] +} + +@test "Pip package generated license includes all dependencies' licenses" { + do_external_licenses_check \ + "//tensorflow/tools/pip_package:build_pip_package" \ + "//tensorflow/tools/pip_package:licenses" +} + +@test "Libtensorflow generated license includes all dependencies' licenses" { + do_external_licenses_check \ + "//tensorflow:libtensorflow.so" \ + "//tensorflow/tools/lib_package:clicenses_generate" +} + +@test "Java library generated license includes all dependencies' licenses" { + do_external_licenses_check \ + "//tensorflow/java:libtensorflow_jni.so" \ + "//tensorflow/tools/lib_package:jnilicenses_generate" +} + +# This test ensures that all the targets built into the Python package include +# their dependencies. It's a rewritten version of the "smoke test", an older +# Python script that was very difficult to understand. See +# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/pip_smoke_test.py +@test "Pip package includes all required //tensorflow dependencies" { + # grep patterns for packages whose dependencies can be ignored + cat > $BATS_TEST_TMPDIR/ignore_deps_for_these_packages < $BATS_TEST_TMPDIR/ignore_these_deps < $BATS_TEST_TMPDIR/pip_deps + # Find all Python py_test targets not tagged "no_pip" or "manual", excluding + # any targets in ignored packages. Combine this list of targets into a bazel + # query list (e.g. the list becomes "target+target2+target3") + bazel query --keep_going 'kind(py_test, //tensorflow/python/...) - attr("tags", "no_pip|manual", //tensorflow/python/...)' | grep -v -f $BATS_TEST_TMPDIR/ignore_deps_for_these_packages | paste -sd "+" - > $BATS_TEST_TMPDIR/deps + # Find all one-step dependencies of those tests which are from //tensorflow + # (since external deps will come from Python-level pip dependencies), + # excluding dependencies and files that are known to be unneccessary. + # This creates a list of targets under //tensorflow that are required for + # TensorFlow python tests. + bazel query --keep_going "deps($(cat $BATS_TEST_TMPDIR/deps), 1)" | grep "^//tensorflow" | grep -v -f $BATS_TEST_TMPDIR/ignore_these_deps | sort -u > $BATS_TEST_TMPDIR/required_deps + + + # Find if any required dependencies are missing from the list of dependencies + # included in the pip package. + # (comm: Column 1 is left, Column 2 is right, Column 3 is shared lines) + comm -2 -3 $BATS_TEST_TMPDIR/required_deps $BATS_TEST_TMPDIR/pip_deps > $BATS_TEST_TMPDIR/missing_deps || true + + if [[ -s $BATS_TEST_TMPDIR/missing_deps ]]; then + cat < $BATS_TEST_TMPDIR/out + + cat < $BATS_TEST_TMPDIR/out + + cat <> errors.txt + fi + if [[ -e errors.txt ]]; then + echo "Broken links found:" + cat errors.txt + rm errors.txt + false + fi + done +} + +@test "No duplicate files on Windows" { + cat < + +$(basename "$KOKORO_JOB_NAME") + + +

+TensorFlow Job Logs and Links
+
+Job Details
+
+  • Job name: $KOKORO_JOB_NAME
+  • Job pool: $KOKORO_JOB_POOL
+  • Job ID: $KOKORO_BUILD_ID
+  • Current HEAD Piper Changelist, if any: cl/${KOKORO_PIPER_CHANGELIST:-not available}
+  • Pull Request Number, if any: ${KOKORO_GITHUB_PULL_REQUEST_NUMBER_tensorflow:-none}
+  • Pull Request Link, if any: ${KOKORO_GITHUB_PULL_REQUEST_URL_tensorflow:-none}
+  • Commit: $KOKORO_GIT_COMMIT_tensorflow
+
+Googlers-Only Links
+
+Non-Googler Links

+ + +EOF diff --git a/ci/official/utilities/rename_and_verify_wheels.sh b/ci/official/utilities/rename_and_verify_wheels.sh new file mode 100755 index 00000000000000..500d0d9478dd1a --- /dev/null +++ b/ci/official/utilities/rename_and_verify_wheels.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Check and rename wheels with auditwheel. Inserts the platform tags like +# "manylinux_xyz" into the wheel filename. +set -euxo pipefail + +cd $TFCI_GIT_DIR +for wheel in build/*.whl; do + echo "Checking and renaming $wheel..." + time python3 -m auditwheel repair --plat manylinux2014_x86_64 "$wheel" --wheel-dir build 2>&1 | tee check.txt + + # We don't need the original wheel if it was renamed + new_wheel=$(grep --extended-regexp --only-matching '/tf/pkg/\S+.whl' check.txt) + if [[ "$new_wheel" != "$wheel" ]]; then + rm "$wheel" + wheel="$new_wheel" + fi + rm check.txt + + TF_WHEEL="$wheel" bats ./ci/official/utilities/wheel_verification.bats --timing +done diff --git a/ci/official/utilities/repack_libtensorflow.sh b/ci/official/utilities/repack_libtensorflow.sh new file mode 100755 index 00000000000000..0f549bf0975d73 --- /dev/null +++ b/ci/official/utilities/repack_libtensorflow.sh @@ -0,0 +1,64 @@ +#!/bin/bash +# +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== +# +# Repacks libtensorflow tarballs into $DIR with provided $TARBALL_SUFFIX, +# and also repacks libtensorflow-src.jar into a standardized format. + +# Helper function to copy a srcjar after moving any source files +# directly under the root to the "maven-style" src/main/java layout +# +# Source files generated by annotation processors appear directly +# under the root of srcjars jars created by bazel, rather than under +# the maven-style src/main/java subdirectory. +# +# Bazel manages annotation generated source as follows: First, it +# calls javac with options that create generated files under a +# bazel-out directory. Next, it archives the generated source files +# into a srcjar directly under the root. There doesn't appear to be a +# simple way to parameterize this from bazel, hence this helper to +# "normalize" the srcjar layout. 
+# +# Arguments: +# src_jar - path to the original srcjar +# dest_jar - path to the destination +# Returns: +# None +function cp_normalized_srcjar() { + src_jar="$1" + dest_jar="$2" + tmp_dir=$(mktemp -d) + cp "${src_jar}" "${tmp_dir}/orig.jar" + pushd "${tmp_dir}" + # Extract any src/ files + jar -xf "${tmp_dir}/orig.jar" src/ + # Extract any org/ files under src/main/java + (mkdir -p src/main/java && cd src/main/java && jar -xf "${tmp_dir}/orig.jar" org/) + # Repackage src/ + jar -cMf "${tmp_dir}/new.jar" src + popd + cp "${tmp_dir}/new.jar" "${dest_jar}" + rm -rf "${tmp_dir}" +} +DIR=$1 +TARBALL_SUFFIX=$2 +mkdir -p "$DIR" +cp bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz "${DIR}/libtensorflow${TARBALL_SUFFIX}.tar.gz" +cp bazel-bin/tensorflow/tools/lib_package/libtensorflow_jni.tar.gz "${DIR}/libtensorflow_jni${TARBALL_SUFFIX}.tar.gz" +cp bazel-bin/tensorflow/java/libtensorflow.jar "${DIR}" +cp_normalized_srcjar bazel-bin/tensorflow/java/libtensorflow-src.jar "${DIR}/libtensorflow-src.jar" +cp bazel-bin/tensorflow/tools/lib_package/libtensorflow_proto.zip "${DIR}" diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh new file mode 100755 index 00000000000000..faba21808b5d3c --- /dev/null +++ b/ci/official/utilities/setup.sh @@ -0,0 +1,86 @@ +#!/bin/bash +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Common setup for all TF scripts. +# +# Make as FEW changes to this file as possible. It should not contain utility +# functions (except for tfrun); use dedicated scripts instead and reference them +# specifically. Use your best judgment to keep the scripts in this directory +# lean and easy to follow. When in doubt, remember that for CI scripts, "keep it +# simple" is MUCH more important than "don't repeat yourself." + +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o pipefail: entire command fails if pipe fails. watch out for yes | ... +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +# (affects 'source $TFCI') +set -euxo pipefail -o history -o allexport + +# "TFCI" may optionally be set to the name of an env-type file with TFCI +# variables in it, OR may be left empty if the user has already exported the +# relevant variables in their environment. Because of 'set -o allexport' above +# (which is equivalent to "set -a"), every variable in the file is exported +# for other files to use. +if [[ -n "${TFCI:-}" ]]; then + source "$TFCI" +else + echo '==TFCI==: The $TFCI variable is not set. This is fine as long as you' + echo 'already sourced a TFCI env file with "set -a; source ; set +a".' + echo 'If you have not, you will see a lot of undefined variable errors.' 
+fi + +# Make a "build" directory for outputting all build artifacts (TF's .gitignore +# ignores the "build" directory) +cd "$TFCI_GIT_DIR" +mkdir -p build + +# Setup tfrun, a helper function for executing steps that can either be run +# locally or run under Docker. docker.sh, below, redefines it as "docker exec". +# Important: "tfrun foo | bar" is "( tfrun foo ) | bar", not tfrun (foo | bar). +# Therefore, "tfrun" commands cannot include pipes -- which is probably for the +# better. If a pipe is necessary for something, it is probably complex. Write a +# well-documented script under utilities/ to encapsulate the functionality +# instead. +tfrun() { "$@"; } + +# For Google-internal jobs, run copybara, which will overwrite the source tree. +# Never useful for outside users. Requires that the Kokoro job define a gfile +# resource pointing to copybara.sh, which is then loaded into the GFILE_DIR. +# See: cs/official/copybara.sh +if [[ "$TFCI_COPYBARA_ENABLE" == 1 ]]; then + if [[ -e "$KOKORO_GFILE_DIR/copybara.sh" ]]; then + source "$KOKORO_GFILE_DIR/copybara.sh" + else + echo "TF_CI_COPYBARA_ENABLE is 1, but \$KOKORO_GFILE_DIR/copybara.sh" + echo "could not be found. If you are an internal user, make sure your" + echo "Kokoro job has a gfile_resources item pointing to the right file." + echo "If you are an external user, Copybara is useless for you, and you" + echo "should set TFCI_COPYBARA_ENABLE=0" + exit 1 + fi +fi + +# Run all "tfrun" commands under Docker. See docker.sh for details +if [[ "$TFCI_DOCKER_ENABLE" == 1 ]]; then + source ./ci/official/utilities/docker.sh +fi + +# Generate an overview page describing the build +if [[ "$TFCI_INDEX_HTML_ENABLE" == 1 ]]; then + ./ci/official/utilities/generate_index_html.sh build/index.html +fi diff --git a/ci/official/utilities/wheel_verification.bats b/ci/official/utilities/wheel_verification.bats new file mode 100644 index 00000000000000..6a35adc0f05748 --- /dev/null +++ b/ci/official/utilities/wheel_verification.bats @@ -0,0 +1,79 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Suite of verification tests for the SINGLE TensorFlow wheel in /tf/pkg +# or whatever path is set as $TF_WHEEL. 
+ +setup_file() { + cd "$TFCI_GIT_DIR/build" + if [[ -z "$TF_WHEEL" ]]; then + export TF_WHEEL=$(find "$TFCI_GIT_DIR/build" -iname "*.whl") + fi +} + +teardown_file() { + rm -rf "$BATS_FILE_TMPDIR/venv" + python3 -m venv +} + +@test "Wheel is manylinux2014 (manylinux_2_17) compliant" { + python3 -m auditwheel show "$TF_WHEEL" > audit.txt + grep --quiet 'This constrains the platform tag to "manylinux_2_17_x86_64"' audit.txt +} + +@test "Wheel conforms to upstream size limitations" { + WHEEL_MEGABYTES=$(stat --format %s "$TF_WHEEL" | awk '{print int($1/(1024*1024))}') + # Googlers: search for "test_tf_whl_size" + case "$TF_WHEEL" in + # CPU: + *cpu*manylinux*) LARGEST_OK_SIZE=240 ;; + # GPU: + *manylinux*) LARGEST_OK_SIZE=580 ;; + # Unknown: + *) + echo "The wheel's name is in an unknown format." + exit 1 + ;; + esac + # >&3 forces output in bats even if the test passes. See + # https://bats-core.readthedocs.io/en/stable/writing-tests.html#printing-to-the-terminal + echo "# Size of $TF_WHEEL is $WHEEL_MEGABYTES / $LARGEST_OK_SIZE megabytes." >&3 + test "$WHEEL_MEGABYTES" -le "$LARGEST_OK_SIZE" +} + +# Note: this runs before the tests further down the file, so TF is installed in +# the venv and the venv is active when those tests run. The venv gets cleaned +# up in teardown_file() above. +@test "Wheel is installable" { + python3 -m venv "$BATS_FILE_TMPDIR/venv" + source "$BATS_FILE_TMPDIR/bin/activate" + python3 -m pip install "$TF_WHEEL" +} + +@test "TensorFlow is importable" { + source "$BATS_FILE_TMPDIR/bin/activate" + python3 -c 'import tensorflow as tf; t1=tf.constant([1,2,3,4]); t2=tf.constant([5,6,7,8]); print(tf.add(t1,t2).shape)' +} + +# Is this still useful? +@test "TensorFlow has Keras" { + source "$BATS_FILE_TMPDIR/bin/activate" + python3 -c 'import sys; import tensorflow as tf; sys.exit(0 if "_v2.keras" in tf.keras.__name__ else 1)' +} + +# Is this still useful? +@test "TensorFlow has Estimator" { + source "$BATS_FILE_TMPDIR/bin/activate" + python3 -c 'import sys; import tensorflow as tf; sys.exit(0 if "_v2.estimator" in tf.estimator.__name__ else 1)' +} diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh new file mode 100755 index 00000000000000..60f626a4f45980 --- /dev/null +++ b/ci/official/wheel.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +source "${BASH_SOURCE%/*}/utilities/setup.sh" + +# Record GPU count and CUDA version status +if [[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]]; then + tfrun nvidia-smi +fi + +# Update the version numbers for Nightly only +if [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]]; then + tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly +fi + +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_COMMON_ARGS[@]}" //tensorflow/tools/pip_package:build_pip_package +tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package build "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" +tfrun ./ci/official/utilities/rename_and_verify_wheels.sh build + +if [[ "$TFCI_UPLOAD_ENABLE" == 1 ]]; then + twine upload "${TFCI_UPLOAD_PYPI_ARGS[@]}" build/*.whl + gsutil cp build/*.whl "$TFCI_UPLOAD_GCS_DESTINATION" +fi + +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_COMMON_ARGS[@]}" --config=nonpip diff --git a/configure.py b/configure.py index 9b67fae22f6334..262637734a55ec 100644 --- a/configure.py +++ b/configure.py @@ -84,6 +84,10 @@ def is_ppc64le(): return platform.machine() == 'ppc64le' +def is_s390x(): + return platform.machine() == 's390x' + + def is_cygwin(): return platform.system().startswith('CYGWIN_NT') @@ -1100,7 +1104,12 @@ def system_specific_test_config(environ_cp): def set_system_libs_flag(environ_cp): + """Set system libs flags.""" syslibs = environ_cp.get('TF_SYSTEM_LIBS', '') + + if is_s390x() and 'boringssl' not in syslibs: + syslibs = 'boringssl' + (', ' + syslibs if syslibs else '') + if syslibs: if ',' in syslibs: syslibs = ','.join(sorted(syslibs.split(','))) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 53d444b410deec..ddf9b47151a9fe 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -571,6 +571,15 @@ config_setting( visibility = ["//visibility:public"], ) +# This condition takes precedence over :linux_x86_64 +# TODO(b/290533709): Remove this with PJRT build rule cleanup. +config_setting( + name = "linux_x86_64_with_weightwatcher", + define_values = {"tensorflow_weightwatcher": "true"}, + values = {"cpu": "k8"}, + visibility = ["//visibility:public"], +) + config_setting( name = "linux_ppc64le", values = {"cpu": "ppc"}, @@ -1108,7 +1117,6 @@ cc_library( name = "grpc", visibility = ["//visibility:public"], deps = select({ - ":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"], "//conditions:default": ["@com_github_grpc_grpc//:grpc"], }), ) @@ -1117,7 +1125,6 @@ cc_library( name = "grpc++", visibility = ["//visibility:public"], deps = select({ - ":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"], "//conditions:default": ["@com_github_grpc_grpc//:grpc++"], }), ) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index a1d2e6e86b9214..29dc30bedf92f1 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -1097,7 +1097,6 @@ tf_cuda_library( ":c_api_internal", "//tensorflow/core:protos_all_cc", # TODO(b/74620627): remove when _USE_C_SHAPES is removed - "//tensorflow/python/framework:cpp_shape_inference_proto_cc", ], alwayslink = 1, ) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 48c16296d973de..bf598c4d57c148 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -50,11 +50,15 @@ limitations under the License. 
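Aside on the cpp_shape_inference proto move reflected in the c/BUILD and include changes above (and in the python_api.cc hunks further down): the HandleData message is now generated from tensorflow/core/framework/cpp_shape_inference.proto and is addressed under the tensorflow::core namespace instead of the python/framework package. A minimal illustration; the helper name here is made up and not part of the change:

#include "tensorflow/core/framework/cpp_shape_inference.pb.h"

tensorflow::core::CppShapeInferenceResult::HandleData MakeEmptyHandleData() {
  // Same proto as before, just reached through its new generated header
  // and the tensorflow::core namespace.
  tensorflow::core::CppShapeInferenceResult::HandleData handle_data;
  handle_data.set_is_set(true);
  return handle_data;
}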
#include "tensorflow/core/config/flag_defs.h" #include "tensorflow/core/config/flags.h" #include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/cpp_shape_inference.pb.h" +#include "tensorflow/core/framework/full_type.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.h" @@ -81,10 +85,10 @@ limitations under the License. // The implementation below is at the top level instead of the // brain namespace because we are defining 'extern "C"' functions. -using tensorflow::AllocationDescription; using tensorflow::AttrValueMap; using tensorflow::DataType; using tensorflow::ExtendSessionGraphHelper; +using tensorflow::FullTypeDef; using tensorflow::Graph; using tensorflow::GraphDef; using tensorflow::mutex_lock; @@ -93,10 +97,7 @@ using tensorflow::NameRangesForNode; using tensorflow::NewSession; using tensorflow::Node; using tensorflow::NodeBuilder; -using tensorflow::NodeDef; using tensorflow::OpDef; -using tensorflow::OpRegistry; -using tensorflow::OutputTensor; using tensorflow::PartialTensorShape; using tensorflow::RunMetadata; using tensorflow::RunOptions; @@ -104,9 +105,7 @@ using tensorflow::Session; using tensorflow::Status; using tensorflow::string; using tensorflow::Tensor; -using tensorflow::TensorBuffer; using tensorflow::TensorId; -using tensorflow::TensorShape; using tensorflow::TensorShapeProto; using tensorflow::VersionDef; using tensorflow::errors::FailedPrecondition; @@ -2575,6 +2574,150 @@ void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, } } +// Apis that are corresponding to python c api. 
-------------------------- + +void TF_AddOperationControlInput(TF_Graph* graph, TF_Operation* op, + TF_Operation* input) { + using tensorflow::RecordMutation; + mutex_lock l(graph->mu); + graph->graph.AddControlEdge(&input->node, &op->node); + RecordMutation(graph, *op, "adding control input"); +} + +void TF_SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Buffer* attr_value_proto, TF_Status* status) { + using tensorflow::RecordMutation; + tensorflow::AttrValue attr_val; + if (!attr_val.ParseFromArray(attr_value_proto->data, + attr_value_proto->length)) { + status->status = absl::InvalidArgumentError("Invalid AttrValue proto"); + return; + } + + mutex_lock l(graph->mu); + op->node.AddAttr(attr_name, attr_val); + RecordMutation(graph, *op, "setting attribute"); +} + +void TF_ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Status* status) { + using tensorflow::RecordMutation; + mutex_lock l(graph->mu); + op->node.ClearAttr(attr_name); + RecordMutation(graph, *op, "clearing attribute"); +} + +void TF_SetFullType(TF_Graph* graph, TF_Operation* op, + const TF_Buffer* full_type_proto) { + using tensorflow::RecordMutation; + mutex_lock l(graph->mu); + FullTypeDef full_type; + full_type.ParseFromArray(full_type_proto->data, full_type_proto->length); + *op->node.mutable_def()->mutable_experimental_type() = full_type; + RecordMutation(graph, *op, "setting fulltype"); +} + +void TF_SetRequestedDevice(TF_Graph* graph, TF_Operation* op, + const char* device) { + using tensorflow::RecordMutation; + mutex_lock l(graph->mu); + op->node.set_requested_device(device); + RecordMutation(graph, *op, "setting device"); +} + +void TF_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { + mutex_lock l(graph->mu); + std::vector control_edges; + for (const tensorflow::Edge* edge : op->node.in_edges()) { + if (!edge->IsControlEdge()) continue; + control_edges.push_back(edge); + } + for (const tensorflow::Edge* edge : control_edges) { + graph->graph.RemoveControlEdge(edge); + } +} + +void TF_SetRequireShapeInferenceFns(TF_Graph* graph, bool require) { + mutex_lock l(graph->mu); + graph->refiner.set_require_shape_inference_fns(require); +} + +void TF_ExtendSession(TF_Session* session, TF_Status* status) { + ExtendSessionGraphHelper(session, status); + session->extend_before_run = false; +} + +TF_Buffer* TF_GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { + Node* node = &output.oper->node; + tensorflow::core::CppShapeInferenceResult::HandleData handle_data; + handle_data.set_is_set(true); + { + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + CHECK(ic != nullptr); // Crash OK + CHECK_LT(output.index, ic->num_outputs()); // Crash OK + const auto* shapes_and_types = + ic->output_handle_shapes_and_types(output.index); + if (shapes_and_types == nullptr) return nullptr; + + for (const auto& p : *shapes_and_types) { + auto* out_shape_and_type = handle_data.add_shape_and_type(); + ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); + out_shape_and_type->set_dtype(p.dtype); + *out_shape_and_type->mutable_type() = p.type; + } + } + string str_data; + handle_data.SerializeToString(&str_data); + + TF_Buffer* result = TF_NewBufferFromString(str_data.c_str(), str_data.size()); + return result; +} + +void TF_SetHandleShapeAndType(TF_Graph* graph, TF_Output output, + const void* proto, size_t proto_len, + TF_Status* status) { + tensorflow::core::CppShapeInferenceResult::HandleData 
handle_data; + if (!handle_data.ParseFromArray(proto, proto_len)) { + status->status = + absl::InvalidArgumentError("Couldn't deserialize HandleData proto"); + return; + } + DCHECK(handle_data.is_set()); + + tensorflow::mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&output.oper->node); + + std::vector shapes_and_types; + for (const auto& shape_and_type_proto : handle_data.shape_and_type()) { + tensorflow::shape_inference::ShapeHandle shape; + status->status = + ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); + if (TF_GetCode(status) != TF_OK) return; + shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(), + shape_and_type_proto.type()); + } + ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); +} + +void TF_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status) { + using tensorflow::RecordMutation; + mutex_lock l(graph->mu); + status->status = graph->graph.AddWhileInputHack(&new_src.oper->node, + new_src.index, &dst->node); + if (TF_GetCode(status) == TF_OK) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst, "adding input tensor"); + } +} + +// ------------------------------------------------------------------- + // TF_Server functions ---------------------------------------------- #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index e4c6499506ec76..2f4cf6062de04c 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1577,6 +1577,81 @@ TF_CAPI_EXPORT extern void TF_RegisterLogListener( TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin( const char* plugin_filename, TF_Status* status); +// Apis that are corresponding to python c api. -------------------- + +// Add control input to `op`. +TF_CAPI_EXPORT extern void TF_AddOperationControlInput(TF_Graph* graph, + TF_Operation* op, + TF_Operation* input); + +// Changes an attr value in the node_def Protocol Buffer and sets a status upon +// completion. +TF_CAPI_EXPORT extern void TF_SetAttr(TF_Graph* graph, TF_Operation* op, + const char* attr_name, + TF_Buffer* attr_value_proto, + TF_Status* status); + +// Clears the attr in the node_def Protocol Buffer and sets a status upon +// completion. +TF_CAPI_EXPORT extern void TF_ClearAttr(TF_Graph* graph, TF_Operation* op, + const char* attr_name, + TF_Status* status); + +// Sets the experimental_type` field in the node_def Protocol Buffer. +TF_CAPI_EXPORT extern void TF_SetFullType(TF_Graph* graph, TF_Operation* op, + const TF_Buffer* full_type_proto); + +// Set the requested device for `graph`. +TF_CAPI_EXPORT extern void TF_SetRequestedDevice(TF_Graph* graph, + TF_Operation* op, + const char* device); + +// Remove all the control inputs from `op` in `graph`. +TF_CAPI_EXPORT extern void TF_RemoveAllControlInputs(TF_Graph* graph, + TF_Operation* op); + +// Set if `graph` requires shape inference functions. +TF_CAPI_EXPORT extern void TF_SetRequireShapeInferenceFns(TF_Graph* graph, + bool require); + +// Extends `session` with any new operations added to its associated graph. +// Usually this happens automatically in TF_SessionRun. After this is called, +// TF_SessionRun will no longer extend the session on every call. 
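As a rough illustration of the entry points declared here (this is not code from the change itself): these functions mirror what tensorflow/c/python_api.cc previously did internally, so a C++ caller can now mutate a graph directly through the C API. The op names, the attribute name "_test_only_attr", the device string, and the function name below are arbitrary, and error checking is elided for brevity:

#include <string>

#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/attr_value.pb.h"

void MutateGraphSketch() {
  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Two placeholder nodes to mutate.
  TF_OperationDescription* d1 = TF_NewOperation(graph, "Placeholder", "x");
  TF_SetAttrType(d1, "dtype", TF_FLOAT);
  TF_Operation* x = TF_FinishOperation(d1, status);

  TF_OperationDescription* d2 = TF_NewOperation(graph, "Placeholder", "y");
  TF_SetAttrType(d2, "dtype", TF_FLOAT);
  TF_Operation* y = TF_FinishOperation(d2, status);

  // Adds a control edge x -> y, i.e. y runs only after x.
  TF_AddOperationControlInput(graph, y, x);

  // Attach, then remove, a serialized AttrValue on x.
  tensorflow::AttrValue attr;
  attr.set_i(42);
  std::string serialized = attr.SerializeAsString();
  TF_Buffer* buf = TF_NewBufferFromString(serialized.data(), serialized.size());
  TF_SetAttr(graph, x, "_test_only_attr", buf, status);
  TF_ClearAttr(graph, x, "_test_only_attr", status);
  TF_DeleteBuffer(buf);

  // Pin x to a device; like the other calls, this records a graph mutation.
  TF_SetRequestedDevice(graph, x, "/device:CPU:0");

  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
}

If a TF_Session is already attached to the graph, TF_ExtendSession (declared just below) lets the session pick up such mutations before its next run, as the surrounding comment explains.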
+// +// We expose this here to allow fine-grained synchronization in multi-threaded +// workloads, which is required since the Python implementation depends on the +// above mutation methods. This allows us to prevent modifications to nodes in +// the graph after the session has been made aware of them. +TF_CAPI_EXPORT extern void TF_ExtendSession(TF_Session* session, + TF_Status* status); + +// Returns the serialized CppShapeInferenceResult::HandleData proto for +// `output` if its a resource or variant tensor, or otherwise returns the empty +// string. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetHandleShapeAndType(TF_Graph* graph, + TF_Output output); + +// Sets `output` based on `proto`, which should be a serialized +// CppShapeInferenceResult::HandleData proto. `output` should be a resource +// or variant tensor. +// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string +// because I couldn't get SWIG to work otherwise. +TF_CAPI_EXPORT extern void TF_SetHandleShapeAndType(TF_Graph* graph, + TF_Output output, + const void* proto, + size_t proto_len, + TF_Status* status); + +// This method is used to add a new input edge to 'dst', which must be a While +// op. The While op's "T" attribute must have already been updated to include +// the new edge. This is used to construct tf.while_loop gradients. +TF_CAPI_EXPORT extern void TF_AddWhileInputHack(TF_Graph* graph, + TF_Output new_src, + TF_Operation* dst, + TF_Status* status); + +// ---------------------------------------------------------------- + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 748d49565f64a1..0b543f7dcbf9b7 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -930,10 +930,12 @@ tf_cuda_cc_test( ":c_api_experimental", ":c_api_test_util", "//tensorflow/c:c_test_util", + "//tensorflow/c/eager:c_api_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/platform:status", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 6fbcb7bb56a69e..db8b28437607f6 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -939,3 +939,18 @@ void TFE_WaitAtBarrier(TFE_Context* ctx, const char* barrier_id, status->status = coord_agent->WaitAtBarrier( barrier_id, absl::Milliseconds(barrier_timeout_in_ms), {}); } + +void TFE_InitializeLocalOnlyContext(TFE_Context* ctx, int keep_alive_secs, + const void* proto, size_t proto_len, + TF_Status* status) { + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + return; + } + status->status = + tensorflow::unwrap(ctx) + ->GetDistributedManager() + ->InitializeLocalOnlyContext(server_def, keep_alive_secs); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index fcbced2080a082..dc88de351f74fa 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -747,6 +747,12 @@ TF_CAPI_EXPORT extern void TFE_WaitAtBarrier(TFE_Context* ctx, int64_t barrier_timeout_in_ms, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_InitializeLocalOnlyContext(TFE_Context* ctx, + int keep_alive_secs, + const 
void* proto, + size_t proto_len, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index 68dbafc4d2a1e6..51e56827114cb6 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -18,13 +18,17 @@ limitations under the License. #include #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/lib/monitoring/collection_registry.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" using tensorflow::string; @@ -522,5 +526,63 @@ TEST(CAPI, TensorHandleDefaults) { TFE_DeleteContext(ctx); } +TEST(CAPI, CreateLocalContextAsReset) { + tensorflow::ServerDef server_def = GetServerDef("worker", 2); + server_def.mutable_default_session_config()->set_isolate_session_state(false); + + ServerFactory* factory; + ASSERT_TRUE(ServerFactory::GetFactory(server_def, &factory).ok()); + server_def.set_job_name("worker"); + server_def.set_task_index(0); + std::unique_ptr w0; + ASSERT_TRUE( + factory->NewServer(server_def, ServerFactory::Options(), &w0).ok()); + ASSERT_TRUE(w0->Start().ok()); + server_def.set_task_index(1); + std::unique_ptr w1; + ASSERT_TRUE( + factory->NewServer(server_def, ServerFactory::Options(), &w1).ok()); + ASSERT_TRUE(w1->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + opts->session_options.options.config.set_isolate_session_state(false); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + server_def.set_task_index(0); + auto cluster = server_def.mutable_cluster(); + auto client_job = cluster->add_job(); + client_job->set_name("localhost"); + int client_port = tensorflow::testing::PickUnusedPortOrDie(); + client_job->mutable_tasks()->insert( + {0, strings::StrCat("localhost:", client_port)}); + server_def.set_job_name("localhost"); + auto serialized = server_def.SerializeAsString(); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + server_def.set_job_name("worker"); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->mutable_job(0); + int worker_port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->at(0) = + tensorflow::strings::StrCat("localhost:", worker_port); + serialized = server_def.SerializeAsString(); + TFE_InitializeLocalOnlyContext(ctx, 0, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_DeleteContextOptions(opts); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + w0.release(); + w1.release(); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_test.cc 
b/tensorflow/c/eager/c_api_test.cc index 254648d9e09309..19c078cbc47e9d 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1946,71 +1946,6 @@ tensorflow::ServerDef ReplaceTaskInServerDef( return server_def_copy; } -TFE_TensorHandle* CreateVarHandle(TFE_Context* ctx, - const tensorflow::string& device_name, - const tensorflow::string& variable_name) { - TF_Status* status = TF_NewStatus(); - // Create the variable handle. - TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpSetAttrShape(op, "shape", {}, 0, status); - TFE_OpSetAttrString(op, "container", "localhost", 0); - TFE_OpSetAttrString(op, "shared_name", variable_name.data(), - variable_name.size()); - if (!device_name.empty()) { - TFE_OpSetDevice(op, device_name.c_str(), status); - } - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_TensorHandle* var_handle = nullptr; - int num_retvals = 1; - TFE_Execute(op, &var_handle, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_DeleteOp(op); - if (TF_GetCode(status) != TF_OK) return nullptr; - CHECK_EQ(1, num_retvals); - TF_DeleteStatus(status); - return var_handle; -} - -TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, - const tensorflow::string& device_name, - const tensorflow::string& variable_name) { - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* var_handle = - CreateVarHandle(ctx, device_name, variable_name); - - // Assign 'value' to it. - TFE_Op* op = TFE_NewOp(ctx, "AssignVariableOp", status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpAddInput(op, var_handle, status); - if (!device_name.empty()) { - TFE_OpSetDevice(op, device_name.c_str(), status); - } - - // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. - std::unique_ptr t( - TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor); - memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); - - std::unique_ptr - value_handle(TFE_NewTensorHandle(t.get(), status), - TFE_DeleteTensorHandle); - if (TF_GetCode(status) != TF_OK) return nullptr; - - TFE_OpAddInput(op, value_handle.get(), status); - if (TF_GetCode(status) != TF_OK) return nullptr; - - int num_retvals = 0; - TFE_Execute(op, nullptr, &num_retvals, status); - TFE_DeleteOp(op); - if (TF_GetCode(status) != TF_OK) return nullptr; - CHECK_EQ(0, num_retvals); - TF_DeleteStatus(status); - return var_handle; -} - TFE_Context* CreateContext(const string& serialized_server_def, bool isolate_session_state) { TF_Status* status = TF_NewStatus(); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 1fb76748059a20..75450e5c7aa88b 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -485,3 +485,68 @@ tensorflow::ServerDef GetMultiClientServerDef(const std::string& job_name, } return server_def; } + +TFE_TensorHandle* CreateVarHandle(TFE_Context* ctx, + const tensorflow::string& device_name, + const tensorflow::string& variable_name) { + TF_Status* status = TF_NewStatus(); + // Create the variable handle. 
+ TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpSetAttrShape(op, "shape", {}, 0, status); + TFE_OpSetAttrString(op, "container", "localhost", 0); + TFE_OpSetAttrString(op, "shared_name", variable_name.data(), + variable_name.size()); + if (!device_name.empty()) { + TFE_OpSetDevice(op, device_name.c_str(), status); + } + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_TensorHandle* var_handle = nullptr; + int num_retvals = 1; + TFE_Execute(op, &var_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_DeleteOp(op); + if (TF_GetCode(status) != TF_OK) return nullptr; + CHECK_EQ(1, num_retvals); + TF_DeleteStatus(status); + return var_handle; +} + +TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, + const tensorflow::string& device_name, + const tensorflow::string& variable_name) { + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* var_handle = + CreateVarHandle(ctx, device_name, variable_name); + + // Assign 'value' to it. + TFE_Op* op = TFE_NewOp(ctx, "AssignVariableOp", status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var_handle, status); + if (!device_name.empty()) { + TFE_OpSetDevice(op, device_name.c_str(), status); + } + + // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. + std::unique_ptr t( + TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor); + memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); + + std::unique_ptr + value_handle(TFE_NewTensorHandle(t.get(), status), + TFE_DeleteTensorHandle); + if (TF_GetCode(status) != TF_OK) return nullptr; + + TFE_OpAddInput(op, value_handle.get(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + + int num_retvals = 0; + TFE_Execute(op, nullptr, &num_retvals, status); + TFE_DeleteOp(op); + if (TF_GetCode(status) != TF_OK) return nullptr; + CHECK_EQ(0, num_retvals); + TF_DeleteStatus(status); + return var_handle; +} diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index ce8546fb4f4186..3dad82723b6453 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -145,4 +145,16 @@ tensorflow::ServerDef GetMultiClientServerDef(const std::string& job_name, int num_tasks, int num_virtual_gpus = 0); +// Create a variable handle with name `variable_name` on a device with name +// `device_name`. +TFE_TensorHandle* CreateVarHandle(TFE_Context* ctx, + const tensorflow::string& device_name, + const tensorflow::string& variable_name); + +// Create a variable with value `value` and name `variable_name` on a device +// with name `device_name`. +TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, + const tensorflow::string& device_name, + const tensorflow::string& variable_name); + #endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ diff --git a/tensorflow/c/eager/immediate_execution_distributed_manager.h b/tensorflow/c/eager/immediate_execution_distributed_manager.h index 4f96992e7393af..b0fcc49c0b8c36 100644 --- a/tensorflow/c/eager/immediate_execution_distributed_manager.h +++ b/tensorflow/c/eager/immediate_execution_distributed_manager.h @@ -44,6 +44,12 @@ class ImmediateExecutionDistributedManager { bool reset_context, int keep_alive_secs) = 0; + // Initializes context for the local worker and no contexts will be created + // for remote workers. 
Currently this only works for resetting context. + // TODO(b/289445025): Consider removing this when we find a proper fix. + virtual Status InitializeLocalOnlyContext(const ServerDef& server_def, + int keep_alive_secs) = 0; + // Set up a multi-client distributed execution environment. Must be called // on all tasks in the cluster. This call internally coordinates with other // tasks to initialize the eager context and TF server for multi-client diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 59f978000e66f9..ad3146a37fdbfc 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -357,6 +357,10 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor, return; } const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i)); + if ((&cc_tensor) == nullptr) { // NOLINT: Error observed in OSS. + *tensor = nullptr; + return; + } TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status); if (TF_GetCode(status) == TF_OK) { diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index faf93475541da3..c2a4d73f8ad620 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -15,51 +15,43 @@ limitations under the License. #include "tensorflow/c/python_api.h" +#include + #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/core/framework/cpp_shape_inference.pb.h" #include "tensorflow/core/framework/full_type.pb.h" -#include "tensorflow/python/framework/cpp_shape_inference.pb.h" namespace tensorflow { -void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { +// Hack to export the tensorflow::RecordMutation symbol for windows. +// Do not delete. Do not use. +void ExportRecordMutation( // NOLINT: Intentionally unused function. + TF_Graph* graph, const TF_Operation& op, const char* mutation_type) { mutex_lock l(graph->mu); - graph->graph.AddControlEdge(&input->node, &op->node); - RecordMutation(graph, *op, "adding control input"); + RecordMutation(graph, op, mutation_type); +} + +void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { + TF_AddOperationControlInput(graph, op, input); } void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, TF_Buffer* attr_value_proto, TF_Status* status) { - AttrValue attr_val; - if (!attr_val.ParseFromArray(attr_value_proto->data, - attr_value_proto->length)) { - status->status = - tensorflow::errors::InvalidArgument("Invalid AttrValue proto"); - return; - } - - mutex_lock l(graph->mu); - op->node.AddAttr(attr_name, attr_val); - RecordMutation(graph, *op, "setting attribute"); + TF_SetAttr(graph, op, attr_name, attr_value_proto, status); } void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, TF_Status* status) { - mutex_lock l(graph->mu); - op->node.ClearAttr(attr_name); - RecordMutation(graph, *op, "clearing attribute"); + TF_ClearAttr(graph, op, attr_name, status); } void SetFullType(TF_Graph* graph, TF_Operation* op, - const FullTypeDef& full_type) { - mutex_lock l(graph->mu); - *op->node.mutable_def()->mutable_experimental_type() = full_type; - RecordMutation(graph, *op, "setting fulltype"); + const TF_Buffer* full_type_proto) { + TF_SetFullType(graph, op, full_type_proto); } void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { - mutex_lock l(graph->mu); - op->node.set_requested_device(device); - RecordMutation(graph, *op, "setting device"); + TF_SetRequestedDevice(graph, op, device); } void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, @@ 
-68,13 +60,12 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, } void ExtendSession(TF_Session* session, TF_Status* status) { - ExtendSessionGraphHelper(session, status); - session->extend_before_run = false; + TF_ExtendSession(session, status); } std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { Node* node = &output.oper->node; - CppShapeInferenceResult::HandleData handle_data; + tensorflow::core::CppShapeInferenceResult::HandleData handle_data; handle_data.set_is_set(true); { mutex_lock l(graph->mu); @@ -100,41 +91,12 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, size_t proto_len, TF_Status* status) { - tensorflow::CppShapeInferenceResult::HandleData handle_data; - if (!handle_data.ParseFromArray(proto, proto_len)) { - status->status = tensorflow::errors::InvalidArgument( - "Couldn't deserialize HandleData proto"); - return; - } - DCHECK(handle_data.is_set()); - - tensorflow::mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(&output.oper->node); - - std::vector shapes_and_types; - for (const auto& shape_and_type_proto : handle_data.shape_and_type()) { - tensorflow::shape_inference::ShapeHandle shape; - status->status = - ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); - if (TF_GetCode(status) != TF_OK) return; - shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(), - shape_and_type_proto.type()); - } - ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); + TF_SetHandleShapeAndType(graph, output, proto, proto_len, status); } void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, TF_Status* status) { - mutex_lock l(graph->mu); - status->status = graph->graph.AddWhileInputHack(&new_src.oper->node, - new_src.index, &dst->node); - if (TF_GetCode(status) == TF_OK) { - // This modification only updates the destination node for - // the purposes of running this graph in a session. Thus, we don't - // record the source node as being modified. - RecordMutation(graph, *dst, "adding input tensor"); - } + TF_AddWhileInputHack(graph, new_src, dst, status); } } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index ef677161091ad5..043b76686b4bca 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -40,7 +40,7 @@ void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, // Sets the experimental_type` field in the node_def Protocol Buffer. void SetFullType(TF_Graph* graph, TF_Operation* op, - const FullTypeDef& full_type); + const TF_Buffer* full_type_proto); void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); diff --git a/tensorflow/cc/experimental/libexport/load.cc b/tensorflow/cc/experimental/libexport/load.cc index c045dbd4e78058..be9319b066d74d 100644 --- a/tensorflow/cc/experimental/libexport/load.cc +++ b/tensorflow/cc/experimental/libexport/load.cc @@ -23,12 +23,6 @@ limitations under the License. 
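For callers of the changed tensorflow::SetFullType signature in python_api.h above, the migration is to serialize the FullTypeDef into a TF_Buffer first. A small sketch; the TFT_PRODUCT type id is an arbitrary choice made purely for illustration:

#include <string>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/python_api.h"
#include "tensorflow/core/framework/full_type.pb.h"

void SetFullTypeSketch(TF_Graph* graph, TF_Operation* op) {
  tensorflow::FullTypeDef full_type;
  full_type.set_type_id(tensorflow::TFT_PRODUCT);
  std::string serialized = full_type.SerializeAsString();
  TF_Buffer* buf = TF_NewBufferFromString(serialized.data(), serialized.size());
  // Previously this took a const FullTypeDef&; it now takes the serialized
  // proto wrapped in a TF_Buffer.
  tensorflow::SetFullType(graph, op, buf);
  TF_DeleteBuffer(buf);
}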
#include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" -#define RETURN_IF_ERROR(s) \ - { \ - auto c = (s); \ - if (!c.ok()) return c; \ - } - namespace tensorflow { namespace libexport { @@ -41,11 +35,11 @@ tensorflow::StatusOr TFPackage::Load(const std::string& path) { const string saved_model_pbtxt_path = io::JoinPath(path, kSavedModelFilenamePbTxt); if (Env::Default()->FileExists(saved_model_pb_path).ok()) { - RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), saved_model_pb_path, - &tf_package.saved_model_proto_)); + TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), saved_model_pb_path, + &tf_package.saved_model_proto_)); } else if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) { - RETURN_IF_ERROR(ReadTextProto(Env::Default(), saved_model_pbtxt_path, - &tf_package.saved_model_proto_)); + TF_RETURN_IF_ERROR(ReadTextProto(Env::Default(), saved_model_pbtxt_path, + &tf_package.saved_model_proto_)); } else { return Status(absl::StatusCode::kNotFound, "Could not find SavedModel .pb or .pbtxt at supplied export " @@ -65,7 +59,7 @@ tensorflow::StatusOr TFPackage::Load(const std::string& path) { tf_package.variable_reader_ = std::make_unique( tensorflow::Env::Default(), tf_package.variables_filepath_); tensorflow::Tensor object_graph_tensor; - RETURN_IF_ERROR(tf_package.variable_reader_->Lookup( + TF_RETURN_IF_ERROR(tf_package.variable_reader_->Lookup( tensorflow::kObjectGraphProtoKey, &object_graph_tensor)); const auto* object_graph_string = reinterpret_cast( diff --git a/tensorflow/cc/experimental/libexport/load.h b/tensorflow/cc/experimental/libexport/load.h index cd85fb5f2b7efc..8ab5019eba45fe 100644 --- a/tensorflow/cc/experimental/libexport/load.h +++ b/tensorflow/cc/experimental/libexport/load.h @@ -92,7 +92,7 @@ class TFPackage { bool HasCheckpoint() { return has_checkpoint_; } // Returns the path to the variables file. - const std::string GetVariablesFilepath() { return variables_filepath_; } + const std::string GetVariablesFilepath() const { return variables_filepath_; } private: SavedModel saved_model_proto_; diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index d9ed820e15e24e..6667b6919d52e6 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -306,11 +306,11 @@ void Scope::UpdateStatus(const Status& s) const { } } -Status Scope::ToGraphDef(GraphDef* gdef) const { +Status Scope::ToGraphDef(GraphDef* gdef, bool include_debug_info) const { if (!ok()) { return *impl()->status_; } - graph()->ToGraphDef(gdef); + graph()->ToGraphDef(gdef, /*include_flib_def=*/true, include_debug_info); return OkStatus(); } diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 777d0ed6c01e39..771fdaa11688c9 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -200,8 +200,10 @@ class Scope { /// If status() is ok, convert the Graph object stored in this scope /// to a GraphDef proto and return an ok Status. Otherwise, return the error - /// status as is without performing GraphDef conversion. - Status ToGraphDef(GraphDef* gdef) const; + /// status as is without performing GraphDef conversion. If + /// `include_debug_info` is true, populate the `debug_info` field of the + /// GraphDef from stack traces in this Graph. 
+ Status ToGraphDef(GraphDef* gdef, bool include_debug_info = false) const; // START_SKIP_DOXYGEN diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index de7b00e2eaeb6e..c33a59027c5f31 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -2,6 +2,7 @@ #Description: # TensorFlow SavedModel. +load("//tensorflow:strict.default.bzl", "py_strict_binary") load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( @@ -249,7 +250,7 @@ tf_cc_test( ) # A subset of the TF2 saved models can be generated with this tool. -py_binary( +py_strict_binary( name = "testdata/generate_saved_models", srcs = ["testdata/generate_saved_models.py"], data = [ @@ -259,12 +260,14 @@ py_binary( python_version = "PY3", srcs_version = "PY3", deps = [ + "//tensorflow/python/client:session", "//tensorflow/python/compat:v2_compat", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/module", + "//tensorflow/python/ops:io_ops", "//tensorflow/python/ops:lookup_ops", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", @@ -277,17 +280,18 @@ py_binary( # copybara:uncomment_begin(google-only) # -# py_binary( +# py_strict_binary( # name = "testdata/generate_chunked_models", # srcs = ["testdata/generate_chunked_models.py"], # python_version = "PY3", # srcs_version = "PY3", # deps = [ +# "//third_party/py/numpy", # "//tensorflow/python/compat:v2_compat", # "//tensorflow/python/eager:def_function", # "//tensorflow/python/framework:constant_op", +# "//tensorflow/python/lib/io:lib", # "//tensorflow/python/module", -# "//tensorflow/python/platform:client_testlib", # "//tensorflow/python/saved_model:loader", # "//tensorflow/python/saved_model:save", # "//tensorflow/python/saved_model:save_options", @@ -295,6 +299,7 @@ py_binary( # "//tensorflow/tools/proto_splitter:constants", # "//tensorflow/tools/proto_splitter/python:saved_model", # "@absl_py//absl:app", +# "@absl_py//absl/flags", # ], # ) # @@ -459,6 +464,7 @@ cc_library( ]) + if_android([ "//tensorflow/core:portable_tensorflow_lib_lite", ]) + if_google([ + ":fingerprinting_utils", "//tensorflow/tools/proto_splitter/cc:util", ]), alwayslink = True, @@ -481,6 +487,104 @@ cc_library( ]) + if_not_mobile(["//tensorflow/core:lib"]) + if_android(["//tensorflow/core:portable_tensorflow_lib_lite"]), ) +# copybara:uncomment_begin(google-only) +# +# cc_library( +# name = "fingerprinting_utils_impl", +# srcs = [ +# "fingerprinting_utils.cc", +# "fingerprinting_utils.h", +# ], +# visibility = [ +# "//tensorflow:__pkg__", +# ], +# deps = [ +# ":constants", +# "@com_google_absl//absl/status", +# "@com_google_absl//absl/status:statusor", +# "@com_google_absl//absl/strings", +# "//third_party/riegeli/bytes:file_reader", +# "//third_party/riegeli/records:record_reader", +# "//tensorflow/core:lib", +# "//tensorflow/core:protos_all_cc", +# "//tensorflow/core/graph/regularization:simple_delete", +# "//tensorflow/core/graph/regularization:util", +# "//tensorflow/core/util/tensor_bundle:naming", +# "//tensorflow/tools/proto_splitter:chunk_proto_cc", +# "//tensorflow/tools/proto_splitter:merge", +# "//tensorflow/tools/proto_splitter/cc:util", +# "//tensorflow/tsl/platform:protobuf", +# ], +# alwayslink = True, +# ) +# +# cc_library( +# name = "fingerprinting_utils", +# hdrs = ["fingerprinting_utils.h"], +# visibility = [ 
+# "//tensorflow/cc/saved_model:__subpackages__", +# ], +# deps = if_static([ +# ":fingerprinting_utils_impl", +# "@com_google_protobuf//:protobuf_headers", +# "@com_google_absl//absl/status", +# "@com_google_absl//absl/status:statusor", +# "@com_google_absl//absl/strings", +# "//tensorflow/core:protos_all_cc", +# "//tensorflow/tools/proto_splitter:chunk_proto_cc", +# "//tensorflow/tsl/platform:protobuf", +# "//third_party/riegeli/bytes:file_reader", +# "//third_party/riegeli/records:record_reader", +# "//tensorflow/core:lib", +# ]), +# ) +# +# tf_cc_test( +# name = "fingerprinting_utils_test", +# srcs = ["fingerprinting_utils_test.cc"], +# data = [ +# "//tensorflow/tools/proto_splitter/testdata:many-field.cpb", +# "//tensorflow/tools/proto_splitter/testdata:split-standard.cpb", +# ], +# deps = [ +# ":fingerprinting_utils", +# "@com_google_absl//absl/status", +# "@com_google_absl//absl/strings", +# "//third_party/protobuf", +# "//third_party/riegeli/bytes:file_reader", +# "//third_party/riegeli/records:record_reader", +# "//tensorflow/core:protos_all_cc", +# "//tensorflow/core/platform:errors", +# "//tensorflow/core/platform:path", +# "//tensorflow/core/platform:protobuf", +# "//tensorflow/core/platform:test", +# "//tensorflow/tools/proto_splitter:chunk_proto_cc", +# "//tensorflow/tools/proto_splitter/cc:util", +# "//tensorflow/tools/proto_splitter/testdata:test_message_proto_cc", +# "@com_google_googletest//:gtest_main", +# ], +# ) +# +# tf_cc_test( +# name = "fingerprinting_chunked_test", +# size = "small", +# srcs = ["fingerprinting_chunked_test.cc"], +# data = [ +# ":saved_model_fingerprinting_test_files", +# ":saved_model_test_files", +# ], +# deps = [ +# ":fingerprinting", +# "@com_google_absl//absl/container:flat_hash_set", +# "//tensorflow/core:protos_all_cc", +# "//tensorflow/core:test", +# "//tensorflow/core/platform:path", +# "@com_google_googletest//:gtest_main", +# ], +# ) +# +# copybara:uncomment_end + tf_cc_test( name = "fingerprinting_test", size = "small", diff --git a/tensorflow/cc/saved_model/fingerprinting.cc b/tensorflow/cc/saved_model/fingerprinting.cc index 72f23ed745b5e1..d8f91267483f6d 100644 --- a/tensorflow/cc/saved_model/fingerprinting.cc +++ b/tensorflow/cc/saved_model/fingerprinting.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "tensorflow/cc/saved_model/constants.h" +// Placeholder for protosplitter riegeli includes. 
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/regularization/simple_delete.h" @@ -178,8 +179,7 @@ absl::StatusOr CreateFingerprintDef( return CreateFingerprintDefPb(export_dir, absl::StrCat(prefix, ".pb")); - return absl::UnimplementedError( - "Chunked proto fingerprinting unimplemented."); + return absl::PermissionDeniedError("Chunked proto format is not available in OSS."); } absl::StatusOr ReadSavedModelFingerprint( diff --git a/tensorflow/cc/saved_model/testdata/generate_saved_models.py b/tensorflow/cc/saved_model/testdata/generate_saved_models.py index 5644feaaeea5da..5b2e458bbb64c6 100644 --- a/tensorflow/cc/saved_model/testdata/generate_saved_models.py +++ b/tensorflow/cc/saved_model/testdata/generate_saved_models.py @@ -17,7 +17,7 @@ import os from absl import app -from keras.optimizers.optimizers_v2 import adam +from keras.optimizers.legacy import adam from tensorflow.python.client import session as session_lib from tensorflow.python.compat import v2_compat diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 0d06315e3b4c57..ab84540ec8c683 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1,7 +1,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "if_libtpu", "if_with_tpu_support", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_cuda_only_cc_test") load("//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm") -load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps") +load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "tf_custom_op_py_strict_library", "tf_jit_compilation_passes_extra_deps") load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library") load( "//tensorflow/core/platform:build_config_root.bzl", @@ -1351,7 +1351,7 @@ tf_cc_test( ], ) -tf_custom_op_py_library( +tf_custom_op_py_strict_library( name = "xla_ops_py", kernels = ["//tensorflow/compiler/jit/ops:xla_ops"], visibility = [ diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 1059a263d57f43..a2c4bbd466848c 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") @@ -17,10 +18,17 @@ cc_library( tf_gen_op_wrapper_py( name = "xla_ops_wrapper_py", out = "xla_ops.py", + extra_py_deps = [ + "//tensorflow/python:pywrap_tfe", + "//tensorflow/python/util:dispatch", + "//tensorflow/python/util:deprecation", + "//tensorflow/python/util:tf_export", + ], + py_lib_rule = py_strict_library, deps = ["//tensorflow/compiler/jit/ops:xla_ops"], ) -py_library( +py_strict_library( name = "xla_ops_grad", srcs = ["xla_ops_grad.py"], srcs_version = "PY3", diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index a0ba7086f91306..0f6fcaa8913fc7 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -689,6 +689,11 @@ Status PopulateCtxOutputsFromPjRtExecutableOutputs( const DataType& type = compilation_result.outputs[i].type; VLOG(2) << "Populating output for retval " << i << " type " << 
DataTypeString(type); + if (type == DT_VARIANT) { + return absl::UnimplementedError( + "Support for TensorList crossing the XLA/TF boundary " + "is not implemented"); + } if (compilation_result.outputs[i].is_constant) { bool requires_copy_to_device = GetDeviceType(ctx) != DEVICE_CPU; diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 429d1da9754f4a..af0cf97c31b64d 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary") @@ -39,7 +39,7 @@ td_library( "ir/tfl_op_interfaces.td", "ir/tfl_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", @@ -63,7 +63,7 @@ td_library( "transforms/tensorlist_patterns.td", "utils/utils.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:private"], deps = [ ":tensorflow_lite_ops_td_files", @@ -76,7 +76,7 @@ td_library( gentbl_cc_library( name = "tensorflow_lite_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -95,7 +95,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -122,7 +122,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_op_interfaces_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-interface-decls"], @@ -150,7 +150,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_op_enums_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-enum-decls"], @@ -178,7 +178,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_prepare_tf_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -192,7 +192,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_lower_static_tensor_list_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -206,7 +206,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_tf_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -220,7 +220,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_variables_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -234,7 +234,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_optimize_inc_gen", - compatible_with = get_compatible_with_cloud(), + 
compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -248,7 +248,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_quantize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -262,7 +262,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_post_quantize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -276,7 +276,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_tensorlist_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -852,7 +852,7 @@ filegroup( gentbl_cc_library( name = "op_quant_spec_getters_inc", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [([], "utils/generated_op_quant_spec_getters.inc")], tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen", td_file = "ir/tfl_ops.td", @@ -863,7 +863,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tflite_op_coverage_spec_inc", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [([], "utils/tflite_op_coverage_spec.inc")], tblgen = "//tensorflow/compiler/mlir/lite/quantization:tflite_op_coverage_spec_getters_gen", td_file = "ir/tfl_ops.td", @@ -878,7 +878,7 @@ tf_native_cc_binary( srcs = [ "converter_gen.cc", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//llvm:Support", "@llvm-project//llvm:TableGen", @@ -888,7 +888,7 @@ tf_native_cc_binary( gentbl_cc_library( name = "converter_inc", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["--gen-operator-converters"], diff --git a/tensorflow/compiler/mlir/lite/experimental/common/BUILD b/tensorflow/compiler/mlir/lite/experimental/common/BUILD index 02fab009fda976..c7e4fa006d868a 100644 --- a/tensorflow/compiler/mlir/lite/experimental/common/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/common/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -6,7 +6,7 @@ cc_library( name = "outline_operations", srcs = ["outline_operations.cc"], hdrs = ["outline_operations.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD index b012bd60b154ed..b0514e13937956 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD @@ -2,7 +2,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") 
+load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load( "@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", @@ -19,7 +19,7 @@ package( flatbuffer_cc_library( name = "runtime_metadata_fbs", srcs = ["runtime_metadata.fbs"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) cc_library( @@ -88,7 +88,7 @@ cc_library( gentbl_cc_library( name = "transform_patterns_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -386,11 +386,11 @@ py_strict_library( proto_library( name = "tac_filter_proto", srcs = ["tac_filter.proto"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) cc_proto_library( name = "tac_filter_cc_proto", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [":tac_filter_proto"], ) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 230e3565ad1865..9dc61d0d907981 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -1305,7 +1305,7 @@ static LogicalResult ComputeConvWindowedOutputSize( int64_t pad_low; int64_t pad_high; - tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( + tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( input_size, filter_size, dilation_rate, stride, padding, output_size, &pad_low, &pad_high); // Return failure if expected_output_size could not be calculated. diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index 51618d4826e6ad..9290272db90d28 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -27,11 +27,10 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tf_tfl_passes", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", - "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/core:core_cpu_base", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index efa633e736ae69..b683e3859afc44 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -117,8 +117,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, pass_config.guarantee_all_funcs_one_use = toco_flags.guarantee_all_funcs_one_use(); pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo(); - pass_config.enable_hlo_to_tf_conversion = - toco_flags.enable_hlo_to_tf_conversion(); return internal::ConvertMLIRToTFLiteFlatBuffer( model_flags, toco_flags, std::move(module), pass_config, diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 2e5819a0e2fc63..e955159990457e 100644 --- 
a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -201,8 +201,6 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, pass_config.guarantee_all_funcs_one_use = toco_flags.guarantee_all_funcs_one_use(); pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo(); - pass_config.enable_hlo_to_tf_conversion = - toco_flags.enable_hlo_to_tf_conversion(); pass_config.legalize_custom_tensor_list_ops = toco_flags.legalize_custom_tensor_list_ops(); diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index fb5efba769a066..c695d0a7f499df 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" @@ -346,6 +347,18 @@ Status ConvertMLIRToTFLiteFlatBuffer( mlir::TFL::PassConfig pass_config_copy = pass_config; pass_config_copy.outline_tf_while = true; + + // Checks whether the model contains an `XlaCallModuleOp` operation which + // is a wrapper around StableHLO. + // This option is mutually exclusive to `enable_stablehlo_conversion`, the + // latter of which takes precedence. + // TODO(b/290109282): explore removing the enable_hlo_to_tf_conversion flag + // entirely, such that the added passes are no-ops in the non-shlo case. 
+ module->walk([&](mlir::TF::XlaCallModuleOp xla_call_module_op) { + pass_config_copy.enable_hlo_to_tf_conversion = true; + mlir::WalkResult::interrupt(); + }); + auto status = ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, toco_flags, pass_config_copy, saved_model_tags, model_flags.saved_model_dir(), session, result); diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index 33d54e4e449d60..ec839967182c61 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -1,6 +1,6 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -33,7 +33,7 @@ td_library( srcs = [ "quantization.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization/ir:QuantizationOpsTdFiles", "@llvm-project//mlir:OpBaseTdFiles", @@ -42,7 +42,7 @@ td_library( gentbl_cc_library( name = "quantization_interfaces_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-interface-decls"], diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD index dc1c7d841b5d05..727fb03d833964 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -14,7 +14,7 @@ td_library( "QuantOps.td", "QuantOpsBase.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", @@ -25,7 +25,7 @@ td_library( gentbl_cc_library( name = "QuantOpsIncGen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -57,7 +57,7 @@ gentbl_cc_library( gentbl_cc_library( name = "QuantPassIncGen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -89,7 +89,7 @@ cc_library( "QuantizeUtils.h", "UniformSupport.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":QuantOpsIncGen", ":QuantPassIncGen", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 20346adde817c7..93c50fed86f77a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -30,18 +30,17 @@ cc_library( "//tensorflow/compiler/mlir/lite:common", 
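The `XlaCallModuleOp` walk added in tf_tfl_flatbuffer_helpers.cc only needs to learn whether at least one such op is present in the module. For reference, here is a minimal sketch of the same check written with MLIR's early-exit walk idiom; the callback has to return the `WalkResult` for `interrupt()` to take effect, and `wasInterrupted()` on the walk's result then reports whether a match was found (variable names are illustrative):

    // Sketch: detect whether the module wraps StableHLO in an XlaCallModuleOp.
    const bool has_xla_call_module =
        module
            ->walk([](mlir::TF::XlaCallModuleOp op) {
              // One match is enough; stop traversing the rest of the module.
              return mlir::WalkResult::interrupt();
            })
            .wasInterrupted();
    pass_config_copy.enable_hlo_to_tf_conversion = has_xla_call_module;
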
"//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize", "//tensorflow/compiler/mlir/lite:tf_tfl_passes", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/core:protos_all_cc", "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core/api", "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", ], diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 7581b5c78cfbcd..3d7503cd64128b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -20,9 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -38,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/schema/schema_generated.h" namespace mlir { @@ -50,12 +49,11 @@ std::string TfLiteToMlir(const absl::string_view tflite_op_name) { // TODO(fengliuai): check the result for `fully_quantize` flag. TfLiteStatus QuantizeModel( - const tflite::ModelT& input_model, const tflite::TensorType& input_type, + const absl::string_view model_buffer, const tflite::TensorType& input_type, const tflite::TensorType& output_type, const tflite::TensorType& inference_type, const std::unordered_set& operator_names, - bool disable_per_channel, bool fully_quantize, - flatbuffers::FlatBufferBuilder* builder, + bool disable_per_channel, bool fully_quantize, std::string& output_buffer, tflite::ErrorReporter* error_reporter, bool verify_numeric, bool whole_model_verify, bool legacy_float_scale, const absl::flat_hash_set& denylisted_ops, @@ -73,18 +71,8 @@ TfLiteStatus QuantizeModel( StatusScopedDiagnosticHandler statusHandler(&context, /*propagate=*/true); - // Import input_model to a MLIR module - flatbuffers::FlatBufferBuilder input_builder; - flatbuffers::Offset input_model_location = - tflite::Model::Pack(input_builder, &input_model); - tflite::FinishModelBuffer(input_builder, input_model_location); - - std::string serialized_model( - reinterpret_cast(input_builder.GetBufferPointer()), - input_builder.GetSize()); - OwningOpRef module = tflite::FlatBufferToMlir( - serialized_model, &context, UnknownLoc::get(&context)); + model_buffer, &context, UnknownLoc::get(&context)); if (!module) { error_reporter->Report("Couldn't import flatbuffer to MLIR."); return kTfLiteError; @@ -130,20 +118,16 @@ TfLiteStatus QuantizeModel( return kTfLiteError; } - // Export the results to the builder - std::string result; + // Export the results. 
tflite::FlatbufferExportOptions options; options.toco_flags.set_force_select_tf_ops(false); options.toco_flags.set_enable_select_tf_ops(true); options.toco_flags.set_allow_custom_ops(true); if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options, - &result)) { + &output_buffer)) { error_reporter->Report("Failed to export MLIR to flatbuffer."); return kTfLiteError; } - builder->PushFlatBuffer(reinterpret_cast(result.data()), - result.size()); - return kTfLiteOk; } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index 243af219da689b..d85aba47811675 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -15,39 +15,47 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_MODEL_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_MODEL_H_ -#include #include #include #include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" namespace mlir { namespace lite { -// Quantize the `input_model` and write the result to a flatbuffer `builder`. -// The `input_type`, `output_type` and `inference_type` can be -// float32/qint8/int8/int16. -// Return partially quantized model if `fully_quantize` is false. +// Quantizes the input model represented as `model_buffer` and writes the result +// to the `output_buffer`. Both `model_buffer` and `output_buffer` should be a +// valid FlatBuffer format for Model supported by TFLite. +// +// The `input_type`, `output_type` and `inference_type` can be float32 / qint8 / +// int8 / int16. +// +// Returns a partially quantized model if `fully_quantize` is false. Returns a +// non-OK status if the quantization fails. +// // When `verify_numeric` is true, the model will have it's original float ops // and NumericVerify ops to compare output values from the quantized and float -// ops. When `legacy_float_scale` is true, the quantizer will use float scale -// instead of double, and call TOCO's quantization routines to maintain -// bit-exactness of the values with the TOCO quantizer. +// ops. +// +// When `legacy_float_scale` is true, the quantizer will use float scale instead +// of double, and call TOCO's quantization routines to maintain bit-exactness of +// the values with the TOCO quantizer. 
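The comment above documents the reworked entry point, whose declaration follows immediately below: the model now goes in as a serialized FlatBuffer (`absl::string_view`) and the quantized result comes back in a `std::string`, replacing the old `ModelT` plus `FlatBufferBuilder*` pair. A rough caller-side sketch of the migration, with illustrative variable names (the packing step mirrors what the updated quantize_model_test.cc helper does):

    // Pack the in-memory ModelT into a FlatBuffer string once, up front.
    flatbuffers::FlatBufferBuilder fbb;
    tflite::FinishModelBuffer(fbb, tflite::Model::Pack(fbb, &model));
    const std::string input_buffer(
        reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize());

    // On success the quantized model is written into output_buffer.
    std::string output_buffer;
    tflite::StderrReporter error_reporter;
    const TfLiteStatus status = mlir::lite::QuantizeModel(
        input_buffer, tflite::TensorType_INT8, tflite::TensorType_INT8,
        tflite::TensorType_INT8, /*operator_names=*/{},
        /*disable_per_channel=*/false, /*fully_quantize=*/true, output_buffer,
        &error_reporter);
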
TfLiteStatus QuantizeModel( - const tflite::ModelT& input_model, const tflite::TensorType& input_type, + absl::string_view model_buffer, const tflite::TensorType& input_type, const tflite::TensorType& output_type, const tflite::TensorType& inference_type, const std::unordered_set& operator_names, - bool disable_per_channel, bool fully_quantize, - flatbuffers::FlatBufferBuilder* builder, + bool disable_per_channel, bool fully_quantize, std::string& output_buffer, tflite::ErrorReporter* error_reporter, bool verify_numeric = false, bool whole_model_verify = false, bool legacy_float_scale = true, const absl::flat_hash_set& denylisted_ops = {}, const absl::flat_hash_set& denylisted_nodes = {}, bool enable_variable_quantization = false); + } // namespace lite } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index 798e011dec247d..ee9c5e7852aea9 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -67,74 +67,77 @@ ModelT UnPackFlatBufferModel(const Model& flatbuffer_model) { } TfLiteStatus QuantizeModel( - flatbuffers::FlatBufferBuilder* builder, ModelT* model, - const TensorType& input_type, const TensorType& output_type, - bool allow_float, const std::unordered_set& operator_names, + ModelT* model, const TensorType& input_type, const TensorType& output_type, + const bool allow_float, const std::unordered_set& operator_names, const TensorType& activations_type, ErrorReporter* error_reporter, - bool disable_per_channel = false, + std::string& output_buffer, const bool disable_per_channel = false, const absl::flat_hash_set& blocked_ops = {}, const absl::flat_hash_set& blocked_nodes = {}) { TensorType inference_tensor_type = activations_type; - bool fully_quantize = !allow_float; + const bool fully_quantize = !allow_float; + + flatbuffers::FlatBufferBuilder input_builder; + tflite::FinishModelBuffer(input_builder, + tflite::Model::Pack(input_builder, model)); + + const std::string input_buffer( + reinterpret_cast(input_builder.GetBufferPointer()), + input_builder.GetSize()); auto status = mlir::lite::QuantizeModel( - *model, input_type, output_type, inference_tensor_type, - /*operator_names=*/{}, disable_per_channel, fully_quantize, builder, + input_buffer, input_type, output_type, inference_tensor_type, + /*operator_names=*/{}, disable_per_channel, fully_quantize, output_buffer, error_reporter, /*verify_numeric=*/false, /*whole_model_verify=*/false, /*legacy_float_scale=*/true, blocked_ops, blocked_nodes); if (status != kTfLiteOk) { return status; } - std::string buffer( - reinterpret_cast(builder->GetCurrentBufferPointer()), - builder->GetSize()); - auto flatbuffer_model = - FlatBufferModel::BuildFromBuffer(buffer.c_str(), buffer.size()); + auto flatbuffer_model = FlatBufferModel::BuildFromBuffer( + output_buffer.data(), output_buffer.size()); *model = UnPackFlatBufferModel(*flatbuffer_model->GetModel()); return kTfLiteOk; } -TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, const TensorType& input_type, +TfLiteStatus QuantizeModel(ModelT* model, const TensorType& input_type, const TensorType& output_type, bool allow_float, - ErrorReporter* error_reporter) { - return QuantizeModel(builder, model, input_type, output_type, allow_float, - /*operator_names=*/{}, TensorType_INT8, error_reporter); + ErrorReporter* error_reporter, + 
std::string& output_buffer) { + return QuantizeModel(model, input_type, output_type, allow_float, + /*operator_names=*/{}, TensorType_INT8, error_reporter, + output_buffer); } -TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, const TensorType& input_type, +TfLiteStatus QuantizeModel(ModelT* model, const TensorType& input_type, const TensorType& output_type, - ErrorReporter* error_reporter) { - return QuantizeModel(builder, model, input_type, output_type, - /*allow_float=*/false, error_reporter); + ErrorReporter* error_reporter, + std::string& output_buffer) { + return QuantizeModel(model, input_type, output_type, + /*allow_float=*/false, error_reporter, output_buffer); } -TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, ErrorReporter* error_reporter) { - return QuantizeModel(builder, model, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/true, error_reporter); +TfLiteStatus QuantizeModel(ModelT* model, ErrorReporter* error_reporter, + std::string& output_buffer) { + return QuantizeModel(model, TensorType_FLOAT32, TensorType_FLOAT32, + /*allow_float=*/true, error_reporter, output_buffer); } TfLiteStatus QuantizeModelAllOperators( - flatbuffers::FlatBufferBuilder* builder, ModelT* model, - const TensorType& input_type, const TensorType& output_type, + ModelT* model, const TensorType& input_type, const TensorType& output_type, bool allow_float, const TensorType& activations_type, - bool disable_per_channel, ErrorReporter* error_reporter) { - return QuantizeModel(builder, model, input_type, output_type, allow_float, + bool disable_per_channel, ErrorReporter* error_reporter, + std::string& output_buffer) { + return QuantizeModel(model, input_type, output_type, allow_float, /*operator_names=*/{}, activations_type, error_reporter, - disable_per_channel); + output_buffer, disable_per_channel); } -TfLiteStatus QuantizeModelAllOperators(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, - const TensorType& input_type, - const TensorType& output_type, - bool allow_float, - const TensorType& activations_type, - ErrorReporter* error_reporter) { - return QuantizeModel(builder, model, input_type, output_type, allow_float, - /*operator_names=*/{}, activations_type, error_reporter); +TfLiteStatus QuantizeModelAllOperators( + ModelT* model, const TensorType& input_type, const TensorType& output_type, + bool allow_float, const TensorType& activations_type, + ErrorReporter* error_reporter, std::string& output_buffer) { + return QuantizeModel(model, input_type, output_type, allow_float, + /*operator_names=*/{}, activations_type, error_reporter, + output_buffer); } std::unique_ptr ReadModel(const string& model_name) { @@ -180,8 +183,8 @@ class QuantizeModelTest : public testing::Test { std::unique_ptr input_model_; const Model* readonly_model_; tflite::ModelT model_; - flatbuffers::FlatBufferBuilder builder_; internal::FailOnErrorReporter error_reporter_; + std::string output_buffer_; // Raw buffer for quantized output model. 
}; void ExpectEqualTensor(TensorT* tensor, TensorT* expected_tensor) { @@ -279,20 +282,21 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeConvModelTest, QuantizationSucceeds) { - auto status = QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false, - tensor_type_, &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); - const uint8_t* buffer = builder_.GetBufferPointer(); - const Model* output_model = GetModel(buffer); + + const Model* output_model = GetModel(output_buffer_.data()); ASSERT_TRUE(output_model); } TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) { - auto status = QuantizeModel( - &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/true, /*operator_names=*/{}, TensorType_FLOAT32, - &error_reporter_, /*disable_per_channel=*/false, {"CONV_2D"}); + auto status = + QuantizeModel(&model_, TensorType_FLOAT32, TensorType_FLOAT32, + /*allow_float=*/true, /*operator_names=*/{}, + TensorType_FLOAT32, &error_reporter_, output_buffer_, + /*disable_per_channel=*/false, {"CONV_2D"}); EXPECT_THAT(status, Eq(kTfLiteOk)); ModelT expected_model; @@ -302,11 +306,11 @@ TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) { } TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayerByName) { - auto status = QuantizeModel( - &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/true, /*operator_names=*/{}, TensorType_FLOAT32, - &error_reporter_, /*disable_per_channel=*/false, /*blocked_ops=*/{}, - {"output"}); + auto status = QuantizeModel(&model_, TensorType_FLOAT32, TensorType_FLOAT32, + /*allow_float=*/true, /*operator_names=*/{}, + TensorType_FLOAT32, &error_reporter_, + output_buffer_, /*disable_per_channel=*/false, + /*blocked_ops=*/{}, {"output"}); EXPECT_THAT(status, Eq(kTfLiteOk)); ModelT expected_model; @@ -316,9 +320,9 @@ TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayerByName) { } TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) { - auto status = QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, - /*allow_float=*/false, tensor_type_, &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); for (const auto& subgraph : model_.subgraphs) { @@ -340,11 +344,10 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest { TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, + /*allow_float=*/false, TensorType_INT8, &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); - const uint8_t* buffer = builder_.GetBufferPointer(); - const Model* output_model = GetModel(buffer); + const Model* output_model = GetModel(output_buffer_.data()); ASSERT_TRUE(output_model); } @@ -361,8 +364,8 @@ class QuantizeSplitModelTest : public QuantizeModelTest { // should have the scales be hardcodes to the input scale value. 
TEST_F(QuantizeSplitModelTest, QuantizeSplit) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); // There is only one subgraph. @@ -458,9 +461,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) { - auto status = QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false, - tensor_type_, &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; auto conv_op = subgraph->operators[0].get(); @@ -566,8 +569,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) { TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) { auto status = QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false, - tensor_type_, /*disable_per_channel=*/true, &error_reporter_); + &model_, tensor_type_, tensor_type_, /*allow_float=*/false, tensor_type_, + /*disable_per_channel=*/true, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; auto conv_op = subgraph->operators[0].get(); @@ -684,8 +687,8 @@ class QuantizeSoftmaxTest : public QuantizeModelTest { TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; @@ -748,8 +751,8 @@ class QuantizeAvgPoolTest : public QuantizeModelTest { TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; @@ -809,8 +812,8 @@ class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest { TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); @@ -861,8 +864,8 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, 
Eq(kTfLiteOk)); // Verify ADD is quantized. @@ -935,10 +938,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConstInputTestInst, QuantizeConstInputTest, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeConstInputTest, VerifyConstOpInput) { - auto status = - QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false, - tensor_type_, &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); // Verify ConstOp is quantized. @@ -981,8 +983,8 @@ class QuantizeArgMaxTest : public QuantizeModelTest { TEST_F(QuantizeArgMaxTest, VerifyArgMax) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; @@ -1026,10 +1028,9 @@ class QuantizeLSTMTest : public QuantizeModelTest { }; TEST_F(QuantizeLSTMTest, VerifyLSTM) { - // Quantize model. auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, true, - TensorType_INT8, &error_reporter_); + &model_, TensorType_FLOAT32, TensorType_FLOAT32, /*allow_float=*/true, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); // Read expected model. @@ -1053,8 +1054,8 @@ class QuantizeLSTM2Test : public QuantizeModelTest { TEST_F(QuantizeLSTM2Test, VerifyLSTM) { // Quantize model. auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_FLOAT32, TensorType_FLOAT32, + /*allow_float=*/false, TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); // Read expected model. @@ -1077,10 +1078,9 @@ class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest { TEST_F(QuantizeUnidirectionalSequenceLSTMTest, VerifyUnidirectionalSequenceLSTM) { - // Quantize model. auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_FLOAT32, TensorType_FLOAT32, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); // Read expected model. @@ -1105,8 +1105,8 @@ class QuantizeSVDFTest : public QuantizeModelTest { TEST_F(QuantizeSVDFTest, VerifySVDF) { // Quantize model. auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); // Read expected model. 
@@ -1129,8 +1129,8 @@ class QuantizeFCTest : public QuantizeModelTest { TEST_F(QuantizeFCTest, VerifyFC8x8) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; @@ -1182,8 +1182,8 @@ TEST_F(QuantizeFCTest, VerifyFC8x8) { TEST_F(QuantizeFCTest, VerifyFCFor16x8) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT16, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT16, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const std::unique_ptr& subgraph = model_.subgraphs[0]; @@ -1247,9 +1247,9 @@ class QuantizeCustomOpTest }; TEST_P(QuantizeCustomOpTest, VerifyMixedQuantization) { - auto status = QuantizeModelAllOperators( - &builder_, &model_, GetParam(), GetParam(), - /*allow_float=*/true, GetParam(), &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, GetParam(), GetParam(), + /*allow_float=*/true, GetParam(), + &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; auto float_graph = readonly_model_->subgraphs()->Get(0); @@ -1286,7 +1286,7 @@ class QuantizePackTest : public QuantizeModelTest { }; TEST_F(QuantizePackTest, VerifyPack) { - auto status = QuantizeModel(&builder_, &model_, &error_reporter_); + auto status = QuantizeModel(&model_, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); @@ -1350,7 +1350,7 @@ class QuantizeMinimumMaximumTest }; TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { - auto status = QuantizeModel(&builder_, &model_, &error_reporter_); + auto status = QuantizeModel(&model_, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; // Check that the first op is Quantize and the last is Dequant. @@ -1413,7 +1413,7 @@ class QuantizeUnpackTest : public QuantizeModelTest { }; TEST_F(QuantizeUnpackTest, VerifyUnpack) { - auto status = QuantizeModel(&builder_, &model_, &error_reporter_); + auto status = QuantizeModel(&model_, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); @@ -1470,9 +1470,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeBroadcastToModelTestInst, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeBroadcastToModelTest, VerifyBroadcastToQuantization) { - auto status = QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false, - tensor_type_, &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); // There is only one subgraph. 
@@ -1537,9 +1537,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeGatherNDModelTestInst, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) { - auto status = QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false, - tensor_type_, &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); // There is only one subgraph. @@ -1596,8 +1596,8 @@ TEST_F(QuantizeWhereModelTest, QuantizeWhere) { // Where operator takes a BOOL tensor as input // and outputs INT64 indices, both of which // should not be quantized - auto status = QuantizeModel(&builder_, &model_, TensorType_BOOL, - TensorType_INT64, &error_reporter_); + auto status = QuantizeModel(&model_, TensorType_BOOL, TensorType_INT64, + &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); // There is only one subgraph. diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc index fe5ca2ca8f1d47..4bf154e892bcdb 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc @@ -14,10 +14,7 @@ limitations under the License. ==============================================================================*/ #include -#include -#include -#include "absl/strings/string_view.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/MemoryBuffer.h" @@ -36,21 +33,14 @@ static opt inputFileName(llvm::cl::Positional, namespace mlir { namespace { -TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer, - flatbuffers::FlatBufferBuilder* builder) { - auto model_ptr = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( - buffer.data(), buffer.size()); - if (nullptr == model_ptr) { - return TfLiteStatus::kTfLiteError; - } - std::unique_ptr model(model_ptr->GetModel()->UnPack()); +TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer, + std::string& output_buffer) { tflite::StderrReporter error_reporter; return mlir::lite::QuantizeModel( - *model, tflite::TensorType_INT8, tflite::TensorType_INT8, - tflite::TensorType_INT8, {}, - /*disable_per_channel=*/false, - /*fully_quantize=*/true, builder, &error_reporter); + buffer, tflite::TensorType_INT8, tflite::TensorType_INT8, + tflite::TensorType_INT8, {}, /*disable_per_channel=*/false, + /*fully_quantize=*/true, output_buffer, &error_reporter); } } // namespace @@ -66,16 +56,13 @@ int main(int argc, char** argv) { return 1; } auto buffer = file_or_err->get(); - flatbuffers::FlatBufferBuilder builder; - auto status = - mlir::QuantizeAnnotatedModel(buffer->getBuffer().str(), &builder); - if (status != kTfLiteOk) { + std::string output_buffer; + if (auto status = mlir::QuantizeAnnotatedModel(buffer->getBuffer().str(), + output_buffer); + status != kTfLiteOk) { return 1; } - std::cout << std::string( - reinterpret_cast(builder.GetBufferPointer()), - builder.GetSize()) - << "\n"; + std::cout << output_buffer << "\n"; return 0; } diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD index 9652196367f398..cdefbdb1e28a4e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD @@ -1,6 +1,6 @@ 
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -24,7 +24,7 @@ td_library( srcs = [ "fallback_to_flex_patterns.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:private"], deps = [ "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", @@ -35,7 +35,7 @@ td_library( gentbl_cc_library( name = "ptq_fallback_to_flex_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index aaeb0d4dff4836..9fd1317d132c9d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -291,7 +291,6 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Transforms", ], alwayslink = 1, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/serializer/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/serializer/BUILD index c7e8f44563dd42..8910f55f81eb8b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/serializer/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/serializer/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -17,7 +17,7 @@ cc_library( "flatbuffer_operator.h", "flatbuffer_translator.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/tensorflow", @@ -48,7 +48,7 @@ cc_library( "flatbuffer_export.cc", ], hdrs = ["flatbuffer_export.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":flatbuffer_translator", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir index b2e2aac35983df..f97bab00c89503 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir @@ -79,6 +79,7 @@ func.func @listFromTensor(%tensor: tensor<3xi32>, %shape : tensor) -> ten func.return %0 : tensor>> // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "TensorListFromTensor", custom_option = #tfl} : (tensor<3xi32>, tensor) -> tensor>> } + // ----- // CHECK-LABEL: typeNotSupportedNotLegalized diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 4993a8babfbb79..f780e10f89b97c 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -602,6 +602,90 @@ func.func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<4x37xf32>, %arg1: // CHECK: return %[[add_result]] } +// CHECK-LABEL: @FuseReshapeAroundBMMLHS 
+func.func @FuseReshapeAroundBMMLHS(%arg0: tensor<6x5x1024xf32>) -> tensor<6x5x8192xf32> { + %cst = arith.constant dense_resource<__elided__> : tensor<1024x8192xf32> + %cst_0 = arith.constant dense_resource<__elided__> : tensor<3xi32> + %cst_1 = arith.constant dense_resource<__elided__> : tensor<2xi32> + %0 = "tfl.reshape"(%arg0, %cst_1) : (tensor<6x5x1024xf32>, tensor<2xi32>) -> tensor<30x1024xf32> + %1 = "tfl.batch_matmul"(%0, %cst) {adj_x = false, adj_y = false} : (tensor<30x1024xf32>, tensor<1024x8192xf32>) -> tensor<30x8192xf32> + %2 = "tfl.reshape"(%1, %cst_0) : (tensor<30x8192xf32>, tensor<3xi32>) -> tensor<6x5x8192xf32> + return %2 : tensor<6x5x8192xf32> + // CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<1024x8192xf32> + // CHECK: %0 = "tfl.batch_matmul"(%arg0, %cst) {adj_x = false, adj_y = false} : (tensor<6x5x1024xf32>, tensor<1024x8192xf32>) -> tensor<6x5x8192xf32> + // CHECK: return %0 : tensor<6x5x8192xf32> +} + +// CHECK-LABEL: @FuseReshapeAroundBMMNagativeTest +func.func @FuseReshapeAroundBMMNagativeTest(%arg0: tensor<5x4x1x1024xf32>, %arg1: tensor<5x1024x8192xf32>) -> tensor<5x4x1x8192xf32> { + %cst = arith.constant dense_resource<__elided__> : tensor<3xi32> + %cst_0 = arith.constant dense_resource<__elided__> : tensor<4xi32> + %0 = "tfl.reshape"(%arg0, %cst) : (tensor<5x4x1x1024xf32>, tensor<3xi32>) -> tensor<5x4x1024xf32> + %1 = "tfl.batch_matmul"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<5x4x1024xf32>, tensor<5x1024x8192xf32>) -> tensor<5x4x8192xf32> + %2 = "tfl.reshape"(%1, %cst_0) : (tensor<5x4x8192xf32>, tensor<4xi32>) -> tensor<5x4x1x8192xf32> + return %2 : tensor<5x4x1x8192xf32> + // CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<3xi32> + // CHECK: %cst_0 = arith.constant dense_resource<__elided__> : tensor<4xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<5x4x1x1024xf32>, tensor<3xi32>) -> tensor<5x4x1024xf32> + // CHECK: %1 = "tfl.batch_matmul"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<5x4x1024xf32>, tensor<5x1024x8192xf32>) -> tensor<5x4x8192xf32> + // CHECK: %2 = "tfl.reshape"(%1, %cst_0) : (tensor<5x4x8192xf32>, tensor<4xi32>) -> tensor<5x4x1x8192xf32> + // CHECK: return %2 : tensor<5x4x1x8192xf32> +} + +// CHECK-LABEL: @FuseReshapeAroundBMMNagativeTest2 +// Checks that the pattern matcher FuseReshapesAroundBatchMatMulLHS does not get +// applied for this case that does not pass the constraint around input rank. 
+func.func @FuseReshapeAroundBMMNagativeTest2(%arg0: tensor<2x1536xf32>) -> tensor<2x768xf32> { + %cst = arith.constant dense_resource<__elided__> : tensor<3xi32> + %cst_0 = arith.constant dense_resource<__elided__> : tensor<2xi32> + %402 = "tfl.reshape"(%arg0, %cst) : (tensor<2x1536xf32>, tensor<3xi32>) -> tensor<2x12x128xf32> + %403 = "tfl.pseudo_qconst"() {qtype = tensor<128x64x!quant.uniform>, value = dense<9> : tensor<128x64xi8>} : () -> tensor<128x64x!quant.uniform> + %404 = "tfl.batch_matmul"(%402, %403) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = true} : (tensor<2x12x128xf32>, tensor<128x64x!quant.uniform>) -> tensor<2x12x64xf32> + %405 = "tfl.reshape"(%404, %cst_0) : (tensor<2x12x64xf32>, tensor<2xi32>) -> tensor<2x768xf32> + return %405 : tensor<2x768xf32> + // CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<3xi32> + // CHECK: %cst_0 = arith.constant dense_resource<__elided__> : tensor<2xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<2x1536xf32>, tensor<3xi32>) -> tensor<2x12x128xf32> + // CHECK: %1 = "tfl.pseudo_qconst"() {qtype = tensor<128x64x!quant.uniform>, value = dense<9> : tensor<128x64xi8>} : () -> tensor<128x64x!quant.uniform> + // CHECK: %2 = "tfl.batch_matmul"(%0, %1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = true} : (tensor<2x12x128xf32>, tensor<128x64x!quant.uniform>) -> tensor<2x12x64xf32> + // CHECK: %3 = "tfl.reshape"(%2, %cst_0) : (tensor<2x12x64xf32>, tensor<2xi32>) -> tensor<2x768xf32> + // CHECK: return %3 : tensor<2x768xf32> +} + +// CHECK-LABEL: @FuseReshapeAroundBMMRHS +func.func @FuseReshapeAroundBMMRHS(%arg0: tensor<1x3x6x5x1024xf32>) -> tensor<1x3x6x5x8192xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "inputs", outputs = "Identity_1"}} { + %cst = arith.constant dense_resource<__elided__> : tensor<1x1024x8192xf32> + %cst_0 = arith.constant dense_resource<__elided__> : tensor<5xi32> + %cst_1 = arith.constant dense_resource<__elided__> : tensor<3xi32> + %0 = "tfl.reshape"(%arg0, %cst_1) : (tensor<1x3x6x5x1024xf32>, tensor<3xi32>) -> tensor<1x90x1024xf32> + %1 = "tfl.batch_matmul"(%0, %cst) {adj_x = false, adj_y = false} : (tensor<1x90x1024xf32>, tensor<1x1024x8192xf32>) -> tensor<1x90x8192xf32> + %2 = "tfl.reshape"(%1, %cst_0) : (tensor<1x90x8192xf32>, tensor<5xi32>) -> tensor<1x3x6x5x8192xf32> + return %2 : tensor<1x3x6x5x8192xf32> + // CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<1x1024x8192xf32> + // CHECK: %0 = "tfl.batch_matmul"(%arg0, %cst) {adj_x = false, adj_y = false} : (tensor<1x3x6x5x1024xf32>, tensor<1x1024x8192xf32>) -> tensor<1x3x6x5x8192xf32> + // CHECK: return %0 : tensor<1x3x6x5x8192xf32> +} + +// CHECK-LABEL: @FuseTransposeIntoBMM_RHS +func.func @FuseTransposeIntoBMM_RHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<1x1440x256xf32>) -> tensor<1x4x1440x1440xf32> { + %cst_1 = arith.constant dense_resource<__elided__> : tensor<3xi32> + %32 = "tfl.transpose"(%arg1, %cst_1) : (tensor<1x1440x256xf32>, tensor<3xi32>) -> tensor<1x256x1440xf32> + %33 = "tfl.batch_matmul"(%arg0, %32) {adj_x = false, adj_y = false} : (tensor<1x4x1440x256xf32>, tensor<1x256x1440xf32>) -> tensor<1x4x1440x1440xf32> + return %33 : tensor<1x4x1440x1440xf32> + // CHECK: %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x1440x256xf32>, tensor<1x1440x256xf32>) -> tensor<1x4x1440x1440xf32> + // CHECK: return %0 : tensor<1x4x1440x1440xf32> +} + +// CHECK-LABEL: @FuseTransposeIntoBMM_LHS +func.func @FuseTransposeIntoBMM_LHS(%arg0: 
tensor<1x4x1440x256xf32>, %arg1: tensor<1x1440x256xf32>) -> tensor<1x4x256x256xf32> { + %cst_1 = arith.constant dense_resource<__elided__> : tensor<3xi32> + %32 = "tfl.transpose"(%arg1, %cst_1) : (tensor<1x1440x256xf32>, tensor<3xi32>) -> tensor<1x256x1440xf32> + %33 = "tfl.batch_matmul"(%32, %arg0) {adj_x = false, adj_y = false} : (tensor<1x256x1440xf32>, tensor<1x4x1440x256xf32>) -> tensor<1x4x256x256xf32> + return %33 : tensor<1x4x256x256xf32> + // CHECK: %0 = "tfl.batch_matmul"(%arg1, %arg0) {adj_x = true, adj_y = false} : (tensor<1x1440x256xf32>, tensor<1x4x1440x256xf32>) -> tensor<1x4x256x256xf32> + // CHECK: return %0 : tensor<1x4x256x256xf32> +} + // CHECK-LABEL: @FuseFullyConnectedReshapeAddConst // FOLD-LABEL: @FuseFullyConnectedReshapeAddConst func.func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index eb84f8143113b4..23886f95d089db 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -282,7 +282,7 @@ Status ConvertTFExecutorToStablehloFlatbuffer( } // for now always output mlir - if (/*export_to_mlir*/ true) { + if (/*export_to_mlir*/ /* DISABLES CODE */ (true)) { llvm::raw_string_ostream os(*result); module.print(os); return statusHandler.ConsumeStatus(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 1c9d7bbb002d2c..e0caac4e90490d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -123,6 +123,33 @@ class OptimizePass : public impl::OptimizePassBase { void runOnOperation() override; }; +// Return true if the product of dimension values of a subsection of the tensor +// is equal to the non-contracting dimension after a reshape +bool BroadcastDimsProductEqual(Value input, Value output, + size_t agg_start_idx) { + ArrayRef input_shape = input.getType().cast().getShape(); + ArrayRef output_shape = + output.getType().cast().getShape(); + + int64_t agg_value = 1; + for (size_t i = agg_start_idx; i < input_shape.size() - 1; ++i) { + agg_value *= input_shape[i]; + } + + return (agg_value == output_shape[agg_start_idx]); +} + +// Return true if the product of dimension values of a subsection of the tensor +// is equal to the non-contracting dimension after a reshape +bool AreLastTwoDimsTransposed(Value input, Value output) { + ArrayRef input_shape = input.getType().cast().getShape(); + ArrayRef output_shape = + output.getType().cast().getShape(); + + return (input_shape.back() == output_shape[output_shape.size() - 2]) && + (input_shape[input_shape.size() - 2] == output_shape.back()); +} + // Returns whether the given type `a` is broadcast-compatible with `b`. bool IsBroadcastableElementsAttrAndType(Type a, Type b) { return OpTrait::util::getBroadcastedType(a, b) != Type(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 66d625e970b31c..b056dc1c977347 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -57,6 +57,8 @@ class HasRank : Constraint< class FloatValueEquals : Constraint>; +class IsBoolAttrEqual : Constraint>; // Flattens a constant tensor to 1D. 
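The two shape predicates added to transforms/optimize.cc above gate the new reshape and transpose fusions around `tfl.batch_matmul`. A simplified restatement on plain shape vectors makes their intent concrete (sketch only; the real helpers take `mlir::Value` and read the shapes from the tensor types): `BroadcastDimsProductEqual` checks that the product of the input dimensions from `agg_start_idx` through the second-to-last dimension equals dimension `agg_start_idx` of the other tensor, and `AreLastTwoDimsTransposed` checks that the output shape is the input shape with its trailing two dimensions swapped.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Sketch with the same logic as the optimize.cc helpers, on raw shapes.
    bool BroadcastDimsProductEqualShapes(const std::vector<int64_t>& in,
                                         const std::vector<int64_t>& out,
                                         size_t agg_start_idx) {
      int64_t agg = 1;
      // Multiply every dim from agg_start_idx up to (excluding) the last one.
      for (size_t i = agg_start_idx; i + 1 < in.size(); ++i) agg *= in[i];
      return agg == out[agg_start_idx];
    }

    bool AreLastTwoDimsTransposedShapes(const std::vector<int64_t>& in,
                                        const std::vector<int64_t>& out) {
      return in.back() == out[out.size() - 2] &&
             in[in.size() - 2] == out.back();
    }

    // On the shapes exercised by the optimize.mlir tests above:
    //   BroadcastDimsProductEqualShapes({6, 5, 1024}, {30, 1024}, 0)    -> true
    //   AreLastTwoDimsTransposedShapes({1, 1440, 256}, {1, 256, 1440})  -> true
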
def FlattenTo1D : NativeCodeCall<"FlattenTo1D($0)">; @@ -1430,3 +1432,82 @@ def FuseLeakyReluConst : Pat< (HasOneUse $geq_out), (HasOneUse $mul_out), ]>; + +// Return true if the product of dimension values of a subsection of the tensor +// is equal to the non-contracting dimension after a reshape +class BroadcastDimsProductEqual : Constraint>; + +// Returns true if the dimensions of a subsection of two tensors is equal +// and the subsections are not empty +class AreTensorSubSectionShapesEqual : Constraint().getShape()" + ".drop_back("#skip_last#").drop_front("#skip_first#") ==" + "$1.getType().dyn_cast().getShape()" + ".drop_back("#skip_last#").drop_front("#skip_first#"))" + "&& !$0.getType().dyn_cast().getShape()" + ".drop_back("#skip_last#").drop_front("#skip_first#").empty()">>; + +// Returns true if the broadcast dimension of a tensor is [1] +// here- broadcast dimension is first prefix dimension +// excluding the last two dimensions +def IsBroadcastDimEqualToOne : Constraint().getShape()[0] == 1">>; + +// Pattern to fuse/fold the reshape ops around TFL_BatchMatMulOp +// This pattern is applied when the rank of rhs is 2 +// which means it has empty broadcast dimensions +def FuseReshapesAroundBatchMatMulLHS: Pat< + (TFL_ReshapeOp:$final_shape_change + (TFL_BatchMatMulOp:$bmm_tmp_output + (TFL_ReshapeOp:$initial_shape_change $input, (Arith_ConstantOp $s0)), + $rhs, $adj_x, $adj_y, $bool_attr), + (Arith_ConstantOp $s1)), + (TFL_BatchMatMulOp $input, $rhs, $adj_x, $adj_y, $bool_attr), + [(HasRank<2> $rhs), + (HasRank<2> $initial_shape_change), + (BroadcastDimsProductEqual<0> $input, $initial_shape_change), + (BroadcastDimsProductEqual<0> $final_shape_change, $bmm_tmp_output), + (AreTensorSubSectionShapesEqual<0, 1> $input, $final_shape_change)]>; + +// Pattern to fuse/fold the reshape ops around TFL_BatchMatMulOp +// This pattern is applied when the rank of rhs is 3 +// and the broadcast dimension is [1] +def FuseReshapesAroundBatchMatMulLHS1: Pat< + (TFL_ReshapeOp:$final_shape_change + (TFL_BatchMatMulOp:$bmm_tmp_output + (TFL_ReshapeOp:$initial_shape_change $input, (Arith_ConstantOp $s0)), + $rhs, $adj_x, $adj_y, $bool_attr), + (Arith_ConstantOp $s1)), + (TFL_BatchMatMulOp $input, $rhs, $adj_x, $adj_y, $bool_attr), + [(HasRank<3> $rhs), + (HasRank<3> $initial_shape_change), + (IsBroadcastDimEqualToOne $rhs), + (IsBroadcastDimEqualToOne $input), + (BroadcastDimsProductEqual<1> $input, $initial_shape_change), + (BroadcastDimsProductEqual<1> $final_shape_change, $bmm_tmp_output), + (AreTensorSubSectionShapesEqual<1, 1> $input, $final_shape_change)]>; + +def AreLastTwoDimsTransposed : Constraint>; + +// Fuse redundant TFL_TransposeOp into TFL_BatchMatMulOp +def FuseTransposeIntoBatchMatMulRHS: Pat< + (TFL_BatchMatMulOp $lhs, + (TFL_TransposeOp:$transposed_value $input, (Arith_ConstantOp $p0)), + $adj_x, $adj_y, $asymmetric_quantize_inputs), + (TFL_BatchMatMulOp $lhs, $input, $adj_x, ConstBoolAttrTrue, $asymmetric_quantize_inputs), + [(AreLastTwoDimsTransposed $input, $transposed_value), + (IsBoolAttrEqual<"false"> $adj_y), + (AreTensorSubSectionShapesEqual<0, 2> $input, $transposed_value)]>; + +// Fuse redundant TFL_TransposeOp into TFL_BatchMatMulOp +def FuseTransposeIntoBatchMatMulLHS: Pat< + (TFL_BatchMatMulOp + (TFL_TransposeOp:$transposed_value $input, (Arith_ConstantOp $p0)), + $rhs, $adj_x, $adj_y, $asymmetric_quantize_inputs), + (TFL_BatchMatMulOp $input, $rhs, ConstBoolAttrTrue, $adj_y, $asymmetric_quantize_inputs), + [(AreLastTwoDimsTransposed $input, $transposed_value), + 
(IsBoolAttrEqual<"false"> $adj_x), + (AreTensorSubSectionShapesEqual<0, 2> $input, $transposed_value)]>; + diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 50404aab815c4a..c3b045f19a86b0 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist") @@ -36,7 +36,7 @@ cc_library( hdrs = [ "passes/passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":fill_quantization_options", ":quantization_options_proto_cc", @@ -63,7 +63,7 @@ cc_library( gentbl_cc_library( name = "bridge_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -89,7 +89,7 @@ cc_library( hdrs = [ "passes/bridge/passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [ "//tensorflow/compiler/mlir/lite:__subpackages__", "//tensorflow/compiler/mlir/tf2xla:__subpackages__", @@ -162,7 +162,7 @@ cc_library( hdrs = [ "quantize_passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [":internal_visibility_allowlist_package"], deps = [ ":fill_quantization_options", @@ -180,7 +180,7 @@ cc_library( gentbl_cc_library( name = "stablehlo_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -200,7 +200,7 @@ cc_library( name = "fill_quantization_options", srcs = ["utils/fill_quantization_options.cc"], hdrs = ["utils/fill_quantization_options.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", @@ -213,7 +213,7 @@ cc_library( name = "math_utils", srcs = ["utils/math_utils.cc"], hdrs = ["utils/math_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = ["@llvm-project//mlir:Support"], ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc index 0424ad97cc61ca..b54fc3bfc2ff22 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc @@ -14,8 +14,11 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include +#include +#include #include #include "llvm/ADT/STLExtras.h" @@ -435,6 +438,55 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { } }; +// A shared matchAndRewrite implementation for dot-like hybrid quantized +// operators. 
+// Hybrid ops are currently only interpreted as weight-only quantization ops;
+// this might change in the future.
+//
+// All attrs of the original op are preserved after the conversion.
+template <typename OpType, typename OpAdaptorType>
+LogicalResult matchAndRewriteDotLikeHybridOp(
+    OpType &op, OpAdaptorType &adaptor, ConversionPatternRewriter &rewriter,
+    const quant::UniformQuantizedType &rhs_element_type) {
+  // For dot-like hybrid ops, lhs is a float type, rhs is a uniform
+  // quantized type, and the result is a float type.
+  // For weight-only quantization:
+  //   result = hybridOp(lhs, dequant(rhs))
+  Value lhs_float32_tensor = adaptor.getLhs();
+  Value rhs = adaptor.getRhs();
+  auto res_float32_tensor_type =
+      op.getResult().getType().template cast<TensorType>();
+
+  // Get scale and zero point for rhs.
+  Value rhs_zero_point = rewriter.create<mhlo::ConstantOp>(
+      op->getLoc(),
+      rewriter.getF32FloatAttr((rhs_element_type.getZeroPoint())));
+  Value rhs_scale_constant = rewriter.create<mhlo::ConstantOp>(
+      op->getLoc(), rewriter.getF32FloatAttr(
+                        static_cast<float>(rhs_element_type.getScale())));
+
+  // Dequantize rhs into rhs_float32_tensor.
+  Value rhs_float32_tensor = rewriter.create<mhlo::ConvertOp>(
+      op->getLoc(), res_float32_tensor_type, rhs);
+  rhs_float32_tensor = rewriter.create<chlo::BroadcastSubOp>(
+      op->getLoc(), res_float32_tensor_type, rhs_float32_tensor, rhs_zero_point,
+      nullptr);
+  rhs_float32_tensor = rewriter.create<chlo::BroadcastMulOp>(
+      op->getLoc(), res_float32_tensor_type, rhs_float32_tensor,
+      rhs_scale_constant, nullptr);
+
+  // Execute conversion target op.
+  SmallVector<Value> operands{lhs_float32_tensor, rhs_float32_tensor};
+  Value res_float32 = rewriter.create<OpType>(
+      op->getLoc(), res_float32_tensor_type, operands, op->getAttrs());
+
+  Value half = rewriter.create<mhlo::ConstantOp>(
+      op->getLoc(), rewriter.getF32FloatAttr(0.5f));
+  res_float32 = rewriter.create<chlo::BroadcastAddOp>(
+      op->getLoc(), res_float32_tensor_type, res_float32, half, nullptr);
+  rewriter.replaceOpWithNewOp<mhlo::FloorOp>(op, res_float32);
+  return success();
+}
+
 // A shared matchAndRewrite implementation for dot-like quantized operators.
 //
 // Dot-like operators refer to operators that generate a tensor where each
@@ -446,24 +498,42 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern<mhlo::AddOp> {
 template <typename OpType, typename OpAdaptorType>
 LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor,
                                        ConversionPatternRewriter &rewriter) {
-  auto lhs_element_type = op.getLhs()
-                              .getType()
-                              .getElementType()
-                              .template dyn_cast<quant::UniformQuantizedType>();
-  auto rhs_element_type = op.getRhs()
-                              .getType()
-                              .getElementType()
-                              .template dyn_cast<quant::UniformQuantizedType>();
-  auto res_element_type = op.getResult()
-                              .getType()
-                              .getElementType()
-                              .template dyn_cast<quant::UniformQuantizedType>();
-
-  // Check if the operands and result are UniformQuantizedTypes.
-  if (!lhs_element_type || !rhs_element_type || !res_element_type) {
+  auto lhs_element_type = getElementTypeOrSelf(op.getLhs().getType());
+  auto rhs_element_quant_type =
+      op.getRhs()
+          .getType()
+          .getElementType()
+          .template dyn_cast<quant::UniformQuantizedType>();
+  auto res_element_type = getElementTypeOrSelf(op.getResult());
+
+  // Check if the right operand is a UniformQuantizedType.
+  if (!rhs_element_quant_type) {
     return rewriter.notifyMatchFailure(
         op, "Legalization failed: supports only per-tensor quantization.");
   }
+
+  if (lhs_element_type.template isa<quant::UniformQuantizedType>()) {
+    // If lhs is a uniform quantized type, the result should also be a uniform
+    // quantized type, representing a non-hybrid op.
+ if (!res_element_type.template isa()) { + op->emitError("Unsupported result element type for " + + op->getName().getStringRef().str()); + return failure(); + } + } else if (lhs_element_type.isF32()) { + // If lhs is float32 type, result should also be float32 type, + // representing hybrid op. + if (!res_element_type.isF32()) { + op->emitError("Unsupported result element type for " + + op->getName().getStringRef().str()); + return failure(); + } + return matchAndRewriteDotLikeHybridOp(op, adaptor, rewriter, + rhs_element_quant_type); + } else { + return rewriter.notifyMatchFailure(op, "Unsupported input element type."); + } + auto res_float32_tensor_type_or = GetSameShapeTensorType( op, op.getResult().getType().template cast(), rewriter.getF32Type(), rewriter); @@ -471,6 +541,10 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, return failure(); } + auto lhs_element_quant_type = + lhs_element_type.template dyn_cast(); + auto res_element_quant_type = + res_element_type.template dyn_cast(); Value lhs = adaptor.getLhs(); Value rhs = adaptor.getRhs(); @@ -481,10 +555,10 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, // Get scales and zero points for both operands. Value lhs_zero_point = rewriter.create( op->getLoc(), - rewriter.getF32FloatAttr((lhs_element_type.getZeroPoint()))); + rewriter.getF32FloatAttr((lhs_element_quant_type.getZeroPoint()))); Value rhs_zero_point = rewriter.create( op->getLoc(), - rewriter.getF32FloatAttr((rhs_element_type.getZeroPoint()))); + rewriter.getF32FloatAttr((rhs_element_quant_type.getZeroPoint()))); // Offset xxx_int32_tensor according to zero points. Value lhs_float32_tensor = rewriter.create( @@ -507,10 +581,10 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, // scales. Value result_zero_point = rewriter.create( op->getLoc(), - rewriter.getF32FloatAttr((res_element_type.getZeroPoint()))); - const double effective_scale = lhs_element_type.getScale() * - rhs_element_type.getScale() / - res_element_type.getScale(); + rewriter.getF32FloatAttr((res_element_quant_type.getZeroPoint()))); + const double effective_scale = lhs_element_quant_type.getScale() * + rhs_element_quant_type.getScale() / + res_element_quant_type.getScale(); Value effective_scale_constant = rewriter.create( op->getLoc(), rewriter.getF32FloatAttr(static_cast(effective_scale))); @@ -543,10 +617,10 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, // Clamp results by [quantization_min, quantization_max]. Value result_quantization_min = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(static_cast( - res_element_type.getStorageTypeMin()))); + res_element_quant_type.getStorageTypeMin()))); Value result_quantization_max = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(static_cast( - res_element_type.getStorageTypeMax()))); + res_element_quant_type.getStorageTypeMax()))); res_int32 = rewriter.create( op->getLoc(), *res_int32_tensor_type_or, result_quantization_min, res_int32, result_quantization_max); @@ -554,7 +628,7 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, // Convert results back to int8. 
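Numerically, the two paths above can be summarized as follows. The hybrid (weight-only) path dequantizes rhs with (q - zero_point) * scale, runs the dot in float, and rounds the result with floor(x + 0.5). The fully quantized path offsets both operands by their zero points, performs the dot, rescales by lhs_scale * rhs_scale / res_scale, rounds, re-applies the result zero point, and clamps to the result storage range. A standalone sketch of that arithmetic with made-up quantization parameters (an illustration of the math, not the MLIR rewrite itself):

// Illustrative arithmetic for the dot-like conversions above, on small
// vectors instead of tensors. All quantization parameters are made up.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const float lhs_scale = 0.1f, rhs_scale = 0.2f, res_scale = 0.05f;
  const int32_t lhs_zp = 1, rhs_zp = 2, res_zp = 3;
  const int32_t res_min = -128, res_max = 127;  // int8 storage range

  const std::vector<float> lhs_f = {1.0f, 2.0f, 3.0f};  // float lhs (hybrid)
  const std::vector<int8_t> lhs_q = {5, 7, 9};          // quantized lhs
  const std::vector<int8_t> rhs_q = {4, 6, 8};          // quantized rhs

  // Hybrid (weight-only) path: result = op(lhs, dequant(rhs)), then
  // floor(x + 0.5), matching the chlo.broadcast_add / mhlo.floor pair above.
  float hybrid = 0.0f;
  for (size_t i = 0; i < rhs_q.size(); ++i) {
    const float rhs_f = (static_cast<float>(rhs_q[i]) - rhs_zp) * rhs_scale;
    hybrid += lhs_f[i] * rhs_f;
  }
  hybrid = std::floor(hybrid + 0.5f);

  // Fully quantized path: offset by zero points, dot, rescale by the
  // effective scale, round, re-apply the result zero point, clamp to int8.
  float acc = 0.0f;
  for (size_t i = 0; i < rhs_q.size(); ++i) {
    acc += static_cast<float>((lhs_q[i] - lhs_zp) * (rhs_q[i] - rhs_zp));
  }
  const float effective_scale = lhs_scale * rhs_scale / res_scale;
  const int32_t rounded =
      static_cast<int32_t>(std::floor(acc * effective_scale + 0.5f));
  const int32_t res_q = std::clamp(rounded + res_zp, res_min, res_max);

  std::printf("hybrid=%f quantized=%d\n", hybrid, res_q);
  return 0;
}

In the pass itself these steps operate on whole tensors via the chlo broadcast ops, with the clamp bounds taken from the result type's storage min/max before the final convert back to the storage type.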
auto res_final_tensor_type_or = GetSameShapeTensorType( op, res_int32_tensor_type_or->template cast(), - res_element_type.getStorageType(), rewriter); + res_element_quant_type.getStorageType(), rewriter); rewriter.replaceOpWithNewOp(op, *res_final_tensor_type_or, res_int32); @@ -568,8 +642,7 @@ class ConvertUniformQuantizedDotOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::DotOp op, mhlo::DotOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - return matchAndRewriteDotLikeOp( - op, adaptor, rewriter); + return matchAndRewriteDotLikeOp(op, adaptor, rewriter); } }; @@ -581,9 +654,7 @@ class ConvertUniformQuantizedConvolutionOp LogicalResult matchAndRewrite( mhlo::ConvolutionOp op, mhlo::ConvolutionOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - return matchAndRewriteDotLikeOp(op, adaptor, - rewriter); + return matchAndRewriteDotLikeOp(op, adaptor, rewriter); } }; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index cece788381b0b6..896aa1ba833395 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -24,6 +24,6 @@ def QuantizeWeightPass : Pass<"stablehlo-quantize-weight", "mlir::func::FuncOp"> def PrepareSrqQuantizePass : Pass<"stablehlo-prepare-srq-quantize", "mlir::func::FuncOp"> { let summary = "Prepare StableHLO dialect for static range quantization."; let constructor = "CreatePrepareSrqQuantizePass()"; - let dependentDialects = ["stablehlo::StablehloDialect", "quant::QuantizationDialect"]; + let dependentDialects = ["stablehlo::StablehloDialect", "quant::QuantizationDialect", "quantfork::QuantizationForkDialect"]; } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.cc index 78e056efef8d2b..12ccddcce58ba7 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.cc @@ -71,8 +71,8 @@ class PrepareSrqQuantizePass }; using ReplaceStatsWithQDQs = - quant::ConvertStatsToQDQs; + quant::ConvertStatsToQDQs; void PrepareSrqQuantizePass::runOnOperation() { func::FuncOp func = getOperation(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int-no-chlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int-no-chlo.mlir index df3886645a3112..95a247ffdc19b8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int-no-chlo.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int-no-chlo.mlir @@ -125,3 +125,13 @@ func.func @uniform_quantized_convolution(%arg0: tensor, %arg1: tens -> tensor> return } + +// ----- + +// CHECK-LABEL: func @uniform_quantize_dot_hybrid +func.func @uniform_quantize_dot_hybrid(%arg0: tensor, %arg1: tensor) { + // CHECK-NOT: chlo + %0 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> + %1 = "mhlo.dot" (%arg0, %0): (tensor, tensor>) -> tensor + return +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir index 186a576c3731c3..d01c18bce6c607 
100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir @@ -84,13 +84,13 @@ func.func @uniform_quantize_and_dequantize_sparse_tensor_encoding(%arg0: tensor< // CHECK-LABEL: func @uniform_quantize_add func.func @uniform_quantize_add(%arg0: tensor, %arg1: tensor) -> () { - // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor - // CHECK-DAG: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : (tensor) -> tensor + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : (tensor) -> tensor // CHECK-DAG: %[[VAL5:.*]] = mhlo.constant dense<3> : tensor - // CHECK-DAG: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1:.*]], %[[VAL3:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4:.*]], %[[VAL5:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL9:.*]] = mhlo.clamp %[[VAL7:.*]], %[[VAL6:.*]], %[[VAL8:.*]] : (tensor, tensor, tensor) -> tensor - // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9:.*]] : (tensor) -> tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1]], %[[VAL3]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4]], %[[VAL5]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = mhlo.clamp %[[VAL7:.*]], %[[VAL6]], %[[VAL8:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9]] : (tensor) -> tensor %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> %1 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> %2 = mhlo.add %0, %1: (tensor>, tensor>) -> tensor> @@ -101,13 +101,13 @@ func.func @uniform_quantize_add(%arg0: tensor, %arg1: tensor) // CHECK-LABEL: func @uniform_quantize_add_int4 func.func @uniform_quantize_add_int4(%arg0: tensor, %arg1: tensor) -> () { - // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor - // CHECK-DAG: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : (tensor) -> tensor + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : (tensor) -> tensor // CHECK-DAG: %[[VAL5:.*]] = mhlo.constant dense<3> : tensor - // CHECK-DAG: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1:.*]], %[[VAL3:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4:.*]], %[[VAL5:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL9:.*]] = mhlo.clamp %[[VAL7:.*]], %[[VAL6:.*]], %[[VAL8:.*]] : (tensor, tensor, tensor) -> tensor - // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9:.*]] : (tensor) -> tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1]], %[[VAL3]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4]], %[[VAL5]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = mhlo.clamp %[[VAL7:.*]], %[[VAL6]], %[[VAL8:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9]] : (tensor) -> tensor %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> %1 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> %2 = mhlo.add %0, %1: (tensor>, tensor>) -> tensor> @@ -243,17 +243,17 @@ func.func @uniform_quantize_requantize_and_dequantize(%arg0: tensor) -> // CHECK-LABEL: func @uniform_quantize_dot_dequantize func.func @uniform_quantize_dot_dequantize(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor 
- // CHECK: %[[VAL3:.*]] = chlo.broadcast_subtract %[[VAL1:.*]], %[[VAL2:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL3:.*]] = chlo.broadcast_subtract %[[VAL1]], %[[VAL2:.*]] : (tensor, tensor) -> tensor // CHECK: %[[VAL5:.*]] = mhlo.convert %[[VAL4:.*]] : (tensor) -> tensor - // CHECK: %[[VAL7:.*]] = chlo.broadcast_subtract %[[VAL5:.*]], %[[VAL6:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL8:.*]] = "mhlo.dot"(%[[VAL3:.*]], %[[VAL7:.*]]) : (tensor, tensor) -> tensor - // CHECK: %[[VAL10:.*]] = chlo.broadcast_multiply %[[VAL8:.*]], %[[VAL9:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL12:.*]] = chlo.broadcast_add %[[VAL10:.*]], %[[VAL11:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL13:.*]] = mhlo.floor %[[VAL12:.*]] : tensor - // CHECK: %[[VAL15:.*]] = chlo.broadcast_add %[[VAL13:.*]], %[[VAL14:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL16:.*]] = mhlo.convert %[[VAL15:.*]] : (tensor) -> tensor - // CHECK: %[[VAL19:.*]] = mhlo.clamp %[[VAL17:.*]], %[[VAL16:.*]], %[[VAL18:.*]] : (tensor, tensor, tensor) -> tensor - // CHECK: %[[VAL20:.*]] = mhlo.convert %[[VAL19:.*]] : (tensor) -> tensor + // CHECK: %[[VAL7:.*]] = chlo.broadcast_subtract %[[VAL5]], %[[VAL6:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL8:.*]] = "mhlo.dot"(%[[VAL3]], %[[VAL7]]) : (tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = chlo.broadcast_multiply %[[VAL8]], %[[VAL9:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL12:.*]] = chlo.broadcast_add %[[VAL10]], %[[VAL11:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL13:.*]] = mhlo.floor %[[VAL12]] : tensor + // CHECK: %[[VAL15:.*]] = chlo.broadcast_add %[[VAL13]], %[[VAL14:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL16:.*]] = mhlo.convert %[[VAL15]] : (tensor) -> tensor + // CHECK: %[[VAL19:.*]] = mhlo.clamp %[[VAL17:.*]], %[[VAL16]], %[[VAL18:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL20:.*]] = mhlo.convert %[[VAL19]] : (tensor) -> tensor %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> %1 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> %2 = "mhlo.dot" (%0, %1) : (tensor>, tensor>) -> tensor> @@ -308,3 +308,28 @@ func.func @uniform_quantized_convolution(%arg0: tensor, %arg1: tens -> tensor> return } + +// ----- + +// CHECK-LABEL: func @uniform_quantize_dot_hybrid +func.func @uniform_quantize_dot_hybrid(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + // CHECK: %[[VAL3:.*]] = chlo.broadcast_subtract %[[VAL1]], %[[VAL2:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL5:.*]] = chlo.broadcast_multiply %[[VAL3]], %[[VAL4:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL7:.*]] = "mhlo.dot"(%[[VAL6:.*]], %[[VAL5]]) : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = chlo.broadcast_add %[[VAL7]], %[[VAL8:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = mhlo.floor %[[VAL9]] : tensor + %0 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> + %1 = "mhlo.dot" (%arg0, %0): (tensor, tensor>) -> tensor + return %1: tensor +} + +// ----- + +func.func @uniform_quantize_dot_hybrid_result_type_not_float(%arg0: tensor, %arg1: tensor) { + %0 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> + // expected-error@+2 {{Unsupported result element type for mhlo.dot}} + // expected-error@+1 {{failed to legalize operation 'mhlo.dot' that was explicitly marked illegal}} + %1 = "mhlo.dot" (%arg0, %0): (tensor, tensor>) -> tensor> + return +} diff --git 
a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_srq_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_srq_quantize.mlir index 67e57d23c200b4..4c7d909094fe27 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_srq_quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_srq_quantize.mlir @@ -9,13 +9,13 @@ func.func @main(%arg0: tensor) -> tensor { } // CHECK: %[[cst:.*]] = stablehlo.constant -// CHECK: %[[q1:.*]] = stablehlo.uniform_quantize %arg0 +// CHECK: %[[q1:.*]] = "quantfork.qcast"(%arg0) // CHECK-SAME: quant.uniform -// CHECK: %[[dq1:.*]] = stablehlo.uniform_dequantize %[[q1]] +// CHECK: %[[dq1:.*]] = "quantfork.dcast"(%[[q1]]) // CHECK-SAME: quant.uniform // CHECK: %[[dot:.*]] = stablehlo.dot %[[dq1]], %[[cst]] -// CHECK: %[[q2:.*]] = stablehlo.uniform_quantize %[[dot]] +// CHECK: %[[q2:.*]] = "quantfork.qcast"(%[[dot]]) // CHECK-SAME: quant.uniform> -// CHECK: %[[dq2:.*]] = stablehlo.uniform_dequantize %[[q2]] +// CHECK: %[[dq2:.*]] = "quantfork.dcast"(%[[q2]]) // CHECK-SAME: quant.uniform> // CHECK: return %[[dq2]] diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 1c9e70d1a260a0..d32fbcc7ed853c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -1,6 +1,6 @@ load("//tensorflow:strict.default.bzl", "py_strict_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist.bzl", "internal_visibility_allowlist") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") @@ -45,7 +45,7 @@ genrule( "passes/quantized_function_library.h", ], cmd = "$(location gen_quantized_function_library) --output_file $(RULEDIR)/passes/quantized_function_library.h --src '$(SRCS)'", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tools = ["gen_quantized_function_library"], ) @@ -57,7 +57,7 @@ cc_library( hdrs = [ "passes/utils.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":quantization_options_proto_cc", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", @@ -76,7 +76,7 @@ cc_library( hdrs = [ "passes/manipulate_model_attr.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//llvm:Support", @@ -93,7 +93,7 @@ cc_library( hdrs = [ "passes/remove_identity_op_pattern.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//mlir:IR", @@ -118,7 +118,7 @@ td_library( "passes/tf_quant_ops.td", "passes/utils.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", "//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils_td_files", @@ -130,7 +130,7 @@ td_library( gentbl_cc_library( name = "cast_bf16_ops_to_f32_inc_gen", - 
compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -144,7 +144,7 @@ gentbl_cc_library( gentbl_cc_library( name = "prepare_lifting_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -158,7 +158,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lift_quantizable_spots_as_functions_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -172,7 +172,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lift_quantizable_spots_as_functions_drq_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -186,7 +186,7 @@ gentbl_cc_library( gentbl_cc_library( name = "prepare_quantize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -200,7 +200,7 @@ gentbl_cc_library( gentbl_cc_library( name = "quantize_composite_functions_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -214,7 +214,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_quant_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -234,7 +234,7 @@ gentbl_cc_library( gentbl_cc_library( name = "optimize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -248,7 +248,7 @@ gentbl_cc_library( gentbl_cc_library( name = "convert_tpu_model_to_cpu_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -262,7 +262,7 @@ gentbl_cc_library( gentbl_cc_library( name = "post_quantize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -276,7 +276,7 @@ gentbl_cc_library( gentbl_cc_library( name = "preprocess_op_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -296,7 +296,7 @@ cc_library( "passes/tf_quant_ops.h.inc", ], hdrs = ["passes/tf_quant_ops.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_op_interfaces", @@ -324,7 +324,7 @@ cc_library( "ops/tf_op_quant_spec.cc", ], hdrs = ["ops/tf_op_quant_spec.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", @@ -340,7 +340,7 @@ cc_library( "ops/uniform_op_quant_spec.cc", ], hdrs = ["ops/uniform_op_quant_spec.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":tf_quant_ops", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", @@ -353,7 +353,7 @@ cc_library( gentbl_cc_library( name = "replace_cast_hacks_with_tf_xla_ops_inc_gen", - 
compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -414,7 +414,7 @@ cc_library( "passes/constants.h", "passes/passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":manipulate_model_attr", ":pass_utils", @@ -489,7 +489,7 @@ cc_library( hdrs = [ "quantize_preprocess.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":passes", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", @@ -517,7 +517,7 @@ cc_library( hdrs = [ "quantize_passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":pass_utils", ":passes", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD index d8a0c975aac48a..97553651573e31 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD @@ -5,7 +5,7 @@ load( ) load( "//tensorflow:tensorflow.default.bzl", - "get_compatible_with_cloud", + "get_compatible_with_portable", "tf_kernel_library", "tf_py_strict_test", ) @@ -29,7 +29,7 @@ cc_library( name = "calibrator_singleton_impl", srcs = ["calibrator_singleton.cc"], hdrs = ["calibrator_singleton.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:private"], deps = [ "@com_google_absl//absl/container:flat_hash_map", @@ -42,7 +42,7 @@ cc_library( cc_library( name = "calibrator_singleton", hdrs = ["calibrator_singleton.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = if_static([":calibrator_singleton_impl"]) + [ "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -65,7 +65,7 @@ tf_cc_test( tf_kernel_library( name = "custom_aggregator_op", srcs = ["custom_aggregator_op.cc"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [ "//tensorflow:__pkg__", "//tensorflow/compiler/mlir/quantization/tensorflow/python:__pkg__", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD index f121b219d2ed4a..86e6efc5ec43a2 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD @@ -1,6 +1,6 @@ load( "//tensorflow:tensorflow.default.bzl", - "get_compatible_with_cloud", + "get_compatible_with_portable", ) load( "//tensorflow:tensorflow.bzl", @@ -21,7 +21,7 @@ cc_library( name = "save_variables", srcs = ["save_variables.cc"], hdrs = ["save_variables.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:framework", @@ -68,7 +68,7 @@ cc_library( name = "const_op_size", srcs = ["const_op_size.cc"], hdrs = ["const_op_size.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_remaining_ops", @@ -97,7 +97,7 @@ cc_library( name = "convert_asset_args", srcs = ["convert_asset_args.cc"], hdrs = 
["convert_asset_args.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:import_model", @@ -128,7 +128,7 @@ tf_cc_test( cc_library( name = "status_macro", hdrs = ["status_macro.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform:macros", "@com_google_absl//absl/status", @@ -150,7 +150,7 @@ cc_library( name = "run_passes", srcs = ["run_passes.cc"], hdrs = ["run_passes.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow/debugging:mlir_dump", "//tensorflow/compiler/mlir/tensorflow:error_util", @@ -170,7 +170,7 @@ cc_library( hdrs = [ "constant_fold.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils", "//tensorflow/compiler/mlir/tensorflow", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD index 879ccc88de0583..cd755f83e2d963 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -13,7 +13,7 @@ cc_library( name = "mlir_dump", srcs = ["mlir_dump.cc"], hdrs = ["mlir_dump.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:path", @@ -31,7 +31,7 @@ cc_library( tf_cc_test( name = "mlir_dump_test", srcs = ["mlir_dump_test.cc"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":mlir_dump", "//tensorflow/tsl/platform:path", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index 38363200a62b4a..8faafa1c17770c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -5,7 +5,7 @@ load( ) load( "//tensorflow:tensorflow.default.bzl", - "get_compatible_with_cloud", + "get_compatible_with_portable", "tf_py_test", "tf_python_pybind_extension", ) @@ -25,7 +25,7 @@ cc_library( name = "quantize_model_cc_impl", srcs = ["quantize_model.cc"], hdrs = ["quantize_model.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [ # Directly linked to `libtensorflow_cc.so` or # `_pywrap_tensorflow_internal.so` if static build. 
@@ -82,7 +82,7 @@ cc_library( cc_library( name = "quantize_model_cc", hdrs = ["quantize_model.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = if_static([":quantize_model_cc_impl"]) + [ "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py index d7dc023e1dc87c..5983c2b581953a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py @@ -907,7 +907,7 @@ def _save_tf1_model( inputs: Mapping[str, core.Tensor], outputs: Mapping[str, core.Tensor], init_op: Optional[ops.Operation] = None, - assets_collection: Optional[Sequence[ops.Tensor]] = None, + assets_collection: Optional[Sequence[core.Symbol]] = None, ) -> None: """Saves a TF1 model. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD index a01f881d88e04c..80955803f1cc5b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD @@ -1,5 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( @@ -14,7 +14,7 @@ cc_library( name = "fake_quant_utils", srcs = ["fake_quant_utils.cc"], hdrs = ["fake_quant_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", @@ -29,7 +29,7 @@ td_library( srcs = [ "lift_as_function_call_utils.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:FuncTdFiles", ], @@ -39,7 +39,7 @@ cc_library( name = "lift_as_function_call_utils", srcs = ["lift_as_function_call_utils.cc"], hdrs = ["lift_as_function_call_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", @@ -56,7 +56,7 @@ cc_library( name = "tf_to_uniform_attribute_utils", srcs = ["tf_to_uniform_attribute_utils.cc"], hdrs = ["tf_to_uniform_attribute_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", @@ -73,7 +73,7 @@ cc_library( name = "tf_to_xla_attribute_utils", srcs = ["tf_to_xla_attribute_utils.cc"], hdrs = ["tf_to_xla_attribute_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:constant_fold", 
diff --git a/tensorflow/compiler/mlir/stablehlo/BUILD b/tensorflow/compiler/mlir/stablehlo/BUILD index 162572e5d7c298..ddea76059bfdcd 100644 --- a/tensorflow/compiler/mlir/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/stablehlo/BUILD @@ -1,4 +1,5 @@ -load("//tensorflow:pytype.default.bzl", "pytype_library") +load("//tensorflow:strict.default.bzl", "py_strict_test") +load("//tensorflow:pytype.default.bzl", "pytype_strict_library") load("//tensorflow/tsl:tsl.default.bzl", "tsl_pybind_extension") package( @@ -36,7 +37,7 @@ tsl_pybind_extension( ], ) -pytype_library( +pytype_strict_library( name = "stablehlo", srcs = ["stablehlo.py"], srcs_version = "PY3", @@ -46,7 +47,7 @@ pytype_library( ], ) -py_test( +py_strict_test( name = "stablehlo_test", srcs = ["stablehlo_test.py"], python_version = "PY3", diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 60dcd27cff61d4..474ab34155e8eb 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1,4 +1,5 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_cloud") +load("//tensorflow:strict.default.bzl", "py_strict_library") +load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") @@ -30,7 +31,7 @@ td_library( "ir/tf_ops.td", "ir/tfrt_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:CallInterfacesTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", @@ -43,7 +44,7 @@ td_library( gentbl_cc_library( name = "tensorflow_op_interfaces_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-interface-decls"], @@ -64,7 +65,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_struct_doc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-dialect-doc"], @@ -103,7 +104,7 @@ cc_library( gentbl_cc_library( name = "tensorflow_all_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -123,7 +124,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_tfrt_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -156,7 +157,7 @@ tf_ops_category_list = [ [[ gentbl_cc_library( name = "tensorflow_" + target["name"] + "_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -183,7 +184,7 @@ tf_ops_category_list = [ gentbl_cc_library( name = "tensorflow_remaining_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -209,7 +210,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_saved_model_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -235,7 +236,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_executor_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = 
get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -266,7 +267,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_device_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -294,7 +295,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_canonicalize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -326,7 +327,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_legalize_tf_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -753,7 +754,7 @@ cc_library( gentbl_cc_library( name = "decompose_resource_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -807,7 +808,7 @@ td_library( srcs = [ "transforms/rewrite_util.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:OpBaseTdFiles", ], @@ -829,7 +830,7 @@ cc_library( gentbl_cc_library( name = "tf_data_optimization_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -1109,7 +1110,7 @@ cc_library( gentbl_cc_library( name = "tf_pass_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -1132,7 +1133,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_device_pass_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -1155,7 +1156,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_savedmodel_pass_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -1178,7 +1179,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_test_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -2233,7 +2234,7 @@ filegroup( gentbl_cc_library( name = "tensorflow_optimize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -2329,6 +2330,7 @@ tf_gen_op_wrapper_py( name = "gen_mlir_passthrough_op_py", out = "gen_mlir_passthrough_op.py", compatible_with = [], + py_lib_rule = py_strict_library, deps = [":mlir_passthrough_op"], ) @@ -2338,7 +2340,7 @@ tf_gen_op_wrapper_py( # without linking any of the other tensorflow passes. gentbl_cc_library( name = "lower_tf_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index 7f066b3f327eb5..f66c996f32a888 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -1922,7 +1922,7 @@ static LogicalResult inferConvReturnTypeComponents( // Skip if input or filter size is dynamic. if (input_ty.isDynamicDim(dim) || filter_ty.isDynamicDim(i)) continue; // Calculate the expected_output_size. 
- tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( + tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( input_ty.getDimSize(dim), filter_ty.getDimSize(i), get_int(dilations[dim]), stride, padding, &expected_output_size, &pad_low, &pad_high); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc index 01c6c6c68b38ad..8cce823ae5233c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc @@ -71,6 +71,19 @@ LogicalResult _TfrtGetResourceOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// PwStreamResults +//===----------------------------------------------------------------------===// + +mlir::LogicalResult PwStreamResultsOp::verify() { + if (getArgs().size() != getNames().size()) { + return emitOpError() + << "has a mismatch between the number of arguments and their names (" + << getArgs().size() << " vs. " << getNames().size() << ")"; + } + return mlir::success(); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td index bd6f35db525ffa..a0e2935255e95a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td @@ -63,4 +63,37 @@ def TF__TfrtGetResourceOp : TF_Op<"_TfrtGetResource", let hasVerifier = 1; } +// TODO(chky): Consider adding this op to tensorflow core ops. +def TF_PwStreamResultsOp : TF_Op<"PwStreamResults"> { + let summary = "Streams results back to the controller"; + + let description = [{ + This op is a TensorFlow op that represents "streamed outputs", where + intermediate results can be returned immediately without waiting for the + entire signature computation to complete. + + This op takes `args` with their `names` (their cardinality must match) and + sends the given argument tensors back to the serving controller. This + triggers a controller-side stream callback (see `ScopedStreamCallback`). + + In addition to the listed attributes, this op has two "hidden" attributes + that do not exist in SavedModel but are dynamically populated by the serving + runtime: + + * `_controller_address`: Address of the remote instance to which tensors + will be sent via e.g. RPC. + * `_callback_id`: Identifier for the callback to be called from the + controller. See `ScopedStreamCallback`. 
+ }]; + + let arguments = (ins + Variadic : $args, + StrArrayAttr : $names + ); + + TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; + + let hasVerifier = 1; +} + #endif // TFRT_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 1da0ee417402b9..f369fbde5f6fa3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1555,6 +1555,42 @@ func.func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x func.return %0 : tensor<3x8x8x16xf32> } +// CHECK-LABEL: func @dynamic_broadcast_in_dim_tf_style( +// CHECK-SAME: %[[ARG_0:.*]]: tensor, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<5xi32>) -> tensor { +// CHECK %[[VAL_0:.*]] = "tf.BroadcastTo"(%[[ARG_0]], %[[ARG_1]]) : (tensor, tensor<5xi32>) -> tensor +// CHECK return %[[VAL_0]] : tensor +func.func @dynamic_broadcast_in_dim_tf_style(%arg0: tensor, %arg1: tensor<5xi32>) -> tensor { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>} : (tensor, tensor<5xi32>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @dynamic_broadcast_in_dim_general_case_expand_back_dims( +// CHECK-SAME: %[[ARG_0:.*]]: tensor, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<4xi32>) -> tensor { +// CHECK %[[CST_0:.*]] = "tf.Const"() {value = dense<2> : tensor} : () -> tensor +// CHECK %[[VAL_0:.*]] = "tf.ExpandDims"(%[[ARG_0]], %[[CST_0]]) : (tensor, tensor) -> tensor +// CHECK %[[CST_1:.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor +// CHECK %[[VAL_1:.*]] = "tf.ExpandDims"(%[[VAL_0]], %[[CST_1]]) : (tensor, tensor) -> tensor +// CHECK %[[VAL_2:.*]] = "tf.BroadcastTo"(%[[VAL_1]], %[[ARG_1]]) : (tensor, tensor<4xi32>) -> tensor +// CHECK return %[[VAL_2]] : tensor +func.func @dynamic_broadcast_in_dim_general_case_expand_back_dims(%arg0: tensor, %arg1: tensor<4xi32>) -> tensor { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<4xi32>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @dynamic_broadcast_in_dim_general_case_expand_middle_dim( +// CHECK-SAME: %[[ARG_0:.*]]: tensor, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<4xi32>) -> tensor { +// CHECK %[[CST_0:.*]] = "tf.Const"() {value = dense<2> : tensor} : () -> tensor +// CHECK %[[VAL_0:.*]] = "tf.ExpandDims"(%[[ARG_0]], %[[CST_0]]) : (tensor, tensor) -> tensor +// CHECK %[[VAL_1:.*]] = "tf.BroadcastTo"(%[[VAL_0]], %[[ARG_1]]) : (tensor, tensor<4xi32>) -> tensor +// CHECK return %[[VAL_1]] : tensor +func.func @dynamic_broadcast_in_dim_general_case_expand_middle_dim(%arg0: tensor, %arg1: tensor<4xi32>) -> tensor { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1, 3]> : tensor<3xi64>} : (tensor, tensor<4xi32>) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: func @convert_dot_general( // CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x6x5x1xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 71fbf7cca9ee58..c65bd89421132c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -1297,6 +1297,12 @@ module attributes {tf.versions = {bad_consumers = [], 
min_consumer = 0 : i32, pr func.return %0 : tensor<*xf32> } + func.func @xla_call_module_parsing_error(%arg0: tensor) -> tensor<*xf32> { + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor<*xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) {Sout = [#tf_type.shape<*>], device = "", dim_args_spec = [], module = "invalid-stablehlo-module", platforms = [], version = 4 : i64} : (tensor, tensor<*xf32>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> + } + // CHECK-LABEL: func @xla_host_compute_mlir_empty_module func.func @xla_host_compute_mlir_empty_module(%arg0: tensor<2xf32>) -> tensor<*xf32> { // CHECK: "tf._XlaHostComputeMlir" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir new file mode 100644 index 00000000000000..3fb11e56172276 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir @@ -0,0 +1,15 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics | FileCheck %s + +// Tests for TensorFlow TFRT ops with custom verifiers. + +//===--------------------------------------------------------------------===// +// Test TF operations (tf.*) +//===--------------------------------------------------------------------===// + +// CHECK-LABEL: func @testPwStreamResults +func.func @testPwStreamResults(%arg0: tensor, %arg1: tensor) { + "tf.PwStreamResults"(%arg0, %arg1) {names = ["foo", "bar"]} : (tensor, tensor) -> () + return +} + +// ----- diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 3ffa9a120ca573..a69aef8edc5695 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLForwardCompat.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" @@ -502,7 +503,7 @@ class Convert2DConvOp : public OpConversionPattern, int64_t output_size; int64_t pad_low_int64; int64_t pad_high_int64; - tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( + tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( conv_op.getLhs().getType().cast().getDimSize( input_spatial_dim[i]), conv_op.getRhs().getType().cast().getDimSize( @@ -3741,6 +3742,33 @@ arith::ConstantOp ExpandedShape(PatternRewriter& rewriter, Value input, return rewriter.create(output.getLoc(), attr_type, attr); } +Value ExpandedDynamicShape(PatternRewriter& rewriter, Value input, + DenseIntElementsAttr broadcast_dimensions, + Value output) { + assert(output.getType().cast() && + "output type must be of ShapedType"); + int64_t output_rank = output.getType().cast().getRank(); + llvm::SmallVector expanded_dimensions; + llvm::SmallSet broadcast_dimensions_values; + for (auto x : llvm::enumerate(broadcast_dimensions)) { + broadcast_dimensions_values.insert(x.value().getSExtValue()); + } + for (int64_t i = 0; i < output_rank; i++) { + if (!broadcast_dimensions_values.contains(i)) { + expanded_dimensions.push_back(i); + } + } + Value expanded_input = input; + for (int64_t i : expanded_dimensions) { + auto index_attr = DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getI64Type()), {i}); + Value index = rewriter.create(output.getLoc(), index_attr); + expanded_input = rewriter.create(output.getLoc(), + expanded_input, index); + } + return expanded_input; +} + #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc" /// Performs the lowering to TF dialect. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index 0261783da7c33f..7ac8934e5bb915 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -39,6 +39,7 @@ def IsNotTFStyleBroadcast : Constraint>, // Return intermediate shape before broadcasting, wrapped in a constant op. 
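The ExpandedDynamicShape helper above brings the input up to the output rank before a single tf.BroadcastTo by inserting one tf.ExpandDims for every output dimension that is not listed in broadcast_dimensions. A standalone sketch of that index computation, mirroring the two expand-dims test cases earlier in this change (plain C++, not the MLIR builder code):

// Given the output rank and the broadcast_dimensions attribute, compute the
// axes at which the lowering inserts tf.ExpandDims before tf.BroadcastTo.
#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

std::vector<int64_t> DimsToExpand(int64_t output_rank,
                                  const std::vector<int64_t>& broadcast_dims) {
  const std::set<int64_t> mapped(broadcast_dims.begin(), broadcast_dims.end());
  std::vector<int64_t> expand;
  for (int64_t i = 0; i < output_rank; ++i) {
    // Output dims not produced by any input dim need an ExpandDims.
    if (mapped.count(i) == 0) expand.push_back(i);
  }
  return expand;
}

int main() {
  // broadcast_dimensions = [0, 1], output rank 4 -> ExpandDims at axes 2, 3
  // (the "expand_back_dims" test case).
  for (int64_t d : DimsToExpand(4, {0, 1}))
    std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");
  // broadcast_dimensions = [0, 1, 3], output rank 4 -> ExpandDims at axis 2
  // (the "expand_middle_dim" test case).
  for (int64_t d : DimsToExpand(4, {0, 1, 3}))
    std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");
  return 0;
}

In the lowering itself each such axis becomes a tf.ExpandDims fed by a scalar index constant, followed by one tf.BroadcastTo with the dynamic output shape.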
def ExpandedShape : NativeCodeCall<"ExpandedShape($_builder, $0, $1, $2)">; +def ExpandedDynamicShape : NativeCodeCall<"ExpandedDynamicShape($_builder, $0, $1, $2)">; def : Pat<(MHLO_ConstantOp:$output $value), (TF_ConstOp $value), [(TF_Tensor $output)]>; @@ -183,6 +184,17 @@ def : Pat<(MHLO_BroadcastInDimOp:$output $input, $broadcast_dimensions), (ExpandedShape $input, $broadcast_dimensions, $output)), (ShapeToConst $output)), [(IsNotTFStyleBroadcast $broadcast_dimensions, $output)]>; +// Dynamism op +def : Pat<(MHLO_DynamicBroadcastInDimOp:$output $input, $output_dimensions, + $broadcast_dimensions, $expanding_dimensions_unused, $nonexpanding_dimensions_unused), + (TF_BroadcastToOp $input, $output_dimensions), + [(IsTFStyleBroadcast $broadcast_dimensions, $output)]>; +def : Pat<(MHLO_DynamicBroadcastInDimOp:$output $input, $output_dimensions, + $broadcast_dimensions, $expanding_dimensions_unused, $nonexpanding_dimensions_unused), + (TF_BroadcastToOp (ExpandedDynamicShape $input, $broadcast_dimensions, $output), $output_dimensions), + [(IsNotTFStyleBroadcast $broadcast_dimensions, $output)]>; + + def : Pat<(MHLO_TransposeOp $arg, $permutation), (TF_TransposeOp $arg, (TF_ConstOp $permutation))>; def : Pat<(MHLO_ReverseOp $op, $dims), (TF_ReverseV2Op $op, (TF_ConstOp $dims))>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 19c6050aaf21c8..92bbc1f5a9910b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -1191,49 +1191,49 @@ bool ShapeInference::InferShapeForCaseRegion(CaseRegionOp op) { bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) { tensorflow::XlaCallModuleLoader* loader; - { - const auto [it, inserted] = xla_call_module_loaders_.insert({op, nullptr}); - + if (auto it = xla_call_module_loaders_.find(op); + it != xla_call_module_loaders_.end()) { + loader = it->second.get(); + } else { // Lazily parse XlaCallModule's embedded HLO module and cache the loader to // avoid repeatedly parsing the module. - if (inserted) { - std::vector dim_args_spec; - for (auto attr : op.getDimArgsSpec().getAsRange()) { - dim_args_spec.push_back(attr.getValue().str()); - } - std::vector disabled_checks; - for (auto attr : op.getDisabledChecks().getAsRange()) { - disabled_checks.push_back(attr.getValue().str()); - } - std::vector platforms; - for (auto attr : op.getPlatforms().getAsRange()) { - platforms.push_back(attr.getValue().str()); - } - // Always use the first platform. The assumption is that shape inference - // results should be the same regardless of which platform is chosen. - // Very old versions of the op have an empty platforms attribute. - std::string loading_platform = - (platforms.empty() ? "CPU" : platforms.front()); - - // It is a terrible idea to have local MLIR contexts so we need to - // register extensions here, again. 
- mlir::DialectRegistry registry; - registry.insert(); - mlir::func::registerAllExtensions(registry); - xla_call_module_context_.appendDialectRegistry(registry); - - auto l = tensorflow::XlaCallModuleLoader::Create( - &xla_call_module_context_, op.getVersion(), op.getModule().str(), - std::move(dim_args_spec), std::move(disabled_checks), - std::move(platforms), std::move(loading_platform)); - if (!l.ok()) { - LLVM_DEBUG(llvm::dbgs() << "Parsing error in XlaCallModule: " - << l.status().ToString() << "\n"); - return false; - } - it->second = *std::move(l); + + std::vector dim_args_spec; + for (auto attr : op.getDimArgsSpec().getAsRange()) { + dim_args_spec.push_back(attr.getValue().str()); + } + std::vector disabled_checks; + for (auto attr : op.getDisabledChecks().getAsRange()) { + disabled_checks.push_back(attr.getValue().str()); + } + std::vector platforms; + for (auto attr : op.getPlatforms().getAsRange()) { + platforms.push_back(attr.getValue().str()); + } + // Always use the first platform. The assumption is that shape inference + // results should be the same regardless of which platform is chosen. + // Very old versions of the op have an empty platforms attribute. + std::string loading_platform = + (platforms.empty() ? "CPU" : platforms.front()); + + // It is a terrible idea to have local MLIR contexts so we need to + // register extensions here, again. + mlir::DialectRegistry registry; + registry.insert(); + mlir::func::registerAllExtensions(registry); + xla_call_module_context_.appendDialectRegistry(registry); + + auto l = tensorflow::XlaCallModuleLoader::Create( + &xla_call_module_context_, op.getVersion(), op.getModule().str(), + std::move(dim_args_spec), std::move(disabled_checks), + std::move(platforms), std::move(loading_platform)); + if (!l.ok()) { + LLVM_DEBUG(llvm::dbgs() << "Parsing error in XlaCallModule: " + << l.status().ToString() << "\n"); + return false; } + it = xla_call_module_loaders_.insert({op, *std::move(l)}).first; loader = it->second.get(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc b/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc index d4a05aae890ecd..484523dfd7c912 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "llvm/ADT/SmallVector.h" @@ -55,6 +54,36 @@ using OpToParallelIdsMap = using OpToOpsMap = absl::flat_hash_map>; +// Many operations have the same dependency and parallel id set. We cache the +// processed result of these operations to speed execution. 
+struct OpCacheEntry { + Operation* template_op; + llvm::SmallVector preds_in_reverse_program_order; +}; + +struct OpCacheKey { + const llvm::SmallVector deps; + const GroupIdToBranchIdMap& group_id_to_branch_id_map; + + template + friend H AbslHashValue(H h, const OpCacheKey& c) { + for (Operation* dep : c.deps) { + h = H::combine(std::move(h), dep); + } + for (auto [group_id, branch_id] : c.group_id_to_branch_id_map) { + h = H::combine(std::move(h), group_id, branch_id); + } + return h; + } + + bool operator==(const OpCacheKey& other) const { + return deps == other.deps && + group_id_to_branch_id_map == other.group_id_to_branch_id_map; + } +}; + +using OpCache = absl::flat_hash_map; + #define GEN_PASS_DEF_EXECUTORUPDATECONTROLDEPENDENCIESPASS #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" @@ -169,27 +198,63 @@ LogicalResult FillOpToParallelIdsMap( // Computes and sets direct control inputs for `op`. Also fills // `active_transitive_preds` and `inactive_transitive_preds` for `op`. -void -UpdateControlDependenciesForOp( +// +// `active_transitive_preds` are those dominated by `op`: taking a dependency +// on `op` will also ensure all `active_transitive_preds[op]` are waited +// for. +// +// `inactive_transitive_preds` are transitive dependencies of op in the original +// graph but are not dominated by `op`. (They run in a different parallel +// execution group). They must be separately considered when processing +// successor operations. +void UpdateControlDependenciesForOp( Operation* op, const TF::SideEffectAnalysis::Info& analysis_for_func, const OpToParallelIdsMap& op_to_parallel_ids_map, + OpCache& op_cache, OpToOpsMap& active_transitive_preds, OpToOpsMap& inactive_transitive_preds, int& num_control_inputs_removed, int& num_control_inputs_added, int& num_invalid_dependencies) { + auto& op_inactive = inactive_transitive_preds[op]; + auto& op_active = active_transitive_preds[op]; + + llvm::SmallVector control_deps = + analysis_for_func.DirectControlPredecessors(op); + OpCacheKey key = { + control_deps, + GetGroupIdToBranchIdMap(op, op_to_parallel_ids_map) + }; + + // We matched with another op in the cache. We will have the same active and + // inactive dependency sets and control inputs, except we swap out our current + // op for the template op in the active set. + if (op_cache.contains(key)) { + auto& entry = op_cache[key]; + op_active = active_transitive_preds[entry.template_op]; + op_active.insert(op); + op_active.erase(entry.template_op); + + op_inactive = inactive_transitive_preds[entry.template_op]; + ClearControlInputs(op, num_control_inputs_removed); + SetControlInputs(op, entry.preds_in_reverse_program_order, + num_control_inputs_added); + return; + } + + op_active.insert(op); + + // First iterate over all direct control dependencies and collect the set of + // potential active dependencies. absl::flat_hash_set pred_set; - active_transitive_preds[op].insert(op); - for (Operation* pred : analysis_for_func.DirectControlPredecessors(op)) { - // Propagate inactive transitive dependencies from `pred` to `op`. - inactive_transitive_preds[op].insert( - inactive_transitive_preds[pred].begin(), - inactive_transitive_preds[pred].end()); + for (Operation* pred : control_deps) { // Inactive transitive predecessors of `pred` are potential direct // predecessors of `op` (they are not tracked by `pred`). 
for (Operation* transitive_pred : inactive_transitive_preds[pred]) { pred_set.insert(transitive_pred); + op_inactive.insert(transitive_pred); } + if (IsValidDependency(pred, op, op_to_parallel_ids_map)) { // We know that any active transitive predecessors will still be covered // by (pred, op), so we don't have to add them to `potential_preds`. @@ -197,40 +262,55 @@ UpdateControlDependenciesForOp( } else { // Active transitive predecessors will not be covered by (pred, op) // anymore, so add them all as candidates. - for (Operation* transitive_pred : active_transitive_preds[pred]) { - pred_set.insert(transitive_pred); - } + pred_set.insert( + active_transitive_preds[pred].begin(), + active_transitive_preds[pred].end()); ++num_invalid_dependencies; } } - std::vector potential_preds(pred_set.begin(), pred_set.end()); - std::sort(potential_preds.begin(), potential_preds.end(), IsAfterInBlock()); + // Now collect a list of valid dependencies and sort them in program order. + std::vector potential_preds; + potential_preds.reserve(pred_set.size()); - llvm::SmallVector preds_in_reverse_program_order; - for (Operation* potential_pred : potential_preds) { - bool is_valid = - IsValidDependency(potential_pred, op, op_to_parallel_ids_map); - if (!is_valid) { + for (Operation* potential_pred : pred_set) { + if (IsValidDependency(potential_pred, op, op_to_parallel_ids_map)) { + potential_preds.push_back(potential_pred); + } else { // We don't keep the (pred, op) dependency, so all active transitive // dependencies become inactive. - inactive_transitive_preds[op].insert( + op_inactive.insert( active_transitive_preds[potential_pred].begin(), active_transitive_preds[potential_pred].end()); - } else if (!active_transitive_preds[op].contains(potential_pred)) { + } + } + std::sort(potential_preds.begin(), potential_preds.end(), IsAfterInBlock()); + + // Finally, accumulate dependencies until we have coverage over all active + // dependencies. + llvm::SmallVector preds_in_reverse_program_order; + for (Operation* potential_pred : potential_preds) { + if (!op_active.contains(potential_pred)) { // `potential_pred` is not an active transitive predecessor of `op` yet, // so we must add it as a direct predecessor. preds_in_reverse_program_order.push_back(potential_pred); // We keep the (pred, op) dependency, so all active transitive // dependencies stay active. - active_transitive_preds[op].insert( + op_active.insert( active_transitive_preds[potential_pred].begin(), active_transitive_preds[potential_pred].end()); } } + + for (Operation* pred : op_active) { + op_inactive.erase(pred); + } + ClearControlInputs(op, num_control_inputs_removed); SetControlInputs(op, preds_in_reverse_program_order, num_control_inputs_added); + + op_cache[key] = {op, preds_in_reverse_program_order}; } // This function updates all control dependencies in `func`, represented as @@ -259,6 +339,7 @@ LogicalResult UpdateAllControlDependencies( // Maps island ops to parallel IDs of the wrapped ops. 
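// Illustrative sketch (integer node ids instead of Operation*) of the
// coverage step above: visit candidates in reverse program order and keep a
// candidate only if it is not already covered transitively by one we kept.
// Names here are hypothetical; this is not the pass's actual helper.
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"

using Node = int;

std::vector<Node> MinimalDirectPreds(
    const std::vector<Node>& candidates,  // assumed sorted latest-first
    const absl::flat_hash_map<Node, absl::flat_hash_set<Node>>& covers) {
  absl::flat_hash_set<Node> covered;
  std::vector<Node> kept;
  for (Node c : candidates) {
    if (covered.contains(c)) continue;  // already reachable via a kept pred
    kept.push_back(c);
    covered.insert(c);
    if (auto it = covers.find(c); it != covers.end()) {
      covered.insert(it->second.begin(), it->second.end());
    }
  }
  return kept;  // the direct control inputs, in reverse program order
}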
OpToParallelIdsMap op_to_parallel_ids_map; + OpCache op_cache; OpToOpsMap active_transitive_preds, inactive_transitive_preds; // We call `VerifyExportSuitable` in the beginning of the pass, so every @@ -275,6 +356,7 @@ LogicalResult UpdateAllControlDependencies( op, analysis_for_func, op_to_parallel_ids_map, + op_cache, active_transitive_preds, inactive_transitive_preds, num_control_inputs_removed, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 74cf842327062a..63ed9aac1db4e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -69,8 +69,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -using llvm::dyn_cast; -using llvm::isa; using mlir::BlockArgument; using mlir::Dialect; using mlir::Operation; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index f07af4f8b85a2c..c90a0b419b1be3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -218,10 +218,10 @@ std::string DumpMlirOpToFile(llvm::StringRef name, mlir::Operation* op, Status result = CreateFileForDumping(name, &os, &filepath, dirname); if (!result.ok()) return std::string(result.message()); + LOG(INFO) << "Dumping MLIR operation '" << op->getName().getStringRef().str() + << "' to '" << filepath << "'"; if (pass_manager) PrintPassPipeline(*pass_manager, op, *os); op->print(*os, mlir::OpPrintingFlags().useLocalScope()); - LOG(INFO) << "Dumped MLIR operation '" << op->getName().getStringRef().str() - << "' to '" << filepath << "'"; return filepath; } diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index 3406b6d0e4739d..19f9bae6ca920f 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -45,6 +45,7 @@ cc_library( "//tensorflow/core/tpu/kernels:tpu_compile_op_support", "//tensorflow/core/tpu/kernels:tpu_compile_proto_cc", "//tensorflow/core/tpu/kernels:tpu_util_hdrs", + "//tensorflow/tsl/platform:error_logging", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/log", diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc index 59ebb5e913b3a0..cf21467ad7ff6d 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc @@ -48,6 +48,7 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/kernels/tpu_util.h" #include "tensorflow/core/tpu/tpu_compile.h" +#include "tensorflow/tsl/platform/error_logging.h" #include "tensorflow/tsl/platform/status.h" #include "tensorflow/tsl/platform/statusor.h" @@ -95,6 +96,9 @@ constexpr char kOldBridgeWithFallbackModeSuccess[] = // Old Bridge failed in fallback (was run because MLIR bridge failed first). constexpr char kOldBridgeWithFallbackModeFailure[] = "kOldBridgeWithFallbackModeFailure"; +// Name of component for error logging. This name is fixed and required to +// enable logging. +constexpr char kBridgeComponent[] = "TFXLABridge"; // Time the execution of kernels (in CPU cycles). Meant to be used as RAII. 
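// Minimal RAII-style timer sketch using std::chrono; the struct below counts
// CPU cycles instead, but the usage pattern is the same: construct at the
// start of the scope, read the elapsed value when done. `ScopedTimer` is a
// hypothetical name, not the struct defined here.
#include <chrono>
#include <cstdint>

struct ScopedTimer {
  std::chrono::steady_clock::time_point start =
      std::chrono::steady_clock::now();

  int64_t ElapsedMicros() const {
    return std::chrono::duration_cast<std::chrono::microseconds>(
               std::chrono::steady_clock::now() - start)
        .count();
  }
};

// Usage sketch: { ScopedTimer t; DoWork(); int64_t us = t.ElapsedMicros(); }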
struct CompilationTimer { @@ -227,9 +231,19 @@ tsl::StatusOr LegalizeMlirToHlo( } else if (!enable_op_fallback) { // Don't fallback to the old bridge if op-by-op fallback isn't enabled. mlir_second_phase_count->GetCell(kMlirModeFailure)->IncrementBy(1); + if (!mlir_bridge_status.ok()) { + tsl::error_logging::Log(kBridgeComponent, + "TFXLA_API_V1_BRIDGE_NO_FALLBACK", + mlir_bridge_status.ToString()) + .IgnoreError(); + } return mlir_bridge_status; + } else { + tsl::error_logging::Log(kBridgeComponent, + "TFXLA_API_V1_BRIDGE_WITH_FALLBACK_FAIL", + mlir_bridge_status.ToString()) + .IgnoreError(); } - bool filtered_graph = false; if (mlir_bridge_status == CompileToHloGraphAnalysisFailedError()) { VLOG(1) << "Filtered out MLIR computation to XLA HLO using MLIR tf2xla " @@ -263,6 +277,11 @@ tsl::StatusOr LegalizeMlirToHlo( mlir_second_phase_count->GetCell(kOldBridgeWithFallbackModeFailure) ->IncrementBy(1); } + if (!old_bridge_status.ok()) { + tsl::error_logging::Log(kBridgeComponent, "TFXLA_API_V1_OLD_BRIDGE", + mlir_bridge_status.ToString()) + .IgnoreError(); + } return old_bridge_status; } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index f8aada0c497aa3..5733d107fa4574 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -4,7 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/tsl/platform:build_config_root.bzl", "if_static") package( @@ -15,7 +15,7 @@ package( gentbl_cc_library( name = "legalize_tf_patterns_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -35,7 +35,7 @@ gentbl_cc_library( gentbl_cc_library( name = "xla_legalize_tf_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -54,7 +54,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_xla_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index 8130518f2aef73..79220db6087a14 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -131,7 +131,7 @@ TEST_F(LegalizationOpConfigTest, CountLoweringsSet) { // a new op, we should expect these to change too. EXPECT_EQ(mlir_lowering_count, 71); EXPECT_EQ(tf2xla_fallback_count, 295); - EXPECT_EQ(non_categorized_count, 418); + EXPECT_EQ(non_categorized_count, 419); } // Just a counter test to see which ops have duplicate lowerings. This isn't a diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index cce28c45d704ec..be7fceb6632773 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -14,14 +14,18 @@ limitations under the License. 
==============================================================================*/ // This file implements logic for lowering TensorFlow dialect to XLA dialect. - +#include #include +#include #include #include #include #include #include #include +#include +#include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -1222,7 +1226,7 @@ class ConvertConvOp : public OpRewritePattern { int64_t pad_high_int64; int64_t input_size = input_ty.getDimSize(dim); if (input_size == ShapedType::kDynamic) return failure(); - tsl::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( + tsl::Status status = tensorflow::GetWindowedOutputSizeVerbose( input_size, filter_ty.getDimSize(i), dilation, stride, padding, &output_size, &pad_low_int64, &pad_high_int64); if (!status.ok()) return failure(); @@ -4886,7 +4890,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { padding_after = explicit_paddings[2 * spatial_dim + 1]; } int64_t expected_output_size = 0; - auto status = GetWindowedOutputSizeVerboseV2( + auto status = GetWindowedOutputSizeVerbose( input_size, filter_size, dilation, stride, padding, &expected_output_size, &padding_before, &padding_after); if (!status.ok()) return failure(); diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types.cc index d8fc61604cebcb..bab46615f90003 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types.cc @@ -20,6 +20,10 @@ limitations under the License. // TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass // rather than its own pass. +#include +#include +#include + #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -31,6 +35,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/core/lib/monitoring/counter.h" #define DEBUG_TYPE "xla-legalize-tf-types" @@ -38,6 +43,12 @@ namespace mlir { namespace mhlo { namespace { +// TODO: b/290366702 - Temporarily added metrics for debugging. +auto *mlir_tf_quant_op_count = tensorflow::monitoring::Counter<1>::New( + "/tensorflow/core/tf2xla/tf_quant_op_count" /*metric_name*/, + "Counts the number of ops that has qint types" /*metric description*/, + "op_name" /*metric label*/); + bool IsIllegalElementType(Type type) { return type .isagetResults()); + // TODO: b/290366702 - Temporarily added metrics for debugging. 
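// Sketch of the monitoring idiom used for this debug metric: declare a
// labelled Counter once, then increment one cell per label value. The metric
// name and helper below are hypothetical examples, not the metric this pass
// registers.
#include <string>

#include "tensorflow/core/lib/monitoring/counter.h"

namespace {
auto* example_op_count = tensorflow::monitoring::Counter<1>::New(
    "/tensorflow/example/debug_op_count",  // metric name (hypothetical)
    "Counts how often each op name is seen",
    "op_name");
}  // namespace

inline void CountOp(const std::string& op_name) {
  example_op_count->GetCell(op_name)->IncrementBy(1);
}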
+ if (llvm::any_of(op->getResultTypes(), IsIllegalType)) { + mlir_tf_quant_op_count->GetCell(std::string(op->getName().getStringRef())) + ->IncrementBy(1); + } return success(); } }; diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index 7fd0a0446d2e77..b18a7872d5e0db 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -4,7 +4,7 @@ load( "tf_cc_binary", "tf_cc_test", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud", "tf_py_strict_test", "tf_python_pybind_extension") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable", "tf_py_strict_test", "tf_python_pybind_extension") load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") load( "@llvm-project//mlir:tblgen.bzl", @@ -37,7 +37,7 @@ td_library( srcs = [ "ir/tfr_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization/ir:QuantizationOpsTdFiles", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", @@ -52,7 +52,7 @@ td_library( gentbl_cc_library( name = "tfr_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -72,7 +72,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tfr_decompose_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 9032b75d11ee44..db856a8e92f011 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -8,7 +8,7 @@ load("//tensorflow:tensorflow.bzl", "if_google", "tf_cc_binary") load("@tf_runtime//:build_defs.bzl", "tfrt_cc_library", "tfrt_cc_test") # Note: keep the following lines separate due to the way copybara works -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud", "get_compatible_with_portable") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") # TF to TFRT kernels conversion. 
package( @@ -108,7 +108,7 @@ cc_library( name = "tf_jitrt_pipeline", srcs = ["jit/tf_jitrt_pipeline.cc"], hdrs = ["jit/tf_jitrt_pipeline.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", @@ -365,13 +365,10 @@ cc_library( srcs = ["transforms/gpu_passes.cc"], hdrs = ["transforms/gpu_passes.h"], deps = [ - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_gpu_opdefs", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], ) @@ -414,7 +411,6 @@ cc_library( ":cost_analysis", ":fallback_converter", ":tensor_array_side_effect_analysis", - ":tfrt_jitrt_stub", ":tfrt_pipeline_options", ":tpu_passes", ":transform_utils", @@ -471,7 +467,6 @@ cc_library( deps = [ ":tf_to_tfrt", ":tfrt_compile_options", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:import_model", @@ -644,8 +639,6 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/tfrt:tf_to_tfrt", - "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", - "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_test_passes", ], ) @@ -672,7 +665,6 @@ cc_library( ":test_tensor_array_side_effect_analysis", ":tf_jitrt_opdefs", ":tf_to_tfrt", - ":tfrt_jitrt_passes", ":transforms/gpu_passes", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir:passes", @@ -683,7 +675,6 @@ cc_library( "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_sync_opdefs", - "//tensorflow/compiler/mlir/tfrt/ir:tfrt_gpu_opdefs", "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops", "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_ops", "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", @@ -755,7 +746,6 @@ tf_cc_binary( testonly = True, visibility = [":friends"], deps = [ - ":tf_jitrt_kernels_alwayslink", "@tf_runtime//:dtype", "@tf_runtime//:simple_tracing_sink", "@tf_runtime//tools:bef_executor_expensive_kernels", @@ -868,40 +858,6 @@ cc_library( ], ) -cc_library( - name = "tfrt_jitrt_passes", - srcs = ["transforms/tfrt_jitrt_passes.cc"], - deps = [ - ":fallback_converter", - ":tf_jitrt_opdefs", - ":tf_jitrt_pipeline", - ":tfrt_jitrt_stub", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", - "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", - "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_clustering", - "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:TransformUtils", - "@tf_runtime//:basic_kernels_opdefs", - "@tf_runtime//backends/jitrt:jitrt_opdefs", - ], - alwayslink = 1, -) - -cc_library( - name = "tfrt_jitrt_stub", - srcs = ["transforms/tfrt_jitrt_stub.cc"], - hdrs = ["transforms/tfrt_jitrt_stub.h"], - deps = [ - ":corert_converter", - ":tfrt_pipeline_options", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:TransformUtils", - ], -) - cc_library( name = "constants", hdrs = ["constants.h"], diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD 
b/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD index 8632e4a71855af..a5a53c3e5150a7 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD @@ -51,7 +51,6 @@ cc_library( deps = [ ":benchmark", "//tensorflow/compiler/jit:flags", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tfrt:host_context_util", "//tensorflow/compiler/mlir/tfrt:runtime_fallback_executor", "//tensorflow/compiler/mlir/tfrt:tf_jitrt_pipeline", @@ -61,7 +60,6 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:test_main", "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:mlir_c_runner_utils", "@tf_runtime//:basic_kernels_alwayslink", diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD b/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD index 6314fb77c9052a..868f9814127bd2 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD +++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD @@ -1,22 +1,19 @@ load("//tensorflow:tensorflow.default.bzl", "pybind_extension", "pybind_library") -load("//tensorflow:strict.default.bzl", "py_strict_test") - -licenses(["notice"]) +load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":__subpackages__"], ) -py_library( +licenses(["notice"]) + +py_strict_library( name = "tf_jitrt", testonly = 1, srcs = ["tf_jitrt.py"], visibility = ["//tensorflow/compiler/mlir/tfrt:__subpackages__"], - deps = [ - ":_tf_jitrt_executor", - "//third_party/py/numpy", - ], + deps = [":_tf_jitrt_executor"], ) py_strict_test( @@ -59,7 +56,7 @@ pybind_extension( ], ) -py_library( +py_strict_library( name = "tfrt_fallback", testonly = True, srcs = ["tfrt_fallback.py"], diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD index 78b4f9d64d4a13..edbdcb36a6b765 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD @@ -1,6 +1,6 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("@tf_runtime//:build_defs.bzl", "tfrt_cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") # TF to TFRT kernels conversion. 
package( @@ -13,7 +13,7 @@ tfrt_cc_library( name = "tf_jitrt_clustering", srcs = ["tf_jitrt_clustering.cc"], hdrs = ["tf_jitrt_clustering.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", @@ -27,7 +27,7 @@ tfrt_cc_library( gentbl_cc_library( name = "tf_jitrt_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -46,7 +46,6 @@ cc_library( name = "tf_jitrt_passes", srcs = [ "tf_jitrt_buffer_forwarding.cc", - "tf_jitrt_clustering_pass.cc", "tf_jitrt_copy_removal.cc", "tf_jitrt_fission.cc", "tf_jitrt_fusion.cc", @@ -55,7 +54,7 @@ cc_library( "tf_jitrt_passes.cc", ], hdrs = ["tf_jitrt_passes.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":tf_jitrt_clustering", ":tf_jitrt_passes_inc_gen", diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering_pass.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering_pass.cc deleted file mode 100644 index c1416bb5c45c8c..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering_pass.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" - -namespace tensorflow { -namespace { - -#define GEN_PASS_DEF_CLUSTERING -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc" - -using llvm::ArrayRef; - -using mlir::TF::ConstOp; -using mlir::TF::HashTableV2Op; -using mlir::TF::ReadVariableOp; - -using mlir::TFDevice::Cluster; -using mlir::TFDevice::ClusteringPolicySet; -using mlir::TFDevice::CreateClusterOp; -using mlir::TFDevice::FindClustersInTheBlock; - -// -------------------------------------------------------------------------- // -// Cluster operations based on the TF JitRt clustering policy. -// -------------------------------------------------------------------------- // -struct ClusteringPass : public impl::ClusteringBase { - ClusteringPass() = default; - ClusteringPass(ArrayRef cluster_oplist, int cluster_min_size) { - oplist = cluster_oplist; - min_cluster_size = cluster_min_size; - } - - void runOnOperation() override { - ClusteringPolicySet policies; - - // Parse clustering tier and operations filter from the oplist. 
- llvm::DenseSet opset; - std::optional tier; - - for (const auto& op : oplist) { - if (op == "tier0") { - tier = JitRtClusteringTier::kTier0; - } else if (op == "tier1") { - tier = JitRtClusteringTier::kTier1; - } else if (op == "tier1metadata") { - tier = JitRtClusteringTier::kTier1Metadata; - } else if (op == "tier1reductions") { - tier = JitRtClusteringTier::kTier1Reductions; - } else if (op == "all") { - tier = JitRtClusteringTier::kAll; - } else { - opset.insert(op); - } - } - - // Run clustering only if the clustering tier or supported operations are - // explicitly defined by the oplist. - if (!tier.has_value() && opset.empty()) return; - - // If the clustering tier is not defined, it means that the opset will later - // filter supported operations, so it's ok to use `all` tier. - populateTfJitRtClusteringPolicies(policies, - tier.value_or(JitRtClusteringTier::kAll)); - - // If opset is not empty restrict operations that are enabled for - // clustering. - auto opset_filter = [&](mlir::Operation* op) -> bool { - return opset.empty() || opset.contains(op->getName().getStringRef()); - }; - - // Find operations that could be hoisted from the function body into the - // TFRT resource initialization function. Currently it is an approximation - // of hoisting rules in the TFRT, we just find all the operations that - // depend only on ConstOp, ReadVariableOp or HashTableV2Op operations. We - // don't do any side effects analysis and conservatively can mark as - // hoistable operations that will not be hoisted by TFRT because of side - // effect dependencies. - // - // TODO(ezhulenev): This should be shared with TFRT hoisting implementation. - - // Initialize a set of operations that we assume we will hoist. - llvm::DenseSet hoisted_ops; - getOperation().walk([&](mlir::Operation* op) { - if (mlir::isa(op)) - hoisted_ops.insert(op); - }); - - // Initialize work list with users of ReadVariableOp results. - llvm::SmallVector work_list; - for (mlir::Operation* hoisted : hoisted_ops) - work_list.append(hoisted->user_begin(), hoisted->user_end()); - - // Traverse all users until we find all operations that could be hoisted. - while (!work_list.empty()) { - mlir::Operation* op = work_list.pop_back_val(); - - // Skip operations that are already in the hoisted set. - if (hoisted_ops.contains(op)) continue; - - // Add operation to hoisted ops set if all operands can be hoisted. - bool all_operands_hoisted = - llvm::all_of(op->getOperands(), [&](mlir::Value value) { - return hoisted_ops.contains(value.getDefiningOp()); - }); - if (!all_operands_hoisted) continue; - - hoisted_ops.insert(op); - work_list.append(op->user_begin(), op->user_end()); - } - - auto hoist_filter = [&](mlir::Operation* op) { - return !hoisted_ops.contains(op); - }; - - // Combine together opset and hoist filters. - auto filter = [&](mlir::Operation* op) -> bool { - return opset_filter(op) && hoist_filter(op); - }; - - // Annotate all formed clusters with an attribute. - auto policy = mlir::StringAttr::get(&getContext(), "tfrt.auto-fusion"); - - getOperation().walk([&](mlir::Block* block) { - for (Cluster& cluster : FindClustersInTheBlock(block, policies, filter)) { - // Do not create too small clusters. - if (cluster.operations.size() < min_cluster_size) continue; - // Verify that JIT runtime can compile the cluster. 
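// Self-contained sketch (plain integers instead of MLIR operations) of the
// worklist fixed point described in the hoisting comments above: seed with
// trivially hoistable nodes, then mark any user whose operands are all
// already marked, until nothing changes. All names here are hypothetical.
#include <algorithm>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"

using Node = int;
using Graph = absl::flat_hash_map<Node, std::vector<Node>>;

absl::flat_hash_set<Node> FindHoistable(const std::vector<Node>& seeds,
                                        const Graph& operands,
                                        const Graph& users) {
  absl::flat_hash_set<Node> hoistable(seeds.begin(), seeds.end());

  auto list_of = [](const Graph& g, Node n) -> const std::vector<Node>& {
    static const std::vector<Node> empty;
    auto it = g.find(n);
    return it == g.end() ? empty : it->second;
  };

  // Start from the users of the seed nodes.
  std::vector<Node> work_list;
  for (Node s : seeds) {
    const auto& us = list_of(users, s);
    work_list.insert(work_list.end(), us.begin(), us.end());
  }

  while (!work_list.empty()) {
    Node n = work_list.back();
    work_list.pop_back();
    if (hoistable.contains(n)) continue;

    const auto& ops = list_of(operands, n);
    bool all_hoistable = std::all_of(
        ops.begin(), ops.end(), [&](Node o) { return hoistable.contains(o); });
    if (!all_hoistable) continue;  // may be revisited via another user edge

    hoistable.insert(n);
    const auto& us = list_of(users, n);
    work_list.insert(work_list.end(), us.begin(), us.end());
  }
  return hoistable;
}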
- if (failed(VerifyCluster(cluster))) continue; - - CreateClusterOp(cluster, policy); - } - }); - } -}; - -} // namespace - -std::unique_ptr> -CreateTfJitRtClusteringPass() { - return std::make_unique(); -} - -std::unique_ptr> -CreateTfJitRtClusteringPass(llvm::ArrayRef oplist, - int min_cluster_size) { - return std::make_unique(oplist, min_cluster_size); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h index ae7a7b8da17f06..b50dad8498b90f 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h @@ -57,14 +57,6 @@ std::unique_ptr> CreateFissionPass(); // Pass to fuse Linalg generic operations on Tensors. std::unique_ptr> CreateFusionPass(); -// Creates `tf_device.cluster` operations according to the TF JitRt clustering -// policy. -std::unique_ptr> -CreateTfJitRtClusteringPass(); -std::unique_ptr> -CreateTfJitRtClusteringPass(llvm::ArrayRef oplist, - int min_cluster_size); - // Pass to replace math ops with approximations. std::unique_ptr> CreateMathApproximationPass(llvm::ArrayRef oplist = {}); diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.td b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.td index 5e2578bfb50a4d..dbf765dec9bc31 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.td +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.td @@ -85,27 +85,6 @@ def JitRtLegalizeI1Types ]; } -def Clustering : Pass<"tf-jitrt-clustering", "mlir::func::FuncOp"> { - let summary = "Creates `tf_device.cluster` operations according to the TF " - "JitRt clustering policy"; - - let constructor = "tensorflow::CreateTfJitRtClusteringPass()"; - - let dependentDialects = ["mlir::tf_device::TensorFlowDeviceDialect"]; - - let options = [ - Option<"min_cluster_size", "min-cluster-size", "int" , /*default=*/"1", - "Do not form clusters smaller of the given size.">, - // TODO(ezhulenev): This is a temporary workaround to control TF->JitRt - // clustering policy at runtime. - ListOption<"oplist", "oplist", "std::string", - "Explicitly allow operations for clustering. Only operations in " - "this list will be passed to the TF->JitRt clustering policy. " - "Alternatively use 'tier1', ..., 'all' to allow clustering for " - "all operations included in the given clustering tier."> - ]; -} - def MathApproximation : Pass<"tf-jitrt-math-approximation", "mlir::func::FuncOp"> { let summary = "Approximate math operations with an implementation meant to " "match Eigen's results. 
This is a useful property to have when " diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist.mlir deleted file mode 100644 index cde0cef4e38ed7..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: tf-tfrt-opt %s \ -// RUN: -tf-jitrt-clustering="oplist=tf.Add,tf.Sub,tf.Neg min-cluster-size=2"\ -// RUN: | FileCheck %s - -// CHECK-LABEL: func @single_cluster_one_result -func.func @single_cluster_one_result(%arg0 : tensor, %arg1 : tensor) - -> tensor { - // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() - // CHECK: "tf.Add" - // CHECK: "tf.Neg" - // CHECK: "tf.Sub" - // CHECK: "tf.Neg" - // CHECK: %[[RET:.*]] = "tf.Add" - // CHECK: tf_device.return %[[RET]] - %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "tf.Neg"(%0) : (tensor) -> tensor - %2 = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor - %3 = "tf.Neg"(%2) : (tensor) -> tensor - %4 = "tf.Add"(%1, %3) : (tensor, tensor) -> tensor - // CHECK: }) {policy = "tfrt.auto-fusion"} - // CHECK: return %[[CLUSTER]] - func.return %4 : tensor -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_all.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_all.mlir deleted file mode 100644 index d6f6ed6da92768..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_all.mlir +++ /dev/null @@ -1,54 +0,0 @@ -// RUN: tf-tfrt-opt %s \ -// RUN: -tf-jitrt-clustering="oplist=all min-cluster-size=2" \ -// RUN: | FileCheck %s - -// CHECK-LABEL: func @single_cluster_one_result -func.func @single_cluster_one_result(%arg0 : tensor, %arg1 : tensor) - -> tensor { - // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() - // CHECK: "tf.Add" - // CHECK: "tf.Neg" - // CHECK: "tf.Sub" - // CHECK: "tf.Neg" - // CHECK: %[[RET:.*]] = "tf.Add" - // CHECK: tf_device.return %[[RET]] - %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "tf.Neg"(%0) : (tensor) -> tensor - %2 = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor - %3 = "tf.Neg"(%2) : (tensor) -> tensor - %4 = "tf.Add"(%1, %3) : (tensor, tensor) -> tensor - // CHECK: }) {policy = "tfrt.auto-fusion"} - // CHECK: return %[[CLUSTER]] - func.return %4 : tensor -} - -// CHECK-LABEL: func @do_not_cluster_hoistable_ops -func.func @do_not_cluster_hoistable_ops( - %arg0 : tensor, - %arg1 : tensor<*x!tf_type.resource>, - %arg2 : tensor<*x!tf_type.resource> - ) -> tensor { - // CHECK: "tf.Const" - // CHECK: "tf.ReadVariableOp" - // CHECK: "tf.ReadVariableOp" - // CHECK: "tf.Add" - // CHECK: "tf.Neg" - // CHECK: "tf.Sub" - %c = "tf.Const"() { value = dense<1> : tensor } : () -> tensor - %x = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>) -> tensor - %y = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf_type.resource>) -> tensor - %0 = "tf.Add"(%x, %y) : (tensor, tensor) -> tensor - %1 = "tf.Neg"(%0) : (tensor) -> tensor - %2 = "tf.Sub"(%0, %c) : (tensor, tensor) -> tensor - // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() - // CHECK: "tf.Sub" - // CHECK: "tf.Neg" - // CHECK: %[[RET:.*]] = "tf.Add" - // CHECK: tf_device.return %[[RET]] - %3 = "tf.Sub"(%arg0, %2) : (tensor, tensor) -> tensor - %4 = "tf.Neg"(%3) : (tensor) -> tensor - %5 = "tf.Add"(%2, %4) : (tensor, tensor) -> tensor - // CHECK: }) {policy = "tfrt.auto-fusion"} - // CHECK: return %[[CLUSTER]] - func.return %5 : tensor -} diff --git 
a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_tier1.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_tier1.mlir deleted file mode 100644 index 026828ae368718..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_tier1.mlir +++ /dev/null @@ -1,54 +0,0 @@ -// RUN: tf-tfrt-opt %s \ -// RUN: -tf-jitrt-clustering="oplist=tier1 min-cluster-size=2" \ -// RUN: | FileCheck %s --check-prefix CHECK --check-prefix=TIER1 -// RUN: tf-tfrt-opt %s \ -// RUN: -tf-jitrt-clustering="oplist=tier1metadata min-cluster-size=2" \ -// RUN: | FileCheck %s --check-prefix CHECK --check-prefix=METADATA -// RUN: tf-tfrt-opt %s \ -// RUN: -tf-jitrt-clustering="oplist=tier1reductions min-cluster-size=2" \ -// RUN: | FileCheck %s --check-prefix CHECK --check-prefix=REDUCTIONS - -// CHECK-LABEL: func @single_cluster_one_result -func.func @single_cluster_one_result(%arg0 : tensor, %arg1 : tensor) - -> tensor { - // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() - // TIER1-NOT: "tf.Sum" - // TIER1: "tf.Add" - // TIER1: "tf.Neg" - // TIER1: "tf.Sub" - // TIER1: "tf.Neg" - // TIER1: %[[RET:.*]] = "tf.Add" - // TIER1: tf_device.return %[[RET]] - - // METADATA-NOT: "tf.Sum" - // METADATA: "tf.Add" - // METADATA: "tf.Neg" - // METADATA: "tf.Sub" - // METADATA: "tf.Neg" - // METADATA: "tf.Add" - // METADATA: %[[RET:.*]] = "tf.Shape" - // METADATA: tf_device.return %[[RET]] - - // REDUCTIONS: "tf.Sum" - // REDUCTIONS: "tf.Add" - // REDUCTIONS: "tf.Neg" - // REDUCTIONS: "tf.Sub" - // REDUCTIONS: "tf.Neg" - // REDUCTIONS: %[[RET:.*]] = "tf.Add" - // REDUCTIONS: tf_device.return %[[RET]] - %dimension = "tf.Const"() { value = dense<0> : tensor<1xi64> } : () -> tensor<1xi64> - %s = "tf.Sum"(%arg0, %dimension) { keep_dims = false }: (tensor, tensor<1xi64>) -> tensor - %0 = "tf.Add"(%s, %arg1) : (tensor, tensor) -> tensor - %1 = "tf.Neg"(%0) : (tensor) -> tensor - %2 = "tf.Sub"(%s, %arg1) : (tensor, tensor) -> tensor - %3 = "tf.Neg"(%2) : (tensor) -> tensor - %4 = "tf.Add"(%1, %3) : (tensor, tensor) -> tensor - %5 = "tf.Shape"(%4) : (tensor) -> tensor - // CHECK: }) {policy = "tfrt.auto-fusion"} - // TIER1: %[[SHAPE:.*]] = "tf.Shape"(%[[CLUSTER]]) - // TIER1: return %[[SHAPE]] - - // REDUCTIONS: %[[SHAPE:.*]] = "tf.Shape"(%[[CLUSTER]]) - // REDUCTIONS: return %[[SHAPE]] - func.return %5 : tensor -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose.mlir deleted file mode 100644 index 855b5f6e6d5489..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose.mlir +++ /dev/null @@ -1,130 +0,0 @@ -// RUN: tf-tfrt-opt -tf-jitrt-pipeline="vectorize" -split-input-file %s | FileCheck %s - -func.func @transpose_2d(%arg0: tensor) -> tensor { - %0 = "tf.Const"() - {value = dense<[1, 0]> : tensor<2xi64>, - device = "/job:localhost/replica:0/task:0/device:CPU:0"} - : () -> tensor<2xi64> - %1 = "tf.Transpose"(%arg0, %0) - {device = "/job:localhost/replica:0/task:0/device:CPU:0"} - : (tensor, tensor<2xi64>) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @transpose_2d -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// 8x8 tiling. -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// Vector xfer reads: unrolled second vector dimension. 
-// CHECK-COUNT-8: vector.transfer_read -// AVX2 shuffle/asm sequence. -// CHECK-COUNT-12: vector.shuffle -// CHECK-COUNT-8: llvm.inline_asm -// CHECK-COUNT-8: vector.shuffle -// Vector xfer writes: unrolled second vector dimension. - -// ----- - -func.func @transpose_3d_021(%arg0: tensor) -> tensor { - %0 = "tf.Const"() { value = dense<[0, 2, 1]> : tensor<3xi64> } - : () -> tensor<3xi64> - %1 = "tf.Transpose"(%arg0, %0) - : (tensor, tensor<3xi64>) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @transpose_3d -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// 1x8x8 tiling. -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C1]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// Vector xfer reads: unrolled second vector dimension. -// CHECK-COUNT-8: vector.transfer_read -// AVX2 shuffle/asm sequence. -// CHECK-COUNT-12: vector.shuffle -// CHECK-COUNT-8: llvm.inline_asm -// CHECK-COUNT-8: vector.shuffle -// Vector xfer writes: unrolled second vector dimension. - -// ----- - -func.func @transpose_3d_201(%arg0: tensor) -> tensor { - %0 = "tf.Const"() { value = dense<[2, 0, 1]> : tensor<3xi64> } - : () -> tensor<3xi64> - %1 = "tf.Transpose"(%arg0, %0) - : (tensor, tensor<3xi64>) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @transpose_3d_201 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// 8x1x8 tiling. -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C1]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// Vector xfer reads: unrolled second vector dimension. -// CHECK-COUNT-8: vector.transfer_read -// AVX2 shuffle/asm sequence. -// CHECK-COUNT-12: vector.shuffle -// CHECK-COUNT-8: llvm.inline_asm -// CHECK-COUNT-8: vector.shuffle -// Vector xfer writes: unrolled second vector dimension. - -// ----- - -func.func @transpose_3d_210(%arg0: tensor) -> tensor { - %0 = "tf.Const"() { value = dense<[2, 1, 0]> : tensor<3xi64> } - : () -> tensor<3xi64> - %1 = "tf.Transpose"(%arg0, %0) - : (tensor, tensor<3xi64>) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @transpose_3d_210 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// 8x1x8 tiling. -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C1]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// Vector xfer reads: unrolled second vector dimension. -// CHECK-COUNT-8: vector.transfer_read -// AVX2 shuffle/asm sequence. -// CHECK-COUNT-12: vector.shuffle -// CHECK-COUNT-8: llvm.inline_asm -// CHECK-COUNT-8: vector.shuffle -// Vector xfer writes: unrolled second vector dimension. 
- -// ----- - -func.func @transpose_3d_120(%arg0: tensor) -> tensor { - %0 = "tf.Const"() { value = dense<[1, 2, 0]> : tensor<3xi64> } - : () -> tensor<3xi64> - %1 = "tf.Transpose"(%arg0, %0) - : (tensor, tensor<3xi64>) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @transpose_3d_120 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// 1x8x8 tiling. -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C1]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// Vector xfer reads: unrolled second vector dimension. -// CHECK-COUNT-8: vector.transfer_read -// AVX2 shuffle/asm sequence. -// CHECK-COUNT-12: vector.shuffle -// CHECK-COUNT-8: llvm.inline_asm -// CHECK-COUNT-8: vector.shuffle -// Vector xfer writes: unrolled second vector dimension. diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline.mlir deleted file mode 100644 index a01636a4d9dfe2..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline.mlir +++ /dev/null @@ -1,440 +0,0 @@ -// RUN: tf-tfrt-opt -split-input-file -tf-jitrt-pipeline %s | FileCheck %s - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @tanh_lower_and_fuse -// CHECK-SAME: %[[ARG:.*]]: memref -func.func @tanh_lower_and_fuse(%arg0: tensor) -> tensor { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[DIM:.*]] = memref.dim %[[ARG]], %[[C0]] - // CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DIM]]) {{.*}} : memref - - // CHECK: linalg.generic - // CHECK-SAME: indexing_maps = [#map, #map] - // CHECK-SAME: iterator_types = ["parallel", "parallel"] - // CHECK-SAME: ins(%[[ARG]] : memref) - // CHECK-SAME: outs(%[[MEMREF]] : memref) - // CHECK: tanh - // CHECK-NEXT: tanh - - // CHECK: return %[[MEMREF]] - %0 = "tf.Tanh"(%arg0): (tensor) -> tensor - %1 = "tf.Tanh"(%0): (tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @sigmoid_dynamic_dim -func.func @sigmoid_dynamic_dim(%arg0: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: indexing_maps = [#map, #map] - // CHECK-SAME: iterator_types = ["parallel", "parallel"] - %0 = "tf.Sigmoid"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0) -> ()> -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @add_scalar_with_vec -func.func @add_scalar_with_vec(%arg0: tensor, - %arg1: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1): (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK: #map = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @add_vec_vec -func.func @add_vec_vec( - %arg0: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg1: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>} -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1): (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK: #map = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @add_vec_vec_vec -func.func @add_vec_vec_vec( - %arg0: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg1: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - 
%arg2: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>} -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: addf - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1): (tensor, tensor) -> tensor - %1 = "tf.AddV2"(%0, %arg2): (tensor, tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// Verify that symbolic shape optimization can move all the broadcasts up, and -// progressively remove all shape constraints and replace mhlo broadcasts with -// linalg.generic operations that in the end all are fused together. - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, 0)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - -// CHECK: compute_with_bcast -func.func @compute_with_bcast( - %arg0: tensor<1x?x1xf32> - {rt.symbolic_shape = dense<[1, -2, 1]> : tensor<3xi64>}, - %arg1: tensor<512xf32>, - %arg2: tensor<1x?x512xf32> - {rt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>}, - %arg3: tensor<1x?x1xf32> - {rt.symbolic_shape = dense<[1, -2, 1]> : tensor<3xi64>}, - %arg4: tensor<512xf32> -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: addf - // CHECK-NEXT: math.rsqrt - // CHECK-NEXT: mulf - // CHECK-NEXT: mulf - // CHECK-NEXT: subf - // CHECK-NEXT: mulf - // CHECK-NEXT: addf - // CHECK-NEXT: linalg.yield - // CHECK-NOT: linalg.generic - %c = "tf.Const"() {value = dense<9.99999996E-13> - : tensor} : () -> tensor - %0 = "tf.AddV2"(%arg0, %c) - : (tensor<1x?x1xf32>, tensor) -> tensor - %1 = "tf.Rsqrt"(%0) - : (tensor) -> tensor - %2 = "tf.Mul"(%1, %arg1) - : (tensor, tensor<512xf32>) -> tensor - %3 = "tf.Mul"(%2, %arg2) - : (tensor, tensor<1x?x512xf32>) -> tensor - %4 = "tf.Mul"(%2, %arg3) - : (tensor, tensor<1x?x1xf32>) -> tensor - %5 = "tf.Sub"(%arg4, %4) - : (tensor<512xf32>, tensor) -> tensor - %6 = "tf.AddV2"(%3, %5) - : (tensor, tensor) -> tensor - func.return %6 : tensor -} - -// ----- - -// CHECK: add_vec_vec_vec_vec -func.func @add_vec_vec_vec_vec( - %arg0: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg1: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg2: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg3: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>} -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: addf - // CHECK: addf - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1): (tensor, tensor) -> tensor - %1 = "tf.AddV2"(%0, %arg2): (tensor, tensor) -> tensor - %2 = "tf.AddV2"(%1, %arg3): (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// ----- - -// CHECK: add_vec_tensor_tensor -func.func @add_vec_tensor_tensor( - %arg0: tensor<512xf32>, - %arg1: tensor<1x?x512xf32> - {rt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>}, - %arg2: tensor<1x?x512xf32> - {rt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>} -) -> tensor<1x?x512xf32> { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: addf - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1) - : (tensor<512xf32>, tensor<1x?x512xf32>) -> tensor<1x?x512xf32> - %1 = "tf.AddV2"(%arg2, %0) - : (tensor<1x?x512xf32>, tensor<1x?x512xf32>) -> tensor<1x?x512xf32> - func.return %1 : tensor<1x?x512xf32> -} - -// ----- - -// CHECK-LABEL: @tf_binary_with_bcast -func.func @tf_binary_with_bcast(%arg0: tensor, - %arg1: tensor) -> tensor { - // 
CHECK-NOT: shape. - // CHECK: %[[LHS:.*]] = memref.reinterpret_cast - // CHECK: %[[RHS:.*]] = memref.reinterpret_cast - // CHECK: linalg.generic {{.*}} ins(%[[LHS]], %[[RHS]] : - // CHECK: mulf - %0 = "tf.Mul"(%arg0, %arg1) - : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @tf_binary_with_bcast_and_fusion -// CHECK-SAME: %[[ARG0:.*]]: memref, -// CHECK-SAME: %[[ARG1:.*]]: memref<4xf32>, -// CHECK-SAME: %[[ARG2:.*]]: memref<4xf32> -func.func @tf_binary_with_bcast_and_fusion(%arg0: tensor, - %arg1: tensor<4xf32>, - %arg2: tensor<4xf32>) -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : {{.*}}) - // CHECK: math.log1p - // CHECK-NEXT: subf - // CHECK-NEXT: mulf - // CHECK-NEXT: linalg.yield - // CHECK-NOT: linalg.generic - %0 = "tf.Log1p"(%arg0) - : (tensor) -> tensor - %1 = "tf.Sub"(%0, %arg1) - : (tensor, tensor<4xf32>) -> tensor - %2 = "tf.Mul"(%1, %arg2) - : (tensor, tensor<4xf32>) -> tensor - func.return %2 : tensor -} - -// ----- - -// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK: tf_binary_with_bcast_symbolic_shapes -func.func @tf_binary_with_bcast_symbolic_shapes( - %arg0: tensor {rt.symbolic_shape = dense<[ -3]>: tensor<1xi64>}, - %arg1: tensor {rt.symbolic_shape = dense<[-2,-3]>: tensor<2xi64>}, - %arg2: tensor {rt.symbolic_shape = dense<[-2,-3]>: tensor<2xi64>}, - %arg3: tensor {rt.symbolic_shape = dense<[-2,-3]>: tensor<2xi64>} -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: log1p - // CHECK: addf - // CHECK: addf - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = "tf.Log1p"(%arg0) - : (tensor) -> tensor - %1 = "tf.AddV2"(%0, %arg1) - : (tensor, tensor) -> tensor - %2 = "tf.AddV2"(%1, %arg2) - : (tensor, tensor) -> tensor - %3 = "tf.AddV2"(%2, %arg3) - : (tensor, tensor) -> tensor - func.return %3 : tensor -} - -// ----- - -// CHECK-LABEL: @cast_sub -func.func @cast_sub(%arg0: tensor, %arg1: tensor) - -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: outs(%[[RESULT_BUF:.*]] : memref) - // CHECK-SAME: { - // CHECK: ^bb0(%[[LHS:.*]]: f16, %[[RHS:.*]]: i16, %{{.*}}: f16): - // CHECK: %[[RHS_CASTED:.*]] = arith.sitofp %[[RHS]] : i16 to f16 - // CHECK: %[[RESULT:.*]] = arith.subf %[[LHS]], %[[RHS_CASTED]] : f16 - // CHECK: linalg.yield %[[RESULT]] : f16 - // CHECK: } - // CHECK: return %[[RESULT_BUF]] : memref - %0 = "tf.Cast"(%arg0) : (tensor) -> tensor - %1 = "tf.Sub"(%arg1, %0) : (tensor, tensor) - -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d1, d0)> -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @tf_transpose_const_perm -func.func @tf_transpose_const_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { - // CHECK: %[[OUT:.*]] = memref.alloc() {{.*}} : memref<3x2xf32> - // CHECK: linalg.generic {indexing_maps = [#map{{[0-9]*}}, #map{{[0-9]*}}] - // CHECK-SAME: ins(%arg0 : memref<2x3xf32>) - // CHECK-SAME: outs(%[[OUT]] : memref<3x2xf32>) - %0 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } - : () -> tensor<2xi32> - %1 = "tf.Transpose"(%arg0, %0) - : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - func.return %1 : tensor<3x2xf32> -} - -// ----- - -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d2, d0, d1)> -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - -// CHECK-LABEL: @tf_transpose_after_transpose -func.func @tf_transpose_after_transpose(%arg0: tensor) - -> tensor { - // 
CHECK: %[[OUT:.*]] = memref.alloc - // CHECK: linalg.generic {indexing_maps = [#map{{[0-9]*}}, #map{{[0-9]*}}] - // CHECK-SAME: ins(%arg0 : memref) - // CHECK-SAME: outs(%[[OUT]] : memref) - // CHECK-NOT: linalg.generic - %0 = "tf.Const"() { value = dense<[0, 2, 1]> : tensor<3xi32> } - : () -> tensor<3xi32> - %1 = "tf.Const"() { value = dense<[2, 1, 0]> : tensor<3xi32> } - : () -> tensor<3xi32> - %2 = "tf.Transpose"(%arg0, %0) - : (tensor, tensor<3xi32>) -> tensor - %3 = "tf.Transpose"(%2, %1) - : (tensor, tensor<3xi32>) -> tensor - func.return %3 : tensor -} - -// ----- - -// CHECK-LABEL: @bias_add_and_relu -// CHECK-SAME: %[[ARG0:.*]]: memref -// CHECK-SAME: %[[ARG1:.*]]: memref<32xf32> -func.func @bias_add_and_relu(%arg0: tensor, - %arg1: tensor<32xf32>) -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) - // CHECK: addf - // CHECK: maxf - // CHECK-NEXT: linalg.yield - // CHECK-NOT: linalg.generic - %0 = "tf.BiasAdd"(%arg0, %arg1) - : (tensor, tensor<32xf32>) -> tensor - %1 = "tf.Relu"(%0): (tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @sub_sub -func.func @sub_sub(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: outs(%[[RESULT_BUF:.*]] : memref) - // CHECK: ^bb0(%[[A:.*]]: f16, %[[B:.*]]: f16, %[[C:.*]]: f16, %{{.*}}: f16): - // CHECK: %[[TMP:.*]] = arith.subf %[[B]], %[[C]] - // CHECK: %[[RESULT:.*]] = arith.subf %[[A]], %[[TMP]] - // CHECK: linalg.yield %[[RESULT]] - // CHECK: return %[[RESULT_BUF]] : memref - %0 = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "tf.Sub"(%arg2, %0) : (tensor, tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @strided_slice_1d_to_0d -func.func @strided_slice_1d_to_0d(%arg0: tensor<3xi32>) -> tensor { - %cst_0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[0] [1] [1] - // CHECK-SAME: : memref<3xi32> to memref<1xi32, strided<[1]>> - // CHECK: %[[RET:.*]] = memref.collapse_shape %[[SUBVIEW]] - // CHECK: return %[[RET]] - %0 = "tf.StridedSlice"(%arg0, %cst_1, %cst_0, %cst_0) - { - begin_mask = 0 : i64, - ellipsis_mask = 0 : i64, - end_mask = 0 : i64, - new_axis_mask = 0 : i64, - shrink_axis_mask = 1 : i64 - } : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) - -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK: memref.global "private" constant @__constant_2xi32 : memref<2xi32> = dense<[0, 1]> -// CHECK-SAME: {alignment = 64 : i64} -// CHECK-LABEL: @constant_folding -func.func @constant_folding() -> tensor<2xi32> { - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %[[CONST:.*]] = memref.get_global @__constant_2xi32 : memref<2xi32> - // CHECK: return %[[CONST]] - %2 = "tf.Pack"(%0, %1) {axis = 0 : i64} - : (tensor, tensor) -> tensor<2xi32> - func.return %2 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: @add_floormod_add -func.func @add_floormod_add(%arg0: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg0) - : (tensor, tensor) -> tensor - %1 = "tf.FloorMod"(%0, %arg0) - : (tensor, tensor) -> tensor - %2 = "tf.AddV2"(%1, %arg0) - : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// ----- - -// CHECK-LABEL: @min_clip_by_value -func.func @min_clip_by_value(%V__0: 
tensor) -> tensor { - %dims0 = "tf.Const"() { value = dense<[1, 2]> : tensor<2xi32> }: () -> tensor<2xi32> - %0 = "tf.Min"(%V__0, %dims0) {keep_dims = true} : (tensor, tensor<2xi32>) -> tensor - %1 = "tf.ClipByValue"(%V__0, %0, %V__0) : (tensor, tensor, tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @rint_sq_sub -func.func @rint_sq_sub(%arg0: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-NOT: linalg.generic - %0 = "tf.Rint"(%arg0) : (tensor) -> tensor - %1 = "tf.Square"(%arg0) : (tensor) -> tensor - %2 = "tf.Sub"(%0, %1) : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// ----- - -// CHECK-LABEL: @do_not_fuse_if_multiple_uses -func.func @do_not_fuse_if_multiple_uses(%arg0: tensor) - -> (tensor, tensor) { - // CHECK: linalg.generic - // CHECK: math.rsqrt - // CHECK-NEXT: math.rsqrt - // CHECK-NEXT: linalg.yield - %0 = "tf.Rsqrt"(%arg0) : (tensor) -> tensor - %1 = "tf.Rsqrt"(%0) : (tensor) -> tensor - // CHECK: linalg.generic - // CHECK: math.rsqrt - // CHECK-NEXT: linalg.yield - %2 = "tf.Rsqrt"(%1) : (tensor) -> tensor - // CHECK-NOT: linalg.generic - func.return %1, %2 : tensor, tensor -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_vectorized.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_vectorized.mlir deleted file mode 100644 index 07d2d6a3f08434..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_vectorized.mlir +++ /dev/null @@ -1,75 +0,0 @@ -// RUN: tf-tfrt-opt -tf-jitrt-pipeline="vectorize" \ -// RUN: %s -split-input-file | FileCheck %s - -// CHECK-LABEL: @reduce_row_sum_2d_dynamic -func.func @reduce_row_sum_2d_dynamic(%input: tensor) -> tensor { - %dim_to_reduce = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} - : () -> tensor<1xi32> - %0 = "tf.Sum"(%input, %dim_to_reduce) {keep_dims = false} - : (tensor, tensor<1xi32>) -> tensor - func.return %0 : tensor -} -// CHECK: scf.for -// CHECK: scf.for -// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> - -// ----- - -// CHECK-LABEL: @reduce_column_sum_2d_dynamic -func.func @reduce_column_sum_2d_dynamic(%input: tensor) -> tensor { - %dim_to_reduce = "tf.Const"() {value = dense<[0]> : tensor<1xi32>} - : () -> tensor<1xi32> - %0 = "tf.Sum"(%input, %dim_to_reduce) {keep_dims = false} - : (tensor, tensor<1xi32>) -> tensor - func.return %0 : tensor -} -// CHECK: scf.for -// CHECK: scf.for -// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> - -// ----- - -// CHECK-LABEL: @reduce_row_mean_2d_dynamic -func.func @reduce_row_mean_2d_dynamic(%input: tensor) -> tensor { - %dim_to_reduce = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} - : () -> tensor<1xi32> - %0 = "tf.Mean"(%input, %dim_to_reduce) {keep_dims = false} - : (tensor, tensor<1xi32>) -> tensor - func.return %0 : tensor -} -// CHECK: scf.for -// CHECK: scf.for -// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK: scf.yield -// CHECK: arith.divf %{{.*}}, %{{.*}} : vector<4xf32> - -// ----- - -// CHECK-LABEL: @reduce_1d_dynamic -func.func @reduce_1d_dynamic(%input: tensor) -> tensor { - %dim_to_reduce = "tf.Const"() {value = dense<[0]> : tensor<1xi32>} - : () -> tensor<1xi32> - %0 = "tf.Sum"(%input, %dim_to_reduce) {keep_dims = false} - : (tensor, tensor<1xi32>) -> tensor - func.return %0 : tensor -} -// CHECK: scf.for -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8xf32> -// CHECK: vector.reduction - -// ----- - -// CHECK-LABEL: @reduction_of_cast -func.func @reduction_of_cast(%arg0: tensor) -> tensor { - %cst = 
"tf.Const"() - {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %0 = "tf.Cast"(%arg0) {Truncate = false} - : (tensor) -> tensor - %1 = "tf.Prod"(%0, %cst) {keep_dims = false} - : (tensor, tensor<1xi32>) -> tensor - func.return %1 : tensor -} -// CHECK: scf.for -// CHECK: arith.trunci -// CHECK: scf.for -// CHECK: arith.muli diff --git a/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc b/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc index 3089afe93686b5..4cd8ce4b833ce3 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc +++ b/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc @@ -131,7 +131,6 @@ TEST(SavedModelTest, ConvertTfMlirToBefWithXlaFuncExport) { tfrt_stub::GraphExecutionOptions options(runtime.get()); options.compile_options.device_target = TfrtDeviceInfraTarget::kGpu; - options.compile_options.use_bridge_for_gpu = true; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr fallback_state, @@ -170,7 +169,6 @@ TEST(SavedModelTest, ConvertTfMlirToBefExportingXlaReduceWindow) { tensorflow::tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/1); tfrt_stub::GraphExecutionOptions options(runtime.get()); options.compile_options.device_target = TfrtDeviceInfraTarget::kGpu; - options.compile_options.use_bridge_for_gpu = true; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr fallback_state, diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/auto-fusion.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/auto-fusion.mlir deleted file mode 100644 index b6c63f3f560bcf..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/auto-fusion.mlir +++ /dev/null @@ -1,67 +0,0 @@ -// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline="auto-fusion-oplist=tf.Rsqrt,tf.Tanh auto-fusion-min-cluster-size=1" -split-input-file %s \ -// RUN: | FileCheck %s --dump-input=always - -// CHECK-LABEL: func @single_op_cluster -// CHECK: %[[ARG0:.*]]: !tfrt.chain -// CHECK: %[[ARG1:.*]]: !corert.tensorhandle -func.func @single_op_cluster(%arg0: tensor) -> tensor { - // CHECK: %[[ARG:.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor - // CHECK-SAME: %[[ARG1]] - // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:CPU:0" - // CHECK: %[[RES:.*]] = tf_jitrt.fallback.execute @kernel::@compute(%[[ARG]]) - // CHECK: %[[OUT:.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle - // CHECK-SAME: %[[RES]] - // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:CPU:0" - // CHECK: tfrt.return %[[ARG0]], %[[OUT]] : !tfrt.chain, !corert.tensorhandle - %0 = "tf.Rsqrt"(%arg0) {T = f32, device="/device:CPU:0"} : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK-LABEL: func @compute -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK: %[[RES:.*]] = "tf.Rsqrt"(%[[ARG0]]) -// CHECK: return %[[RES]] - -// ----- - -// CHECK-LABEL: func @one_compiled_cluster -func.func @one_compiled_cluster(%arg0: tensor) -> tensor { - // CHECK: %[[RES:.*]] = tf_jitrt.fallback.execute @kernel::@compute - // CHECK-NOT: Rsqrt - // CHECK-NOT: Tanh - %0 = "tf.Rsqrt"(%arg0) {T = f32, device="/device:CPU:0"} : (tensor) -> tensor - %1 = "tf.Tanh"(%0) {T = f32, device="/device:CPU:0"} : (tensor) -> tensor - func.return %1 : tensor -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK-LABEL: 
func @compute -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK: %[[RES0:.*]] = "tf.Rsqrt"(%[[ARG0]]) -// CHECK: %[[RES1:.*]] = "tf.Tanh"(%[[RES0]]) -// CHECK: return %[[RES1]] - -// ----- - -// CHECK-LABEL: func @two_compiled_clusters -func.func @two_compiled_clusters(%arg0: tensor) -> tensor { - // CHECK: tf_jitrt.fallback.execute @kernel::@compute - %0 = "tf.Rsqrt"(%arg0) {T = f32, device="/device:CPU:0"} : (tensor) -> tensor - // CHECK: tfrt_fallback_async.executeop {{.*}} "tf.Sqrt" - %1 = "tf.Sqrt"(%0) {T = f32, device="/device:CPU:0"} : (tensor) -> tensor - // CHECK: tf_jitrt.fallback.execute @kernel_0::@compute - %2 = "tf.Tanh"(%1) {T = f32, device="/device:CPU:0"} : (tensor) -> tensor - func.return %2 : tensor -} - -// CHECK: module @kernel -// CHECK: tf.Rsqrt -// CHECK: module @kernel_0 -// CHECK: tf.Tanh diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/outline-cpurt-cluster.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/outline-cpurt-cluster.mlir deleted file mode 100644 index 9ade6a0d6f0243..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/outline-cpurt-cluster.mlir +++ /dev/null @@ -1,57 +0,0 @@ -// RUN: tf-tfrt-opt -split-input-file -tf-outline-jitrt-cluster %s \ -// RUN: | FileCheck %s - -// ----- -// Outline a simple cluster with a single operation. - -// CHECK-LABEL: func @simple_cluster -func.func @simple_cluster(%arg0: tensor) -> tensor { - // CHECK: %[[RES:.*]] = jitrt.call(%arg0) - // CHECK-SAME: {callee = @kernel::@compute} - // CHECK-SAME: (tensor) -> tensor - %0 = "tf_device.cluster"() ({ - %1 = "tf.Rsqrt"(%arg0) : (tensor) -> tensor - tf_device.return %1 : tensor - }) { policy = "tfrt.auto-fusion" } : () -> tensor - func.return %0 : tensor -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK: func @compute( -// CHECK-SAME: %arg0: tensor -// CHECK-SAME: ) -> tensor { -// CHECK: %[[RET:.*]] = "tf.Rsqrt"(%arg0) -// CHECK: return %[[RET]] -// CHECK: } - -// ----- -// Check that tf.Transpose constraint propagated to the function argument. 
- -// CHECK-LABEL: func @cluster_with_transpose -func.func @cluster_with_transpose(%arg0: tensor, - %arg1: tensor<2xi32>) -> tensor { - // CHECK: %[[RES:.*]] = jitrt.call(%arg0, %arg1) - // CHECK-SAME: {callee = @kernel::@compute} - // CHECK-SAME: (tensor, tensor<2xi32>) -> tensor - %0 = "tf_device.cluster"() ({ - %1 = "tf.Transpose"(%arg0, %arg1) - : (tensor, tensor<2xi32>) -> tensor - tf_device.return %1 : tensor - }) { policy = "tfrt.auto-fusion" } : () -> tensor - func.return %0 : tensor -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 2 : i64 -// CHECK-SAME: } -// CHECK: func @compute( -// CHECK-SAME: %arg0: tensor -// CHECK-SAME: %arg1: tensor<2xi32> {rt.constraint = "value"} -// CHECK-SAME: ) -> tensor { -// CHECK: %[[RET:.*]] = "tf.Transpose"(%arg0, %arg1) -// CHECK: return %[[RET]] -// CHECK: } diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_cpurt.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_cpurt.mlir deleted file mode 100644 index 43267d6ec2264e..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_cpurt.mlir +++ /dev/null @@ -1,190 +0,0 @@ -// RUN: tf-tfrt-opt %s \ -// RUN: -split-input-file \ -// RUN: -tf-executor-to-tfrt-pipeline=" \ -// RUN: enable-optimizer=true \ -// RUN: tfrt-cost-threshold=1024 \ -// RUN: auto-fusion-oplist=tf.Relu,tf.Transpose,tf.Const \ -// RUN: auto-fusion-min-cluster-size=1" \ -// RUN: | FileCheck %s --dump-input=always - -// Check TF->JitRT JIT compiled operations clustering and outlining starting -// from the Tensorflow executor dialect. - -// ----- -// Simple cluster consisting of a single operation. - -module attributes {tf.versions = {producer = 462 : i32}} { - // CHECK: func @_tfrt_fallback_init(%[[ARG:.*]]: !tfrt.chain) - // CHECK: %[[COMPILED:.*]] = tf_jitrt.fallback.compile @kernel::@compute - // CHECK: %[[CHAIN:.*]] = tfrt.merge.chains %[[ARG]], %[[COMPILED]] - // CHECK: tfrt.return %[[CHAIN]] - - // CHECK: func @call - func.func @call(%arg0: tensor) -> (tensor) - attributes { tf.entry_function = {control_outputs = "", - inputs = "input_0", - outputs = "output_0"}} { - // CHECK: tf_jitrt.fallback.execute @kernel::@compute - %0 = tf_executor.graph { - %outs, %control = tf_executor.island wraps "tf.Relu"(%arg0) - {device = ""} : (tensor) -> tensor - tf_executor.fetch %outs: tensor - } - func.return %0 : tensor - } -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK: func @compute( -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK-SAME: ) -> tensor { -// CHECK: %[[RELU:.*]] = "tf.Relu"(%[[ARG0]]) -// CHECK: return %[[RELU]] -// CHECK: } - -// ----- -// Two identical clusters (except the _class attribute) consisting of a single -// `Relu` operation. Check that outlined clusters are deduplicated and we -// compile only once. 
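The deduplication exercised by the test above is implemented in the removed OutlineJitRtClustersPass (further down in this diff), which keys each outlined module by its printed, serialized form so identical clusters reuse a single compiled module. A minimal C++ sketch of that caching idea, with illustrative names and no MLIR dependencies (not the actual pass code), might look like:

#include <string>
#include <unordered_map>
#include <utility>

// Sketch: outlined modules are keyed by their serialized text, so two
// clusters containing the same operations map to one compiled module and
// tf_jitrt.fallback.compile runs only once. Names here are illustrative.
struct OutlinedModule {
  std::string name;        // symbol name of the nested module, e.g. "kernel"
  std::string entrypoint;  // symbol name of the entry function, e.g. "compute"
};

class OutlineCache {
 public:
  // Returns the previously outlined module if an identical serialization was
  // seen before; otherwise stores `candidate` and returns it.
  const OutlinedModule& GetOrInsert(const std::string& serialized,
                                    OutlinedModule candidate) {
    auto it = outlined_.find(serialized);
    if (it != outlined_.end()) return it->second;  // reuse: compile once
    return outlined_.emplace(serialized, std::move(candidate)).first->second;
  }

 private:
  std::unordered_map<std::string, OutlinedModule> outlined_;
};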
- -module attributes {tf.versions = {producer = 462 : i32}} { - // CHECK: func @_tfrt_fallback_init - // CHECK: tf_jitrt.fallback.compile @kernel::@compute - // CHECK-NOT: tf_jitrt.fallback.compile - - // CHECK: func @call - func.func @call(%arg0: tensor) -> (tensor) - attributes { tf.entry_function = {control_outputs = "", - inputs = "input_0", - outputs = "output_0"}} { - // CHECK: tf_jitrt.fallback.execute @kernel::@compute - // CHECK: tfrt_fallback_async.executeop {{.*}} "tf.Sqrt" - // CHECK: tf_jitrt.fallback.execute @kernel::@compute - %0 = tf_executor.graph { - %outs0, %control0 = tf_executor.island wraps "tf.Relu"(%arg0) - {device = "", _class = ["loc:@Relu_0"]} - : (tensor) -> tensor - %outs1, %control1 = tf_executor.island wraps "tf.Sqrt"(%outs0) - {device = ""} : (tensor) -> tensor - %outs2, %control2 = tf_executor.island wraps "tf.Relu"(%outs1) - {device = "", _class = ["loc:@Relu_1"]} - : (tensor) -> tensor - tf_executor.fetch %outs2: tensor - } - func.return %0 : tensor - } -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK: func @compute( -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK-SAME: ) -> tensor { -// CHECK: %[[RELU:.*]] = "tf.Relu"(%[[ARG0]]) -// CHECK: return %[[RELU]] -// CHECK: } - -// ----- -// Constants sunk into the outlined compiled functions. - -module attributes {tf.versions = {producer = 462 : i32}} { - // CHECK: func @_tfrt_fallback_init - // CHECK: tf_jitrt.fallback.compile @kernel::@compute - - // CHECK: func @call - func.func @call(%arg0: tensor) -> (tensor) - attributes { tf.entry_function = {control_outputs = "", - inputs = "input_0", - outputs = "output_0"}} { - // CHECK: tf_jitrt.fallback.execute @kernel::@compute - %0 = tf_executor.graph { - %perm, %perm_ctl = tf_executor.island wraps "tf.Const"() - {device = "", value = dense<[1, 0]> : tensor<2xi32>} - : () -> tensor<2xi32> - %out, %out_ctl = tf_executor.island wraps "tf.Transpose"(%arg0, %perm) - {device = ""} - : (tensor, tensor<2xi32>) -> tensor - tf_executor.fetch %out: tensor - } - func.return %0 : tensor - } -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK: func @compute( -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK-SAME: ) -> tensor { -// CHECK: %[[PERM:.*]] = "tf.Const"() {{.*}} dense<[1, 0]> -// CHECK: %[[RET:.*]] = "tf.Transpose"(%[[ARG0]], %[[PERM]]) -// CHECK: return %[[RET]] -// CHECK: } - -// ----- -// tf.Transpose: a non-const permutation parameter cannot be sunk into the -// compiled function. Such a transpose should, however, support clustering, -// and its permutation parameter should compile to be value-constrained. 
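The value constraint described above is attached to the compiled function's argument as an `rt.constraint` attribute (see `SetEntrypointConstraints` in the removed pass later in this diff), which is what the CHECK lines below verify. A simplified, MLIR-free sketch of that annotation step, with illustrative names and constraint kinds (only "value" is confirmed by the tests here), might look like:

#include <map>
#include <optional>
#include <string>
#include <vector>

// Sketch: for each entrypoint argument with an inferred constraint (e.g. the
// permutation operand of tf.Transpose must be known by value), record an
// "rt.constraint" attribute so the JIT specializes on that operand.
enum class Constraint { kRank, kShape, kValue };

struct Argument {
  std::string name;
  std::map<std::string, std::string> attrs;  // attribute name -> value
};

inline const char* ToString(Constraint c) {
  switch (c) {
    case Constraint::kRank: return "rank";
    case Constraint::kShape: return "shape";
    case Constraint::kValue: return "value";
  }
  return "unknown";
}

// `constraints[i]` holds the constraint inferred for argument i, if any.
void AnnotateArguments(std::vector<Argument>& args,
                       const std::vector<std::optional<Constraint>>& constraints) {
  for (size_t i = 0; i < args.size() && i < constraints.size(); ++i) {
    if (constraints[i]) {
      args[i].attrs["rt.constraint"] = ToString(*constraints[i]);
    }
  }
}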
- -module attributes {tf.versions = {producer = 462 : i32}} { - // CHECK: func @_tfrt_fallback_init - // CHECK: tf_jitrt.fallback.compile @kernel::@compute - - // CHECK: func @call - func.func @call(%arg0: tensor, %arg1: tensor) -> (tensor) - attributes { tf.entry_function = {control_outputs = "", - inputs = "input_0,input_1", - outputs = "output_0"}} { - // CHECK: tf_jitrt.fallback.execute @kernel::@compute - %0 = tf_executor.graph { - %out, %out_ctl = tf_executor.island wraps "tf.Transpose"(%arg0, %arg1) - {device = ""} - : (tensor, tensor) -> tensor - tf_executor.fetch %out: tensor - } - func.return %0 : tensor - } -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK: func @compute( -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK-SAME: %[[ARG1:.*]]: tensor {rt.constraint = "value"} -// CHECK-SAME: ) -> tensor { -// CHECK-NEXT: %[[RET:.*]] = "tf.Transpose"(%[[ARG0]], %[[ARG1]]) -// CHECK: return %[[RET]] -// CHECK: } - -// ----- -// Operations with unsupported data type operands/results are not clustered. - -module attributes {tf.versions = {producer = 462 : i32}} { - func.func @call(%arg0: tensor) -> (tensor) - attributes { tf.entry_function = {control_outputs = "", - inputs = "input_0", - outputs = "output_0"}} { - // CHECK-NOT: tf_jitrt.fallback.compile - // CHECK-NOT: tf_jitrt.fallback.execute - // CHECK-NOT: module @kernel - %0 = tf_executor.graph { - %perm, %perm_ctl = - tf_executor.island wraps "tf.Const"() - {device = "", value = dense<[1, 0]> : tensor<2xi32>} - : () -> tensor<2xi32> - %out, %out_ctl = - tf_executor.island wraps "tf.Transpose"(%arg0, %perm) {device = ""} - : (tensor, tensor<2xi32>) -> tensor - tf_executor.fetch %out: tensor - } - func.return %0 : tensor - } -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir b/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir index 3905074bfd25b6..6fee8545cfe6c4 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -split-input-file -tf-executor-to-tfrt-pipeline="target-gpu=true use-bridge-for-gpu=true func-use-fallback-tensor=true" -tfrt-lower-tf-savedmodel=hoist-invariant-ops=true %s | FileCheck %s --dump-input=fail --dump-input-filter=all +// RUN: tf-tfrt-opt -split-input-file -tf-executor-to-tfrt-pipeline="target-gpu=true func-use-fallback-tensor=true" -tfrt-lower-tf-savedmodel=hoist-invariant-ops=true %s | FileCheck %s --dump-input=fail --dump-input-filter=all func.func private @xla_func_0(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._XlaMustCompile = true, tf._noinline = true, tf._original_func_name = "should_not_be_used"} { %1 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc index bf27bac6ffb11c..c812cf1c9f51ef 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc @@ -27,7 +27,6 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/device_name_utils.h" @@ -128,7 +127,7 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( // flow, which is converted back after the optimization passes are performed. pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); pm.addPass(mlir::createInlinerPass()); - pm.addNestedPass( + pm.addNestedPass( mlir::TF::CreateRemoveUnusedWhileResultsPass()); pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); @@ -173,8 +172,6 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( pm.addNestedPass( mlir::TF::CreateTensorDeviceCopyConversionPass()); - AddTfrtJitRtPasses(options, pm); - // Rewriter operation sequences to device specific fusions. DeviceNameUtils::ParsedName parsed_name; @@ -218,8 +215,7 @@ void CreateTFExecutorToTFInvariantOptimizationPipelineHelper( } Status ValidateTfrtPipelineOptions(const TfrtPipelineOptions &options) { - if (options.target_tpurt && - (options.target_gpu || options.use_bridge_for_gpu)) { + if (options.target_tpurt && options.target_gpu) { return tensorflow::errors::Internal( "Invalid pipeline options. Targeting both TPU and GPU is not " "supported."); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc index 5cde65d2c65508..9b57bf04156c06 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc @@ -55,7 +55,6 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h" #include "tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" #include "tensorflow/compiler/mlir/tfrt/transforms/utils.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" @@ -85,7 +84,6 @@ void getDependentConversionDialects(mlir::DialectRegistry ®istry) { tfrt::fallback_async::FallbackAsyncDialect, tfrt::compiler::TFRTDialect>(); mlir::func::registerAllExtensions(registry); - RegisterJitRtDialects(registry); } mlir::Value GetFunctionInputChain(mlir::Operation *op) { @@ -156,7 +154,7 @@ class FallbackExecuteOpConversion : public mlir::ConversionPattern { tfrt_compiler::FallbackConverter *fallback_converter, const mlir::SymbolTable *symbol_table, const tfrt_compiler::CostAnalysis *cost_analysis, - bool tpu_lower_to_fallback, bool target_tpurt, bool use_bridge_for_gpu) + bool tpu_lower_to_fallback, bool target_tpurt) : mlir::ConversionPattern(mlir::Pattern::MatchAnyOpTypeTag(), kFallbackBenefit, context), corert_converter_(*corert_converter), @@ -164,8 +162,7 @@ class FallbackExecuteOpConversion : public mlir::ConversionPattern { symbol_table_(*symbol_table), cost_analysis_(*cost_analysis), tpu_lower_to_fallback_(tpu_lower_to_fallback), - target_tpurt_(target_tpurt), - use_bridge_for_gpu_(use_bridge_for_gpu) {} + target_tpurt_(target_tpurt) {} LogicalResult matchAndRewrite( mlir::Operation *op, ArrayRef operands, @@ -195,7 +192,7 @@ class FallbackExecuteOpConversion : public mlir::ConversionPattern { // e.g., variable lifting. The new MLIR function will need to be exported to // the function library for runtime to use. bool use_mlir_func_name = - parsed_device_name->device_type == DEVICE_GPU && use_bridge_for_gpu_ && + parsed_device_name->device_type == DEVICE_GPU && op->getName().getStringRef().str() == "tf.XlaLaunch"; mlir::ArrayAttr op_func_attrs = corert_converter_.CreateOpFuncAttrs( @@ -295,8 +292,6 @@ class FallbackExecuteOpConversion : public mlir::ConversionPattern { const tfrt_compiler::CostAnalysis &cost_analysis_; bool tpu_lower_to_fallback_; bool target_tpurt_; - // TODO(b/260915352): Remove the flag and default to using bridge. - bool use_bridge_for_gpu_; }; mlir::LogicalResult FallbackExecuteOpConversion::ConvertToFallbackExecuteOp( @@ -334,8 +329,7 @@ mlir::LogicalResult FallbackExecuteOpConversion::ConvertToFallbackExecuteOp( // For now, we only consider GPU XLA clusters in the form of XlaLaunch for // simplicity. We could extend to support other GPU ops that cann't be XLAed. bool is_xla_launch_on_gpu = - is_gpu_op && use_bridge_for_gpu_ && - op->getName().getStringRef().str() == "tf.XlaLaunch"; + is_gpu_op && op->getName().getStringRef().str() == "tf.XlaLaunch"; if (is_xla_launch_on_gpu) { new_operands = AddGpuVariableAndInputTensorTransferOps(op, new_operands, rewriter); @@ -1446,11 +1440,11 @@ void PopulateTFToTFRTConversionPatterns( const tfrt_compiler::TensorArraySideEffectAnalysis *tensor_array_side_effect_analysis, bool func_use_fallback_tensor, bool enable_while_parallel_iterations, - bool tpu_lower_to_fallback, bool target_tpurt, bool use_bridge_for_gpu) { + bool tpu_lower_to_fallback, bool target_tpurt) { // By default, we lower all TF ops to fallback ops. 
patterns->add( context, corert_converter, fallback_converter, symbol_table, - cost_analysis, tpu_lower_to_fallback, target_tpurt, use_bridge_for_gpu); + cost_analysis, tpu_lower_to_fallback, target_tpurt); patterns->add(context, corert_converter); @@ -1525,7 +1519,6 @@ class TfToTfrtConversionPass enable_while_parallel_iterations_ = options.enable_while_parallel_iterations; target_gpu_ = options.target_gpu; - use_bridge_for_gpu_ = options.use_bridge_for_gpu; } TfToTfrtConversionPass(const TfToTfrtConversionPass &) {} @@ -1564,14 +1557,11 @@ class TfToTfrtConversionPass SetUpTFToTFRTConversionLegality(&target, func_type_converter, corert_converter.chain_type()); - PopulateJitRtConversionPatterns(&target, &context, &patterns, - &corert_converter); - PopulateTFToTFRTConversionPatterns( &context, &patterns, &corert_converter, &fallback_converter, &symbol_table, &cost_analysis, &tensor_array_side_effect_analysis, func_use_fallback_tensor_, enable_while_parallel_iterations_, - tpu_lower_to_fallback_, target_tpurt_, use_bridge_for_gpu_); + tpu_lower_to_fallback_, target_tpurt_); return mlir::applyPartialConversion(func, target, std::move(patterns)); } @@ -1689,9 +1679,6 @@ class TfToTfrtConversionPass chain_value = create_op; } - chain_value = - CreateJitRtFallbackCompileKernel(builder, module, chain_value); - builder.create(func_op.getLoc(), chain_value); } @@ -1769,11 +1756,6 @@ class TfToTfrtConversionPass llvm::cl::desc("If true, target GPU compiler passes."), llvm::cl::init(false)}; - // TODO(b/260915352): Remove the flag and default to using bridge. - Option use_bridge_for_gpu_{ - *this, "use-bridge-for-gpu", - llvm::cl::desc("If true, GPU bridge is used."), llvm::cl::init(false)}; - Option cost_threshold_{ *this, "tfrt-cost-threshold", llvm::cl::desc( diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc deleted file mode 100644 index 91a4c1d61fd166..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc +++ /dev/null @@ -1,414 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include - -#include "llvm/Support/FormatVariadic.h" -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" -#include "tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" -#include "tfrt/jitrt/opdefs/jitrt_ops.h" // from @tf_runtime -#include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime - -namespace tensorflow { -namespace { - -class TfrtJitRtStubImpl : public TfrtJitRtStub { - void RegisterJitRtDialects(mlir::DialectRegistry ®istry) override; - - void PopulateJitRtConversionPatterns( - mlir::ConversionTarget *target, mlir::MLIRContext *context, - mlir::RewritePatternSet *patterns, - CoreRTConverter *corert_converter) override; - - mlir::Value CreateJitRtFallbackCompileKernel( - mlir::OpBuilder &builder, mlir::ModuleOp module, - mlir::Value chain_value) override; - - void AddTfrtJitRtPasses(const TfrtPipelineOptions &options, - mlir::OpPassManager &pm) override; -}; - -void TfrtJitRtStubImpl::RegisterJitRtDialects(mlir::DialectRegistry ®istry) { - registry.insert(); -} - -// TODO(ezhulenev): tf_device.cluster operations after auto-fusion should -// have the correct device assigned based on the fused operations. We should -// use this device to convert operands and results from/to corert handles. -// For now it is safe to assume that it is "CPU" because we do not support -// any other devices and do not support distributed models. -constexpr char kJitRtDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0"; - -// Convert jitrt.call operations to the tf_jitrt.fallback.execute operation. -class JitRtCallToJitRtCompileAndExecuteConversion - : public OpConversionPattern { - public: - explicit JitRtCallToJitRtCompileAndExecuteConversion(MLIRContext *context) - : OpConversionPattern(context) {} - - LogicalResult matchAndRewrite( - tfrt::jitrt::CallOp call, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Convert operands to fallback tensors. - llvm::SmallVector fallback_operands; - if (failed(tfrt_compiler::ConvertFallbackOperands( - call, kJitRtDevice, adaptor.getOperands(), &fallback_operands, - rewriter))) - return rewriter.notifyMatchFailure(call, "failed to convert operand"); - - // tf_jitrt.fallback.execute always produces fallback tensors. - llvm::SmallVector result_types( - call->getNumResults(), - rewriter.getType()); - - // Replace jitrt.call operation with a tf_jitrt.fallback.execute operation. - rewriter.replaceOpWithNewOp( - call, result_types, call.getCallee(), fallback_operands, kJitRtDevice); - - return success(); - } -}; - -// Helper function for inserting TFRT JitRt dialect conversions. -void TfrtJitRtStubImpl::PopulateJitRtConversionPatterns( - mlir::ConversionTarget *target, MLIRContext *context, - RewritePatternSet *patterns, CoreRTConverter *corert_converter) { - target->addLegalDialect(); - target->addIllegalDialect(); - // Lower jitrt.call to the pair of compile and execute operations. 
- patterns->add(context); -} - -mlir::Value TfrtJitRtStubImpl::CreateJitRtFallbackCompileKernel( - mlir::OpBuilder &builder, mlir::ModuleOp module, mlir::Value chain_value) { - // Pre-compile all JIT compiled kernels found in the module. - llvm::SmallVector compiled; - - // A set SymbolRef attributes referencing compiled kernels. - llvm::DenseSet kernels; - - // Compile all kernels in parallell. - module.walk([&](tf_jitrt::FallbackExecuteOp execute) { - // Do not compiled the same kernel multiple times. - if (kernels.contains(execute.getKernel())) return; - - auto compile = builder.create( - execute.getLoc(), builder.getType(), - execute.getKernel(), execute.getDevice()); - compiled.push_back(compile.getResult()); - kernels.insert(compile.getKernel()); - }); - - // Wait for the compilation completion before returning from init function. - if (!compiled.empty()) { - // Do not forget to wait for the fallback kernels initialization. - compiled.insert(compiled.begin(), chain_value); - chain_value = builder.create( - module.getLoc(), builder.getType(), - compiled); - } - - return chain_value; -} - -// -------------------------------------------------------------------------- // -// Outline tf_device.cluster operation regions into functions in the nested -// modules and replaces all cluster operations with jitrt.call operations. -// -------------------------------------------------------------------------- // - -class OutlineJitRtClustersPass - : public PassWrapper> { - public: - llvm::StringRef getArgument() const final { - return "tf-outline-jitrt-cluster"; - } - llvm::StringRef getDescription() const final { - return "Outlines `tf_device.cluster` operations into functions and " - "replaces them with `jitrt.call` operations."; - } - - void runOnOperation() override; - - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - } - - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OutlineJitRtClustersPass) - - private: - struct CompiledModule { - ModuleOp module; - func::FuncOp entrypoint; - llvm::SetVector operands; - }; - - // Creates a nested module with a single function that will be compiled into - // the kernel at runtime. - CompiledModule CreateCompiledModule(tf_device::ClusterOp cluster, - int64_t max_arg_size, - SymbolTable *symbol_table); - - // Update compiled module entrypoint signature with inferred operands - // constraints. - LogicalResult SetEntrypointConstraints(CompiledModule &compiled); - - // Outlines cluster operation regions into compiled modules, and replaces - // cluster operation with a jitrt.call operation. - LogicalResult OutlineClusterOp(tf_device::ClusterOp cluster, - int64_t max_arg_size, - SymbolTable *symbol_table); - - // Mapping from the outlined module string representation to the module itself - // and an entrypoint function. Used to deduplicate identical modules during - // the `tf_device.cluster` outlining. - llvm::StringMap> outlined_; -}; - -OutlineJitRtClustersPass::CompiledModule -OutlineJitRtClustersPass::CreateCompiledModule(tf_device::ClusterOp cluster, - int64_t max_arg_size, - SymbolTable *symbol_table) { - MLIRContext *ctx = cluster->getContext(); - Location loc = cluster.getLoc(); - - // Create a module that will hold compiled function and async wrappers. - // TODO(ezhulenev): Give better names to module and function. 
- auto compiled_module = ModuleOp::create(loc, {"kernel"}); - compiled_module->setAttr("tfrt.compiled", UnitAttr::get(ctx)); - compiled_module->setAttr( - "tfrt.max-arg-size", - IntegerAttr::get(IntegerType::get(ctx, 64), max_arg_size)); - - SymbolTable compiled_module_symbol_table(compiled_module); - - // Find out the cluster arguments and their types. - llvm::SetVector live_ins; - getUsedValuesDefinedAbove(cluster.getBody(), cluster.getBody(), live_ins); - - llvm::SmallVector operand_types; - operand_types.reserve(live_ins.size()); - for (Value v : live_ins) operand_types.emplace_back(v.getType()); - - // Create a function in the compiled module. - auto compiled_func_type = - FunctionType::get(ctx, operand_types, cluster->getResultTypes()); - auto compiled_func = func::FuncOp::create(loc, "compute", compiled_func_type); - compiled_module_symbol_table.insert(compiled_func); - - // Replace uses of live-in values within cluster region with block arguments. - Block *compiled_func_block = compiled_func.addEntryBlock(); - for (auto p : llvm::zip(live_ins, compiled_func_block->getArguments())) - replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), - cluster.getBody()); - - // Move all operations in cluster into compiled_func's entry block. - auto &cluster_body = cluster.GetBody().getOperations(); - compiled_func_block->getOperations().splice( - compiled_func_block->end(), cluster_body, cluster_body.begin(), - cluster_body.end()); - - // Replace `tf_device.return` terminator with `func.return` in the function - // body. - auto device_return = - cast(compiled_func_block->getTerminator()); - OpBuilder builder(device_return.getOperation()); - builder.create(device_return.getLoc(), - device_return.getOperands()); - device_return.erase(); - - // TODO(ezhulenev): MLIR doesn't define operation equivalence upstream yet, - // replace module printing with a more principled solution when available. - // Operations in the cluster can be in different order, however define the - // identical Tensorflow programs, with current approach we'll not be able - // to detect duplicates like this. - - // Remove location attribute attached to Tensorflow operations to be able to - // deduplicate compiled clusters with the same set of operations. - // - // TODO(ezhulenev): Figure out how to propagate locations for error reporting, - // right now JitRt will ignore them anyway. - compiled_module.walk([](Operation *op) { op->removeAttr("_class"); }); - - // Serialize prepared module to string. - std::string serialized; - llvm::raw_string_ostream os(serialized); - compiled_module.print(os); - - // Try to find if identical module was already outlined. - auto it = outlined_.find(serialized); - - // Return identical module that was already outlined earlier. - if (it != outlined_.end()) { - compiled_module.erase(); // erase identical module - return {it->second.first, it->second.second, live_ins}; - } - - // Insert compiled module into the symbol table and assign it a unique name. - symbol_table->insert(compiled_module); - - // Cache unique module. - outlined_.insert({std::move(serialized), {compiled_module, compiled_func}}); - - return {compiled_module, compiled_func, live_ins}; -} - -LogicalResult OutlineJitRtClustersPass::SetEntrypointConstraints( - CompiledModule &compiled) { - func::FuncOp func = compiled.entrypoint; - - // Functions outlined from jitrt device clusters must have a single block. 
- assert(func.getBody().getBlocks().size() == 1 && "expected single block"); - - mlir::TFDevice::ClusteringPolicySet policies; - populateTfJitRtConstraintsPolicies(policies); - - // Infer constraints on the values defined in the entrypoint function - // (including function entry block arguments). - mlir::TFDevice::ValuesConstraintSet constraints; - if (failed(mlir::TFDevice::PropagateValuesConstraints( - func.getBody(), policies, constraints, /*resolve=*/true))) - return failure(); - - // Annotate arguments with inferred constraints. - for (unsigned i = 0; i < func.getNumArguments(); ++i) { - if (auto constraint = constraints.GetConstraint(func.getArgument(i))) { - auto constraint_name = mlir::StringAttr::get( - &getContext(), llvm::formatv("{0}", *constraint).str()); - func.setArgAttr(i, "rt.constraint", constraint_name); - } - } - - return success(); -} - -LogicalResult OutlineJitRtClustersPass::OutlineClusterOp( - tf_device::ClusterOp cluster, int64_t max_arg_size, - SymbolTable *symbol_table) { - Location loc = cluster->getLoc(); - OpBuilder builder(cluster); - - CompiledModule compiled_module = - CreateCompiledModule(cluster, max_arg_size, symbol_table); - func::FuncOp compiled_func = compiled_module.entrypoint; - - // Add constraints to the entrypoint arguments. - if (failed(SetEntrypointConstraints(compiled_module))) return failure(); - - // Replace device cluster with a jitrt.call operation. - auto module_name = *compiled_module.module.getSymName(); - auto func_name = compiled_func.getSymName(); - auto func_flat_ref = - mlir::SymbolRefAttr::get(builder.getContext(), func_name); - auto func_ref = mlir::SymbolRefAttr::get(builder.getContext(), module_name, - {func_flat_ref}); - - auto cluster_func_op = builder.create( - loc, cluster.getResultTypes(), func_ref, - compiled_module.operands.getArrayRef()); - - cluster.replaceAllUsesWith(cluster_func_op); - cluster.erase(); - - return success(); -} - -void OutlineJitRtClustersPass::runOnOperation() { - ModuleOp module = getOperation(); - SymbolTable symbol_table(module); - - // Keep track of the maximum argument size for each function with tf_device - // cluster operations in the function body. We need to pass it to the compiled - // module to correctly compute its cost later. - llvm::DenseMap max_arg_size_map; - - auto get_max_arg_size = [&](mlir::func::FuncOp func) -> int64_t { - auto it = max_arg_size_map.find(func); - if (it != max_arg_size_map.end()) return it->second; - return max_arg_size_map[func] = tf_jitrt::GetMaxArgSize(func); - }; - - OpBuilder builder(module.getContext()); - auto result = module.walk([&](tf_device::ClusterOp cluster) -> WalkResult { - // Ensure that cluster was formed for TFRT JIT compilation. - auto policy = cluster->getAttr("policy").dyn_cast_or_null(); - if (!policy || policy.getValue() != "tfrt.auto-fusion") - return WalkResult::advance(); - - // Get the maximum argument size of the parent function. 
- mlir::func::FuncOp parent_func = - cluster->getParentOfType(); - int64_t max_arg_size = get_max_arg_size(parent_func); - - if (failed(OutlineClusterOp(cluster, max_arg_size, &symbol_table))) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - - if (result.wasInterrupted()) { - module->emitError("Failed to outline tf_device.cluster operations"); - signalPassFailure(); - } -} - -std::unique_ptr CreateOutlineJitRtClustersPass() { - return std::make_unique(); -} - -void TfrtJitRtStubImpl::AddTfrtJitRtPasses(const TfrtPipelineOptions &options, - mlir::OpPassManager &pm) { - // Outline auto-fusion clusters into tf_device.cluster_operations and then - // convert them to functions. We currently support only tfrt fallback tensors - // as operands, so we disable these passes if we can have native ops after - // lowering. - pm.addNestedPass(CreateTfJitRtClusteringPass( - options.auto_fusion_oplist, options.auto_fusion_min_cluster_size)); - - // Sink small constants into the outlined clusters to reduce the number of - // arguments for each of the execute operations. - auto is_compilable_const = [](mlir::tf_device::ClusterOp cluster, - mlir::ElementsAttr value) -> bool { - // Ensure that cluster was formed for TFRT JIT compilation. - auto policy = cluster->getAttr("policy").dyn_cast_or_null(); - if (!policy || policy.getValue() != "tfrt.auto-fusion") return false; - - // Check that TF->JitRt compiler supports constant compilation. - return mlir::succeeded(IsCompilableConstant(value)); - }; - - pm.addNestedPass( - mlir::TFDevice::CreateClusterConstantSinkingPass(is_compilable_const)); - - // Outline formed JIT compiled device clusters into function. - pm.addPass(CreateOutlineJitRtClustersPass()); -} - -mlir::PassRegistration tf_outline_jitrt_cluster_pass( - CreateOutlineJitRtClustersPass); - -const bool kUnused = - (RegisterTfrtJitRtStub(std::make_unique()), true); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.cc b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.cc deleted file mode 100644 index 1bde6382c79bdc..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" - -#include -#include -#include - -namespace tensorflow { -namespace { - -class TfrtJitRtStubRegistry { - public: - TfrtJitRtStubRegistry() : stub_(std::make_unique()) {} - - void Register(std::unique_ptr stub) { - stub_ = std::move(stub); - } - - TfrtJitRtStub &Get() { return *stub_; } - - private: - std::unique_ptr stub_; -}; - -TfrtJitRtStubRegistry &GetGlobalTfrtJitRtStubRegistry() { - static auto *const stub = new TfrtJitRtStubRegistry; - return *stub; -} - -} // namespace - -void RegisterTfrtJitRtStub(std::unique_ptr stub) { - GetGlobalTfrtJitRtStubRegistry().Register(std::move(stub)); -} - -void RegisterJitRtDialects(mlir::DialectRegistry ®istry) { - GetGlobalTfrtJitRtStubRegistry().Get().RegisterJitRtDialects(registry); -} - -// Helper function for inserting TFRT JitRt dialect conversions. -void PopulateJitRtConversionPatterns(mlir::ConversionTarget *target, - mlir::MLIRContext *context, - mlir::RewritePatternSet *patterns, - CoreRTConverter *corert_converter) { - GetGlobalTfrtJitRtStubRegistry().Get().PopulateJitRtConversionPatterns( - target, context, patterns, corert_converter); -} - -mlir::Value CreateJitRtFallbackCompileKernel(mlir::OpBuilder &builder, - mlir::ModuleOp module, - mlir::Value chain_value) { - return GetGlobalTfrtJitRtStubRegistry() - .Get() - .CreateJitRtFallbackCompileKernel(builder, module, chain_value); -} - -void AddTfrtJitRtPasses(const TfrtPipelineOptions &options, - mlir::OpPassManager &pm) { - GetGlobalTfrtJitRtStubRegistry().Get().AddTfrtJitRtPasses(options, pm); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h deleted file mode 100644 index d9c00c4d376909..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_JITRT_STUB_H_ -#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_JITRT_STUB_H_ - -#include - -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/DialectRegistry.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" - -namespace tensorflow { - -class TfrtJitRtStub { - public: - virtual ~TfrtJitRtStub() = default; - - virtual void RegisterJitRtDialects(mlir::DialectRegistry ®istry) {} - - virtual void PopulateJitRtConversionPatterns( - mlir::ConversionTarget *target, mlir::MLIRContext *context, - mlir::RewritePatternSet *patterns, CoreRTConverter *corert_converter) {} - - virtual mlir::Value CreateJitRtFallbackCompileKernel( - mlir::OpBuilder &builder, mlir::ModuleOp module, - mlir::Value chain_value) { - return chain_value; - } - - virtual void AddTfrtJitRtPasses(const TfrtPipelineOptions &options, - mlir::OpPassManager &pm) {} -}; - -void RegisterTfrtJitRtStub(std::unique_ptr stub); - -void RegisterJitRtDialects(mlir::DialectRegistry ®istry); - -// Helper function for inserting TFRT JitRt dialect conversions. -void PopulateJitRtConversionPatterns(mlir::ConversionTarget *target, - mlir::MLIRContext *context, - mlir::RewritePatternSet *patterns, - CoreRTConverter *corert_converter); - -mlir::Value CreateJitRtFallbackCompileKernel(mlir::OpBuilder &builder, - mlir::ModuleOp module, - mlir::Value chain_value); - -void AddTfrtJitRtPasses(const TfrtPipelineOptions &options, - mlir::OpPassManager &pm); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_JITRT_STUB_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h index a4c62f8bf20a2e..0a1209f457be7e 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h @@ -100,11 +100,6 @@ struct TfrtPipelineOptions llvm::cl::desc("If true, target GPU compiler passes."), llvm::cl::init(false)}; - // TODO(b/260915352): Remove the flag and default to using bridge. - Option use_bridge_for_gpu{ - *this, "use-bridge-for-gpu", - llvm::cl::desc("If true, GPU bridge is used."), llvm::cl::init(false)}; - Option func_use_fallback_tensor{ *this, "func-use-fallback-tensor", llvm::cl::desc( @@ -153,26 +148,6 @@ struct TfrtPipelineOptions llvm::cl::desc("If true, streams with inter data depenedencies will be " "preferred to be merged for inline execution."), llvm::cl::init(false)}; - - // A set of flags to control auto-fusion: automatic clustering of Tensorflow - // operations and compiling outlined regions using MLIR based compilation - // stack. - // - // WARNING: These flags are experimental and are intended for manual testing - // of different auto-fusion strategies. They will be removed in the future. - - ListOption auto_fusion_oplist{ - *this, "auto-fusion-oplist", - llvm::cl::desc("A list of Tensorflow operations to cluster together for " - "JIT compilation. 
Alternatively use 'tier1', ..., 'all' " - "to allow clustering for all operations included in the " - "given clustering tier.")}; - - Option auto_fusion_min_cluster_size{ - *this, "auto-fusion-min-cluster-size", - llvm::cl::desc("Minimum size of the cluster that should be outlined for " - "compilation"), - llvm::cl::init(2)}; }; } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index 6045d40d56692a..4ff4a47feaaa15 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -166,7 +166,13 @@ Status ConvertTfMlirToRuntimeExecutable( } } - if (options.device_target == TfrtDeviceInfraTarget::kTpurt) { + if (options.backend_compiler != nullptr) { + if (VLOG_IS_ON(1)) { + tensorflow::DumpMlirOpToFile("tf_dialect_before_backend_compile", module); + } + TF_RETURN_IF_ERROR( + options.backend_compiler->CompileTensorflow(model_context, module)); + } else if (options.device_target == TfrtDeviceInfraTarget::kTpurt) { VLOG(1) << "Running MLIR TPU bridge for tpurt"; if (VLOG_IS_ON(1)) { tensorflow::DumpMlirOpToFile("tpu_bct_conversion_before", module); @@ -198,8 +204,7 @@ Status ConvertTfMlirToRuntimeExecutable( return diag_handler.Combine(absl::InternalError( "Failed to process TPUPartitionedCallOp for fallback execution")); } - } else if (options.device_target == TfrtDeviceInfraTarget::kGpu && - options.use_bridge_for_gpu) { + } else if (options.device_target == TfrtDeviceInfraTarget::kGpu) { TF_RETURN_IF_ERROR(mlir::TF::RunTFXLABridge(module)); // GPU XLA clusters are wrapped in functions, which could be transformed by @@ -212,12 +217,6 @@ Status ConvertTfMlirToRuntimeExecutable( TF_RETURN_IF_ERROR(fallback_state->AddFunctionDef(func_def)); } } - } else if (options.backend_compiler != nullptr) { - if (VLOG_IS_ON(1)) { - tensorflow::DumpMlirOpToFile("tf_dialect_before_backend_compile", module); - } - TF_RETURN_IF_ERROR( - options.backend_compiler->CompileTensorflow(model_context, module)); } if (VLOG_IS_ON(1)) { @@ -300,7 +299,6 @@ std::unique_ptr GetTfrtPipelineOptions( (options.device_target == TfrtDeviceInfraTarget::kTpurt); pipeline_options->target_gpu = (options.device_target == TfrtDeviceInfraTarget::kGpu); - pipeline_options->use_bridge_for_gpu = options.use_bridge_for_gpu; pipeline_options->tpu_fuse_ops = options.tpu_fuse_ops; pipeline_options->use_tpu_host_allocator_for_inputs = options.use_tpu_host_allocator_for_inputs; @@ -312,9 +310,6 @@ std::unique_ptr GetTfrtPipelineOptions( pipeline_options->func_use_fallback_tensor = true; pipeline_options->enable_while_parallel_iterations = options.enable_while_parallel_iterations; - pipeline_options->auto_fusion_oplist = options.auto_fusion_oplist; - pipeline_options->auto_fusion_min_cluster_size = - options.auto_fusion_min_cluster_size; pipeline_options->cost_threshold = options.cost_threshold; pipeline_options->upper_cost_threshold = options.upper_cost_threshold; pipeline_options->merge_inter_dependent_streams = diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc index 1e4a81d0d0cc03..3bfb5d853a9a97 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc @@ -57,10 +57,6 @@ std::ostream& operator<<(std::ostream& os, const TfrtCompileOptions& options) { << ", hoist_invariant_ops = " << 
options.hoist_invariant_ops << ", enable_while_parallel_iterations = " << options.enable_while_parallel_iterations - << ", auto_fusion_oplist = [" - << absl::StrJoin(options.auto_fusion_oplist, ",") << "]" - << ", auto_fusion_min_cluster_size = " - << options.auto_fusion_min_cluster_size << ", cost_threshold = " << options.cost_threshold << ", upper_cost_threshold = " << options.upper_cost_threshold << ", merge_inter_dependent_streams = " diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h index 23ef81be002a27..619f89cfa83d71 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h @@ -125,20 +125,6 @@ struct TfrtCompileOptions { // basis. This is currently experimental. bool enable_while_parallel_iterations = false; - // A set of flags to control auto-fusion: automatic clustering of Tensorflow - // operations and compiling outlined regions using MLIR based compilation - // stack. - // - // WARNING: These flags are experimental and are intended for manual testing - // of different auto-fusion strategies. They will be removed in the future. - - // A list of Tensorflow operations that are supported by auto-fusion - // clustering and compilation (e.g. tf.Tanh). - std::vector auto_fusion_oplist; - - // Minimum size of the cluster to be compiled at runtime. - int auto_fusion_min_cluster_size = 2; - // The cost threshold to decide whether a sequence of operations is cheap, and // then whether it can be executed inline. If the cost is smaller than the // threshold, it will be considered as cheap operations. Since the cost must @@ -161,10 +147,6 @@ struct TfrtCompileOptions { // Whether to compile to sync TFRT dialect. bool compile_to_sync_tfrt_dialect = false; - - // Whether to use bridge for GPU. - // TODO(b/260915352): Remove the flag and default to using bridge. 
- bool use_bridge_for_gpu = false; }; std::ostream& operator<<(std::ostream& os, const TfrtCompileOptions& options); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index e4892ba1d2e3ce..05dbdcc437d562 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,7 +18,7 @@ td_library( "tf_framework_ops.td", "tf_status.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:AllocationOpInterfaceTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", @@ -29,7 +29,7 @@ td_library( gentbl_cc_library( name = "tf_framework_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -55,7 +55,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_status_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-enum-decls"], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 3aeaca1a5a1b48..1f29b8555cab9f 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -7,7 +7,7 @@ load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -19,7 +19,7 @@ cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -31,7 +31,7 @@ cc_library( name = "tf_framework_legalize_to_llvm", srcs = ["tf_framework_legalize_to_llvm.cc"], hdrs = ["rewriters.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":utils", "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", @@ -47,7 +47,7 @@ cc_library( name = "bufferize", srcs = ["bufferize.cc"], hdrs = ["rewriters.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "@llvm-project//mlir:ArithDialect", @@ -66,7 +66,7 @@ cc_library( name = "embed_tf_framework", srcs = ["embed_tf_framework.cc"], hdrs = ["rewriters.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "@llvm-project//mlir:ControlFlowDialect", @@ -79,7 +79,7 @@ cc_library( gentbl_cc_library( name = "kernel_gen_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = 
get_compatible_with_portable(), tbl_outs = [( [ "-gen-pass-decls", @@ -187,7 +187,7 @@ cc_library( "tf_to_jit_invocations.cc", ], hdrs = ["passes.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":bufferize", # buildcleaner: keep ":embed_tf_framework", # buildcleaner: keep diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD index 586204d9594062..33eb74dc1f9560 100644 --- a/tensorflow/compiler/mlir/tosa/BUILD +++ b/tensorflow/compiler/mlir/tosa/BUILD @@ -3,7 +3,7 @@ # https://developer.mlplatform.org/w/tosa/ # https://github.com/llvm/llvm-project/blob/main/mlir/docs/Dialects/TOSA.md -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") # TODO: Tighten visibility once targets are at the right granularity. @@ -35,12 +35,12 @@ filegroup( srcs = [ "@llvm-project//mlir:TosaDialectTdFiles", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) gentbl_cc_library( name = "tosa_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -63,7 +63,7 @@ cc_library( "transforms/passes.h", "transforms/passes.h.inc", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", "@llvm-project//mlir:FuncDialect", @@ -82,7 +82,7 @@ cc_library( "transforms/legalize_common.h", "transforms/legalize_utils.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", @@ -105,7 +105,7 @@ cc_library( gentbl_cc_library( name = "tosa_legalize_tf_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -135,7 +135,7 @@ cc_library( "tf_passes.h", "transforms/passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [":friends"], deps = [ ":legalize_common", @@ -158,7 +158,7 @@ cc_library( gentbl_cc_library( name = "tosa_legalize_tfl_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -194,7 +194,7 @@ cc_library( "tfl_passes.h", "transforms/passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [":friends"], deps = [ ":legalize_common", @@ -228,7 +228,7 @@ cc_library( "tf_tfl_passes.h", "transforms/passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [":friends"], deps = [ ":legalize_common", diff --git a/tensorflow/compiler/mlir/tosa/g3doc/legalization.md b/tensorflow/compiler/mlir/tosa/g3doc/legalization.md index d09a96f178bbaf..8da389f17e0b84 100644 --- a/tensorflow/compiler/mlir/tosa/g3doc/legalization.md +++ b/tensorflow/compiler/mlir/tosa/g3doc/legalization.md @@ -140,7 +140,7 @@ vector get_padding_values_from_pad_type(tensorflow::Padding padding, tens int64 op_size, pad_before_tf, pad_after_tf; - tensorflow::GetWindowedOutputSizeVerboseV2(input_type.shape[ifm_dim], 
filter_type.shape[filter_dim], + tensorflow::GetWindowedOutputSizeVerbose(input_type.shape[ifm_dim], filter_type.shape[filter_dim], dim_dilation, dim_stride, padding, // Outputs &op_size, &pad_before_tf, &pad_after_tf); diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index add1d1bc541259..68e68345ce7436 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -2779,3 +2779,28 @@ func.func @test_imag_non_complex(%arg0: tensor<1x8x9xf32>) -> (tensor<1x8x9xf32> %0 = "tfl.imag"(%arg0) {} : (tensor<1x8x9xf32>) -> tensor<1x8x9xf32> return %0 : tensor<1x8x9xf32> } + +// ----- + +// CHECK-LABEL: test_squared_difference_qi8 +// CHECK-DAG: %[[VAR0:.*]] = "tosa.rescale"(%arg0) +// CHECK-DAG: %[[VAR1:.*]] = "tosa.rescale"(%arg1) +// CHECK-DAG: %[[VAR2:.*]] = "tosa.sub"(%[[VAR0]], %[[VAR1]]) +// CHECK-DAG: %[[VAR3:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR2]]) <{shift = 0 : i32}> : +// CHECK-DAG: %[[VAR4:.*]] = "tosa.rescale"(%[[VAR3]]) +// CHECK: return %[[VAR4]] +func.func @test_squared_difference_qi8(%arg0: tensor<1x197x768x!quant.uniform>, %arg1: tensor<1x197x1x!quant.uniform>) -> tensor<1x197x768x!quant.uniform> { + %0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1x197x768x!quant.uniform>, tensor<1x197x1x!quant.uniform>) -> tensor<1x197x768x!quant.uniform> + func.return %0 : tensor<1x197x768x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_squared_difference_f32 +// CHECK-DAG: %[[VAR0:.*]] = "tosa.sub"(%arg0, %arg1) +// CHECK-DAG: %[[VAR1:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR0]]) <{shift = 0 : i32}> : +// CHECK: return %[[VAR1]] +func.func @test_squared_difference_f32(%arg0: tensor<1x197x768xf32>, %arg1: tensor<1x197x1xf32>) -> tensor<1x197x768xf32> { + %0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1x197x768xf32>, tensor<1x197x1xf32>) -> tensor<1x197x768xf32> + func.return %0 : tensor<1x197x768xf32> +} diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index b72a54a64b4058..d8785910105ea2 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -583,6 +583,78 @@ std::optional convertSquaredDifferenceOp(PatternRewriter& rewriter, return std::nullopt; } + bool x_is_qtype = + x_type.getElementType().isa(); + bool y_is_qtype = + y_type.getElementType().isa(); + bool result_is_qtype = + result_type.getElementType().isa(); + + if (x_is_qtype != result_is_qtype || y_is_qtype != result_is_qtype) { + (void)rewriter.notifyMatchFailure( + op, + "input/output tensor should all be in FP32, INT32 or quantized INT8"); + return std::nullopt; + } + + // If the output is I8 then we need to rescale to I32 + // Then scale back to I8 + if (result_is_qtype) { + auto x_qtype = + x_type.getElementType().cast(); + auto y_qtype = + y_type.getElementType().cast(); + auto result_qtype = + result_type.getElementType().cast(); + + uint32_t result_bits = result_qtype.getStorageTypeIntegralWidth(); + + if (result_bits == 8) { + ShapedType rescale_type = result_type.clone(rewriter.getI32Type()); + + // We need to make sure the inputs are rescaled correctly + // Following the behaviour defined here lite/kernels/squared_difference.cc + double in_x_scale = x_qtype.getScale(); + double in_y_scale = y_qtype.getScale(); + double result_scale = result_qtype.getScale(); + + double 
twice_max_input_scale = 2.0 * std::max(in_x_scale, in_y_scale); + + const int32_t LEFT_SHIFT = 7; + + double x_rescale_scale = in_x_scale / twice_max_input_scale; + double y_rescale_scale = in_y_scale / twice_max_input_scale; + double output_rescale_scale = + (twice_max_input_scale * twice_max_input_scale) / + ((static_cast(1 << LEFT_SHIFT * 2)) * result_scale); + + Value x_scaled = buildRescaleToInt32( + rewriter, op, x, + x_rescale_scale * static_cast(1 << LEFT_SHIFT), + x_qtype.getZeroPoint()); + Value y_scaled = buildRescaleToInt32( + rewriter, op, y, + y_rescale_scale * static_cast(1 << LEFT_SHIFT), + y_qtype.getZeroPoint()); + + auto sub_op = CreateOpAndInfer( + rewriter, op->getLoc(), rescale_type, x_scaled, y_scaled); + auto mul_op = CreateOpAndInfer( + rewriter, op->getLoc(), rescale_type, sub_op.getResult(), + sub_op.getResult(), 0); + + // Convert the operator back to the original type + return buildRescaleFromInt32(rewriter, op, result_type, mul_op, + output_rescale_scale, + result_qtype.getZeroPoint()); + } + + (void)rewriter.notifyMatchFailure( + op, "Only FP32, INT32 or quantized INT8 is supported"); + return std::nullopt; + } + + // This will cover FP32/FP16/INT32 legalization auto sub_op = CreateOpAndInfer(rewriter, op->getLoc(), result_type, x, y); return CreateOpAndInfer(rewriter, op->getLoc(), result_type, diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index eb4edcd9c47080..39c3ec9b5a5eb5 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -528,7 +528,7 @@ bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad, ip_size = ip_size < 0 ? f_size * dim_dilation : ip_size; int64_t op_size, pad_before_tf, pad_after_tf; // Complains if using int64_T - tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( + tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( ip_size, f_size, dim_dilation, dim_stride, tf_pad, &op_size, &pad_before_tf, &pad_after_tf); if (!status.ok()) return false; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 025680547f5988..9af16abb6f5618 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -535,6 +535,7 @@ tf_xla_py_strict_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "noasan", # Timed out on 2023-07-12 ], deps = [ ":xla_test", @@ -744,6 +745,10 @@ tf_xla_py_strict_test( name = "eager_test", size = "medium", srcs = ["eager_test.py"], + # copybara:uncomment_begin + # #TODO(b/287111047): Remove once the bug is fixed. 
+ # disable_tpu_tfrt = True, + # copybara:uncomment_end enable_mlir_bridge = False, python_version = "PY3", tags = [ @@ -1112,6 +1117,7 @@ tf_xla_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:nn_ops", @@ -1701,6 +1707,7 @@ tf_xla_py_strict_test( python_version = "PY3", shard_count = 50, tags = [ + "no_rocm", "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "noasan", #times out @@ -1780,6 +1787,10 @@ tf_xla_py_strict_test( name = "while_test", size = "small", srcs = ["while_test.py"], + # copybara:uncomment_begin + # #TODO(b/291130193): Remove once the bug is fixed. + # disable_tpu_tfrt = True, + # copybara:uncomment_end enable_mlir_bridge = False, python_version = "PY3", tags = [ @@ -1929,6 +1940,7 @@ tf_xla_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:function", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", @@ -2799,6 +2811,36 @@ tf_xla_py_strict_test( ], ) +tf_xla_py_strict_test( + name = "xla_call_module_no_shape_assertions_check_test", + size = "small", + srcs = ["xla_call_module_test.py"], + disabled_backends = ["cpu_ondemand"], # cpu_ondemand overrides the TF_XLA_FLAGS + enable_mlir_bridge = False, + env = {"TF_XLA_FLAGS": "--tf_xla_call_module_disabled_checks=shape_assertions"}, + main = "xla_call_module_test.py", + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + use_xla_device = False, # Uses tf.function(jit_compile=True) + deps = [ + ":xla_test", + "//tensorflow/compiler/mlir/stablehlo", + "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/platform:test", + "//third_party/py/numpy", + ], +) + tf_xla_py_strict_test( name = "bincount_op_test", size = "small", diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index 3a7e22c02e54c5..bb760360687d9e 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops @@ -36,7 +37,7 @@ def NHWCToNCHW(input_tensor): Returns: the converted tensor or a shape array """ - if isinstance(input_tensor, ops.Tensor): + if isinstance(input_tensor, tensor.Tensor): return array_ops.transpose(input_tensor, [0, 3, 1, 2]) else: return [input_tensor[0], input_tensor[3], input_tensor[1], input_tensor[2]] @@ -51,7 +52,7 @@ def NCHWToNHWC(input_tensor): Returns: the converted tensor or a shape array """ - if 
isinstance(input_tensor, ops.Tensor): + if isinstance(input_tensor, tensor.Tensor): return array_ops.transpose(input_tensor, [0, 2, 3, 1]) else: return [input_tensor[0], input_tensor[2], input_tensor[3], input_tensor[1]] diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 250dcc45fe14de..bbadb955356e0f 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -23,7 +23,7 @@ from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.framework import function -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -48,7 +48,7 @@ def _assertOpOutputMatchesExpected(self, op, args, expected): ] feeds = {placeholders[i]: args[i] for i in range(0, len(args))} output = op(*placeholders) - if isinstance(output, ops.Tensor): + if isinstance(output, tensor.Tensor): output = [output] results = session.run(output, feeds) diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index ea0231c356f5da..90c400b9e461a6 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -432,6 +432,203 @@ def f(x): 'arguments, but it has only 1 total arguments'): self._assertOpOutputMatchesExpected(f, (x,), (x,)) + def test_shape_assertion_success(self): + x = np.ones((3, 5), dtype=np.int32) + res = np.int32(x.shape[0]) + + def f(x): # x: f32[b, 5] and b = 3 + # return x.shape[0] + module, version = serialize(""" +module @jit_fun.1 { + func.func public @main(%arg1: tensor) -> tensor { + %b = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %3 = stablehlo.constant dense<3> : tensor + %ok = stablehlo.compare EQ, %b, %3, SIGNED : (tensor, tensor) -> tensor + stablehlo.custom_call @shape_assertion(%ok) { + error_message = "The error message", + has_side_effect = true + } : (tensor) -> () + return %b : tensor + } + +} +""") + return xla.call_module([x,], version=version, + module=module, + Tout=[res.dtype], + Sout=[res.shape], + platforms=[self.testing_platform()],) + + self._assertOpOutputMatchesExpected(f, (x,), (res,)) + + def test_shape_assertion_failure(self): + x = np.ones((3, 5), dtype=np.int32) + res = np.int32(x.shape[0]) + + def f(x): # x: f32[b, 5] and b = 3, with a constraint b == 4. + # return x.shape[0] + module, version = serialize(""" +module @jit_fun.1 { + func.func public @main(%arg1: tensor) -> tensor { + %b = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %4 = stablehlo.constant dense<4> : tensor + %ok = stablehlo.compare EQ, %b, %4, SIGNED : (tensor, tensor) -> tensor + stablehlo.custom_call @shape_assertion(%ok, %b, %4) { + error_message = "Expecting {0} == {1}", + has_side_effect = true + } : (tensor, tensor, tensor) -> () + return %b : tensor + } +} +""") + return xla.call_module([x,], version=version, + module=module, + Tout=[res.dtype], + Sout=[res.shape], + platforms=[self.testing_platform()],) + + # This test runs as part of two targets, with and without + # disabling shape_assertions. 
+ disabled_shape_assertions_check = ( + '--tf_xla_call_module_disabled_checks=shape_assertions' + in os.getenv('TF_XLA_FLAGS', '')) + if disabled_shape_assertions_check: + # No error even though the constraint is false. + self._assertOpOutputMatchesExpected(f, (x,), (res,)) + else: + with self.assertRaisesRegex( + errors.InvalidArgumentError, + 'Expecting 3 == 4'): + self._assertOpOutputMatchesExpected(f, (x,), (res,)) + + def test_invalid_shape_assertion(self): + arg_i1 = np.bool_(True) + arg_i32 = np.int32(2) + res = arg_i32 + + # This test runs as part of two targets, with and without + # disabling shape_assertions. + disabled_shape_assertions_check = ( + '--tf_xla_call_module_disabled_checks=shape_assertions' + in os.getenv('TF_XLA_FLAGS', '')) + if disabled_shape_assertions_check: + self.skipTest('Test is N/A when shape_assertions are disabled') + + subtest_count = 1 + def one_subtest(error_msg: str, module_str: str): + def f(*args): + module, version = serialize(module_str) + return xla.call_module( + list(args), + version=version, + module=module, + Tout=[res.dtype], + Sout=[res.shape], + platforms=[self.testing_platform()], + ) + + nonlocal subtest_count + subtest_count += 1 + with self.subTest(count=subtest_count, error_msg=error_msg): + with self.assertRaisesRegex(errors.InvalidArgumentError, error_msg): + self._assertOpOutputMatchesExpected(f, (arg_i1, arg_i32), (res,)) + + one_subtest( + 'expects assert_what .* to be a constant of type tensor', + """ +module @jit_fun.1 { + func.func public @main(%arg_i1: tensor, %arg_i32: tensor) -> tensor { + %ok = stablehlo.constant dense<0> : tensor + stablehlo.custom_call @shape_assertion(%ok) { + error_message = "Some error", + has_side_effect = true + } : (tensor) -> () + return %arg_i32 : tensor + } +} +""", + ) + + one_subtest( + 'expects static assert_what', + """ +module @jit_fun.1 { + func.func public @main(%arg_i1: tensor, %arg_i32: tensor) -> tensor { + stablehlo.custom_call @shape_assertion(%arg_i1) { + error_message = "Some error", + has_side_effect = true + } : (tensor) -> () + return %arg_i32 : tensor + } +} +""", + ) + + one_subtest( + 'expects has_side_effect=true', + """ +module @jit_fun.1 { + func.func public @main(%arg_i1: tensor, %arg_i32: tensor) -> tensor { + %ok = stablehlo.constant dense : tensor + stablehlo.custom_call @shape_assertion(%ok) { + error_message = "Some error", + has_side_effect = false + } : (tensor) -> () + return %arg_i32 : tensor + } +} +""", + ) + + one_subtest( + 'expects error_message .* Found specifier {0}', + """ +module @jit_fun.1 { + func.func public @main(%arg_i1: tensor, %arg_i32: tensor) -> tensor { + %ok = stablehlo.constant dense : tensor + stablehlo.custom_call @shape_assertion(%ok) { + error_message = "Some error {0}", + has_side_effect = true + } : (tensor) -> () + return %arg_i32 : tensor + } +} +""", + ) + + one_subtest( + 'expects static error_message_input', + """ +module @jit_fun.1 { + func.func public @main(%arg_i1: tensor, %arg_i32: tensor) -> tensor { + %ok = stablehlo.constant dense : tensor + stablehlo.custom_call @shape_assertion(%ok, %arg_i32) { + error_message = "Some error {0}", + has_side_effect = true + } : (tensor, tensor) -> () + return %arg_i32 : tensor + } +} +""", + ) + + one_subtest( + 'expects error_message_input .* to be a constant of type tensor', + """ +module @jit_fun.1 { + func.func public @main(%arg_i1: tensor, %arg_i32: tensor) -> tensor { + %ok = stablehlo.constant dense : tensor + %c = stablehlo.constant dense<2.0> : tensor + stablehlo.custom_call 
@shape_assertion(%ok, %c) { + error_message = "Some error {0}", + has_side_effect = true + } : (tensor, tensor) -> () + return %arg_i32 : tensor + } +} +""", + ) + def test_dynamic_iota(self): x = np.ones((3, 5), dtype=np.int32) res = np.arange(x.shape[0], dtype=np.int32) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 24a6dc43bd1096..96e07f337dee90 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -85,6 +85,7 @@ alias( "@local_config_tensorrt//:use_static_tensorrt": "@local_config_tensorrt//:tensorrt", "//conditions:default": ":tensorrt_stub", }), + visibility = ["//visibility:private"], ) tf_cuda_cc_test( diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 42b2acd6d27cec..dcf4d4880a3344 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -254,7 +254,11 @@ cc_library( ":xla_compiled_cpu_runtime_hdrs", ], copts = runtime_copts() + tf_openmp_copts(), - defines = ["EIGEN_NEON_GEBP_NR=4"], + defines = [ + "EIGEN_NEON_GEBP_NR=4", + # TODO(b/238649163): remove this once no longer necessary. + "EIGEN_USE_AVX512_GEMM_KERNELS=0", + ], features = [ "fully_static_link", "-parse_headers", diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 6d940eccedc72f..1c60ba5874746b 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -394,6 +394,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/mlir_hlo", "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", + "//tensorflow/compiler/xla/python:refine_polymorphic_shapes", "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//tensorflow/tsl/platform:errors", diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 833efb34649950..242c022c892faf 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -292,7 +292,7 @@ StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, } int64_t unused_output_size; - TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( input_shape.dimensions(dim), filter_shape.dimensions(i), rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size, &padding[i].first, &padding[i].second)); diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 49e80226786355..01e1d57e732b6a 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -149,7 +149,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { int64_t unused_output_size; OP_REQUIRES_OK( - ctx, GetWindowedOutputSizeVerboseV2( + ctx, GetWindowedOutputSizeVerbose( input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i], window_strides[i], padding_, &unused_output_size, &padding[i].first, &padding[i].second)); diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index e265184ad2d42d..e7def0a7f8bab4 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -24,6 
+24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project @@ -54,6 +55,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/python/refine_polymorphic_shapes.h" #include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_utils.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "tensorflow/tsl/platform/errors.h" @@ -84,16 +86,29 @@ constexpr int VERSION_START_SUPPORT_CALL_TF_GRAPH = 5; // mandates a non-empty `platforms` attribute. // Used in jax2tf since June 2023. constexpr int VERSION_START_SUPPORT_DISABLED_CHECKS = 6; +// Version 7 adds support for `stablehlo.shape_assertion` operations and +// for `shape_assertions` specified in `disabled_checks`. +// Used in JAX serialization since July 2023. +constexpr int VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7; constexpr int VERSION_MINIMUM_SUPPORTED = VERSION_START_STABLE_HLO_COMPATIBILITY; -constexpr int VERSION_MAXIMUM_SUPPORTED = VERSION_START_SUPPORT_DISABLED_CHECKS; +constexpr int VERSION_MAXIMUM_SUPPORTED = + VERSION_START_SUPPORT_SHAPE_ASSERTIONS; constexpr absl::string_view DISABLED_CHECK_PLATFORM = "platform"; bool IsPlatformCheckDisabled(absl::Span disabled_checks) { - return std::find(disabled_checks.begin(), disabled_checks.end(), - DISABLED_CHECK_PLATFORM) != disabled_checks.end(); + return llvm::is_contained(disabled_checks, DISABLED_CHECK_PLATFORM); +} + +constexpr absl::string_view DISABLED_CHECK_SHAPE_ASSERTIONS = + "shape_assertions"; + +bool IsShapeAssertionsCheckDisabled( + absl::Span loading_disabled_checks) { + return llvm::is_contained(loading_disabled_checks, + DISABLED_CHECK_SHAPE_ASSERTIONS); } // Computes a dimension value from the dim_arg specification. @@ -379,7 +394,7 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( auto arg = main_body.getArgument(i); arg.setType(static_array_input_types[i]); // If the argument is used by `func.return`, then we also need to - // update function result types. It's not great that we need this hack, + // update the function result types. It's not great that we need this hack, // but in the future when we have stablehlo.func, stablehlo.return, etc, // this will not be needed. // TODO(burmako): Once https://github.com/openxla/stablehlo/issues/425 is @@ -395,34 +410,14 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( if (VLOG_IS_ON(5)) { DumpMlirOpToFile("xla_call_module.after_refined_input_types", *module_); } + bool enable_shape_assertions = + (version_ >= VERSION_START_SUPPORT_SHAPE_ASSERTIONS && + !IsShapeAssertionsCheckDisabled(loading_disabled_checks_)); + TF_RETURN_IF_ERROR( + xla::RefinePolymorphicShapes(*module_, enable_shape_assertions)); - // Verify the module before running passes on it. - // If the module doesn't pass verification, all sorts of weirdness might - // happen if we run the pass manager. 
- { - mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); - - if (failed(verify(*module_))) { - return absl::InvalidArgumentError( - absl::StrCat("Module verification failed: ", - diag_handler.ConsumeStatus().ToString())); - } - - mlir::PassManager pm(module_->getContext()); - applyTensorflowAndCLOptions(pm); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::stablehlo::createStablehloRefineShapesPass()); - pm.addNestedPass( - mlir::stablehlo::createStablehloCanonicalizeDynamismPass()); - if (mlir::failed(pm.run(*module_))) { - return absl::InvalidArgumentError( - absl::StrCat("Module shape refinement failed: ", - diag_handler.ConsumeStatus().ToString())); - } - - if (VLOG_IS_ON(3)) { - DumpMlirOpToFile("xla_call_module.after_shape_refinement", *module_); - } + if (VLOG_IS_ON(3)) { + DumpMlirOpToFile("xla_call_module.after_shape_refinement", *module_); } return tsl::OkStatus(); } @@ -458,22 +453,22 @@ tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule( module_ = mlir::parseSourceString(module_str, context_); } - std::vector loading_disabled_checks = disabled_checks; - loading_disabled_checks.insert( - loading_disabled_checks.end(), + loading_disabled_checks_ = disabled_checks; + loading_disabled_checks_.insert( + loading_disabled_checks_.end(), GetXlaCallModuleFlags()->disabled_checks.begin(), GetXlaCallModuleFlags()->disabled_checks.end()); if (!module_) { return absl::InvalidArgumentError("Cannot deserialize computation"); } - VLOG(3) << "Parsed serialized module (version " << version + VLOG(3) << "Parsed serialized module (version = " << version << ", platforms = [" << absl::StrJoin(platforms, ", ") << "], loading_platform = " << loading_platform << ", dim_args_spec = [" << absl::StrJoin(dim_args_spec_, ", ") << "], disabled_checks = [" << absl::StrJoin(disabled_checks, ", ") << "], loading_disabled_checks = [" - << absl::StrJoin(loading_disabled_checks, ", ") << "]), module = " + << absl::StrJoin(loading_disabled_checks_, ", ") << "]), module = " << DumpMlirOpToFile("xla_call_module.parsed", *module_); if (version < VERSION_MINIMUM_SUPPORTED) { @@ -493,7 +488,7 @@ tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule( auto found_platform = std::find(platforms.begin(), platforms.end(), loading_platform); if (found_platform == platforms.end()) { - if (!IsPlatformCheckDisabled(loading_disabled_checks)) { + if (!IsPlatformCheckDisabled(loading_disabled_checks_)) { return absl::NotFoundError(absl::StrCat( "The current platform ", loading_platform, " is not among the platforms required by the module: [", @@ -559,35 +554,8 @@ tsl::Status XlaCallModuleLoader::ValidateDialect() { return tsl::OkStatus(); } -tsl::Status XlaCallModuleLoader::ValidateStaticShapes() { - mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); - bool moduleHasDynamicShapes = false; - - module_->walk([&](mlir::Operation *op) { - // It's sufficient to only check results because operands either come from - // results or from block arguments which are checked below. - auto hasDynamicShape = [](mlir::Value value) { - auto shaped_type = value.getType().dyn_cast(); - return shaped_type ? 
!shaped_type.hasStaticShape() : false; - }; - bool opHasDynamicShapes = false; - opHasDynamicShapes |= llvm::any_of(op->getResults(), hasDynamicShape); - for (mlir::Region ®ion : op->getRegions()) { - opHasDynamicShapes |= - llvm::any_of(region.getArguments(), hasDynamicShape); - } - if (opHasDynamicShapes) { - moduleHasDynamicShapes = true; - op->emitOpError() << "has dynamic shapes"; - } - }); - - if (moduleHasDynamicShapes) { - return absl::InvalidArgumentError( - absl::StrCat("Module has dynamic shapes: ", - diag_handler.ConsumeStatus().ToString())); - } - return tsl::OkStatus(); +absl::Status XlaCallModuleLoader::ValidateStaticShapes() { + return xla::ValidateStaticShapes(*module_); } absl::Status XlaCallModuleLoader::LowerModuleToMhlo() { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h index 54aaa6ae58f097..8d8c30f96fbca8 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h @@ -61,7 +61,7 @@ class XlaCallModuleLoader { // Validates that the module represents a statically-shaped StableHLO program, // otherwise all sorts of weirdness might happen in the HLO exporter which is // much easier to detect here. - tsl::Status ValidateStaticShapes(); + absl::Status ValidateStaticShapes(); // Lowers the StableHLO module to MHLO in place. absl::Status LowerModuleToMhlo(); @@ -97,6 +97,9 @@ class XlaCallModuleLoader { // a platform index arg. int platform_index_; std::vector dim_args_spec_; + // The disabled checks at loading time, including those from the + // disabled_checks attribute and the TF_XLA_FLAGS environment variable. + std::vector loading_disabled_checks_; mlir::func::FuncOp main_; }; diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index d3638443234033..4cc4845b60b7e7 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -663,12 +663,16 @@ def call_module( return res -# pylint: enable=g-doc-args -# pylint: enable=g-doc-return-or-yield +def call_module_maximum_supported_version(): + """Maximum version of XlaCallModule op supported. 
+ See versioning details documentation for the XlaCallModule op at: + https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code + """ + return 7 -def call_module_maximum_supported_version(): - return 6 +# pylint: enable=g-doc-args +# pylint: enable=g-doc-return-or-yield def call_module_disable_check_platform(): diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/BUILD b/tensorflow/compiler/xla/backends/profiler/gpu/BUILD index a44c5735b5b8f9..d166b63e3fd9c1 100644 --- a/tensorflow/compiler/xla/backends/profiler/gpu/BUILD +++ b/tensorflow/compiler/xla/backends/profiler/gpu/BUILD @@ -5,7 +5,6 @@ load( "tsl_copts", "tsl_gpu_library", ) -load("//tensorflow/tsl:tsl.default.bzl", "tsl_gpu_cc_test") load( "//tensorflow/tsl/platform:build_config.bzl", "tf_additional_device_tracer_srcs", @@ -23,6 +22,10 @@ load( "//tensorflow/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load( + "//tensorflow/compiler/xla:xla.bzl", + "xla_cc_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -107,10 +110,11 @@ tsl_gpu_library( ], ) -tsl_gpu_cc_test( +xla_cc_test( name = "cupti_error_manager_test", size = "small", srcs = ["cupti_error_manager_test.cc"], + copts = tf_profiler_copts() + tsl_copts(), tags = tf_cuda_tests_tags() + [ "gpu_cupti", "nomac", @@ -125,9 +129,7 @@ tsl_gpu_cc_test( ":cupti_wrapper", ":mock_cupti", "@com_google_absl//absl/memory", - "//tensorflow/tsl/platform:env_impl", "//tensorflow/tsl/profiler/utils:time_utils", - "//tensorflow/tsl/profiler/backends/cpu:annotation_stack_impl", ]), ) diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index c0603c89de4feb..97ceef5dca9059 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -197,6 +197,9 @@ class ExecutableBuildOptions { } absl::string_view fdo_profile() const { return fdo_profile_; } + void set_fdo_profile(const std::string& fdo_profile) { + fdo_profile_ = fdo_profile; + } std::string* mutable_fdo_profile() { return &fdo_profile_; } // Returns a string representation of the build options, suitable for diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index f6d893ac845dca..e9d6e29ba31fb1 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -42,7 +42,30 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_multi_thread_eigen(true); opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); opts.set_xla_gpu_asm_extra_flags(""); - opts.set_xla_gpu_use_runtime_fusion(true); + + // As of cudnn 8.9.0, enabling cudnn runtime fusion sometimes causes a + // situation where cudnn returns 0 algorithms for an otherwise-valid conv, + // causing compilation to fail. 
Examples of failing convs: + // + // // failing kLeakyRelu, b/290967578 + // (f16[2,256,768,16]{3,2,1,0}, u8[0]{0}) + // custom-call(f16[2,256,768,3]{3,2,1,0} %a, f16[16,3,3,3]{3,2,1,0} %b, + // f16[16]{0} %c), window={size=3x3 pad=1_1x1_1}, + // dim_labels=b01f_o01i->b01f, operand_precision={highest,highest}, + // custom_call_target="__cudnn$convBiasActivationForward", + // backend_config={"activation_mode":"kLeakyRelu","conv_result_scale":1, + // "side_input_scale":0,"leakyrelu_alpha":0.199951171875} + // + // // failing kRelu6, b/291011396 + // (f16[1,384,1024,32]{3,2,1,0}, u8[0]{0}) + // custom-call(f16[1,769,2049,3]{3,2,1,0} %a, f16[32,3,3,3]{3,2,1,0} %b, + // f16[32]{0} %c), window={size=3x3 stride=2x2}, dim_labels=b01f_o01i->b01f, + // operand_precision={highest,highest}, + // custom_call_target="__cudnn$convBiasActivationForward", + // backend_config={"activation_mode":"kRelu6","conv_result_scale":1, + // "side_input_scale":0,"leakyrelu_alpha":0} + opts.set_xla_gpu_use_runtime_fusion(false); + opts.set_xla_eliminate_hlo_implicit_broadcast(true); opts.set_xla_dump_hlo_as_html(false); opts.set_xla_dump_fusion_visualization(false); @@ -82,9 +105,9 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // TODO(b/258036887): Enable cuda_graph_level=2. Currently blocked by CUDA 12 // integration. opts.set_xla_gpu_cuda_graph_level(0); - opts.set_xla_gpu_cuda_graph_num_runs_to_instantiate(2); + opts.set_xla_gpu_cuda_graph_num_runs_to_instantiate(-1); opts.set_xla_gpu_enable_persistent_temp_buffers(false); - opts.set_xla_gpu_cuda_graph_min_graph_size(2); + opts.set_xla_gpu_cuda_graph_min_graph_size(5); opts.set_xla_gpu_cuda_graph_enable_concurrent_region(false); // Despite the name, fast min/max on GPUs does not seem to be any faster, and @@ -137,7 +160,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_triton_gemm(true); opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true); opts.set_xla_gpu_triton_gemm_any(false); - opts.set_xla_gpu_enable_triton_softmax_fusion(false); + opts.set_xla_gpu_enable_triton_softmax_fusion(true); + opts.set_xla_gpu_triton_fusion_level(1); // Moving reduce-scatter out of while loops can increase memory footprint, so // turning it off by default. @@ -145,7 +169,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_collective_inflation_factor(1); - opts.set_xla_gpu_enable_experimental_block_size(false); + opts.set_xla_gpu_enable_experimental_block_size(true); opts.set_xla_gpu_exhaustive_tiling_search(false); opts.set_xla_gpu_enable_priority_fusion(false); @@ -1130,6 +1154,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Forces any reductions during matrix multiplications to use the " "accumulator type and not the output type. 
The precision of the dot " "operation may not increase that much if there is output fusion.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_triton_fusion_level", + int32_setter_for(&DebugOptions::set_xla_gpu_triton_fusion_level), + debug_options->xla_gpu_triton_fusion_level(), + "Triton fusion level, higher levels mean more fused operations.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 83e21ca58de84f..aebf14c80eb0fe 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -44,6 +44,7 @@ channel_id)` | `operand` | `XlaOp` | Array to concatenate across | : : : replicas. : | `all_gather_dim` | `int64` | Concatenation dimension. | +| `shard_count` | `int64` | Size of each replica group. | | `replica_groups` | vector of vectors of | Groups between which the | : : `int64` : concatenation is performed. : | `channel_id` | optional `int64` | Optional channel ID for | diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD index 3b1817c2790e23..6f785714bcaf59 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD @@ -59,6 +59,8 @@ cc_library( deps = [ ":auto_sharding_strategy", "//tensorflow/compiler/xla:statusor", + "//tensorflow/tsl/platform:hash", + "//tensorflow/tsl/platform:types", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_ortools//ortools/linear_solver", diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 6e0cb4b9ae85ff..c31315ea83d8a7 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -26,6 +27,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" +#include "tensorflow/tsl/platform/hash.h" +#include "tensorflow/tsl/platform/types.h" #include "ortools/linear_solver/linear_solver.h" #include "ortools/linear_solver/linear_solver.pb.h" #ifdef PLATFORM_GOOGLE @@ -104,6 +107,13 @@ void PrintLargestInstructions( } } +// Adds deterministic noise to the coefficient using the name & salt multiplier. +void AddSalt(const std::string& name, double saltiplier, double* coeff) { + if (saltiplier <= 0.0) return; + const tsl::uint64 hash = tsl::Hash64(name); // stable across runs & platforms + *coeff *= 1.0 + saltiplier * hash / std::numeric_limits::max(); +} + // We formulate the auto sharding process as the following ILP problem: // Variables: // s[i]: Sharding strategy one-hot vector. @@ -162,12 +172,8 @@ AutoShardingSolverResult CallORToolsSolver( #ifdef PLATFORM_GOOGLE if (solver->ProblemType() == operations_research::MPSolver::SAT_INTEGER_PROGRAMMING) { - // Set random_seed, interleave_search and share_binary_clauses for - // determinism, and num_workers for parallelism. 
- solver_parameter_str = absl::StrCat( - "share_binary_clauses:false,random_seed:1,interleave_" - "search:true,num_workers:", - num_workers); + // Set num_workers for parallelism. + solver_parameter_str = absl::StrCat("num_workers:", num_workers); solver->SetSolverSpecificParametersAsString(solver_parameter_str); } #endif @@ -206,8 +212,10 @@ AutoShardingSolverResult CallORToolsSolver( for (size_t j = 0; j < s[i].size(); ++j) { double accumulated_coefficient = solver->MutableObjective()->GetCoefficient(s[i][j]); + double coefficient = request.c[i][j] + request.d[i][j]; + AddSalt(absl::StrCat(i, "S", j), request.saltiplier, &coefficient); solver->MutableObjective()->SetCoefficient( - s[i][j], accumulated_coefficient + request.c[i][j] + request.d[i][j]); + s[i][j], accumulated_coefficient + coefficient); } } // Edge costs @@ -215,8 +223,10 @@ AutoShardingSolverResult CallORToolsSolver( for (size_t j = 0; j < e[i].size(); ++j) { double accumulated_coefficient = solver->MutableObjective()->GetCoefficient(e[i][j]); + double coefficient = request.r[i][j]; + AddSalt(absl::StrCat(i, "E", j), request.saltiplier, &coefficient); solver->MutableObjective()->SetCoefficient( - e[i][j], accumulated_coefficient + request.r[i][j]); + e[i][j], accumulated_coefficient + coefficient); } } @@ -460,6 +470,7 @@ AutoShardingSolverResult CallORToolsSolver( } // Return value + double unsalted_objective = 0.0; std::vector chosen_strategy(request.num_nodes, -1), e_val(num_edges, -1); for (int i = 0; i < request.num_nodes; ++i) { @@ -467,6 +478,7 @@ AutoShardingSolverResult CallORToolsSolver( // if lhs == 1 if (s[i][j]->solution_value() > 0.5) { chosen_strategy[i] = j; + unsalted_objective += request.c[i][j] + request.d[i][j]; break; } } @@ -476,11 +488,13 @@ AutoShardingSolverResult CallORToolsSolver( // if lhs == 1 if (e[i][j]->solution_value() > 0.5) { e_val[i] = j; + unsalted_objective += request.r[i][j]; break; } } } + LOG(INFO) << "Unsalted objective value: " << unsalted_objective; LOG(INFO) << "N = " << request.num_nodes; if (request.memory_budget < 0) { LOG(INFO) << "memory budget: -1"; @@ -492,7 +506,7 @@ AutoShardingSolverResult CallORToolsSolver( request.instruction_names); return AutoShardingSolverResult( std::make_tuple(std::move(chosen_strategy), std::move(e_val), - solver->Objective().Value()), + unsalted_objective), false); } diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index 5225b91e143c43..88bc578bceec56 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -44,6 +44,7 @@ struct AutoShardingSolverRequest { std::vector instruction_names; std::optional solver_timeout_in_seconds; bool crash_at_infinity_costs_check = false; + double saltiplier = 0.0001; // Modifies each objective term by at most 0.01% }; struct AutoShardingSolverResult { diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc b/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc index 4c096e58b5ed1a..7fd242ca42c32f 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -42,6 +43,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/logging.h" #include "tensorflow/tsl/platform/status.h" namespace xla { @@ -375,6 +377,22 @@ Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction, return OkStatus(); } +HloInstruction* HloComputation::NextInstruction(HloInstruction* current) { + InstructionList::iterator instructions_it; + if (current == nullptr) { + instructions_it = instructions_.begin(); + } else { + auto it = instruction_iterators_.find(current); + CHECK(it != instruction_iterators_.end()); + instructions_it = it->second; + ++instructions_it; + } + if (instructions_it == instructions_.end()) { + return nullptr; + } + return instructions_it->get(); +} + void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, bool accept_different_shape) { // The shape of the root (ignoring layout) is an invariant of the computation @@ -407,25 +425,6 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, root_instruction_ = new_root_instruction; } -namespace { - -// Helper which builds a post order of the HLO call graph. -void ComputeComputationPostOrder(HloComputation* computation, - absl::flat_hash_set* visited, - std::vector* post_order) { - if (visited->insert(computation).second) { - for (auto* instruction : computation->instructions()) { - for (HloComputation* called_computation : - instruction->called_computations()) { - ComputeComputationPostOrder(called_computation, visited, post_order); - } - } - post_order->push_back(computation); - } -} - -} // namespace - void HloComputation::ComputeInstructionPostOrder( HloInstruction* root, const ChannelDependencies& channel_dependencies, absl::flat_hash_map& visited, @@ -583,21 +582,43 @@ std::vector HloComputation::MakeEmbeddedComputationsList() const { absl::flat_hash_set visited; std::vector post_order; - - // To avoid special handling of this computation, cast away const of - // 'this'. 'this' is immediately removed from the post order after - // construction. - // - // TODO(b/78350259): This violates const-correctness, since while the original - // computation is not returned, we still retrieve non-const computations from - // a const one. Consider also avoiding const for HloComputation, or review XLA - // for const-correctness of non-HloInstruction* types like this. - ComputeComputationPostOrder(const_cast(this), &visited, - &post_order); - - // We don't want to include this computation in the post order. - CHECK_EQ(this, post_order.back()); - post_order.pop_back(); + // The first element of the pair is the currently processed computation, the + // second is the instruction within the computation that is currently being + // processed. 'nullptr' for the instruction indicates that no instruction has + // been processed so far. + std::stack> st; + + // We cannot directly push (this, nullptr) to the stack, as the stack should + // contain only mutable computations. Also, we don't want to include the + // computation itself in the list of embedded computations. + for (auto* instruction : instructions()) { + auto process_called_computations = + [&](std::vector called_computations) { + // Put the called computations in reverse order onto the stack. + // Otherwise we don't match the recursive enumeration of + // computations, which processes the first called computation first. 
+ absl::c_reverse(called_computations); + for (HloComputation* called_computation : called_computations) { + if (visited.insert(called_computation).second) { + st.emplace(called_computation, nullptr); + } + } + }; + process_called_computations(instruction->called_computations()); + while (!st.empty()) { + auto cur = st.top(); + st.pop(); + HloComputation* computation = cur.first; + HloInstruction* next_instruction = + computation->NextInstruction(cur.second); + if (next_instruction == nullptr) { + post_order.push_back(computation); + } else { + st.emplace(computation, next_instruction); + process_called_computations(next_instruction->called_computations()); + } + } + } return post_order; } @@ -1279,7 +1300,6 @@ void SortClonedInstructionUsersAndControlLists( const HloCloneContext& context, absl::FunctionRef replace, const HloComputation::InstructionList& sorted_instructions) { - using InstructionSorter = MappedPtrContainerSorter; auto instruction_mapper = [&context, replace](const HloInstruction* i) { return context.FindInstruction(replace(i)); }; diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_computation.h b/tensorflow/compiler/xla/hlo/ir/hlo_computation.h index d3aeedc3c7c133..1da5ee5109e37e 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_computation.h +++ b/tensorflow/compiler/xla/hlo/ir/hlo_computation.h @@ -782,6 +782,13 @@ class HloComputation { Status RemoveInstructionImpl(HloInstruction* instruction, bool ignore_safety_check); + // Finds the next instruction in the 'instructions_' list after 'current'. + // 'current' must either be nullptr or an instruction that is part of this + // computation. If it is nullptr, next_instruction returns the first + // instruction of the computation. Returns nullptr if there is no next + // instruction. + HloInstruction* NextInstruction(HloInstruction* current); + std::string name_; int64_t unique_id_; HloInstruction* root_instruction_; diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc index 48282e2db005a3..1f583a923fb6cc 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc @@ -1049,6 +1049,10 @@ StatusOr> HloInstruction::CreateFromProto( instruction->set_frontend_attributes(proto.frontend_attributes()); } + if (proto.has_statistics_viz()) { + instruction->set_statistics_viz(proto.statistics_viz()); + } + return std::move(instruction); } @@ -1761,6 +1765,7 @@ HloInstruction::CreateBroadcastSequence( broadcast->copy_sharding(operand); } broadcast->set_frontend_attributes(operand->frontend_attributes()); + broadcast->set_statistics_viz(operand->statistics_viz()); return broadcast; } // Do explicit broadcast for degenerate broadcast. @@ -1787,6 +1792,7 @@ HloInstruction::CreateBroadcastSequence( reshaped_operand->copy_sharding(operand); } reshaped_operand->set_frontend_attributes(operand->frontend_attributes()); + reshaped_operand->set_statistics_viz(operand->statistics_viz()); // Broadcast 'reshape' up to the larger size. 
auto broadcast = HloInstruction::CreateBroadcast( broadcast_shape, reshaped_operand, broadcast_dimensions); @@ -1795,6 +1801,7 @@ HloInstruction::CreateBroadcastSequence( broadcast->copy_sharding(operand); } broadcast->set_frontend_attributes(operand->frontend_attributes()); + broadcast->set_statistics_viz(operand->statistics_viz()); return broadcast; } @@ -1878,6 +1885,7 @@ void HloInstruction::SetupDerivedInstruction( } derived_instruction->set_metadata(metadata_); derived_instruction->set_frontend_attributes(frontend_attributes_); + derived_instruction->set_statistics_viz(statistics_viz_); } bool HloInstruction::IsRoot() const { @@ -3518,6 +3526,12 @@ void HloInstruction::PrintExtraAttributes( printer->Append("}"); }); } + + if (!statistics_viz_.statistics().empty()) { + printer.Next([this](Printer* printer) { + AppendCat(printer, "statistics=", StatisticsVizToString(statistics_viz_)); + }); + } } std::vector HloInstruction::ExtraAttributesToString( @@ -3585,6 +3599,8 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_frontend_attributes() = frontend_attributes_; + *proto.mutable_statistics_viz() = statistics_viz_; + return proto; } @@ -4268,6 +4284,28 @@ std::string FrontendAttributesToString( absl::StrJoin(sorted_attributes, ",", formatter)); } +std::string StatisticsVizToString(const StatisticsViz& statistics_viz) { + // Statistics is either empty, or always starts with the index of the + // statistic that is rendered on the graph, followed by the statistics that + // are being tracked. The index is 0 based, starting from the first statistic + // being tracked. The index and statistics are within a comma-separated list + // of attribute=value pairs, + // e.g., statistics={visualizing_index=0, count_nan=100, count_inf=200}. + + if (statistics_viz.statistics().empty()) return "{}"; + + std::vector all_statistics(statistics_viz.statistics().begin(), + statistics_viz.statistics().end()); + + const auto formatter = [](std::string* out, const Statistic& item) { + absl::StrAppend(out, item.stat_name(), "=", item.stat_val()); + }; + return absl::StrFormat("{%s,%s}", + absl::StrCat("visualizing_index=", + statistics_viz.stat_index_to_visualize()), + absl::StrJoin(all_statistics, ",", formatter)); +} + std::string PaddingConfigToString(const PaddingConfig& padding) { bool has_interior_padding = absl::c_any_of(padding.dimensions(), diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h index e0bafcbcf46322..443698af0219c2 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +++ b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h @@ -1835,6 +1835,27 @@ class HloInstruction { return frontend_attributes_; } + void add_single_statistic(Statistic statistic) { + *statistics_viz_.add_statistics() = std::move(statistic); + } + + void set_stat_index_to_visualize(int64_t index) { + statistics_viz_.set_stat_index_to_visualize(index); + } + + bool has_statistics() const { return !statistics_viz_.statistics().empty(); } + + const Statistic& statistic_to_visualize() const { + return statistics_viz_.statistics().at( + statistics_viz_.stat_index_to_visualize()); + } + + void set_statistics_viz(StatisticsViz statistics_viz) { + statistics_viz_ = std::move(statistics_viz); + } + + const StatisticsViz& statistics_viz() const { return statistics_viz_; } + // Getter/setter for raw JSON-encoded backend config. Prefer the // functions above that deal in proto Messages where possible. 
const std::string& raw_backend_config_string() const { @@ -2434,6 +2455,10 @@ class HloInstruction { // z' = const(20), frontend_attributes={?} FrontendAttributes frontend_attributes_; + // Used to render an HLO graph when tracking the propagation desired values + // through it. + StatisticsViz statistics_viz_; + // String identifier for instruction. std::string name_; @@ -2468,6 +2493,7 @@ StatusOr StringToFusionKind( std::string PaddingConfigToString(const PaddingConfig& padding); std::string FrontendAttributesToString( const FrontendAttributes& frontend_attributes); +std::string StatisticsVizToString(const StatisticsViz& statistics_viz); std::string RandomAlgorithmToString(const RandomAlgorithm& algorithm); std::string RandomDistributionToString(const RandomDistribution& distribution); std::string PrecisionToString(const PrecisionConfig::Precision& precision); diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc b/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc index cc8771ce430d80..196fb2ded19cbf 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -40,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/protobuf.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h index e8b8378279807e..756bf319b0fb46 100644 --- a/tensorflow/compiler/xla/layout.h +++ b/tensorflow/compiler/xla/layout.h @@ -86,6 +86,8 @@ class Tile { absl::InlinedVector dimensions_; }; +using TileVector = absl::InlinedVector; + // TODO: Rename the `dim_level_types` field to `lvl_types`, so that it // matches `mlir::sparse_tensor::SparseTensorEncodingAttr`. class Layout { @@ -293,7 +295,7 @@ class Layout { return *this; } absl::Span tiles() const { return tiles_; } - absl::InlinedVector* mutable_tiles() { return &tiles_; } + TileVector* mutable_tiles() { return &tiles_; } int64_t element_size_in_bits() const { return element_size_in_bits_; } Layout& set_element_size_in_bits(int64_t value) { @@ -376,7 +378,7 @@ class Layout { DimensionVector minor_to_major_; // The tiles used in tiling-based layout. - absl::InlinedVector tiles_; + TileVector tiles_; // The primitive type to use for sparse array indices and pointers. Each of // these must either be INVALID, or an unsigned integer type. 
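The statistics_viz plumbing added to HloInstruction a few hunks above (add_single_statistic, set_stat_index_to_visualize, statistics_viz, and the StatisticsVizToString helper) can be exercised roughly as follows. This is a minimal C++ sketch, not part of the change itself: it assumes the usual proto-generated setters for the Statistic fields (stat_name, stat_val) that the formatter reads, and instr stands for any existing HloInstruction*.

// Track two statistics on an instruction and mark the first one (index 0)
// as the value to render in the HLO graph.
Statistic count_nan;
count_nan.set_stat_name("count_nan");
count_nan.set_stat_val(100);
Statistic count_inf;
count_inf.set_stat_name("count_inf");
count_inf.set_stat_val(200);

instr->add_single_statistic(count_nan);
instr->add_single_statistic(count_inf);
instr->set_stat_index_to_visualize(0);

// With the PrintExtraAttributes change above, the instruction text now
// carries an attribute along the lines of
// statistics={visualizing_index=0,count_nan=100,count_inf=200}.
if (instr->has_statistics()) {
  LOG(INFO) << StatisticsVizToString(instr->statistics_viz());
}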
diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/BUILD b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/BUILD index 55fa13e9372023..eee5a6e926911f 100644 --- a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( @@ -10,7 +10,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc index 0b473a634fedab..83c1bdf5d0b371 100644 --- a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc @@ -495,6 +495,54 @@ struct SparseSDDMMCallRewriter { } }; +// This rewriter rewrites 2:4 SpMM custom op to linalg.generic operator that +// carries the DENSE24 trait and does multiplication. +struct Sparse2To4SpMMCallRewriter { + LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { + assert(op.getInputs().size() == 3 && "Need C, A, B matrices"); + assert(op.getResults().size() == 1 && "Need one output tensor"); + Location loc = op.getLoc(); + Value mat_c = op.getInputs()[0]; + Value mat_a = op.getInputs()[1]; + Value mat_b = op.getInputs()[2]; + + auto etp = mat_c.getType().dyn_cast().getElementType(); + // Build the enveloping generic op with the following trait: + // indexing_maps = [ + // affine_map<(i,j,k) -> (i,k)>, // A + // affine_map<(i,j,k) -> (k,j)>, // B + // affine_map<(i,j,k) -> (i,j)> // S + // ], + // iterator_types = ["parallel", "parallel", "reduction"], + // doc = "C(i,j) += SUM_k A(i,k) B(k,j)" + SmallVector iteratorTypes; + iteratorTypes.push_back(utils::IteratorType::parallel); + iteratorTypes.push_back(utils::IteratorType::parallel); + iteratorTypes.push_back(utils::IteratorType::reduction); + using MapList = ArrayRef>; + auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; + AffineExpr i, j, k; + bindDims(op.getContext(), i, j, k); + auto indexing_maps = infer({{i, k}, {k, j}, {i, j}}); + auto generic_op = rewriter.create( + loc, TypeRange{mat_c.getType()}, ValueRange{mat_a, mat_b}, + ValueRange{mat_c}, indexing_maps, iteratorTypes); + // Set DENSE24 attribute. + generic_op->setAttr("DENSE24", rewriter.getI32IntegerAttr(1)); + // Construct operations in the linalg.generic block. 
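// For reference, the trait documented above encodes an ordinary
// matmul-accumulate, C(i,j) += SUM_k A(i,k) * B(k,j); as the rewriter's name
// suggests, the DENSE24 attribute marks the op for the 2:4 structured-sparsity
// path, where each group of four consecutive values along the reduction
// dimension of the sparse operand holds at most two nonzeros. A standalone
// dense sketch of that arithmetic on row-major float buffers (illustrative
// only, not the MLIR lowering itself):
#include <vector>

void MatmulAccumulateSketch(const std::vector<float>& a,  // m x k
                            const std::vector<float>& b,  // k x n
                            std::vector<float>& c,        // m x n, accumulated
                            int m, int n, int k) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int kk = 0; kk < k; ++kk)
        c[i * n + j] += a[i * k + kk] * b[kk * n + j];
}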
+ Block* main = rewriter.createBlock(&generic_op.getRegion(), {}, + {etp, etp, etp}, {loc, loc, loc}); + Value arg_c = main->getArgument(2); + rewriter.setInsertionPointToStart(&generic_op.getRegion().front()); + auto mul = rewriter.create(loc, main->getArgument(0), + main->getArgument(1)); + auto add = rewriter.create(loc, mul.getResult(), arg_c); + rewriter.create(loc, add.getResult()); + rewriter.replaceOp(op, generic_op.getResults()); + return success(); + } +}; + class SparseCustomCallRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; using SparseCustomTargetRewriter = std::function { std::make_pair("sparse_tensor_transpose", SparseTransposeCallRewriter()), // User custom ops that need rewriting. std::make_pair("sparse_jax_sddmm", SparseSDDMMCallRewriter()), + std::make_pair("sparse_jax_2to4_spmm", Sparse2To4SpMMCallRewriter()), }; // Rewrites a CustomCallOp to corresponding sparse_tensor operation. diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD index 1f192d5507e424..622293491f938e 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( @@ -10,7 +10,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -33,6 +33,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla/mlir_hlo:lhlo", "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", + "@com_google_absl//absl/strings", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", @@ -73,6 +74,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:nccl_collective_thunks", "//tensorflow/compiler/xla/stream_executor:blas", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:location_exporter", + "//tensorflow/tsl/platform:env", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc index c524043ace6cc9..60a2f1ff1fba52 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include #include #include #include @@ -31,6 +32,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h" #include "tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h" +#include "tensorflow/tsl/platform/env.h" namespace xla { namespace gpu { @@ -93,6 +95,13 @@ llvm::SmallVector GetRegionInfos( DataflowAnalysis::DataflowGraph dataflow_graph = dataflow_analysis.GetDataflowGraph(capture_func); + // If verbose logging is enabled print the dataflow graph as a DOT graph. 
+ if (VLOG_IS_ON(100)) { + std::cout << "Dependency graph for graph capture function " + << capture_func.getName().str() << ":\n" + << dataflow_analysis.ToDot(dataflow_graph); + } + llvm::SmallVector region; auto store_region_and_start_new_region = [&]() { diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc index acca9c7b11e5e0..f4b23864a1fe08 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc @@ -16,8 +16,11 @@ limitations under the License. #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h" #include +#include +#include #include +#include "absl/strings/str_cat.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project @@ -165,6 +168,50 @@ bool HasDependency(llvm::ArrayRef buffer_uses_a, return false; } +bool Reachable(const DataflowAnalysis::DataflowGraph& graph, size_t from_index, + size_t to_index) { + std::queue bfs_queue; + bfs_queue.push(from_index); + + while (!bfs_queue.empty()) { + size_t index = bfs_queue.front(); + bfs_queue.pop(); + if (index == to_index) return true; + + const DataflowAnalysis::Node& node = graph[index]; + for (size_t child_index : node.children) { + bfs_queue.push(child_index); + } + } + + return false; +} + +// Remove edges that are redundant for determining the execution order of +// kernels. We use the following algorithm to compute the transitive reduction: +// +// for edge (u,v) do +// if there is a path from u to v in that does not use edge (u,v) then +// remove edge (u,v) +// +// TODO(b/288594057): Use a more efficient algorithm. +void TransitiveReduction(DataflowAnalysis::DataflowGraph& graph) { + for (DataflowAnalysis::Node& node : graph) { + auto is_reducible = [&](size_t to_index) -> bool { + for (size_t child_index : node.children) { + if (child_index != to_index) { + if (Reachable(graph, child_index, to_index)) return true; + } + } + return false; + }; + + node.children.erase(std::remove_if(node.children.begin(), + node.children.end(), is_reducible), + node.children.end()); + } +} + } // namespace DataflowAnalysis::DataflowGraph DataflowAnalysis::GetDataflowGraph( @@ -192,8 +239,35 @@ DataflowAnalysis::DataflowGraph DataflowAnalysis::GetDataflowGraph( } } + TransitiveReduction(graph); return graph; } +std::string DataflowAnalysis::ToDot(const DataflowGraph& graph) { + std::string pad; + std::string res; + auto indent = [&] { pad.append(2, ' '); }; + auto outdent = [&] { pad.resize(pad.size() - 2); }; + auto addline = [&](auto&&... 
args) { + absl::StrAppend(&res, pad, args..., "\n"); + }; + auto get_name = [](const Node& node) -> std::string { + return absl::StrCat("\"", node.operation->getName().getStringRef().str(), + "_", node.index, "\""); + }; + + addline("digraph {"); + indent(); + for (const Node& node : graph) { + for (size_t child_index : node.children) { + Node child = graph[child_index]; + addline(get_name(node), " -> ", get_name(child)); + } + } + outdent(); + addline("}"); + return res; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h index 72deb6eed23c81..f301e4297dc62b 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_GPU_TRANSFORMS_DATAFLOW_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_GPU_TRANSFORMS_DATAFLOW_ANALYSIS_H_ +#include #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -45,6 +46,8 @@ class DataflowAnalysis { // have write-conflicts. // (3) We have information about read-only and read-write buffer arguments. DataflowGraph GetDataflowGraph(mlir::func::FuncOp graph_capture_function); + + std::string ToDot(const DataflowGraph& graph); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc index cb07346409e8be..1f43ff28261b0a 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc @@ -372,6 +372,7 @@ class ConvOpLowering : public OpRewritePattern { if (auto fused = dyn_cast(op.getOperation())) { call->setAttr(b.getStringAttr("activation_mode"), fused.getActivationModeAttr()); + set_attr("leakyrelu_alpha", fused.getLeakyreluAlphaAttr()); } // Copy attributes specific for fused convolutions with side input. diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc index 7c5f31c8bbb973..3a921bb4af7b5d 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc @@ -57,8 +57,10 @@ class OutlineCudaGraphsPass : public impl::OutlineCudaGraphsPassBase { public: OutlineCudaGraphsPass() = default; - explicit OutlineCudaGraphsPass(int cuda_graph_level) - : cuda_graph_level_(cuda_graph_level) {} + explicit OutlineCudaGraphsPass(int cuda_graph_level, int min_graph_size) + : cuda_graph_level_(cuda_graph_level) { + this->min_graph_size_ = min_graph_size; + } void runOnOperation() override; @@ -326,16 +328,14 @@ static std::vector GetGraphCaptureFuncArgs(const CaptureSequence& seq) { // and replace them with an XLA Gpu runtime function call. static LogicalResult Outline(unsigned ordinal, CustomCallDeclarations& custom_calls, - CaptureSequence& seq) { + CaptureSequence& seq, int min_graph_size) { // Only operations that have to be moved into the graph capture function // represent Gpu computations. 
unsigned num_move_captures = llvm::count_if(seq, [](auto capture) { return capture.second == OpCapturePattern::Capture::kMove; }); DebugOptions debug_options = GetDebugOptionsFromFlags(); - int32_t graph_capture_threshold = - debug_options.xla_gpu_cuda_graph_min_graph_size(); - if (num_move_captures < graph_capture_threshold) return failure(); + if (num_move_captures < min_graph_size) return failure(); SymbolTable& sym_table = custom_calls.sym_table(); MLIRContext* ctx = sym_table.getOp()->getContext(); @@ -479,7 +479,8 @@ void OutlineCudaGraphsPass::runOnOperation() { unsigned ordinal = 1; // entry point will be exported with ordinal 0 for (auto& seq : CollectCaptureSequences(getAnalysis(), getOperation(), patterns)) { - if (succeeded(Outline(ordinal, custom_calls, seq))) ordinal++; + if (succeeded(Outline(ordinal, custom_calls, seq, min_graph_size_))) + ordinal++; } } @@ -488,8 +489,9 @@ std::unique_ptr> createOutlineCudaGraphsPass() { } std::unique_ptr> createOutlineCudaGraphsPass( - int cuda_graph_level) { - return std::make_unique(cuda_graph_level); + int cuda_graph_level, int min_graph_size) { + return std::make_unique(cuda_graph_level, + min_graph_size); } } // namespace gpu diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc index 2912fb1df9f4a6..54f91186b8be8f 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc @@ -15,9 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h" -#include -#include - #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project @@ -34,14 +31,14 @@ void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm, // Clean up IR before converting it to the runtime operations. pm.addPass(createCSEPass()); - pm.addPass(createCanonicalizerPass()); // Convert global memrefs corresponding to constant arguments. pm.addPass(createConvertMemrefGetGlobalToArgPass()); pm.addPass(createSymbolDCEPass()); // Clean up unused global constants. // Outline CUDA-Graph-compatible operations into graph capture functions. - pm.addPass(createOutlineCudaGraphsPass(opts.cuda_graph_level)); + pm.addPass( + createOutlineCudaGraphsPass(opts.cuda_graph_level, opts.min_graph_size)); if (opts.enable_concurrent_region) { pm.addPass(createAddConcurrentRegionsPass()); } diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h index ebf4058365ef10..3917c73cec634c 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h @@ -42,6 +42,7 @@ struct GpuPipelineOpts { // CUDA Graphs, which allows us to amortize the cost of launching multiple // device kernels. int32_t cuda_graph_level = 0; + int32_t min_graph_size = 0; bool enable_concurrent_region = false; }; @@ -101,7 +102,7 @@ std::unique_ptr> createOutlineCudaGraphsPass(); std::unique_ptr> -createOutlineCudaGraphsPass(int32_t cuda_graph_level); +createOutlineCudaGraphsPass(int32_t cuda_graph_level, int32_t min_graph_size); //===----------------------------------------------------------------------===// // Passes for marking concurrent region in CUDA graph capture function. 
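The dataflow-analysis hunks a little earlier in this diff add a BFS-based Reachable helper, a quadratic TransitiveReduction over the kernel dependency graph, and a ToDot printer used for verbose logging. The following self-contained sketch reproduces the same three steps on a plain adjacency-list graph; the Node/Graph types and function bodies are simplified stand-ins, not the XLA classes, and like the original Reachable they assume the graph is acyclic (there is no visited set).

#include <cstddef>
#include <iostream>
#include <queue>
#include <string>
#include <utility>
#include <vector>

struct Node {
  size_t index;
  std::vector<size_t> children;  // nodes that must execute after this one
};
using Graph = std::vector<Node>;

// Breadth-first search from `from` looking for `to`.
bool Reachable(const Graph& graph, size_t from, size_t to) {
  std::queue<size_t> queue;
  queue.push(from);
  while (!queue.empty()) {
    size_t index = queue.front();
    queue.pop();
    if (index == to) return true;
    for (size_t child : graph[index].children) queue.push(child);
  }
  return false;
}

// Drop edge (u,v) whenever another child of u already reaches v, i.e. the
// edge is implied by a longer path -- the algorithm from the diff's comment.
void TransitiveReduction(Graph& graph) {
  for (Node& node : graph) {
    std::vector<size_t> kept;
    for (size_t to : node.children) {
      bool redundant = false;
      for (size_t other : node.children) {
        if (other != to && Reachable(graph, other, to)) {
          redundant = true;
          break;
        }
      }
      if (!redundant) kept.push_back(to);
    }
    node.children = std::move(kept);
  }
}

// Minimal DOT emission, analogous in spirit to DataflowAnalysis::ToDot.
std::string ToDot(const Graph& graph) {
  std::string dot = "digraph {\n";
  for (const Node& node : graph)
    for (size_t child : node.children)
      dot += "  n" + std::to_string(node.index) + " -> n" +
             std::to_string(child) + "\n";
  return dot + "}\n";
}

int main() {
  // 0 -> 1 -> 2 plus the redundant shortcut 0 -> 2, which the reduction drops.
  Graph graph = {{0, {1, 2}}, {1, {2}}, {2, {}}};
  TransitiveReduction(graph);
  std::cout << ToDot(graph);  // prints only 0 -> 1 and 1 -> 2
  return 0;
}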
diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td index 1c7dbb55bc9f40..c2daab9f2a6ef4 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td @@ -185,6 +185,11 @@ def OutlineCudaGraphsPass : }]; let constructor = "createOutlineCudaGraphsPass()"; + + let options = [ + Option<"min_graph_size_", "min_graph_size", "int64_t", /*default=*/"2", + "The minimum size of the outlined CUDA graph function.">, + ]; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir index cd52e0ae826194..59d6ca10a9bdca 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir @@ -212,6 +212,7 @@ func.func @conv_forward_fused(%input: memref<8x5x5x1xf32, #map1>, reverse = [0, 0] } { activation_mode = #lmhlo_gpu, + leakyrelu_alpha = 0.0 : f64, backend_config = #lmhlo_gpu.convolution_backend_config< algorithm = 11, is_cudnn_frontend = true, diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir index 3cdae0c117489d..ce2a7227b97a89 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir @@ -412,6 +412,7 @@ module attributes {gpu.container_module} { reverse = [0, 0] } { activation_mode = #lmhlo_gpu, + leakyrelu_alpha = 0.0 : f64, backend_config = #lmhlo_gpu.convolution_backend_config< algorithm = -1, is_cudnn_frontend = true, diff --git a/tensorflow/compiler/xla/mlir/framework/ir/BUILD b/tensorflow/compiler/xla/mlir/framework/ir/BUILD index 07f6d6404661ea..d1d021e7ba79e7 100644 --- a/tensorflow/compiler/xla/mlir/framework/ir/BUILD +++ b/tensorflow/compiler/xla/mlir/framework/ir/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -13,7 +13,7 @@ td_library( srcs = [ "xla_framework_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:OpBaseTdFiles", @@ -23,7 +23,7 @@ td_library( gentbl_cc_library( name = "xla_framework_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], diff --git a/tensorflow/compiler/xla/mlir/framework/transforms/BUILD b/tensorflow/compiler/xla/mlir/framework/transforms/BUILD index 5a1fddaa0a9278..52b3ad191bec19 100644 --- a/tensorflow/compiler/xla/mlir/framework/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/framework/transforms/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") 
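For context on the leakyrelu_alpha attribute added to the fused-convolution op and its tests above, and on the LeakyRelu activation enum case added later in this diff: leaky ReLU passes non-negative inputs through unchanged and scales negative inputs by alpha, so alpha = 0.0 degenerates to an ordinary ReLU. A one-function reference sketch (not the cuDNN/XLA implementation):

// Reference semantics of the activation parameterized by leakyrelu_alpha.
inline float LeakyRelu(float x, float alpha) {
  return x >= 0.0f ? x : alpha * x;
}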
-load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -10,7 +10,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ diff --git a/tensorflow/compiler/xla/mlir/math/transforms/BUILD b/tensorflow/compiler/xla/mlir/math/transforms/BUILD index 271cc519266e67..a3a44b926f1393 100644 --- a/tensorflow/compiler/xla/mlir/math/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/math/transforms/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( @@ -10,7 +10,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -32,7 +32,7 @@ cc_library( "math_optimization.cc", ], hdrs = ["passes.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":passes_inc_gen", "@llvm-project//mlir:ArithDialect", diff --git a/tensorflow/compiler/xla/mlir/memref/transforms/BUILD b/tensorflow/compiler/xla/mlir/memref/transforms/BUILD index cd77dc34eea949..d1372b1feb59e1 100644 --- a/tensorflow/compiler/xla/mlir/memref/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/memref/transforms/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( @@ -10,7 +10,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -29,7 +29,7 @@ cc_library( name = "passes", srcs = ["aligned_allocations.cc"], hdrs = ["passes.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":passes_inc_gen", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/xla/mlir/runtime/BUILD b/tensorflow/compiler/xla/mlir/runtime/BUILD index 6c2504899bfd01..48d976de9f2d9d 100644 --- a/tensorflow/compiler/xla/mlir/runtime/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@bazel_skylib//rules:build_test.bzl", "build_test") package_group( @@ -7,7 +7,6 @@ package_group( packages = [ # copybara:uncomment_begin(google-only) # "//platforms/xla/service/cpu/...", - # "//learning/brain/experimental/tfrt/autofusion/...", # "//third_party/mlir_edge/tpgen/...", # # TODO(ezhulenev): Clean up dependencies that are leforvers from Autofusion project. 
# "@tf_runtime//...", @@ -38,7 +37,7 @@ build_test( xla_cc_binary( name = "xla-runtime-opt", srcs = ["xla-runtime-opt.cc"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/mlir/math/transforms:passes", "//tensorflow/compiler/xla/mlir/memref/transforms:passes", diff --git a/tensorflow/compiler/xla/mlir/runtime/ir/BUILD b/tensorflow/compiler/xla/mlir/runtime/ir/BUILD index 2f7875bb9116ff..49b659fe9dcc5d 100644 --- a/tensorflow/compiler/xla/mlir/runtime/ir/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/ir/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") package( @@ -15,7 +15,7 @@ td_library( "rt_interfaces.td", "rt_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["include"], visibility = ["//visibility:private"], deps = [ @@ -27,7 +27,7 @@ td_library( gentbl_cc_library( name = "rt_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-dialect-decls"], @@ -69,7 +69,7 @@ gentbl_cc_library( gentbl_cc_library( name = "rt_interfaces_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-attr-interface-decls"], @@ -97,7 +97,7 @@ cc_library( "rt_interfaces.h", "rt_ops.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":rt_inc_gen", ":rt_interfaces_inc_gen", diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD b/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD index 1fe391993265dd..dc9bd544f64fc2 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("//tensorflow/tsl/platform:build_config.bzl", "if_llvm_system_z_available") @@ -12,7 +12,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -37,7 +37,7 @@ cc_library( "rt_to_llvm.cc", ], hdrs = ["passes.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":custom_call_encoding", ":passes_inc_gen", @@ -66,7 +66,7 @@ cc_library( name = "calling_convention", srcs = ["calling_convention.cc"], hdrs = ["calling_convention.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/mlir/runtime/ir:rt", "@llvm-project//mlir:FuncDialect", @@ -78,7 +78,7 @@ cc_library( xla_cc_test( name = "calling_convention_test", srcs = ["calling_convention_test.cc"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":calling_convention", "//tensorflow/compiler/xla/mlir/runtime/ir:rt", @@ -95,7 
+95,7 @@ cc_library( name = "compilation_pipeline_cpu", srcs = ["compilation_pipeline_cpu.cc"], hdrs = ["compilation_pipeline_cpu.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], deps = [ ":compilation_pipeline_options", @@ -120,10 +120,10 @@ cc_library( "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ComplexToLLVM", "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:FuncToLLVM", + "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:LinalgToLLVM", @@ -150,7 +150,7 @@ cc_library( name = "compilation_pipeline_gpu", srcs = ["compilation_pipeline_gpu.cc"], hdrs = ["compilation_pipeline_gpu.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], deps = [ ":compilation_pipeline_options", @@ -184,7 +184,7 @@ cc_library( cc_library( name = "compilation_pipeline_options", hdrs = ["compilation_pipeline_options.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":custom_call_encoding", "//tensorflow/compiler/xla/runtime:type_id", @@ -196,7 +196,7 @@ cc_library( name = "custom_call_encoding", srcs = ["custom_call_encoding.cc"], hdrs = ["custom_call_encoding.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/mlir/runtime/ir:rt", @@ -219,7 +219,7 @@ cc_library( name = "jit_compiler", srcs = ["jit_compiler.cc"], hdrs = ["jit_compiler.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":calling_convention", ":compiler", @@ -267,7 +267,7 @@ cc_library( name = "specialization", srcs = ["specialization.cc"], hdrs = ["specialization.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":type_converter", "//tensorflow/compiler/xla/mlir/runtime/utils:constraints", @@ -292,7 +292,7 @@ cc_library( name = "type_converter", srcs = ["type_converter.cc"], hdrs = ["type_converter.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/mlir/runtime/ir:rt", @@ -309,7 +309,7 @@ cc_library( xla_cc_test( name = "type_converter_test", srcs = ["type_converter_test.cc"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":type_converter", "//tensorflow/compiler/xla/runtime:types", @@ -322,7 +322,7 @@ xla_cc_test( cc_library( name = "compiler", hdrs = ["compiler.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc index 05fe3773e39365..0ff7e1b0d366da 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc @@ 
-21,11 +21,11 @@ limitations under the License. #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" // from @llvm-project #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project #include "mlir/Conversion/MathToLibm/MathToLibm.h" // from @llvm-project #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project -#include "mlir/Conversion/Passes.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc index a110e81c681531..871fbb41aa0214 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc @@ -28,7 +28,6 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project @@ -63,20 +62,24 @@ void RegisterTestlibDialect(DialectRegistry& dialects) { } static void CreateDefaultXlaGpuRuntimeCompilationPipeline( - mlir::OpPassManager& pm, const CompilationPipelineOptions& opts) { + mlir::OpPassManager& pm, const CompilationPipelineOptions& opts, + bool add_async_passes) { pm.addPass(mlir::createConvertSCFToCFPass()); - pm.addPass(mlir::createAsyncFuncToAsyncRuntimePass()); + + if (add_async_passes) pm.addPass(mlir::createAsyncFuncToAsyncRuntimePass()); // Export functions to the XLA runtime. pm.addPass(CreateExportRuntimeFunctionsPass()); pm.addPass(CreateConvertCustomCallsPass()); pm.addPass(CreateConvertAssertsPass()); - // Lower from high level async operations to async runtime. - pm.addPass(mlir::createAsyncToAsyncRuntimePass()); + if (add_async_passes) { + // Lower from high level async operations to async runtime. + pm.addPass(mlir::createAsyncToAsyncRuntimePass()); - // Add async.runtime reference counting operations. - pm.addPass(mlir::createAsyncRuntimePolicyBasedRefCountingPass()); + // Add async.runtime reference counting operations. + pm.addPass(mlir::createAsyncRuntimePolicyBasedRefCountingPass()); + } // Convert runtime operations and custom calls to LLVM dialect. ConvertRuntimeToLLvmOpts rt_to_llvm_opts = { @@ -86,7 +89,7 @@ static void CreateDefaultXlaGpuRuntimeCompilationPipeline( pm.addPass(CreateConvertRuntimeToLLVMPass(std::move(rt_to_llvm_opts))); // Convert async dialect to LLVM once everything else is in the LLVM dialect. 
- pm.addPass(mlir::createConvertAsyncToLLVMPass()); + if (add_async_passes) pm.addPass(mlir::createConvertAsyncToLLVMPass()); // Convert everything else to LLVM dialect. pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); @@ -94,13 +97,14 @@ static void CreateDefaultXlaGpuRuntimeCompilationPipeline( pm.addPass(mlir::createReconcileUnrealizedCastsPass()); // Clean up IR before passing it to LLVM. - pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); } void CreateDefaultXlaGpuRuntimeCompilationPipeline( - PassManager& passes, const CompilationPipelineOptions& opts) { - CreateDefaultXlaGpuRuntimeCompilationPipeline(*passes, opts); + PassManager& passes, const CompilationPipelineOptions& opts, + bool add_async_passes) { + CreateDefaultXlaGpuRuntimeCompilationPipeline(*passes, opts, + add_async_passes); } void AppendXlaGpuDialectRegistry(mlir::MLIRContext& context) { @@ -111,7 +115,7 @@ void AppendXlaGpuDialectRegistry(mlir::MLIRContext& context) { static void CreateDefaultGpuPipeline(mlir::OpPassManager& pm) { CompilationPipelineOptions copts; - CreateDefaultXlaGpuRuntimeCompilationPipeline(pm, copts); + CreateDefaultXlaGpuRuntimeCompilationPipeline(pm, copts, false); } static mlir::PassPipelineRegistration<> kXlaRuntimePipeline( diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h index 5f78e16fbf3719..4bf1bc7c2d4e66 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h @@ -16,8 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_MLIR_RUNTIME_TRANSFORMS_COMPILATION_PIPELINE_GPU_H_ #define TENSORFLOW_COMPILER_XLA_MLIR_RUNTIME_TRANSFORMS_COMPILATION_PIPELINE_GPU_H_ -#include - #include "tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_options.h" #include "tensorflow/compiler/xla/runtime/compiler.h" @@ -42,7 +40,8 @@ void RegisterTestlibDialect(DialectRegistry& dialects); // it is expected that all end users will construct their own compilation // pipelines from the available XLA and MLIR passes. void CreateDefaultXlaGpuRuntimeCompilationPipeline( - PassManager& passes, const CompilationPipelineOptions& opts); + PassManager& passes, const CompilationPipelineOptions& opts, + bool add_async_passes = false); void AppendXlaGpuDialectRegistry(mlir::MLIRContext& context); diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.cc index a83f477729ee54..b23c7f0d1bd86f 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.cc @@ -422,7 +422,8 @@ static LLVM::AllocaOp PackValue(ImplicitLocOpBuilder &b, Allocas &a, LLVM::AllocaOp alloca = a.GetOrCreate(b, value.getType()); // Start the lifetime of encoded value. b.create(b.getI64IntegerAttr(-1), alloca); - b.create(value, alloca); + // Use volatile store to suppress expensive LLVM optimizations. 
+ b.create(value, alloca, /*alignment=*/0, /*isVolatile=*/true); return alloca; } diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc index 399ad8de6fbbae..00df6405898bd1 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc @@ -337,8 +337,9 @@ static FailureOr EncodeArguments( // Start the lifetime of the encoded arguments pointers. b.create(b.getI64IntegerAttr(-1), alloca); - // Store constructed arguments pointers array into the alloca. - b.create(arr, alloca.getRes()); + // Store constructed arguments pointers array into the alloca. Use volatile + // store to suppress expensive LLVM optimizations. + b.create(arr, alloca, /*alignment=*/0, /*isVolatile=*/true); // Alloca that encodes the custom call arguments. arguments.encoded = alloca; @@ -431,8 +432,9 @@ static FailureOr EncodeResults( // Start the lifetime of the encoded results pointers allocation. b.create(b.getI64IntegerAttr(-1), alloca); - // Store constructed results pointers array on the stack - b.create(arr, alloca); + // Store constructed results pointers array on the stack. Use volatile + // store to suppress expensive LLVM optimizations. + b.create(arr, alloca, /*alignment=*/0, /*isVolatile=*/true); // Alloca that encodes the custom call returns. results.encoded = alloca; diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir index b26b0b6f6062c4..38af292598b59d 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir @@ -335,7 +335,7 @@ func.func @custom_call(%arg0: !rt.execution_context, %arg1 : f32) { // CHECK-DAG: %[[ARGS:.*]] = llvm.alloca {{.*}} x !llvm.array<3 x ptr> // CHECK-DAG: %[[N_ARGS:.*]] = llvm.mlir.addressof @__rt_num_args - // CHECK-DAG: llvm.store %[[ARG]], %[[MEM]] + // CHECK-DAG: llvm.store volatile %[[ARG]], %[[MEM]] // CHECK: %[[ARGS_TYPES:.*]] = llvm.mlir.addressof @__rt_args_type_table // CHECK: llvm.insertvalue %[[ARGS_TYPES]], {{.*}}[1] : !llvm.array<3 x ptr> @@ -460,7 +460,7 @@ func.func @opaque_arg(%ctx: !rt.execution_context, %arg: !rt.opaque) { func.func @opaque_custom_call_arg(%ctx: !rt.execution_context, %arg: !rt.opaque) { // CHECK: %[[ALLOCA:.*]] = llvm.alloca {{.*}} x !llvm.ptr - // CHECK: llvm.store %[[ARG1]], %[[ALLOCA]] : !llvm.ptr + // CHECK: llvm.store volatile %[[ARG1]], %[[ALLOCA]] : !llvm.ptr // CHECK: call @target %status = rt.call %ctx["target"] (%arg) : (!rt.opaque) -> () return @@ -627,7 +627,7 @@ func.func @custom_call(%arg0: !rt.execution_context, %arg1: f32) { // CHECK-NOT: llvm.alloca // llvm.intr.lifetime.start -1, %[[ARG_ALLOCA]] : !llvm.ptr - // CHECK: llvm.store %[[ARG]], %[[ARG_ALLOCA]] : f32, !llvm.ptr + // CHECK: llvm.store volatile %[[ARG]], %[[ARG_ALLOCA]] : f32, !llvm.ptr // llvm.intr.lifetime.start -1, %[[ARGS]] : !llvm.ptr // CHECK: llvm.store {{.*}}, %[[ARGS]] // CHECK: call @target @@ -636,7 +636,7 @@ func.func @custom_call(%arg0: !rt.execution_context, %arg1: f32) { // llvm.intr.lifetime.end -1, %[[ARG_ALLOCA]] : !llvm.ptr // llvm.intr.lifetime.start -1, %[[ARG_ALLOCA]] : !llvm.ptr - // CHECK: llvm.store %[[ARG]], %[[ARG_ALLOCA]] : f32, !llvm.ptr + // CHECK: llvm.store volatile %[[ARG]], %[[ARG_ALLOCA]] : f32, !llvm.ptr // llvm.intr.lifetime.start -1, %[[ARGS]] : !llvm.ptr // CHECK: 
llvm.store {{.*}}, %[[ARGS]] // CHECK: call @target diff --git a/tensorflow/compiler/xla/mlir/runtime/utils/BUILD b/tensorflow/compiler/xla/mlir/runtime/utils/BUILD index 0b552cd7e5c2aa..0bec5a090be1d3 100644 --- a/tensorflow/compiler/xla/mlir/runtime/utils/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/utils/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -11,7 +11,7 @@ cc_library( name = "async_runtime_api", srcs = ["async_runtime_api.cc"], hdrs = ["async_runtime_api.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/runtime:async_runtime", "//tensorflow/tsl/platform:platform_port", @@ -26,7 +26,7 @@ cc_library( cc_library( name = "c_runner_utils", hdrs = ["c_runner_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//llvm:OrcJIT", "@llvm-project//mlir:mlir_c_runner_utils", @@ -37,7 +37,7 @@ cc_library( name = "constraints", srcs = ["constraints.cc"], hdrs = ["constraints.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/runtime:constraints", "//tensorflow/compiler/xla/runtime:errors", @@ -55,7 +55,7 @@ cc_library( name = "custom_calls", srcs = ["custom_calls.cc"], hdrs = ["custom_calls.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", @@ -66,7 +66,7 @@ cc_library( cc_library( name = "float_16bits", hdrs = ["float_16bits.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//llvm:OrcJIT", "@llvm-project//mlir:mlir_float16_utils", diff --git a/tensorflow/compiler/xla/mlir/utils/BUILD b/tensorflow/compiler/xla/mlir/utils/BUILD index 9e9a6f974d91a3..9d8e0449bf34fb 100644 --- a/tensorflow/compiler/xla/mlir/utils/BUILD +++ b/tensorflow/compiler/xla/mlir/utils/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -13,7 +13,7 @@ cc_library( name = "error_util", srcs = ["error_util.cc"], hdrs = ["error_util.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform:errors", "@com_google_absl//absl/status", diff --git a/tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD b/tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD index 33922080c1dcd8..2737d23adda775 100644 --- a/tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD +++ b/tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -14,7 +14,7 @@ td_library( "xla_cpu_enums.td", "xla_cpu_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = 
get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/mlir_hlo:hlo_ops_td_files", "@llvm-project//mlir:BufferizableOpInterfaceTdFiles", @@ -25,7 +25,7 @@ td_library( gentbl_cc_library( name = "xla_cpu_dialect_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-dialect-decls"], @@ -43,7 +43,7 @@ gentbl_cc_library( gentbl_cc_library( name = "xla_cpu_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -61,7 +61,7 @@ gentbl_cc_library( gentbl_cc_library( name = "xla_cpu_enums_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-enum-decls"], diff --git a/tensorflow/compiler/xla/mlir_hlo/BUILD b/tensorflow/compiler/xla/mlir_hlo/BUILD index 9cf6296ae7c05a..8cc3a88e4fe81a 100644 --- a/tensorflow/compiler/xla/mlir_hlo/BUILD +++ b/tensorflow/compiler/xla/mlir_hlo/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") load("@bazel_skylib//rules:build_test.bzl", "build_test") @@ -25,7 +25,7 @@ filegroup( td_library( name = "hlo_ops_td_files", srcs = glob(["mhlo/IR/*.td"]), - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ "@llvm-project//mlir:BuiltinDialectTdFiles", @@ -46,7 +46,7 @@ td_library( gentbl_cc_library( name = "mhlo_pass_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -64,7 +64,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lmhlo_pass_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -82,7 +82,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -101,7 +101,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_ops_attrs_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -120,7 +120,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_ops_enums_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -139,7 +139,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_ops_typedefs_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -164,7 +164,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_ops_pattern_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = "mhlo/IR/", tbl_outs = [ ( @@ -183,7 +183,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lhlo_ops_structs_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ 
-202,7 +202,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lhlo_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -221,7 +221,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lhlo_gpu_ops_enums_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -240,7 +240,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lhlo_gpu_ops_dialect_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -259,7 +259,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lhlo_gpu_ops_attrdefs_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -278,7 +278,7 @@ gentbl_cc_library( gentbl_filegroup( name = "hlo_ops_doc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -295,7 +295,7 @@ gentbl_filegroup( gentbl_filegroup( name = "lhlo_ops_doc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -324,7 +324,7 @@ cc_library( td_library( name = "lhlo_gpu_ops_td_files", srcs = glob(["lhlo_gpu/IR/*.td"]), - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ ":hlo_ops_td_files", @@ -335,7 +335,7 @@ td_library( gentbl_cc_library( name = "lhlo_gpu_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -355,7 +355,7 @@ gentbl_cc_library( #TODO(aminim): revisit the naming and grouping of these rules post-move. 
gentbl_cc_library( name = "canonicalize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -371,7 +371,7 @@ gentbl_cc_library( td_library( name = "deallocation_ops_td_files", srcs = glob(["deallocation/IR/*.td"]), - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ "@llvm-project//mlir:OpBaseTdFiles", @@ -381,7 +381,7 @@ td_library( gentbl_cc_library( name = "deallocation_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -455,7 +455,7 @@ cc_library( gentbl_cc_library( name = "deallocation_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -505,7 +505,7 @@ cc_library( td_library( name = "lhlo_ops_td_files", srcs = glob(["lhlo/IR/*.td"]), - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ ":hlo_ops_td_files", @@ -523,7 +523,7 @@ td_library( gentbl_cc_library( name = "lhlo_structured_interface_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -734,6 +734,7 @@ cc_library( "mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc", "mhlo/transforms/legalize_to_standard/generated_legalize_to_standard.inc", "mhlo/transforms/legalize_to_standard/legalize_to_standard.cc", + "mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc", "mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc", "mhlo/transforms/lower_complex/generated_lower_complex.inc", "mhlo/transforms/lower_complex/lower_complex.cc", @@ -1059,7 +1060,7 @@ cc_library( gentbl_cc_library( name = "legalize_to_standard_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = "mhlo/transforms/", tbl_outs = [ ( @@ -1079,7 +1080,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lower_complex_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = "mhlo/transforms/", tbl_outs = [ ( @@ -1136,7 +1137,7 @@ cc_library( gentbl_cc_library( name = "chlo_legalize_to_hlo_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = "mhlo/transforms", tbl_outs = [ ( @@ -1392,7 +1393,7 @@ cc_library( gentbl_cc_library( name = "gml_st_test_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1433,7 +1434,7 @@ cc_library( gentbl_cc_library( name = "transforms_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1451,7 +1452,7 @@ gentbl_cc_library( gentbl_cc_library( name = "gpu_transforms_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1717,7 +1718,7 @@ filegroup( td_library( name = "gml_st_ops_td_files", srcs = glob(["gml_st/IR/*.td"]), 
- compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ "@llvm-project//mlir:ControlFlowInterfacesTdFiles", @@ -1730,7 +1731,7 @@ td_library( gentbl_cc_library( name = "gml_st_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1797,7 +1798,7 @@ cc_library( gentbl_cc_library( name = "gml_st_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1838,7 +1839,7 @@ cc_library( td_library( name = "thlo_ops_td_files", srcs = glob(["thlo/IR/*.td"]), - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ "@llvm-project//mlir:ControlFlowInterfacesTdFiles", @@ -1850,7 +1851,7 @@ td_library( gentbl_cc_library( name = "thlo_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1921,7 +1922,7 @@ cc_library( gentbl_cc_library( name = "thlo_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc index c07760ff679544..5c7496c4193d17 100644 --- a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/OpDefinition.h" @@ -386,7 +387,7 @@ struct VectorizeForCPUPass ThloReverseVectorizationPattern, TransferReadOfOneDimExpandShape>(ctx); tensor::CastOp::getCanonicalizationPatterns(patterns, ctx); - vector::populateVectorTransferTensorSliceTransforms(patterns); + tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) return signalPassFailure(); } diff --git a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td index 72fd56b2a15ed5..7e61139b761477 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td @@ -100,7 +100,8 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_ConvBaseOp<"conv_forward_fused"> { Arg:$output, Arg:$scratch), GpuConvolutionAttributes<(ins - ActivationAttr:$activation_mode)>.attributes); + ActivationAttr:$activation_mode, + F64Attr:$leakyrelu_alpha)>.attributes); } // output = activation(result_scale * conv(input, filter) + diff --git a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td index 2a527ac553779f..1e151212a3718e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td @@ -30,12 +30,13 @@ def ActivationModeRelu6 : I32EnumAttrCase<"Relu6", 4>; def ActivationModeReluX : I32EnumAttrCase<"ReluX", 5>; def ActivationModeBandPass : I32EnumAttrCase<"BandPass", 6>; def ActivationModeElu: I32EnumAttrCase<"Elu", 7>; +def ActivationModeLeakyRelu: I32EnumAttrCase<"LeakyRelu", 8>; def Activation: I32EnumAttr<"Activation", "Activation applied with fused convolution", [ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh, ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, - ActivationModeBandPass, ActivationModeElu]> { + ActivationModeBandPass, ActivationModeElu, ActivationModeLeakyRelu]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::lmhlo_gpu"; } diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 645366eecc9158..b45ee1443836e9 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -100,9 +100,7 @@ using mlir::hlo::printDimSizes; #define GET_TYPEDEF_CLASSES #include "mhlo/IR/hlo_ops_typedefs.cc.inc" -namespace mlir { -namespace mhlo { - +namespace mlir::mhlo { namespace detail { /// A type representing a collection of other types. 
struct AsyncBundleTypeStorage final @@ -6279,8 +6277,7 @@ LogicalResult UniformDequantizeOp::inferReturnTypeComponents( using mlir::hlo::parseWindowAttributes; using mlir::hlo::printWindowAttributes; -} // namespace mhlo -} // namespace mlir +} // namespace mlir::mhlo using mlir::hlo::parseComplexOpType; using mlir::hlo::parseCustomCallTarget; @@ -6302,8 +6299,7 @@ using mlir::hlo::printVariadicSameOperandsAndResultType; #define GET_OP_CLASSES #include "mhlo/IR/hlo_ops.cc.inc" -namespace mlir { -namespace mhlo { +namespace mlir::mhlo { //===----------------------------------------------------------------------===// // mhlo Dialect Interfaces @@ -6344,7 +6340,7 @@ struct MhloHloDialectInterface : public hlo::HloDialectInterface { return TypeExtensionsAttr::get(getDialect()->getContext(), bounds); } }; -} // end anonymous namespace +} // namespace //===----------------------------------------------------------------------===// // mhlo Dialect Constructor @@ -7365,5 +7361,4 @@ LogicalResult MhloDialect::verifyOperationAttribute(Operation* op, return success(); } -} // namespace mhlo -} // namespace mlir +} // namespace mlir::mhlo diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h index 9b54a8494a8309..4a5483a91c6d87 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h @@ -36,6 +36,11 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" #include "stablehlo/dialect/Base.h" +// Forward declaration for hlo_ops_typedefs.h.inc. +namespace mlir::mhlo::detail { +struct AsyncBundleTypeStorage; +} // namespace mlir::mhlo::detail + // Include order below matters. #include "mhlo/IR/hlo_ops_enums.h.inc" #define GET_ATTRDEF_CLASSES @@ -92,21 +97,19 @@ void printConvolutionDimensions(AsmPrinter &p, Operation *, ParseResult parseConvolutionDimensions(AsmParser &parser, ConvDimensionNumbersAttr &dnums); -} // end namespace mhlo -} // end namespace mlir +} // namespace mhlo +} // namespace mlir #define GET_OP_CLASSES #include "mhlo/IR/hlo_ops.h.inc" -namespace mlir { -namespace mhlo { +namespace mlir::mhlo { SortOp createSortOp(PatternRewriter *rewriter, const Location &loc, const llvm::ArrayRef &operands, const llvm::ArrayRef &elementTypes, int64_t dimension, bool isStable, ComparisonDirection direction); -} // end namespace mhlo -} // end namespace mlir +} // namespace mlir::mhlo #endif // MLIR_HLO_MHLO_IR_HLO_OPS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt index 84c1449a3aaa40..e0fa81c241f3d1 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt @@ -47,6 +47,7 @@ add_mlir_library(MhloPasses legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc legalize_shape_computations/legalize_shape_computations.cc + legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc lower_complex/lower_complex.cc lower_complex/lower_complex_patterns.td diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc 
b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc new file mode 100644 index 00000000000000..daaafd01572399 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc @@ -0,0 +1,157 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace mhlo { + +#define GEN_PASS_DEF_LEGALIZETORCHINDEXSELECTTOGATHERPASS +#include "mhlo/transforms/mhlo_passes.h.inc" + +namespace { + +struct TorchIndexSelectIsGather : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TorchIndexSelectOp op, + PatternRewriter &rewriter) const override { + auto operand = op.getOperand(); + auto operandTy = operand.getType(); + if (!operandTy.hasRank()) { + return rewriter.notifyMatchFailure(op, "unranked operand"); + } + + auto index = op.getIndex(); + if (!operand.getType().hasStaticShape() || + !index.getType().hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "operand and index must have static shapes"); + } + + int64_t dim = static_cast(op.getDim()); + int64_t batchDims = op.getBatchDims(); + if (dim < batchDims) { + return rewriter.notifyMatchFailure( + op, "dim must be greater than or equal to the number of batch dims"); + } + + int64_t indexVectorDim = index.getType().getRank(); + auto indexTy = index.getType(); + auto indexElementTy = indexTy.getElementType().dyn_cast(); + if (!indexElementTy) { + return rewriter.notifyMatchFailure( + op, "index must have integer element type"); + } + + if (index.getType().getElementType().getIntOrFloatBitWidth() == 64 && + operandTy.getShape()[dim] < std::numeric_limits::max()) { + index = rewriter.create( + op.getLoc(), index, rewriter.getIntegerType(32, /*isSigned=*/false)); + } + + if (batchDims > 0) { + llvm::SmallVector newIndexShape(indexTy.getShape()); + newIndexShape.push_back(1); + auto newIndexType = RankedTensorType::get( + newIndexShape, index.getType().getElementType()); + + llvm::SmallVector toConcat; + for (auto batchDim = 0; batchDim < batchDims; ++batchDim) { + toConcat.push_back( + rewriter.create(op.getLoc(), newIndexType, batchDim)); + } + toConcat.push_back( + rewriter.create(op.getLoc(), newIndexType, index)); + index = rewriter.create(op.getLoc(), 
ValueRange(toConcat), + indexVectorDim); + } + + llvm::SmallVector offsetDims; + llvm::SmallVector collapsedSliceDims; + llvm::SmallVector startIndexMap; + llvm::SmallVector sliceSizes(operandTy.getShape()); + for (auto i = 0; i < operandTy.getRank(); ++i) { + if (i < batchDims || i == dim) { + sliceSizes[i] = std::min(sliceSizes[i], static_cast(1)); + collapsedSliceDims.push_back(i); + startIndexMap.push_back(i); + } else { + if (i < dim) { + offsetDims.push_back(i); + } else { + offsetDims.push_back(i + indexVectorDim - (1 + batchDims)); + } + } + } + + auto gatherDimensionNumbersAttr = GatherDimensionNumbersAttr::get( + rewriter.getContext(), offsetDims, collapsedSliceDims, startIndexMap, + indexVectorDim); + + auto sliceSizesAttr = rewriter.getI64TensorAttr(sliceSizes); + + auto gatherOp = + rewriter.create(op.getLoc(), operand, index, + gatherDimensionNumbersAttr, sliceSizesAttr); + rewriter.replaceOp(op, gatherOp); + return success(); + } +}; + +struct LegalizeTorchIndexSelectToGatherPass + : public impl::LegalizeTorchIndexSelectToGatherPassBase< + LegalizeTorchIndexSelectToGatherPass> { + /// Perform the lowering of standard dialect operations to approximations. + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateTorchIndexSelectToGatherPatterns(&getContext(), &patterns); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +void populateTorchIndexSelectToGatherPatterns(mlir::MLIRContext *context, + RewritePatternSet *patterns) { + patterns->add(context); +} + +std::unique_ptr> +createLegalizeTorchIndexSelectToGatherPass() { + return std::make_unique(); +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td index 1f740b45734df3..c8d3b9b6cf5806 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td @@ -93,6 +93,11 @@ def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-i let constructor = "createLegalizeGatherToTorchIndexSelectPass()"; } +def LegalizeTorchIndexSelectToGatherPass : Pass<"mhlo-legalize-torch-index-select-to-gather", "func::FuncOp"> { + let summary = "Legalizes torch index select to a gather."; + let constructor = "createLegalizeTorchIndexSelectToGatherPass()"; +} + def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-trigonometric-to-approximation", "func::FuncOp"> { let summary = "Legalize trigonometric operations from standard dialect to an approximation."; diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h index 50deb8d5b3353b..316aa825584d41 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h @@ -160,6 +160,8 @@ std::unique_ptr> createLegalizeEinsumToDotGeneralPass(); std::unique_ptr> createLegalizeGatherToTorchIndexSelectPass(); +std::unique_ptr> +createLegalizeTorchIndexSelectToGatherPass(); std::unique_ptr> createFlattenTuplePass(); // Creates a pass for expanding mhlo.tuple ops. 
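Note: the new pass added above is exposed both as the textual flag -mhlo-legalize-torch-index-select-to-gather (used in the .mlir test later in this change) and programmatically via createLegalizeTorchIndexSelectToGatherPass(). The following is only a minimal C++ sketch of wiring that pass into a pipeline; the include paths, dialect registration, and the helper function name are assumptions for illustration, not part of this diff.

// Sketch (assumed includes and helper name): run the new MHLO pass over an
// in-memory module parsed from a string.
#include "llvm/ADT/StringRef.h"
#include "mhlo/IR/hlo_ops.h"
#include "mhlo/transforms/passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"

mlir::LogicalResult LowerTorchIndexSelectToGather(llvm::StringRef ir) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>();
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(ir, mlir::ParserConfig(&context));
  if (!module) return mlir::failure();
  mlir::PassManager pm(&context);
  // The pass is declared on func::FuncOp (see mhlo_passes.td in this change),
  // so it is added as a nested pass.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createLegalizeTorchIndexSelectToGatherPass());
  return pm.run(*module);
}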
diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h index dcca8e78cd8671..f2e4f85148f8be 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h @@ -49,6 +49,10 @@ void populateEinsumToDotGeneralPatterns(mlir::MLIRContext *context, void populateGatherToTorchIndexSelectPatterns(mlir::MLIRContext *context, RewritePatternSet *patterns); +// Rewrite patterns for torch index select to equivalent gather legalization. +void populateTorchIndexSelectToGatherPatterns(mlir::MLIRContext *context, + RewritePatternSet *patterns); + void populateMhloToStdPatterns(RewritePatternSet *patterns, MLIRContext *ctx); // Collection of rewrite patterns for lowering all mhlo ops to their diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir index d925e6199d1be2..ca4b2c0be2a117 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir @@ -141,6 +141,7 @@ func.func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32x dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { activation_mode = #lmhlo_gpu, + leakyrelu_alpha = 0.0 : f64, backend_config = #lmhlo_gpu.convolution_backend_config< algorithm = 0, tensor_ops_enabled = true, diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-torch-index-select-to-gather.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-torch-index-select-to-gather.mlir new file mode 100644 index 00000000000000..19230564c80108 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-torch-index-select-to-gather.mlir @@ -0,0 +1,189 @@ +// RUN: mlir-hlo-opt -mhlo-legalize-torch-index-select-to-gather -split-input-file %s -o - | FileCheck %s + +// CHECK-LABEL: @index_select_to_gather_convert_index_type +func.func @index_select_to_gather_convert_index_type(%arg0 : tensor<5x1x5xi64>, %arg1 : tensor<2xi64>) -> tensor<2x1x5xi64> { + // CHECK: [[ARG1:%.+]] = mhlo.convert %arg1 : (tensor<2xi64>) -> tensor<2xui32> + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, [[ARG1]]) { + // CHECK-SAME: dimension_numbers = #mhlo.gather< + // CHECK-SAME: offset_dims = [1, 2], + // CHECK-SAME: collapsed_slice_dims = [0], + // CHECK-SAME: start_index_map = [0], + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<[1, 1, 5]> : tensor<3xi64> + // CHECK-SAME: } : (tensor<5x1x5xi64>, tensor<2xui32>) -> tensor<2x1x5xi64> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xi64>, tensor<2xi64>) -> tensor<2x1x5xi64> + // CHECK: return [[RES]] : tensor<2x1x5xi64> + func.return %0 : tensor<2x1x5xi64> +} + +// ----- + +// CHECK-LABEL: @index_select_to_gather_multi_offset_dims +func.func @index_select_to_gather_multi_offset_dims(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor<2xi32>) -> tensor<2x1x5xi32> { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) { + // CHECK-SAME: dimension_numbers = #mhlo.gather< + // CHECK-SAME: offset_dims = [1, 2], + // CHECK-SAME: collapsed_slice_dims = [0], + // CHECK-SAME: start_index_map = [0], + 
// CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<[1, 1, 5]> : tensor<3xi64> + // CHECK-SAME: } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> + // CHECK: return [[RES]] : tensor<2x1x5xi32> + func.return %0 : tensor<2x1x5xi32> +} + +// ----- + +// CHECK-LABEL: @index_select_to_gather_larger_output +func.func @index_select_to_gather_larger_output(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) { + // CHECK-SAME: dimension_numbers = #mhlo.gather< + // CHECK-SAME: offset_dims = [3], + // CHECK-SAME: collapsed_slice_dims = [0], + // CHECK-SAME: start_index_map = [0], + // CHECK-SAME: index_vector_dim = 3 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<[1, 4]> : tensor<2xi64> + // CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> + // CHECK: return [[RES]] : tensor<1x3x1x4xf32> + func.return %0 : tensor<1x3x1x4xf32> +} + +// ----- + +// CHECK-LABEL: @index_select_to_gather_regular_map +func.func @index_select_to_gather_regular_map(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi32>) -> tensor<2x4xi32> { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) { + // CHECK-SAME: dimension_numbers = #mhlo.gather< + // CHECK-SAME: offset_dims = [1], + // CHECK-SAME: collapsed_slice_dims = [0], + // CHECK-SAME: start_index_map = [0], + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<[1, 4]> : tensor<2xi64> + // CHECK-SAME: } : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<2x4xi32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<2x4xi32> + // CHECK: return [[RES]] : tensor<2x4xi32> + func.return %0 : tensor<2x4xi32> +} + +// ----- + +// CHECK-LABEL: @index_select_to_gather_reverse_map +func.func @index_select_to_gather_reverse_map(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi32>) -> tensor<3x2xi32> { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) { + // CHECK-SAME: dimension_numbers = #mhlo.gather< + // CHECK-SAME: offset_dims = [0], + // CHECK-SAME: collapsed_slice_dims = [1], + // CHECK-SAME: start_index_map = [1], + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<[3, 1]> : tensor<2xi64> + // CHECK-SAME: } : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<3x2xi32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 1 : i64, + batch_dims = 0 : i64 + } : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<3x2xi32> + // CHECK: return [[RES]] : tensor<3x2xi32> + func.return %0 : tensor<3x2xi32> +} + +// ----- + +// CHECK-LABEL: @index_select_to_gather_batch_dim_greater_than_1 +func.func @index_select_to_gather_batch_dim_greater_than_1(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor<2xi32>) -> tensor<2x5xi32> { + // CHECK: [[ARG0:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x1xi32> + // CHECK: [[ARG1:%.+]] = mhlo.reshape %arg1 : (tensor<2xi32>) -> tensor<2x1xi32> + // CHECK: 
[[ARG2:%.+]] = "mhlo.concatenate"([[ARG0]], [[ARG1]]) {dimension = 1 : i64} : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, [[ARG2]]) { + // CHECK-SAME: dimension_numbers = #mhlo.gather< + // CHECK-SAME: offset_dims = [1], + // CHECK-SAME: collapsed_slice_dims = [0, 1], + // CHECK-SAME: start_index_map = [0, 1], + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<[1, 1, 5]> : tensor<3xi64> + // CHECK-SAME: } : (tensor<5x1x5xi32>, tensor<2x2xi32>) -> tensor<2x5xi32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 1 : i64, + batch_dims = 1 : i64 + } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x5xi32> + func.return %0 : tensor<2x5xi32> +} + +// ----- + +func.func @index_select_to_gather_unranked(%arg0 : tensor<*xi32>, %arg1 : tensor<*xi32>) -> tensor<*xi32> { + // CHECK: mhlo.torch_index_select + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + func.return %0 : tensor<*xi32> +} + +// ----- + +func.func @index_select_to_gather_non_static_operand(%arg0 : tensor<5x1x?xi32>, %arg1 : tensor<2xi32>) -> tensor<2x1x5xi32> { + // CHECK: mhlo.torch_index_select + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x?xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> + func.return %0 : tensor<2x1x5xi32> +} + +// ----- + +func.func @index_select_to_gather_non_static_index(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor) -> tensor<2x1x5xi32> { + // CHECK: mhlo.torch_index_select + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xi32>, tensor) -> tensor<2x1x5xi32> + func.return %0 : tensor<2x1x5xi32> +} + +// ----- + +func.func @index_select_to_gather_dim_less_than_batch_dims(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor<2xi32>) -> tensor<2x1x5xi32> { + // CHECK: mhlo.torch_index_select + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 1 : i64 + } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> + func.return %0 : tensor<2x1x5xi32> +} + +// ----- + +func.func @index_select_to_gather_non_integer_index(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor<2xf32>) -> tensor<2x1x5xi32> { + // CHECK: mhlo.torch_index_select + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xi32>, tensor<2xf32>) -> tensor<2x1x5xi32> + func.return %0 : tensor<2x1x5xi32> +} diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index d352faee3674c3..600dce75cbad40 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -222,6 +222,7 @@ cc_library( hdrs = ["pjrt_executable.h"], visibility = [":friends"], deps = [ + ":execute_options_proto_cc", ":pjrt_common", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -231,6 +232,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) @@ -244,6 +246,7 @@ xla_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/tsl/platform:status_matchers", "@com_google_googletest//:gtest_main", ], 
) @@ -492,6 +495,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/runtime:cpu_event", + "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:platform_port", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", @@ -795,3 +799,9 @@ tf_proto_library( # deps = [":compile_options_proto"], # ) # copybara:uncomment_end + +tf_proto_library( + name = "execute_options_proto", + srcs = ["execute_options.proto"], + visibility = ["//visibility:public"], +) diff --git a/tensorflow/compiler/xla/pjrt/c/BUILD b/tensorflow/compiler/xla/pjrt/c/BUILD index 169ad42fc3eb4a..d1531a0bf1c523 100644 --- a/tensorflow/compiler/xla/pjrt/c/BUILD +++ b/tensorflow/compiler/xla/pjrt/c/BUILD @@ -160,6 +160,8 @@ cc_library( deps = [ ":pjrt_c_api_hdrs", ":pjrt_c_api_helpers", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", @@ -168,6 +170,7 @@ cc_library( "//tensorflow/compiler/xla/pjrt:pjrt_future", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h index b94359e9e4c48a..f9b4c3e67e11fe 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h @@ -53,7 +53,7 @@ extern "C" { // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 7 +#define PJRT_API_MINOR 9 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -138,6 +138,50 @@ typedef PJRT_Error* (*PJRT_CallbackError)(PJRT_Error_Code code, const char* message, size_t message_size); +// ---------------------------- Named Values ----------------------------------- + +typedef enum { + PJRT_NamedValue_kString = 0, + PJRT_NamedValue_kInt64, + PJRT_NamedValue_kInt64List, + PJRT_NamedValue_kFloat, +} PJRT_NamedValue_Type; + +// Named value for key-value pairs. +struct PJRT_NamedValue { + size_t struct_size; + void* priv; + const char* name; + size_t name_size; + PJRT_NamedValue_Type type; + union { + const char* string_value; + int64_t int64_value; + const int64_t* int64_array_value; + float float_value; + }; + // `value_size` is the number of elements for array/string and 1 for scalar + // values. + size_t value_size; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_NamedValue, value_size); + +// ---------------------------------- Plugin ----------------------------------- + +struct PJRT_Plugin_Attributes_Args { + size_t struct_size; + void* priv; + // Returned attributes have the lifetime of the process. + PJRT_NamedValue* attributes; // out + size_t num_attributes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Attributes_Args, attributes); + +// Returns an array of plugin attributes which are key-value pairs. One example +// attribute is the minimum supported StableHLO version. +// TODO(b/280349977): standardize the list of attributes. 
+typedef PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args); + // ---------------------------------- Events ----------------------------------- // Represents a notifying event that is returned by PJRT APIs that enqueue @@ -221,34 +265,6 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_OnReady_Args, user_arg); // error status and a pointer to an object of the caller's choice as arguments. typedef PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args); -// ------------------------ Other Common Data Types ---------------------------- - -typedef enum { - PJRT_NamedValue_kString = 0, - PJRT_NamedValue_kInt64, - PJRT_NamedValue_kInt64List, - PJRT_NamedValue_kFloat, -} PJRT_NamedValue_Type; - -// Named value for key-value pairs. -struct PJRT_NamedValue { - size_t struct_size; - void* priv; - const char* name; - size_t name_size; - PJRT_NamedValue_Type type; - union { - const char* string_value; - int64_t int64_value; - const int64_t* int64_array_value; - float float_value; - }; - // `value_size` is the number of elements for array/string and 1 for scalar - // values. - size_t value_size; -}; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_NamedValue, value_size); - // ---------------------------------- Client ----------------------------------- typedef struct PJRT_Client PJRT_Client; @@ -1352,6 +1368,10 @@ struct PJRT_Buffer_ToHostBuffer_Args { void* priv; PJRT_Buffer* src; + // The caller can specify an optional host layout. If nullptr, the layout of + // the src buffer will be used. The caller is responsible to keep the data + // (tiled or strides) in the host_layout alive during the call. + PJRT_Buffer_MemoryLayout* host_layout; // `dst` can be nullptr to query required size which will be set into // `dst_size`. void* dst; // in/out @@ -1660,6 +1680,8 @@ typedef struct { _PJRT_API_STRUCT_FIELD(PJRT_Error_Message); _PJRT_API_STRUCT_FIELD(PJRT_Error_GetCode); + _PJRT_API_STRUCT_FIELD(PJRT_Plugin_Attributes); + _PJRT_API_STRUCT_FIELD(PJRT_Event_Destroy); _PJRT_API_STRUCT_FIELD(PJRT_Event_IsReady); _PJRT_API_STRUCT_FIELD(PJRT_Event_Error); diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu_test.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu_test.cc index 6bb8f9fd37cf11..84d3629031f85a 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu_test.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu_test.cc @@ -24,8 +24,9 @@ namespace xla { namespace pjrt { namespace { -const bool kUnused = - (RegisterPjRtCApiTestFactory([]() { return GetPjrtApi(); }), true); +const bool kUnused = (RegisterPjRtCApiTestFactory([]() { return GetPjrtApi(); }, + /*platform_name=*/"cpu"), + true); class PjrtCApiCpuTest : public ::testing::Test { protected: @@ -66,17 +67,6 @@ class PjrtCApiCpuTest : public ::testing::Test { } }; -TEST_F(PjrtCApiCpuTest, PlatformName) { - PJRT_Client_PlatformName_Args args; - args.client = client_; - args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; - args.priv = nullptr; - PJRT_Error* error = api_->PJRT_Client_PlatformName(&args); - ASSERT_EQ(error, nullptr); - absl::string_view platform_name(args.platform_name, args.platform_name_size); - ASSERT_EQ("cpu", platform_name); -} - } // namespace } // namespace pjrt } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 75f211487c2d71..42f5f8700892b0 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -15,8 +15,10 @@ limitations 
under the License. #include #include +#include #include #include +#include #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h" #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h" @@ -36,11 +38,19 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { args->num_options); const auto kExpectedOptionNameAndTypes = absl::flat_hash_map( - {{"node_id", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}, + {{"visible_devices", + PJRT_NamedValue_Type::PJRT_NamedValue_kInt64List}, + {"node_id", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}, {"num_nodes", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}}); PJRT_RETURN_IF_ERROR( ValidateCreateOptions(create_options, kExpectedOptionNameAndTypes)); + std::optional> visible_devices; + if (auto it = create_options.find("visible_devices"); + it != create_options.end()) { + const auto& vec = std::get>(it->second); + visible_devices->insert(vec.begin(), vec.end()); + } int node_id = 0; if (auto it = create_options.find("node_id"); it != create_options.end()) { node_id = std::get(it->second); @@ -53,16 +63,15 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { // TODO(b/261916900) initializing allocator_config is important as should be // passed through the args later. xla::GpuAllocatorConfig allocator_config; - PJRT_ASSIGN_OR_RETURN( - std::unique_ptr client, - xla::GetStreamExecutorGpuClient( - /*asynchronous=*/true, allocator_config, node_id, num_nodes, - /*allowed_devices=*/std::nullopt, - /*platform_name=*/std::nullopt, true, - pjrt::ToCppKeyValueGetCallback(args->kv_get_callback, - args->kv_get_user_arg), - pjrt::ToCppKeyValuePutCallback(args->kv_put_callback, - args->kv_put_user_arg))); + PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, + xla::GetStreamExecutorGpuClient( + /*asynchronous=*/true, allocator_config, node_id, + num_nodes, visible_devices, + /*platform_name=*/std::nullopt, true, + pjrt::ToCppKeyValueGetCallback( + args->kv_get_callback, args->kv_get_user_arg), + pjrt::ToCppKeyValuePutCallback( + args->kv_put_callback, args->kv_put_user_arg))); args->client = pjrt::CreateWrapperClient(std::move(client)); return nullptr; } diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc index 627c8b864c882a..d1e7ef9ab96b31 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -32,8 +32,9 @@ namespace xla { namespace pjrt { namespace { -const bool kUnused = - (RegisterPjRtCApiTestFactory([]() { return GetPjrtApi(); }), true); +const bool kUnused = (RegisterPjRtCApiTestFactory([]() { return GetPjrtApi(); }, + /*platform_name=*/"gpu"), + true); class PjrtCApiGpuTest : public ::testing::Test { protected: @@ -74,22 +75,6 @@ class PjrtCApiGpuTest : public ::testing::Test { } }; -TEST_F(PjrtCApiGpuTest, PlatformName) { - PJRT_Client_PlatformName_Args args; - args.client = client_; - args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; - args.priv = nullptr; - PJRT_Error* error = api_->PJRT_Client_PlatformName(&args); - ASSERT_EQ(error, nullptr); - absl::string_view platform_name(args.platform_name, args.platform_name_size); - ASSERT_EQ("gpu", platform_name); -} - -TEST_F(PjrtCApiGpuTest, ApiVersion) { - CHECK_EQ(api_->pjrt_api_version.major_version, PJRT_API_MAJOR); - CHECK_EQ(api_->pjrt_api_version.minor_version, PJRT_API_MINOR); -} - std::unique_ptr<::pjrt::PJRT_KeyValueCallbackData> CreateTestCKVCallback( absl::flat_hash_map* kv_store, absl::Mutex& mu) { 
PjRtClient::KeyValueGetCallback kv_get = @@ -124,9 +109,7 @@ std::unique_ptr<::pjrt::PJRT_KeyValueCallbackData> CreateTestCKVCallback( absl::StatusOr BuildCreateArg( ::pjrt::PJRT_KeyValueCallbackData* kv_callback_data, - const absl::flat_hash_map& options) { - TF_ASSIGN_OR_RETURN(std::vector c_options, - ::pjrt::ConvertToPjRtNamedValueList(options)); + std::vector& c_options) { PJRT_Client_Create_Args args; args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; args.priv = nullptr; @@ -158,8 +141,11 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) { absl::flat_hash_map options = { {"num_nodes", static_cast(num_nodes)}, {"node_id", static_cast(i)}}; - TF_ASSERT_OK_AND_ASSIGN(PJRT_Client_Create_Args create_arg, - BuildCreateArg(kv_callback_data.get(), options)); + TF_ASSERT_OK_AND_ASSIGN(std::vector c_options, + ::pjrt::ConvertToPjRtNamedValueList(options)); + TF_ASSERT_OK_AND_ASSIGN( + PJRT_Client_Create_Args create_arg, + BuildCreateArg(kv_callback_data.get(), c_options)); PJRT_Error* error = api->PJRT_Client_Create(&create_arg); EXPECT_EQ(error, nullptr) << error->status.message(); diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc index b4d51f2d18924a..4daf5669954bfe 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc @@ -32,12 +32,15 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h" #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/pjrt/pjrt_future.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" @@ -71,10 +74,14 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f class TestCApiFactory { public: - void Register(std::function factory) { + void Register(std::function factory, + absl::string_view platform_name) { absl::MutexLock lock(&mu_); CHECK(!factory_); factory_ = std::move(factory); + CHECK(platform_name_.empty()) << "Platform name already provided"; + CHECK(!platform_name.empty()) << "Provided platform name is empty"; + platform_name_ = platform_name; } std::function Get() const { @@ -83,9 +90,17 @@ class TestCApiFactory { return factory_; } + std::string GetPlatformName() const { + absl::MutexLock lock(&mu_); + CHECK(!platform_name_.empty()) + << "Test didn't call RegisterPjRtCApiTestFactory()"; + return platform_name_; + } + private: mutable absl::Mutex mu_; std::function factory_ ABSL_GUARDED_BY(mu_); + std::string platform_name_; }; TestCApiFactory& GetGlobalTestCApiFactory() { @@ -95,10 +110,15 @@ TestCApiFactory& GetGlobalTestCApiFactory() { const PJRT_Api* GetCApi() { return GetGlobalTestCApiFactory().Get()(); } +std::string GetPlatformName() { + return GetGlobalTestCApiFactory().GetPlatformName(); +} + } // namespace -void RegisterPjRtCApiTestFactory(std::function factory) { - GetGlobalTestCApiFactory().Register(std::move(factory)); +void 
RegisterPjRtCApiTestFactory(std::function factory, + absl::string_view platform_name) { + GetGlobalTestCApiFactory().Register(std::move(factory), platform_name); } namespace { @@ -106,6 +126,7 @@ class PjrtCApiTest : public ::testing::Test { protected: const PJRT_Api* api_; PJRT_Client* client_; + std::string platform_name_; // We directly access the internal C++ client to test if the C API has the // same behavior as the C++ API. xla::PjRtClient* cc_client_; @@ -114,6 +135,7 @@ class PjrtCApiTest : public ::testing::Test { void SetUp() override { api_ = GetCApi(); client_ = make_client(); + platform_name_ = GetPlatformName(); } void TearDown() override { destroy_client(client_); } @@ -283,6 +305,7 @@ class PjrtCApiTest : public ::testing::Test { .struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE, .priv = nullptr, .src = src_buffer, + .host_layout = nullptr, .dst = nullptr, .dst_size = 0, .event = nullptr, @@ -380,6 +403,17 @@ TEST_F(PjrtCApiTest, ApiVersion) { // ---------------------------------- Client ----------------------------------- +TEST_F(PjrtCApiTest, PlatformName) { + PJRT_Client_PlatformName_Args args; + args.client = client_; + args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; + args.priv = nullptr; + PJRT_Error* error = api_->PJRT_Client_PlatformName(&args); + ASSERT_EQ(error, nullptr); + absl::string_view platform_name(args.platform_name, args.platform_name_size); + ASSERT_EQ(platform_name_, platform_name); +} + TEST_F(PjrtCApiTest, ClientProcessIndex) { PJRT_Client_ProcessIndex_Args process_index_args = PJRT_Client_ProcessIndex_Args{ @@ -864,6 +898,31 @@ TEST_F(PjrtCApiBufferTest, ReadyEvent) { EXPECT_EQ(error, nullptr); } +TEST_F(PjrtCApiBufferTest, ToHostBufferNoHostLayout) { + PJRT_Buffer_ToHostBuffer_Args args; + args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE; + args.priv = nullptr; + args.src = buffer_.get(); + Shape host_shape = ShapeUtil::MakeShape(F32, {4}); + auto literal = std::make_shared(host_shape); + args.host_layout = nullptr; + args.dst = literal->untyped_data(); + args.dst_size = ShapeUtil::ByteSizeOfElements(host_shape); + args.event = nullptr; + + PJRT_Error* error = api_->PJRT_Buffer_ToHostBuffer(&args); + PjRtFuture transfer_to_host = + ::pjrt::ConvertCEventToCppFuture(args.event, api_); + TF_CHECK_OK(transfer_to_host.Await()); + + EXPECT_EQ(error, nullptr); + ASSERT_EQ(literal->data().size(), 4); + std::vector float_data(4); + std::iota(float_data.begin(), float_data.end(), 41.0f); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1(float_data), + *literal)); +} + // --------------------------------- Helpers ----------------------------------- class PjrtCommonCApiHelpersTest : public PjrtCApiTest {}; diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.h index a2f4b7c1334652..742bad437d7b4d 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h" namespace xla { @@ -28,7 +29,8 @@ namespace pjrt { // all the tests in this test factory with the PJRT_Api generated by the input // to RegisterPjRtCApiTestFactory. 
See // tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu_test.cc for an example usage -void RegisterPjRtCApiTestFactory(std::function factory); +void RegisterPjRtCApiTestFactory(std::function factory, + absl::string_view platform_name); } // namespace pjrt } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 32bac4e95fc26f..976c9b19bd4257 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -210,6 +210,13 @@ PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args) { return nullptr; } +// ---------------------------------- Plugin ----------------------------------- + +PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args) { + args->num_attributes = 0; + return nullptr; +} + // ---------------------------------- Client ----------------------------------- PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args) { @@ -471,7 +478,7 @@ PJRT_Error* PJRT_Client_BufferFromHostBuffer( PJRT_Buffer_MemoryLayout_Type_Strides: { PJRT_RETURN_IF_ERROR(absl::InvalidArgumentError(absl::StrCat( "PJRT_Buffer_MemoryLayout_Type_Strides in device_layout is not " - "supported in PJRT_Client_BufferFromHostBuffer for platform '%s'", + "supported in PJRT_Client_BufferFromHostBuffer for platform ", args->client->client->platform_name()))); break; } @@ -1347,8 +1354,24 @@ PJRT_Error* PJRT_Buffer_ToHostBuffer(PJRT_Buffer_ToHostBuffer_Args* args) { } else { device_shape = args->src->buffer->on_device_shape(); } - const xla::Shape& host_shape = - xla::ShapeUtil::DeviceShapeToHostShape(device_shape); + xla::Shape host_shape = xla::ShapeUtil::DeviceShapeToHostShape(device_shape); + if (args->host_layout != nullptr) { + if (args->host_layout->type == + PJRT_Buffer_MemoryLayout_Type::PJRT_Buffer_MemoryLayout_Type_Strides) { + PJRT_RETURN_IF_ERROR(absl::InvalidArgumentError( + absl::StrCat("PJRT_Buffer_ToHostBuffer does not support host_layout " + "with strides for platform ", + args->src->buffer->client()->platform_name()))); + } + if (args->host_layout->tiled.num_tiles > 0) { + PJRT_RETURN_IF_ERROR(absl::InvalidArgumentError( + absl::StrCat("PJRT_Buffer_ToHostBuffer does not support host_layout " + "with tiled dimension for platform ", + args->src->buffer->client()->platform_name()))); + } + PJRT_ASSIGN_OR_RETURN(*host_shape.mutable_layout(), + ConvertToLayout(args->host_layout->tiled)); + } size_t host_buffer_size = xla::ShapeUtil::ByteSizeOfElements(host_shape); diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 63f1e46809eda1..d1b89621b90fca 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -141,6 +141,8 @@ void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args); void PJRT_Error_Message(PJRT_Error_Message_Args* args); PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args); +PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args); + PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args); PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args); PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args); @@ -320,6 +322,8 @@ constexpr PJRT_Api CreatePjrtApi( /*PJRT_Error_Message=*/pjrt::PJRT_Error_Message, /*PJRT_Error_GetCode=*/pjrt::PJRT_Error_GetCode, + /*PJRT_Plugin_Attributes=*/pjrt::PJRT_Plugin_Attributes, + 
/*PJRT_Event_Destroy=*/pjrt::PJRT_Event_Destroy, /*PJRT_Event_IsReady=*/pjrt::PJRT_Event_IsReady, /*PJRT_Event_Error=*/pjrt::PJRT_Event_Error, diff --git a/tensorflow/compiler/xla/pjrt/execute_options.proto b/tensorflow/compiler/xla/pjrt/execute_options.proto new file mode 100644 index 00000000000000..af9558200fb907 --- /dev/null +++ b/tensorflow/compiler/xla/pjrt/execute_options.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package xla; + +enum ExecutionModeProto { + EXECUTION_MODE_UNSPECIFIED = 0; + EXECUTION_MODE_DEFAULT = 1; + EXECUTION_MODE_SYNCHRONOUS = 2; + EXECUTION_MODE_ASYNCHRONOUS = 3; +} + +// Mirrors `xla::ExecuteOptions`. +message ExecuteOptionsProto { + bool arguments_are_tupled = 1; + bool untuple_result = 2; + int32 launch_id = 3; + bool strict_shape_checking = 4; + ExecutionModeProto execution_mode = 6; + repeated int32 non_donatable_input_indices = 7; +} diff --git a/tensorflow/compiler/xla/pjrt/gpu/BUILD b/tensorflow/compiler/xla/pjrt/gpu/BUILD index 33a8e7f9299ff1..3b749c980316fd 100644 --- a/tensorflow/compiler/xla/pjrt/gpu/BUILD +++ b/tensorflow/compiler/xla/pjrt/gpu/BUILD @@ -94,6 +94,7 @@ xla_cc_test( "gpu", "no_oss", "noasan", + "nomsan", "requires-gpu-nvidia:2", "no_rocm", ], diff --git a/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 532dc1bef757b8..3d9fee99e156c6 100644 --- a/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -271,6 +271,47 @@ TEST(StreamExecutorGpuClientTest, ToLiteralAsync) { literal->Relayout(src_literal.shape().layout()).data()); } +TEST(StreamExecutorGpuClientTest, ToLiteralAsyncBeforeBufferReady) { + TF_ASSERT_OK_AND_ASSIGN( + auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, + /*node_id=*/0)); + ASSERT_GE(client->addressable_devices().size(), 1); + + auto src_literal = LiteralUtil::CreateR1({41.0f, 42.0f, 43.0f, 44.0f}); + TF_ASSERT_OK_AND_ASSIGN( + auto transfer_manager, + client->CreateBuffersForAsyncHostToDevice( + {src_literal.shape()}, client->addressable_devices()[0])); + auto buffer = transfer_manager->RetrieveBuffer(0); + + absl::Mutex mu; + auto literal = std::make_shared( + ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape())); + bool got_literal = false; + + buffer->ToLiteral(literal.get(), [&](Status s) { + absl::MutexLock l(&mu); + TF_ASSERT_OK(s); + got_literal = true; + }); + + absl::SleepFor(absl::Milliseconds(10)); + ASSERT_FALSE(got_literal); + TF_ASSERT_OK( + transfer_manager->TransferLiteralToBuffer(0, src_literal, [&]() {})); + + buffer.reset(); + + { + absl::MutexLock l(&mu); + mu.Await(absl::Condition(&got_literal)); + } + + ASSERT_TRUE(ShapeUtil::Compatible(src_literal.shape(), literal->shape())); + ASSERT_EQ(src_literal.data(), + literal->Relayout(src_literal.shape().layout()).data()); +} + TEST(StreamExecutorGpuClientTest, FromHostAsync) { TF_ASSERT_OK_AND_ASSIGN( auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, diff --git a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc index 4bbf4663891f62..44e0accce8f7b1 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc @@ -1441,6 +1441,18 @@ PjRtFuture PjRtCApiBuffer::ToLiteral(MutableLiteralBase* literal) { args.dst_size = ShapeUtil::ByteSizeOfElements(shape); args.dst = literal->untyped_data(); + xla::StatusOr c_layout_data; + if 
(literal->shape().has_layout()) { + c_layout_data = + pjrt::ConvertToBufferMemoryLayoutData(&literal->shape().layout()); + if (!c_layout_data.ok()) { + return PjRtFuture(c_layout_data.status()); + } + args.host_layout = &(c_layout_data->c_layout); + } else { + args.host_layout = nullptr; + } + const PJRT_Api* api = pjrt_c_api(); std::unique_ptr error{ diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index ba6e5856042207..ceff30f1825e43 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -60,7 +60,6 @@ PjRtFuture PjRtBuffer::CopyRawToHostFuture( return PjRtFuture(std::move(promise)); } -MultiSliceConfig::~MultiSliceConfig() = default; std::string CompiledMemoryStats::DebugString() const { return absl::Substitute( diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index a59502f33167fd..ee2715e9743753 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -1199,112 +1199,6 @@ class PjRtBuffer { virtual bool IsOnCpu() const = 0; }; -class ExecuteContext { - public: - virtual ~ExecuteContext() = default; -}; - -struct PjRtTransferMetadata { - // May be invalid if - // ExecuteOptions::use_major_to_minor_data_layout_for_callbacks is true for - // this execution. - Shape device_shape; -}; - -struct SendCallback { - int64_t channel_id; - // The callback for retrieving the send value. It will be invoked once for - // each invocation of the corresponding Send op in the HLO program (So it can - // be invoked multiple times if it is in a loop). Currently there is no - // guarantee that the callback here will be invoked in the same order as their - // corresponding HLO Send ops. The callback can also return errors to indicate - // the execution should fail. - // - // IMPORTANT: the implementation might NOT signal the error to the execution, - // and the execution will run to completion with UNDEFINED DATA returned by - // the callback. If there is any potential control flow that depends on the - // value of the returned data, an error return is unsafe. - // - // TODO(chky): Currently the callback invocation order may not be consistent - // with the HLO send op invocation order, due to limitations in some PjRt - // implementation. Consider making it strictly the same order as HLO program. - std::function - callback; -}; - -struct RecvCallback { - int64_t channel_id; - // The callback for feeding the recv value. It will be invoked once for each - // invocation of the corresponding Recv op in the HLO program (So it can be - // invoked multiple times if it is in a loop). Currently there is no - // guarantee that the callback here will be invoked in the same order as their - // corresponding HLO Recv ops. - std::function stream)> - callback; -}; - -struct ExecuteOptions { - // If true, the client must pass a single PjRtBuffer which contains all of - // the arguments as a single XLA tuple, otherwise each argument must be - // passed in its own PjRtBuffer. May only be true if the executable was - // compiled with parameter_is_tupled_arguments==true. - bool arguments_are_tupled = false; - // If true, the computation must return a tuple, which will be destructured - // into its elements. - bool untuple_result = false; - // If non-zero, identifies this execution as part of a potentially - // multi-device launch. This can be used to detect scheduling errors, e.g. 
if - // multi-host programs are launched in different orders on different hosts, - // the launch IDs may be used by the runtime to detect the mismatch. - int32_t launch_id = 0; - // If non-null, an opaque context passed to an execution that may be used to - // supply additional arguments to a derived class of PjRtExecutable. - const ExecuteContext* context = nullptr; - // If true, check that the PjRtBuffer argument shapes match the compiled - // shapes. Otherwise, any shape with the right size on device may be passed. - bool strict_shape_checking = true; - - // Set multi_slice_config when the computation spans multiple slices. The - // config should match what was used during compilation to generate this - // executable. - const MultiSliceConfig* multi_slice_config = nullptr; - - // The send/recv callbacks for PjRt execution. The first level span is for - // multi-device parallel execution, the second level vector contains the - // callbacks for all send/recv ops in the executable. These callbacks can be - // stateful and the user code is responsible for managing the states here. - // These callbacks must outlive the execution. - absl::Span> send_callbacks; - absl::Span> recv_callbacks; - - // If true, send callbacks are passed PjRtChunks in major-to-minor layout, and - // recv functions should pass major-to-minor chunks to - // CopyToDeviceStream::AddChunk. - // - // If false, send callbacks are passed PjRtChunks in the on-device layout - // specified in the PjRtTransferMetadata, and recv functions should similarly - // pass device-layout chunks to CopyToDeviceStream::AddChunk. - bool use_major_to_minor_data_layout_for_callbacks = false; - - // The `execution_mode` decides whether the execution will be invoked in the - // caller thread or launched to a separate thread. By default, the - // implementation may choose either strategy or use a heuristic to decide. - // Currently it is only applied to CPU implementations - enum class ExecutionMode { kDefault = 0, kSynchronous, kAsynchronous }; - ExecutionMode execution_mode = ExecutionMode::kDefault; - - // A set of indices denoting the input buffers that should not be donated. - // An input buffer may be non-donable, for example, if it is referenced more - // than once. Since such runtime information is not available at compile time, - // the compiler might mark the input as `may-alias`, which could lead PjRt to - // donate the input buffer when it should not. By defining this set of - // indices, a higher-level PjRt caller can instruct PjRtClient not to donate - // specific input buffers. - absl::flat_hash_set non_donatable_input_indices; -}; - // Represents a compiled computation that can be executed given handles to // device-allocated literals. If any input/output alias has been specified in // the computation, the parameter containing the input buffer will be donated diff --git a/tensorflow/compiler/xla/pjrt/pjrt_executable.cc b/tensorflow/compiler/xla/pjrt/pjrt_executable.cc index b4ac9732815130..f087e103003bc1 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_executable.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_executable.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/pjrt/execute_options.pb.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/tsl/platform/statusor.h" @@ -108,6 +109,78 @@ StatusOr CompileOptions::FromProto( return output; } +MultiSliceConfig::~MultiSliceConfig() = default; + +absl::StatusOr ExecuteOptions::ToProto() const { + ExecuteOptionsProto proto; + + proto.set_arguments_are_tupled(arguments_are_tupled); + proto.set_untuple_result(untuple_result); + proto.set_launch_id(launch_id); + if (context != nullptr) { + return absl::UnimplementedError( + "ExecuteOptions with non-nullptr context is not serializable"); + } + proto.set_strict_shape_checking(strict_shape_checking); + + if (multi_slice_config != nullptr) { + return absl::UnimplementedError( + "ExecuteOptions with multi-slice config is not serializable"); + } + + if (!send_callbacks.empty() || !recv_callbacks.empty()) { + return absl::UnimplementedError( + "ExecuteOptions with send/recv calbacks is not serializable"); + } + + switch (execution_mode) { + case ExecutionMode::kDefault: + proto.set_execution_mode(EXECUTION_MODE_DEFAULT); + break; + case ExecutionMode::kSynchronous: + proto.set_execution_mode(EXECUTION_MODE_SYNCHRONOUS); + break; + case ExecutionMode::kAsynchronous: + proto.set_execution_mode(EXECUTION_MODE_ASYNCHRONOUS); + break; + } + + proto.mutable_non_donatable_input_indices()->Add( + non_donatable_input_indices.begin(), non_donatable_input_indices.end()); + + return proto; +} + +absl::StatusOr ExecuteOptions::FromProto( + const ExecuteOptionsProto& proto) { + ExecuteOptions options; + + options.arguments_are_tupled = proto.arguments_are_tupled(); + options.untuple_result = proto.untuple_result(); + options.launch_id = proto.launch_id(); + + switch (proto.execution_mode()) { + case EXECUTION_MODE_DEFAULT: + options.execution_mode = ExecutionMode::kDefault; + break; + case EXECUTION_MODE_SYNCHRONOUS: + options.execution_mode = ExecutionMode::kSynchronous; + break; + case EXECUTION_MODE_ASYNCHRONOUS: + options.execution_mode = ExecutionMode::kAsynchronous; + break; + default: + return absl::UnimplementedError( + absl::StrCat("Unknown execution mode: ", proto.execution_mode())); + } + + options.non_donatable_input_indices.insert( + proto.non_donatable_input_indices().begin(), + proto.non_donatable_input_indices().end()); + + return options; +} + void GetOpSharding(std::vector& out, const OpSharding& sharding) { if (sharding.type() == OpSharding::TUPLE) { for (const OpSharding& s : sharding.tuple_shardings()) { diff --git a/tensorflow/compiler/xla/pjrt/pjrt_executable.h b/tensorflow/compiler/xla/pjrt/pjrt_executable.h index 3dfbb6c038442e..96610182284b14 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_executable.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_executable.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_EXECUTABLE_H_ #include +#include #include #include #include @@ -24,9 +25,11 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" +#include "tensorflow/compiler/xla/pjrt/execute_options.pb.h" #include "tensorflow/compiler/xla/pjrt/pjrt_common.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" @@ -120,6 +123,120 @@ struct LoadOptions { const MultiSliceConfig* multi_slice_config = nullptr; }; +class ExecuteContext { + public: + virtual ~ExecuteContext() = default; +}; + +struct PjRtTransferMetadata { + // May be invalid if + // ExecuteOptions::use_major_to_minor_data_layout_for_callbacks is true for + // this execution. + Shape device_shape; +}; + +class PjRtChunk; +class PjRtTransferMetadata; +class CopyToDeviceStream; + +struct SendCallback { + int64_t channel_id; + // The callback for retrieving the send value. It will be invoked once for + // each invocation of the corresponding Send op in the HLO program (So it can + // be invoked multiple times if it is in a loop). Currently there is no + // guarantee that the callback here will be invoked in the same order as their + // corresponding HLO Send ops. The callback can also return errors to indicate + // the execution should fail. + // + // IMPORTANT: the implementation might NOT signal the error to the execution, + // and the execution will run to completion with UNDEFINED DATA returned by + // the callback. If there is any potential control flow that depends on the + // value of the returned data, an error return is unsafe. + // + // TODO(chky): Currently the callback invocation order may not be consistent + // with the HLO send op invocation order, due to limitations in some PjRt + // implementation. Consider making it strictly the same order as HLO program. + std::function + callback; +}; + +struct RecvCallback { + int64_t channel_id; + // The callback for feeding the recv value. It will be invoked once for each + // invocation of the corresponding Recv op in the HLO program (So it can be + // invoked multiple times if it is in a loop). Currently there is no + // guarantee that the callback here will be invoked in the same order as their + // corresponding HLO Recv ops. + std::function stream)> + callback; +}; + +struct ExecuteOptions { + // If true, the client must pass a single PjRtBuffer which contains all of + // the arguments as a single XLA tuple, otherwise each argument must be + // passed in its own PjRtBuffer. May only be true if the executable was + // compiled with parameter_is_tupled_arguments==true. + bool arguments_are_tupled = false; + // If true, the computation must return a tuple, which will be destructured + // into its elements. + bool untuple_result = false; + // If non-zero, identifies this execution as part of a potentially + // multi-device launch. This can be used to detect scheduling errors, e.g. if + // multi-host programs are launched in different orders on different hosts, + // the launch IDs may be used by the runtime to detect the mismatch. + int32_t launch_id = 0; + // If non-null, an opaque context passed to an execution that may be used to + // supply additional arguments to a derived class of PjRtExecutable. + const ExecuteContext* context = nullptr; + // If true, check that the PjRtBuffer argument shapes match the compiled + // shapes. Otherwise, any shape with the right size on device may be passed. 
+ bool strict_shape_checking = true; + + // Set multi_slice_config when the computation spans multiple slices. The + // config should match what was used during compilation to generate this + // executable. + const MultiSliceConfig* multi_slice_config = nullptr; + + // The send/recv callbacks for PjRt execution. The first level span is for + // multi-device parallel execution, the second level vector contains the + // callbacks for all send/recv ops in the executable. These callbacks can be + // stateful and the user code is responsible for managing the states here. + // These callbacks must outlive the execution. + absl::Span> send_callbacks; + absl::Span> recv_callbacks; + + // If true, send callbacks are passed PjRtChunks in major-to-minor layout, and + // recv functions should pass major-to-minor chunks to + // CopyToDeviceStream::AddChunk. + // + // If false, send callbacks are passed PjRtChunks in the on-device layout + // specified in the PjRtTransferMetadata, and recv functions should similarly + // pass device-layout chunks to CopyToDeviceStream::AddChunk. + bool use_major_to_minor_data_layout_for_callbacks = false; + + // The `execution_mode` decides whether the execution will be invoked in the + // caller thread or launched to a separate thread. By default, the + // implementation may choose either strategy or use a heuristic to decide. + // Currently it is only applied to CPU implementations + enum class ExecutionMode { kDefault = 0, kSynchronous, kAsynchronous }; + ExecutionMode execution_mode = ExecutionMode::kDefault; + + // A set of indices denoting the input buffers that should not be donated. + // An input buffer may be non-donable, for example, if it is referenced more + // than once. Since such runtime information is not available at compile time, + // the compiler might mark the input as `may-alias`, which could lead PjRt to + // donate the input buffer when it should not. By defining this set of + // indices, a higher-level PjRt caller can instruct PjRtClient not to donate + // specific input buffers. + absl::flat_hash_set non_donatable_input_indices; + + absl::StatusOr ToProto() const; + static absl::StatusOr FromProto( + const ExecuteOptionsProto& proto); +}; + // Static device memory usage for a compiled program. // The on-device memory needed to run an executable is at least // generated_code_size_in_bytes diff --git a/tensorflow/compiler/xla/pjrt/pjrt_executable_test.cc b/tensorflow/compiler/xla/pjrt/pjrt_executable_test.cc index 747755a05bd529..a480172117d5f7 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_executable_test.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_executable_test.cc @@ -14,15 +14,21 @@ limitations under the License. 
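The struct documented above is consumed at execution time. The following caller-side sketch shows how the newly added fields might be set; the PjRtLoadedExecutable::Execute overload used here (argument handles, options, optional returned futures) is an assumption about the surrounding PJRT API and is not defined by this patch.

// Hypothetical caller-side sketch of ExecuteOptions usage.
#include <memory>
#include <optional>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

xla::StatusOr<std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>>
Run(xla::PjRtLoadedExecutable& executable,
    absl::Span<const std::vector<xla::PjRtBuffer*>> args) {
  xla::ExecuteOptions options;
  // Keep argument 0 alive for reuse even if the compiled program marks it
  // as may-alias and would otherwise donate it.
  options.non_donatable_input_indices = {0};
  // Ask for major-to-minor chunks in send/recv callbacks instead of the
  // on-device layout.
  options.use_major_to_minor_data_layout_for_callbacks = true;
  // Currently only honored by CPU implementations.
  options.execution_mode = xla::ExecuteOptions::ExecutionMode::kAsynchronous;

  std::optional<std::vector<xla::PjRtFuture<xla::Status>>> returned_futures;
  return executable.Execute(args, options, returned_futures);
}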
==============================================================================*/ #include "tensorflow/compiler/xla/pjrt/pjrt_executable.h" +#include + +#include #include #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/pjrt/compile_options.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/tsl/platform/status_matchers.h" namespace xla { namespace { +using ::tsl::testing::StatusIs; + TEST(CompileOptionsTest, Serialization) { CompileOptions src; src.compile_portable_executable = true; @@ -41,15 +47,47 @@ TEST(CompileOptionsTest, Serialization) { EXPECT_EQ(proto.SerializeAsString(), output_proto.SerializeAsString()); } -TEST(FromProtoTest, MultiSliceConfigNotSupported) { +TEST(CompileOptionsTest, MultiSliceConfigNotSupported) { CompileOptionsProto proto; *proto.mutable_serialized_multi_slice_config() = "multi_size_config"; auto option = CompileOptions::FromProto(proto); - EXPECT_EQ(option.status().code(), tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(option.status().message(), - "multi_slice_config not supported in CompileOptions::FromProto."); + EXPECT_THAT( + option.status(), + StatusIs( + absl::StatusCode::kUnimplemented, + "multi_slice_config not supported in CompileOptions::FromProto.")); +} + +TEST(ExecuteOptionsTest, Serialization) { + ExecuteOptions src; + src.arguments_are_tupled = true; + src.untuple_result = false; + src.launch_id = 1234; + src.strict_shape_checking = true; + src.execution_mode = ExecuteOptions::ExecutionMode::kAsynchronous; + src.non_donatable_input_indices = {2, 3}; + + TF_ASSERT_OK_AND_ASSIGN(ExecuteOptionsProto proto, src.ToProto()); + TF_ASSERT_OK_AND_ASSIGN(ExecuteOptions output, + ExecuteOptions::FromProto(proto)); + TF_ASSERT_OK_AND_ASSIGN(ExecuteOptionsProto output_proto, src.ToProto()); + + EXPECT_EQ(proto.SerializeAsString(), output_proto.SerializeAsString()); +} + +TEST(ExecuteOptionsTest, SendRecvNotSupported) { + ExecuteOptions options; + std::vector> send_callbacks(1); + options.send_callbacks = send_callbacks; + std::vector> recv_callbacks(1); + options.recv_callbacks = recv_callbacks; + + EXPECT_THAT( + options.ToProto(), + StatusIs(absl::StatusCode::kUnimplemented, + "ExecuteOptions with send/recv calbacks is not serializable")); } } // namespace diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc index 2043ff62361fb2..928704167ce290 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc @@ -1088,7 +1088,21 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer( void* device_ptr, const Shape& shape, PjRtDevice* device, std::function on_delete_callback) { se::DeviceMemoryBase buffer(device_ptr, ShapeUtil::ByteSizeOf(shape)); - absl::Span> definition_events; + + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + tensorflow::down_cast(device) + ->GetLocalDeviceState()); + + absl::InlinedVector, 2> + definition_events; + definition_events.emplace_back( + std::make_shared(this->thread_pool())); + TF_ASSIGN_OR_RETURN(EventPool::Handle event, + local_device->event_pool().ThenAllocateAndRecordEvent( + local_device->compute_stream())); + definition_events.back()->SetSequencingEvent(std::move(event), + local_device->compute_stream()); + auto device_buffer = std::make_shared( /*allocator=*/nullptr, device->local_hardware_id(), std::initializer_list{buffer}, 
definition_events, @@ -1343,48 +1357,65 @@ PjRtFuture PjRtStreamExecutorBuffer::ToLiteral( AcquireHoldLocked(&device_buffer); } - WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); - ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_); - StatusOr event_or = - local_device->event_pool().AllocateEvent(stream->parent()); - if (!event_or.ok()) { - return PjRtFuture(event_or.status()); - } - - GenericTransferManager::LiteralFromDeviceMetadata transfer_metadata; - // We never call device functions from the `done` callback. - transfer_metadata.callback_is_host_callback_safe = true; + auto promise = PjRtFuture::CreatePromise(); + auto usage_event = + std::make_shared(client_->thread_pool()); TransferManager* transfer_manager = client_->client()->backend().transfer_manager(); - TransferManager::TransferMetadata* transfer_metadata_ptr = - (dynamic_cast(transfer_manager) != nullptr) - ? &transfer_metadata - : nullptr; + auto tracked_device_buffer = device_buffer.buffer(); + + // When using the ComputeSynchronized allocation model, retain a + // reference to the device_buffer until the copy completes, to + // ensure that the buffer isn't deleted or donated while it is still + // in use. The choice of retaining a reference at the host is a + // heuristic; the alternative is to ensure, before freeing the + // buffer, that the compute stream is synchronized past the + // transfer, but it seems better to hold onto the buffer too long + // than to stall the compute stream, particularly since the + // overwhelmingly common use case of CopyToHostAsync will hold onto + // the reference long enough to read the buffer in a subsequent call + // to ToLiteral. + device_buffer.ConvertUsageHold(stream, usage_event, /*reference_held=*/true); + + auto async_to_literal = [usage_event, tracked_device_buffer, stream, + transfer_manager = std::move(transfer_manager), + on_device_shape{on_device_shape_}, literal, promise, + local_device]() mutable { + StatusOr event_or = + local_device->event_pool().AllocateEvent(stream->parent()); + if (!event_or.ok()) { + promise.Set(event_or.status()); + return; + } + WaitForBufferDefinitionEventsOnStream(*tracked_device_buffer, stream); + ShapedBuffer shaped_buffer = + tracked_device_buffer->AsShapedBuffer(on_device_shape); - auto promise = PjRtFuture::CreatePromise(); - transfer_manager->TransferLiteralFromDevice( - stream, shaped_buffer, literal, - [promise](Status status) mutable { promise.Set(status); }, - transfer_metadata_ptr); + GenericTransferManager::LiteralFromDeviceMetadata transfer_metadata; + // We never call device functions from the `done` callback. + transfer_metadata.callback_is_host_callback_safe = true; - auto usage_event = - std::make_shared(client_->thread_pool()); - local_device->event_pool().ThenRecordEvent(stream, event_or.value()); - usage_event->SetSequencingEvent(std::move(event_or).value(), stream); - // When using the ComputeSynchronized allocation model, retain a reference to - // the device_buffer until the copy completes, to ensure that the buffer isn't - // deleted or donated while it is still in use. 
The choice of retaining a - // reference at the host is a heuristic; the alternative is to ensure, before - // freeing the buffer, that the compute stream is synchronized past the - // transfer, but it seems better to hold onto the buffer too long than to - // stall the compute stream, particularly since the overwhelmingly common - // use case of CopyToHostAsync will hold onto the reference long enough to - // read the buffer in a subsequent call to ToLiteral. - RecordUsage(std::move(device_buffer), local_device, local_device, usage_event, - stream, - /*prefer_to_retain_reference=*/true); + TransferManager::TransferMetadata* transfer_metadata_ptr = + (dynamic_cast(transfer_manager) != nullptr) + ? &transfer_metadata + : nullptr; + + transfer_manager->TransferLiteralFromDevice( + stream, shaped_buffer, literal, + [promise](Status status) mutable { promise.Set(status); }, + transfer_metadata_ptr); + + local_device->event_pool().ThenRecordEvent(stream, event_or.value()); + usage_event->SetSequencingEvent(std::move(event_or).value(), stream); + + local_device->ThenRelease(stream, tracked_device_buffer); + }; + + tracked_device_buffer->definition_events()[0]->ExecuteOrAddToFutureTasks( + absl::StrFormat("async_to_literal_%p", literal), + std::move(async_to_literal)); return PjRtFuture( std::move(promise), @@ -1720,29 +1751,29 @@ struct TupleHandle { }; Status CheckCompatibleShapes(bool strict_shape_checking, - const Shape& buffer_shape, + const Shape& buffer_on_device_shape, const Shape& execution_shape, const TransferManager& transfer_manager, int parameter_index) { // TODO(misard) Support casting of tuple parameters. - if (strict_shape_checking || buffer_shape.IsTuple()) { - if (!ShapeUtil::Equal(buffer_shape, execution_shape)) { + if (strict_shape_checking || buffer_on_device_shape.IsTuple()) { + if (!ShapeUtil::Compatible(buffer_on_device_shape, execution_shape)) { return InvalidArgument( "Executable expected shape %s for argument %d but got " "incompatible " "shape %s", ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index, - ShapeUtil::HumanStringWithLayout(buffer_shape)); + ShapeUtil::HumanStringWithLayout(buffer_on_device_shape)); } } else { - if (transfer_manager.GetByteSizeRequirement(buffer_shape) != + if (transfer_manager.GetByteSizeRequirement(buffer_on_device_shape) != transfer_manager.GetByteSizeRequirement(execution_shape)) { return InvalidArgument( "Executable expected shape %s for argument %d but got " "incompatible " "shape %s", ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index, - ShapeUtil::HumanStringWithLayout(buffer_shape)); + ShapeUtil::HumanStringWithLayout(buffer_on_device_shape)); } } return OkStatus(); diff --git a/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h b/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h index b8c6bdee610fe6..25cf4d8ddff523 100644 --- a/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h +++ b/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h @@ -21,13 +21,14 @@ limitations under the License. 
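The CheckCompatibleShapes change above relies on ShapeUtil::Compatible ignoring layout where ShapeUtil::Equal does not, so a device buffer whose on-device layout differs from the executable's expected layout still passes the strict check when element type and dimensions match. A small illustrative sketch, not taken from this patch, using the long-standing ShapeUtil/LayoutUtil helpers:

// Sketch: Equal is layout-sensitive, Compatible is not.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

void CompareShapePredicates() {
  xla::Shape row_major = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  *row_major.mutable_layout() = xla::LayoutUtil::MakeLayout({1, 0});

  xla::Shape col_major = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  *col_major.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});

  bool equal = xla::ShapeUtil::Equal(row_major, col_major);            // false
  bool compatible = xla::ShapeUtil::Compatible(row_major, col_major);  // true
  (void)equal;
  (void)compatible;
}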
#include #include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/runtime/cpu_event.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/tsl/platform/env.h" #include "tensorflow/tsl/platform/mem.h" +#include "tensorflow/tsl/platform/threadpool.h" #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime namespace xla { diff --git a/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc b/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc index 0bc6162b819b8d..05d47449c6cd3d 100644 --- a/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc +++ b/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include namespace xla { diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 9967de1243a745..f7614d7ae9adc4 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -746,9 +746,8 @@ cc_library( srcs = ["refine_polymorphic_shapes.cc"], hdrs = ["refine_polymorphic_shapes.h"], deps = [ - "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla/mlir/utils:error_util", - "@com_google_absl//absl/log", + "//tensorflow/tsl/platform:errors", "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", diff --git a/tensorflow/compiler/xla/python/ifrt/BUILD b/tensorflow/compiler/xla/python/ifrt/BUILD index ec8774cd5ef187..d2f06b3e175df3 100644 --- a/tensorflow/compiler/xla/python/ifrt/BUILD +++ b/tensorflow/compiler/xla/python/ifrt/BUILD @@ -66,6 +66,7 @@ cc_library( ], deps = [ ":serdes", + ":types_proto_cc", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -75,6 +76,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -146,6 +148,7 @@ xla_cc_test( srcs = ["sharding_test.cc"], deps = [ ":ifrt", + ":sharding_test_util", "//tensorflow/compiler/xla/python/ifrt/ir", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status_matchers", @@ -157,7 +160,7 @@ xla_cc_test( cc_library( name = "test_util", - testonly = 1, + testonly = True, srcs = ["test_util.cc"], hdrs = ["test_util.h"], deps = [ @@ -173,9 +176,22 @@ cc_library( ], ) +cc_library( + name = "sharding_test_util", + testonly = True, + srcs = ["sharding_test_util.cc"], + hdrs = ["sharding_test_util.h"], + deps = [ + ":ifrt", + ":mock", + ":test_util", + "//tensorflow/tsl/platform:test", + ], +) + cc_library( name = "no_impl_test_main", - testonly = 1, + testonly = True, srcs = ["no_impl_test_main.cc"], deps = [ "@com_google_googletest//:gtest", @@ -184,7 +200,7 @@ cc_library( cc_library( name = "array_impl_test_lib", - testonly = 1, + testonly = True, srcs = ["array_impl_test_lib.cc"], deps = [ ":ifrt", @@ -194,7 +210,7 @@ cc_library( "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", ], - alwayslink = 1, + alwayslink = True, ) xla_cc_test( @@ -208,14 +224,14 @@ xla_cc_test( cc_library( name = "client_impl_test_lib", - 
testonly = 1, + testonly = True, srcs = ["client_impl_test_lib.cc"], deps = [ ":ifrt", ":test_util", "//tensorflow/tsl/platform:test", ], - alwayslink = 1, + alwayslink = True, ) xla_cc_test( @@ -230,7 +246,7 @@ xla_cc_test( cc_library( name = "tuple_impl_test_lib", - testonly = 1, + testonly = True, srcs = ["tuple_impl_test_lib.cc"], deps = [ ":ifrt", @@ -241,7 +257,7 @@ cc_library( "@com_google_absl//absl/types:span", "@tf_runtime//:ref_count", ], - alwayslink = 1, + alwayslink = True, ) xla_cc_test( @@ -309,3 +325,42 @@ tf_proto_library( name = "serdes_proto", srcs = ["serdes.proto"], ) + +cc_library( + name = "sharding_serdes", + srcs = ["sharding_serdes.cc"], + hdrs = ["sharding_serdes.h"], + deps = [ + ":ifrt", + ":serdes", + ":sharding_proto_cc", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/tsl/platform:statusor", + "@llvm-project//llvm:Support", + ], + alwayslink = True, +) + +xla_cc_test( + name = "sharding_serdes_test", + srcs = ["sharding_serdes_test.cc"], + deps = [ + ":ifrt", + ":serdes", + ":sharding_serdes", + ":sharding_test_util", + "@com_google_absl//absl/functional:bind_front", + "@com_google_googletest//:gtest_main", + ], +) + +tf_proto_library( + name = "types_proto", + srcs = ["types.proto"], +) + +tf_proto_library( + name = "sharding_proto", + srcs = ["sharding.proto"], + protodeps = [":types_proto"], +) diff --git a/tensorflow/compiler/xla/python/ifrt/array_impl_test_lib.cc b/tensorflow/compiler/xla/python/ifrt/array_impl_test_lib.cc index adbcf74e6f80c2..9d1768fdc8da80 100644 --- a/tensorflow/compiler/xla/python/ifrt/array_impl_test_lib.cc +++ b/tensorflow/compiler/xla/python/ifrt/array_impl_test_lib.cc @@ -355,7 +355,6 @@ TEST(ArrayImplTest, AssembleAndDisassembleArray) { /*on_done_with_host_buffer=*/{})); std::vector> arrays({array0, array1}); - std::vector single_device_shapes({shape, shape}); Shape assembled_shape({4, 3}); ShardingParam sharding_param( /*dim_shards=*/{2, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 1}}); diff --git a/tensorflow/compiler/xla/python/ifrt/device.cc b/tensorflow/compiler/xla/python/ifrt/device.cc index 0f02149ae48a64..629afe4d515854 100644 --- a/tensorflow/compiler/xla/python/ifrt/device.cc +++ b/tensorflow/compiler/xla/python/ifrt/device.cc @@ -15,11 +15,35 @@ limitations under the License. 
#include "tensorflow/compiler/xla/python/ifrt/device.h" +#include #include +#include "tensorflow/compiler/xla/python/ifrt/client.h" +#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" + namespace xla { namespace ifrt { +StatusOr DeviceList::FromProto(LookupDeviceFunc lookup_device, + const DeviceListProto& proto) { + DeviceList::Devices devices; + devices.reserve(proto.device_ids_size()); + for (int device_id : proto.device_ids()) { + TF_ASSIGN_OR_RETURN(Device * device, lookup_device(device_id)); + devices.push_back(device); + } + return DeviceList(std::move(devices)); +} + +DeviceListProto DeviceList::ToProto() const { + DeviceListProto proto; + proto.mutable_device_ids()->Reserve(devices().size()); + for (Device* device : devices()) { + proto.mutable_device_ids()->AddAlreadyReserved(device->id()); + } + return proto; +} + std::vector GetDeviceIds(DeviceList device_list) { std::vector ids; ids.reserve(device_list.devices().size()); diff --git a/tensorflow/compiler/xla/python/ifrt/device.h b/tensorflow/compiler/xla/python/ifrt/device.h index a2d5f61dd35c2a..89f123f5e93869 100644 --- a/tensorflow/compiler/xla/python/ifrt/device.h +++ b/tensorflow/compiler/xla/python/ifrt/device.h @@ -20,11 +20,15 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" namespace xla { namespace ifrt { +class Client; + // Short-term alias to reuse `xla::PjRtDevice` without a separate abstract type. using Device = ::xla::PjRtDevice; @@ -40,8 +44,20 @@ class DeviceList { // better performance. using Devices = absl::InlinedVector; + // Function that matches the semantics of `Client::LookupDevice()`. + using LookupDeviceFunc = absl::FunctionRef(int)>; + explicit DeviceList(Devices devices) : devices_(std::move(devices)) {} + // Constructs `DeviceList` from `DeviceListProto`. Devices are looked up using + // `lookup_device`. Device ids in the proto must be consistent with the + // devices returned by `lookup_device`. + static StatusOr FromProto(LookupDeviceFunc lookup_device, + const DeviceListProto& proto); + + // Returns a `DeviceListProto` representation. + DeviceListProto ToProto() const; + absl::Span devices() const { return devices_; } int size() const { return devices_.size(); } diff --git a/tensorflow/compiler/xla/python/ifrt/dtype.cc b/tensorflow/compiler/xla/python/ifrt/dtype.cc index de04817559cdef..fe11b672449133 100644 --- a/tensorflow/compiler/xla/python/ifrt/dtype.cc +++ b/tensorflow/compiler/xla/python/ifrt/dtype.cc @@ -26,6 +26,7 @@ namespace ifrt { std::optional DType::byte_size() const { switch (kind_) { + case kPred: case kS8: case kU8: return 1; @@ -53,7 +54,6 @@ std::optional DType::byte_size() const { std::optional DType::bit_size() const { switch (kind_) { case kPred: - return 1; case kS8: case kU8: return 8; diff --git a/tensorflow/compiler/xla/python/ifrt/dtype.h b/tensorflow/compiler/xla/python/ifrt/dtype.h index 7888a479b09a48..f98e823c00f82f 100644 --- a/tensorflow/compiler/xla/python/ifrt/dtype.h +++ b/tensorflow/compiler/xla/python/ifrt/dtype.h @@ -97,8 +97,8 @@ class DType { bool operator!=(const DType& other) const { return kind_ != other.kind_; } // Returns the byte size of a single element of this DType. Returns - // std::nullopt if there is no fixed size or not aligned to a byte boundary - // (such as kPred). 
+ // std::nullopt if not aligned to a byte boundary or there is no fixed size + // (such as kString). std::optional byte_size() const; // Returns the bit size of a single element of this DType. Returns diff --git a/tensorflow/compiler/xla/python/ifrt/ir/tests/executable_impl_test_base.h b/tensorflow/compiler/xla/python/ifrt/ir/tests/executable_impl_test_base.h index e903b57713a5f7..788a8a7d8c5ab1 100644 --- a/tensorflow/compiler/xla/python/ifrt/ir/tests/executable_impl_test_base.h +++ b/tensorflow/compiler/xla/python/ifrt/ir/tests/executable_impl_test_base.h @@ -58,7 +58,7 @@ class IfrtIrExecutableImplTestBase : public testing::Test { absl::StatusOr PickDevices(int count); mlir::MLIRContext mlir_context_; - std::unique_ptr client_; + std::shared_ptr client_; }; } // namespace test_util diff --git a/tensorflow/compiler/xla/python/ifrt/shape.cc b/tensorflow/compiler/xla/python/ifrt/shape.cc index bd3ff1fc8e08b6..07e8e2b81494a5 100644 --- a/tensorflow/compiler/xla/python/ifrt/shape.cc +++ b/tensorflow/compiler/xla/python/ifrt/shape.cc @@ -17,12 +17,37 @@ limitations under the License. #include #include +#include #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace ifrt { +StatusOr Shape::FromProto(const ShapeProto& proto) { + Shape::Dimensions dims; + dims.reserve(proto.dims_size()); + for (int64_t dim : proto.dims()) { + if (dim < 0) { + return InvalidArgument( + "Shape expects non-negative dimension sizes, but got %d", dim); + } + dims.push_back(dim); + } + return Shape(std::move(dims)); +} + +ShapeProto Shape::ToProto() const { + ShapeProto proto; + proto.mutable_dims()->Reserve(dims().size()); + for (int64_t dim : dims()) { + proto.mutable_dims()->AddAlreadyReserved(dim); + } + return proto; +} + int64_t Shape::num_elements() const { int64_t count = 1; for (int64_t d : dims_) { diff --git a/tensorflow/compiler/xla/python/ifrt/shape.h b/tensorflow/compiler/xla/python/ifrt/shape.h index 3558e3518ed84d..f3ce028789d5ef 100644 --- a/tensorflow/compiler/xla/python/ifrt/shape.h +++ b/tensorflow/compiler/xla/python/ifrt/shape.h @@ -22,6 +22,8 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" +#include "tensorflow/compiler/xla/statusor.h" namespace xla { namespace ifrt { @@ -42,6 +44,12 @@ class Shape { Shape& operator=(const Shape&) = default; Shape& operator=(Shape&&) = default; + // Constructs `Shape` from `ShapeProto`. + static StatusOr FromProto(const ShapeProto& proto); + + // Returns a `ShapeProto` representation. 
+ ShapeProto ToProto() const; + absl::Span dims() const { return dims_; } bool operator==(const Shape& other) const { return dims_ == other.dims_; } diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.cc b/tensorflow/compiler/xla/python/ifrt/sharding.cc index f057ad53fccf83..8caaf9f12a83e4 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding.cc +++ b/tensorflow/compiler/xla/python/ifrt/sharding.cc @@ -159,8 +159,10 @@ std::ostream& operator<<(std::ostream& os, const Sharding& sharding) { return os << sharding.DebugString(); } -std::unique_ptr SingleDeviceSharding::Create(Device* device) { - return std::unique_ptr(new SingleDeviceSharding(device)); +std::unique_ptr SingleDeviceSharding::Create( + Device* device) { + return std::unique_ptr( + new SingleDeviceSharding(device)); } StatusOr>>> @@ -187,8 +189,9 @@ std::string SingleDeviceSharding::DebugString() const { devices_.front()->ToString()); } -std::unique_ptr OpaqueSharding::Create(DeviceList devices) { - return std::unique_ptr(new OpaqueSharding(std::move(devices))); +std::unique_ptr OpaqueSharding::Create(DeviceList devices) { + return std::unique_ptr( + new OpaqueSharding(std::move(devices))); } OpaqueSharding::OpaqueSharding(DeviceList devices) @@ -217,10 +220,10 @@ std::string OpaqueSharding::DebugString() const { })); } -std::unique_ptr ConcreteSharding::Create( +std::unique_ptr ConcreteSharding::Create( DeviceList devices, Shape shape, std::vector shard_shapes) { CHECK_EQ(devices.size(), shard_shapes.size()); - return std::unique_ptr(new ConcreteSharding( + return std::unique_ptr(new ConcreteSharding( std::move(devices), std::move(shape), std::move(shard_shapes))); } @@ -270,10 +273,9 @@ std::string ConcreteSharding::DebugString() const { })); } -std::unique_ptr ConcreteEvenSharding::Create(DeviceList devices, - Shape shape, - Shape shard_shape) { - return std::unique_ptr(new ConcreteEvenSharding( +std::unique_ptr ConcreteEvenSharding::Create( + DeviceList devices, Shape shape, Shape shard_shape) { + return std::unique_ptr(new ConcreteEvenSharding( std::move(devices), std::move(shape), std::move(shard_shape))); } @@ -318,7 +320,7 @@ std::string ConcreteEvenSharding::DebugString() const { shape_.DebugString(), shard_shape_.DebugString()); } -StatusOr> ShardingParamSharding::Create( +StatusOr> ShardingParamSharding::Create( ShardingParam sharding_param, DeviceList devices) { int64_t device_count = absl::c_accumulate(sharding_param.minor_to_major().axis_sizes, 1, @@ -329,7 +331,7 @@ StatusOr> ShardingParamSharding::Create( "%d", device_count, devices.size()); } - return std::unique_ptr( + return std::unique_ptr( new ShardingParamSharding(std::move(sharding_param), std::move(devices))); } diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.h b/tensorflow/compiler/xla/python/ifrt/sharding.h index a03ce4ebda8e57..6e3d30e99d2584 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding.h +++ b/tensorflow/compiler/xla/python/ifrt/sharding.h @@ -72,7 +72,7 @@ class Sharding : public llvm::RTTIExtends { DeviceList devices_; }; -std::ostream& operator<<(std::ostream& os, const Shape& shape); +std::ostream& operator<<(std::ostream& os, const Sharding& sharding); // Single-device sharding. // @@ -83,7 +83,7 @@ class SingleDeviceSharding final : public llvm::RTTIExtends { public: // Creates a single-device sharding. - static std::unique_ptr Create(Device* device); + static std::unique_ptr Create(Device* device); // Sharding implementation. 
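A rough sketch of the Shape and DeviceList proto round trips introduced above. `client` is assumed to be a valid xla::ifrt::Client*, and the lookup callable mirrors Client::LookupDevice() as described in device.h; this is an illustration of the new helpers, not code from the patch.

// Sketch: ifrt::Shape and ifrt::DeviceList proto round trips.
#include "absl/functional/bind_front.h"
#include "tensorflow/compiler/xla/python/ifrt/client.h"
#include "tensorflow/compiler/xla/python/ifrt/device.h"
#include "tensorflow/compiler/xla/python/ifrt/shape.h"
#include "tensorflow/compiler/xla/python/ifrt/types.pb.h"
#include "tensorflow/tsl/platform/status.h"
#include "tensorflow/tsl/platform/statusor.h"

tsl::Status RoundTripShapeAndDevices(xla::ifrt::Client* client,
                                     const xla::ifrt::DeviceList& devices) {
  // Shape -> ShapeProto -> Shape. FromProto rejects negative dimensions.
  xla::ifrt::Shape shape({4, 3});
  xla::ifrt::ShapeProto shape_proto = shape.ToProto();
  TF_ASSIGN_OR_RETURN(xla::ifrt::Shape restored_shape,
                      xla::ifrt::Shape::FromProto(shape_proto));

  // DeviceList -> DeviceListProto -> DeviceList, resolving device ids through
  // the client, matching the LookupDeviceFunc contract.
  xla::ifrt::DeviceListProto devices_proto = devices.ToProto();
  TF_ASSIGN_OR_RETURN(
      xla::ifrt::DeviceList restored_devices,
      xla::ifrt::DeviceList::FromProto(
          absl::bind_front(&xla::ifrt::Client::LookupDevice, client),
          devices_proto));

  (void)restored_shape;
  (void)restored_devices;
  return tsl::OkStatus();
}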
@@ -110,7 +110,7 @@ class SingleDeviceSharding final class OpaqueSharding : public llvm::RTTIExtends { public: // Creates an opaque sharding. `Disassemble()` will fail. - static std::unique_ptr Create(DeviceList devices); + static std::unique_ptr Create(DeviceList devices); // Sharding implementation. @@ -138,8 +138,8 @@ class ConcreteSharding : public llvm::RTTIExtends { public: // Creates a concrete sharding that may contain non-identical shard shapes. // REQUIRES: devices.size() == shard_shapes.size() - static std::unique_ptr Create(DeviceList devices, Shape shape, - std::vector shard_shapes); + static std::unique_ptr Create( + DeviceList devices, Shape shape, std::vector shard_shapes); Shape shape() const { DCHECK(this); @@ -179,8 +179,9 @@ class ConcreteEvenSharding : public llvm::RTTIExtends { public: // Creates a concrete even sharding. - static std::unique_ptr Create(DeviceList devices, Shape shape, - Shape shard_shape); + static std::unique_ptr Create(DeviceList devices, + Shape shape, + Shape shard_shape); Shape shape() const { DCHECK(this); @@ -216,7 +217,7 @@ class ConcreteEvenSharding class ShardingParamSharding : public llvm::RTTIExtends { public: - static StatusOr> Create( + static StatusOr> Create( ShardingParam sharding_param, DeviceList devices); StatusOr>>> diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.proto b/tensorflow/compiler/xla/python/ifrt/sharding.proto new file mode 100644 index 00000000000000..066bce11413998 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding.proto @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "tensorflow/compiler/xla/python/ifrt/types.proto"; + +// Wire format for `SingleDeviceSharding`. +message SingleDeviceShardingProto { + // Serialization and deserialization are expected to ensure that device ids + // are stable across proto construction and consumption. + int32 device_id = 1; +} + +// Wire format for `OpaqueSharding`. +message OpaqueShardingProto { + DeviceListProto devices = 1; +} + +// Wire format for `ConcreteSharding`. +message ConcreteShardingProto { + DeviceListProto devices = 1; + ShapeProto shape = 2; + repeated ShapeProto shard_shapes = 3; +} + +// Wire format for `ConcreteEvenSharding`. +message ConcreteEvenShardingProto { + DeviceListProto devices = 1; + ShapeProto shape = 2; + ShapeProto shard_shape = 3; +} diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc new file mode 100644 index 00000000000000..5d4881499a09c9 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc @@ -0,0 +1,242 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/python/ifrt/device.h" +#include "tensorflow/compiler/xla/python/ifrt/serdes.h" +#include "tensorflow/compiler/xla/python/ifrt/shape.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding.pb.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { + +char DeserializeShardingOptions::ID = 0; + +namespace { + +// Serialization/deserialization for `SingleDeviceSharding`. +class SingleDeviceShardingSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::SingleDeviceSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const SingleDeviceSharding& sharding = + llvm::cast(serializable); + SingleDeviceShardingProto proto; + proto.set_device_id(sharding.devices().front()->id()); + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + SingleDeviceShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized SimpleDeviceSharding"); + } + TF_ASSIGN_OR_RETURN( + Device * device, + deserialize_sharding_options->lookup_device(proto.device_id())); + return SingleDeviceSharding::Create(device); + } + + static char ID; // NOLINT +}; + +// Serialization/deserialization for `OpaqueSharding`. +class OpaqueShardingSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::OpaqueSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const OpaqueSharding& sharding = llvm::cast(serializable); + OpaqueShardingProto proto; + *proto.mutable_devices() = sharding.devices().ToProto(); + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + + OpaqueShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized OpaqueSharding"); + } + TF_ASSIGN_OR_RETURN( + auto devices, + DeviceList::FromProto(deserialize_sharding_options->lookup_device, + proto.devices())); + return OpaqueSharding::Create(std::move(devices)); + } + + static char ID; // NOLINT +}; + +// Serialization/deserialization for `ConcreteSharding`. 
+class ConcreteShardingSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::ConcreteSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const ConcreteSharding& sharding = + llvm::cast(serializable); + ConcreteShardingProto proto; + *proto.mutable_devices() = sharding.devices().ToProto(); + *proto.mutable_shape() = sharding.shape().ToProto(); + for (const Shape& shape : sharding.shard_shapes()) { + *proto.add_shard_shapes() = shape.ToProto(); + } + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + + ConcreteShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized ConcreteSharding"); + } + TF_ASSIGN_OR_RETURN( + auto devices, + DeviceList::FromProto(deserialize_sharding_options->lookup_device, + proto.devices())); + TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); + std::vector shard_shapes; + shard_shapes.reserve(proto.shard_shapes_size()); + for (const auto& shard_shape_proto : proto.shard_shapes()) { + TF_ASSIGN_OR_RETURN(auto shard_shape, + Shape::FromProto(shard_shape_proto)); + shard_shapes.push_back(std::move(shard_shape)); + } + return ConcreteSharding::Create(std::move(devices), std::move(shape), + std::move(shard_shapes)); + } + + static char ID; // NOLINT +}; + +// Serialization/deserialization for `ConcreteEvenSharding`. +class ConcreteEvenShardingSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::ConcreteEvenSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const ConcreteEvenSharding& sharding = + llvm::cast(serializable); + ConcreteEvenShardingProto proto; + *proto.mutable_devices() = sharding.devices().ToProto(); + *proto.mutable_shape() = sharding.shape().ToProto(); + *proto.mutable_shard_shape() = sharding.shard_shape().ToProto(); + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + + ConcreteEvenShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized ConcreteEvenSharding"); + } + TF_ASSIGN_OR_RETURN( + auto devices, + DeviceList::FromProto(deserialize_sharding_options->lookup_device, + proto.devices())); + TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); + TF_ASSIGN_OR_RETURN(auto shard_shape, + Shape::FromProto(proto.shard_shape())); + return ConcreteEvenSharding::Create(std::move(devices), std::move(shape), + std::move(shard_shape)); + } + + static char ID; // NOLINT +}; + +// TODO(hyeontaek): Implement `ShardingParamShardingSerDes`. 
+ +[[maybe_unused]] char SingleDeviceShardingSerDes::ID = 0; // NOLINT +[[maybe_unused]] char OpaqueShardingSerDes::ID = 0; // NOLINT +[[maybe_unused]] char ConcreteShardingSerDes::ID = 0; // NOLINT +[[maybe_unused]] char ConcreteEvenShardingSerDes::ID = 0; // NOLINT + +// clang-format off +bool register_single_device_sharding_serdes = ([]{ + RegisterSerDes( + std::make_unique()); +}(), true); + +bool register_opaque_sharding_serdes = ([]{ + RegisterSerDes( + std::make_unique()); +}(), true); + +bool register_concrete_sharding_serdes = ([]{ + RegisterSerDes( + std::make_unique()); +}(), true); + +bool register_concrete_even_sharding_serdes = ([]{ + RegisterSerDes( + std::make_unique()); +}(), true); +// clang-format on + +} // namespace + +StatusOr> +GetDeserializeShardingOptions(std::unique_ptr options) { + if (!llvm::isa(options.get())) { + return xla::InvalidArgument("options must be DeserializeShardingOptions"); + } + return std::unique_ptr( + static_cast(options.release())); +} + +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h new file mode 100644 index 00000000000000..7ba47d87df9aac --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h @@ -0,0 +1,52 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_SERDES_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_SERDES_H_ + +#include + +#include "llvm/Support/ExtensibleRTTI.h" +#include "tensorflow/compiler/xla/python/ifrt/device.h" +#include "tensorflow/compiler/xla/python/ifrt/serdes.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { +namespace ifrt { + +class Client; + +// Options for deserializing shardings. Function referenced by `lookup_device` +// must remain valid during deserialization. +struct DeserializeShardingOptions + : llvm::RTTIExtends { + explicit DeserializeShardingOptions( + DeviceList::LookupDeviceFunc lookup_device) + : lookup_device(lookup_device) {} + + static char ID; // NOLINT + + // Function that converts device ids to devices. + DeviceList::LookupDeviceFunc lookup_device; +}; + +// Casts `DeserializeOptions` into `DeserializeShardingOptions`. +StatusOr> +GetDeserializeShardingOptions(std::unique_ptr options); + +} // namespace ifrt +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_SERDES_H_ diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc b/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc new file mode 100644 index 00000000000000..5cbd445da77a75 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" + +#include +#include +#include + +#include +#include +#include "absl/functional/bind_front.h" +#include "tensorflow/compiler/xla/python/ifrt/serdes.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" + +namespace xla { +namespace ifrt { +namespace { + +using ::testing::ElementsAreArray; + +class ShardingSerDesTest : public test_util::ShardingTest {}; + +TEST_P(ShardingSerDesTest, SingleDeviceShardingRoundTrip) { + auto sharding = + SingleDeviceSharding::Create(GetDevices({0}).devices().front()); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, + std::make_unique( + absl::bind_front(&Client::LookupDevice, client())))); + + const auto* out_sharding = + llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); +} + +TEST_P(ShardingSerDesTest, OpaqueShardingRoundTrip) { + auto sharding = OpaqueSharding::Create(GetDevices({0, 1})); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, + std::make_unique( + absl::bind_front(&Client::LookupDevice, client())))); + + const auto* out_sharding = llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); +} + +TEST_P(ShardingSerDesTest, ConcreteShardingRoundTrip) { + auto sharding = ConcreteSharding::Create( + GetDevices({0, 1}), + /*shape=*/Shape({10, 20}), + /*shard_shapes=*/{Shape({3, 20}), Shape({7, 20})}); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, + std::make_unique( + absl::bind_front(&Client::LookupDevice, client())))); + + const auto* out_sharding = + llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); + EXPECT_THAT(out_sharding->shape(), sharding->shape()); + EXPECT_THAT(out_sharding->shard_shapes(), + ElementsAreArray(sharding->shard_shapes())); +} + +TEST_P(ShardingSerDesTest, ConcreteEvenShardingRoundTrip) { + auto sharding = ConcreteEvenSharding::Create(GetDevices({0, 1}), + /*shape=*/Shape({10, 20}), + /*shard_shape=*/Shape({5, 20})); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, + std::make_unique( + absl::bind_front(&Client::LookupDevice, client())))); + + const auto* out_sharding = + llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), 
ElementsAreArray(sharding->devices())); + EXPECT_THAT(out_sharding->shape(), sharding->shape()); + EXPECT_THAT(out_sharding->shard_shape(), sharding->shard_shape()); +} + +INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingSerDesTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 2, .num_addressable_devices = 2})); + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_test.cc b/tensorflow/compiler/xla/python/ifrt/sharding_test.cc index e1b842fde75517..be95647993278a 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding_test.cc +++ b/tensorflow/compiler/xla/python/ifrt/sharding_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "tensorflow/compiler/xla/python/ifrt/device.h" #include "tensorflow/compiler/xla/python/ifrt/ir/sharding_param.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/status_matchers.h" #include "tensorflow/tsl/platform/statusor.h" @@ -33,32 +34,31 @@ namespace ifrt { namespace { using ::testing::ElementsAre; +using ::testing::ElementsAreArray; using ::testing::HasSubstr; using ::testing::SizeIs; using ::tsl::testing::StatusIs; -DeviceList CreateDummyDevices(int count) { - DeviceList::Devices devices; - devices.reserve(count); - for (int i = 0; i < count; ++i) { - devices.push_back(reinterpret_cast(i + 1)); - } - return DeviceList(std::move(devices)); -} +class SingleDeviceShardingTest : public test_util::ShardingTest {}; +class OpaqueShardingTest : public test_util::ShardingTest {}; +class ConcreteShardingTest : public test_util::ShardingTest {}; +class ConcreteEvenShardingTest : public test_util::ShardingTest {}; +class ShardingParamShardingTest : public test_util::ShardingTest {}; -TEST(SingleDeviceShardingTest, IndexDomains) { +TEST_P(SingleDeviceShardingTest, IndexDomains) { + auto device_list = GetDevices({0}); std::shared_ptr sharding = - SingleDeviceSharding::Create(reinterpret_cast(1)); + SingleDeviceSharding::Create(device_list.devices().front()); Shape shape({10, 20}); TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); EXPECT_THAT(index_domains, ElementsAre(IndexDomain(shape))); } -TEST(SingleDeviceShardingTest, Disassemble) { - auto device = reinterpret_cast(1); +TEST_P(SingleDeviceShardingTest, Disassemble) { + auto device_list = GetDevices({0}); std::shared_ptr sharding = - SingleDeviceSharding::Create(device); + SingleDeviceSharding::Create(device_list.devices().front()); Shape shape({10, 20}); TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); @@ -67,11 +67,12 @@ TEST(SingleDeviceShardingTest, Disassemble) { const auto& [result_shape, result_sharding] = disassembled[0]; ASSERT_EQ(shape, result_shape); ASSERT_TRUE(llvm::isa(*result_sharding)); - EXPECT_THAT(result_sharding->devices().devices(), ElementsAre(device)); + EXPECT_THAT(result_sharding->devices().devices(), + ElementsAreArray(device_list.devices())); } -TEST(OpaqueShardingTest, FailedToDisassemble) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(OpaqueShardingTest, FailedToDisassemble) { + auto device_list = GetDevices({0, 1}); std::shared_ptr sharding = OpaqueSharding::Create(device_list); @@ -82,8 +83,8 @@ TEST(OpaqueShardingTest, FailedToDisassemble) { HasSubstr("OpaqueSharding does not have shard shape information"))); } -TEST(OpaqueShardingTest, IndexDomainsFails) { - DeviceList device_list = 
CreateDummyDevices(2); +TEST_P(OpaqueShardingTest, IndexDomainsFails) { + auto device_list = GetDevices({0, 1}); std::shared_ptr sharding = OpaqueSharding::Create(device_list); @@ -94,8 +95,8 @@ TEST(OpaqueShardingTest, IndexDomainsFails) { HasSubstr("OpaqueSharding does not have index domain information"))); } -TEST(ConcreteShardingTest, Disassemble) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ConcreteShardingTest, Disassemble) { + auto device_list = GetDevices({0, 1}); std::vector shard_shapes; shard_shapes.reserve(2); shard_shapes.push_back(Shape({10})); @@ -115,8 +116,8 @@ TEST(ConcreteShardingTest, Disassemble) { } } -TEST(ConcreteShardingTest, DisassembleFailsForUnexpectedShape) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ConcreteShardingTest, DisassembleFailsForUnexpectedShape) { + auto device_list = GetDevices({0, 1}); std::vector shard_shapes; shard_shapes.reserve(2); shard_shapes.push_back(Shape({10})); @@ -129,8 +130,8 @@ TEST(ConcreteShardingTest, DisassembleFailsForUnexpectedShape) { HasSubstr("ConcreteSharding can only disassemble"))); } -TEST(ConcreteShardingTest, IndexDomainsFails) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ConcreteShardingTest, IndexDomainsFails) { + auto device_list = GetDevices({0, 1}); std::vector shard_shapes; shard_shapes.reserve(2); shard_shapes.push_back(Shape({10})); @@ -144,8 +145,8 @@ TEST(ConcreteShardingTest, IndexDomainsFails) { "domain information"))); } -TEST(ConcreteEvenShardingTest, Disassemble) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ConcreteEvenShardingTest, Disassemble) { + auto device_list = GetDevices({0, 1}); std::shared_ptr sharding = ConcreteEvenSharding::Create(device_list, Shape({30}), Shape({15})); @@ -161,8 +162,8 @@ TEST(ConcreteEvenShardingTest, Disassemble) { } } -TEST(ConcreteEvenShardingTest, DisassembleFailsForUnexpectedShape) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ConcreteEvenShardingTest, DisassembleFailsForUnexpectedShape) { + auto device_list = GetDevices({0, 1}); std::shared_ptr sharding = ConcreteEvenSharding::Create(device_list, Shape({30}), Shape({15})); @@ -171,8 +172,8 @@ TEST(ConcreteEvenShardingTest, DisassembleFailsForUnexpectedShape) { HasSubstr("ConcreteEvenSharding can only disassemble"))); } -TEST(ConcreteEvenShardingTest, IndexDomainsFails) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ConcreteEvenShardingTest, IndexDomainsFails) { + auto device_list = GetDevices({0, 1}); std::vector shard_shapes; std::shared_ptr sharding = ConcreteEvenSharding::Create(device_list, Shape({30}), Shape({15})); @@ -185,8 +186,8 @@ TEST(ConcreteEvenShardingTest, IndexDomainsFails) { "ConcreteEvenSharding does not have index domain information"))); } -TEST(ShardingParamShardingTest, CreateFailsWhenDeviceCountNotMatch) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ShardingParamShardingTest, CreateFailsWhenDeviceCountNotMatch) { + auto device_list = GetDevices({0, 1}); ShardingParam param{/*dim_shards=*/{2, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; @@ -196,13 +197,12 @@ TEST(ShardingParamShardingTest, CreateFailsWhenDeviceCountNotMatch) { "ShardingParam 6 vs from DeviceList 2"))); } -TEST(ShardingParamShardingTest, Disassemble) { - DeviceList device_list = CreateDummyDevices(6); +TEST_P(ShardingParamShardingTest, Disassemble) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); ShardingParam param{/*dim_shards=*/{2, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - TF_ASSERT_OK_AND_ASSIGN( - 
std::shared_ptr param_sharding, - ShardingParamSharding::Create(param, CreateDummyDevices(6))); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr param_sharding, + ShardingParamSharding::Create(param, device_list)); TF_ASSERT_OK_AND_ASSIGN(auto disassembled, param_sharding->Disassemble(Shape({6, 6}))); @@ -216,12 +216,12 @@ TEST(ShardingParamShardingTest, Disassemble) { } } -TEST(ShardingParamShardingTest, DisassembleFailsWhenRankNotMatch) { +TEST_P(ShardingParamShardingTest, DisassembleFailsWhenRankNotMatch) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); ShardingParam param{/*dim_shards=*/{2, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr param_sharding, - ShardingParamSharding::Create(param, CreateDummyDevices(6))); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr param_sharding, + ShardingParamSharding::Create(param, device_list)); EXPECT_THAT( param_sharding->Disassemble(Shape({6, 6, 6})), @@ -230,12 +230,12 @@ TEST(ShardingParamShardingTest, DisassembleFailsWhenRankNotMatch) { "Ranks don't match. From Shape 3 vs from ShardingParam 2"))); } -TEST(ShardingParamShardingTest, DisassembleFailsForUnevenSharding) { +TEST_P(ShardingParamShardingTest, DisassembleFailsForUnevenSharding) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); ShardingParam param{/*dim_shards=*/{2, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr param_sharding, - ShardingParamSharding::Create(param, CreateDummyDevices(6))); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr param_sharding, + ShardingParamSharding::Create(param, device_list)); EXPECT_THAT( param_sharding->Disassemble(Shape({7, 6})), @@ -244,12 +244,12 @@ TEST(ShardingParamShardingTest, DisassembleFailsForUnevenSharding) { HasSubstr("Uneven shard is not supported. 
dim: 7, dim_shards: 2"))); } -TEST(ShardingParamShardingTest, IndexDomain) { +TEST_P(ShardingParamShardingTest, IndexDomain) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); ShardingParam param{/*dim_shards=*/{2, 3}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr param_sharding, - ShardingParamSharding::Create(param, CreateDummyDevices(6))); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr param_sharding, + ShardingParamSharding::Create(param, device_list)); TF_ASSERT_OK_AND_ASSIGN(auto index_domains, param_sharding->IndexDomains(Shape({6, 6}))); @@ -262,12 +262,12 @@ TEST(ShardingParamShardingTest, IndexDomain) { IndexDomain(Index({3, 4}), Shape({3, 2})))); } -TEST(ShardingParamShardingTest, IndexDomainWithPermutation) { +TEST_P(ShardingParamShardingTest, IndexDomainWithPermutation) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); ShardingParam param{/*dim_shards=*/{2, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr param_sharding, - ShardingParamSharding::Create(param, CreateDummyDevices(6))); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr param_sharding, + ShardingParamSharding::Create(param, device_list)); TF_ASSERT_OK_AND_ASSIGN(auto index_domains, param_sharding->IndexDomains(Shape({6, 6}))); @@ -280,12 +280,12 @@ TEST(ShardingParamShardingTest, IndexDomainWithPermutation) { IndexDomain(Index({3, 4}), Shape({3, 2})))); } -TEST(ShardingParamShardingTest, IndexDomainWithReplication) { +TEST_P(ShardingParamShardingTest, IndexDomainWithReplication) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); ShardingParam param{/*dim_shards=*/{2, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr param_sharding, - ShardingParamSharding::Create(param, CreateDummyDevices(6))); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr param_sharding, + ShardingParamSharding::Create(param, device_list)); TF_ASSERT_OK_AND_ASSIGN(auto index_domains, param_sharding->IndexDomains(Shape({6, 6}))); @@ -298,6 +298,22 @@ TEST(ShardingParamShardingTest, IndexDomainWithReplication) { IndexDomain(Index({3, 0}), Shape({3, 6})))); } +INSTANTIATE_TEST_SUITE_P(NumDevices, SingleDeviceShardingTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 6})); +INSTANTIATE_TEST_SUITE_P(NumDevices, OpaqueShardingTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 6})); +INSTANTIATE_TEST_SUITE_P(NumDevices, ConcreteShardingTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 6})); +INSTANTIATE_TEST_SUITE_P(NumDevices, ConcreteEvenShardingTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 6})); +INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingParamShardingTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 4})); + } // namespace } // namespace ifrt } // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_test_util.cc b/tensorflow/compiler/xla/python/ifrt/sharding_test_util.cc new file mode 100644 index 00000000000000..e43c363eff7f41 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_test_util.cc @@ -0,0 +1,96 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/python/ifrt/device.h" +#include "tensorflow/compiler/xla/python/ifrt/mock.h" +#include "tensorflow/compiler/xla/python/ifrt/test_util.h" +#include "tensorflow/tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace test_util { + +namespace { + +using ::testing::Return; + +// Internal state of a client for sharding tests. +struct ShardingTestClientState { + // Mapping from a device ID to the mock device object. + absl::flat_hash_map> device_map; + // Raw pointers to mock devices. + std::vector devices; +}; + +// Creates a mock client for sharding tests. The client will have a specified +// number of fake addressable and non-addressable devices. Client implements +// `devices()` and `LookupDevice()`. Device implements `id()`, with an arbitrary +// deterministic device ids assigned. +std::shared_ptr MakeShardingTestClient( + int num_devices, int num_addressable_devices) { + auto state = std::make_shared(); + state->device_map.reserve(num_devices); + state->devices.reserve(num_devices); + + for (int i = 0; i < num_addressable_devices; ++i) { + auto device = std::make_unique(); + ON_CALL(*device, id).WillByDefault(Return(i + 10)); + ON_CALL(*device, IsAddressable).WillByDefault(Return(true)); + state->devices.push_back(device.get()); + state->device_map.insert({i + 10, std::move(device)}); + } + for (int i = num_addressable_devices; i < num_devices; ++i) { + auto device = std::make_unique(); + ON_CALL(*device, id).WillByDefault(Return(i + 10)); + ON_CALL(*device, IsAddressable).WillByDefault(Return(false)); + state->devices.push_back(device.get()); + state->device_map.insert({i + 10, std::move(device)}); + } + + auto client = std::make_shared(); + ON_CALL(*client, devices) + .WillByDefault( + [state]() -> absl::Span { return state->devices; }); + ON_CALL(*client, LookupDevice) + .WillByDefault([state](int device_id) -> StatusOr { + auto it = state->device_map.find(device_id); + if (it == state->device_map.end()) { + return InvalidArgument("Unexpected device id: %d", device_id); + } + return it->second.get(); + }); + return client; +} + +} // namespace + +void ShardingTest::SetUp() { + const auto [num_devices, num_addressable_devices] = GetParam(); + client_ = MakeShardingTestClient(num_devices, num_addressable_devices); +} + +DeviceList ShardingTest::GetDevices(absl::Span device_indices) { + return test_util::GetDevices(client_.get(), device_indices).value(); +} + +} // namespace test_util +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_test_util.h b/tensorflow/compiler/xla/python/ifrt/sharding_test_util.h new file mode 100644 index 00000000000000..b7ba399c1689e8 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_test_util.h @@ -0,0 +1,55 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_TEST_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/python/ifrt/client.h" +#include "tensorflow/tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace test_util { + +// Parameters for ShardingTest. +// Requests `num_devices` total devices, where `num_addressable_devices` of them +// are addressable, and the rest of devices are non-addressable. +struct ShardingTestParam { + int num_devices; + int num_addressable_devices; +}; + +// Test fixture for sharding tests. +class ShardingTest : public testing::TestWithParam { + public: + void SetUp() override; + Client* client() { return client_.get(); } + + // Returns `DeviceList` containing devices at given indexes (not ids) within + // `client.devices()`. + // REQUIRES: 0 <= device_indices[i] < num_devices + DeviceList GetDevices(absl::Span device_indices); + + private: + std::shared_ptr client_; +}; + +} // namespace test_util +} // namespace ifrt +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_TEST_UTIL_H_ diff --git a/tensorflow/compiler/xla/python/ifrt/support/BUILD b/tensorflow/compiler/xla/python/ifrt/support/BUILD index 448dca40f6791e..08ecda728a4614 100644 --- a/tensorflow/compiler/xla/python/ifrt/support/BUILD +++ b/tensorflow/compiler/xla/python/ifrt/support/BUILD @@ -29,6 +29,7 @@ xla_cc_test( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/python/ifrt", + "//tensorflow/compiler/xla/python/ifrt:sharding_test_util", "//tensorflow/compiler/xla/python/ifrt/ir", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status_matchers", diff --git a/tensorflow/compiler/xla/python/ifrt/support/sharding_param_to_op_sharding_test.cc b/tensorflow/compiler/xla/python/ifrt/support/sharding_param_to_op_sharding_test.cc index ad61143c31cc4f..94d06804e167aa 100644 --- a/tensorflow/compiler/xla/python/ifrt/support/sharding_param_to_op_sharding_test.cc +++ b/tensorflow/compiler/xla/python/ifrt/support/sharding_param_to_op_sharding_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/python/ifrt/support/sharding_param_to_op_sharding.h" -#include #include #include #include @@ -29,6 +28,7 @@ limitations under the License. 
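[Editor's note] A minimal sketch of how a test file adopts the new `test_util::ShardingTest` fixture declared above. The suite name `MyShardingTest` and the parameter values are hypothetical; the `TEST_P`/`GetDevices`/`INSTANTIATE_TEST_SUITE_P` pattern mirrors the tests changed in this diff.

```cpp
#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h"

namespace xla {
namespace ifrt {
namespace {

class MyShardingTest : public test_util::ShardingTest {};

TEST_P(MyShardingTest, UsesMockDevices) {
  // Indices are positions within client()->devices(), not device ids.
  auto device_list = GetDevices({0, 1});
  EXPECT_THAT(device_list, ::testing::SizeIs(2));
}

// num_addressable_devices <= num_devices; the remainder are non-addressable.
INSTANTIATE_TEST_SUITE_P(NumDevices, MyShardingTest,
                         testing::Values(test_util::ShardingTestParam{
                             .num_devices = 4, .num_addressable_devices = 2}));

}  // namespace
}  // namespace ifrt
}  // namespace xla
```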
#include "tensorflow/compiler/xla/python/ifrt/ir/sharding_param.h" #include "tensorflow/compiler/xla/python/ifrt/shape.h" #include "tensorflow/compiler/xla/python/ifrt/sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" @@ -49,15 +49,6 @@ StatusOr ToHloSharding(const ShardingParam& sharding_param, return xla::HloSharding::FromProto(op_sharding); } -DeviceList CreateDummyDevices(int count) { - DeviceList::Devices devices; - devices.reserve(count); - for (int i = 0; i < count; ++i) { - devices.push_back(reinterpret_cast(i + 1)); - } - return DeviceList(std::move(devices)); -} - TEST(ShardingParamToOpShardingTest, Replicated) { ShardingParam sharding_param{/*dim_shards=*/{1, 1, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; @@ -114,27 +105,34 @@ TEST(ShardingParamToOpShardingTest, ErrorOnDeviceAssignment) { StatusIs(tsl::error::OUT_OF_RANGE, "Can't map device 5")); } -void AssertSameTiling(const ShardingParam& sharding_param, - const HloSharding& hlo_sharding, const Shape& shape) { - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr sharding, - ShardingParamSharding::Create(sharding_param, CreateDummyDevices(6))); - const xla::Shape xla_shape(PrimitiveType::F16, shape.dims(), {}, {}); - - TF_ASSERT_OK_AND_ASSIGN(const std::vector index_domains, - sharding->IndexDomains(shape)); - ASSERT_EQ(index_domains.size(), - hlo_sharding.tile_assignment().num_elements()); - const xla::Shape xla_tile_shape = hlo_sharding.TileShape(xla_shape); - for (int i = 0; i < index_domains.size(); ++i) { - SCOPED_TRACE(absl::StrCat("on device ", i)); - EXPECT_EQ(index_domains[i].origin().elements(), - hlo_sharding.TileOffsetForDevice(xla_shape, i)); - EXPECT_EQ(index_domains[i].shape().dims(), xla_tile_shape.dimensions()); +class ShardingParamToOpShardingEquivalentTest : public test_util::ShardingTest { + public: + void AssertSameTiling(const ShardingParam& sharding_param, + const HloSharding& hlo_sharding, const Shape& shape) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); + TF_ASSERT_OK_AND_ASSIGN( + std::shared_ptr sharding, + ShardingParamSharding::Create(sharding_param, device_list)); + const xla::Shape xla_shape(PrimitiveType::F16, shape.dims(), {}, {}); + + TF_ASSERT_OK_AND_ASSIGN(const std::vector index_domains, + sharding->IndexDomains(shape)); + ASSERT_EQ(index_domains.size(), + hlo_sharding.tile_assignment().num_elements()); + const xla::Shape xla_tile_shape = hlo_sharding.TileShape(xla_shape); + for (int i = 0; i < index_domains.size(); ++i) { + SCOPED_TRACE(absl::StrCat("on device ", i)); + EXPECT_EQ(index_domains[i].origin().elements(), + hlo_sharding.TileOffsetForDevice(xla_shape, i)); + EXPECT_EQ(index_domains[i].shape().dims(), xla_tile_shape.dimensions()); + } } -} -TEST(ShardingParamToOpShardingEquivalentTest, FullySharded) { + private: + std::shared_ptr client_; +}; + +TEST_P(ShardingParamToOpShardingEquivalentTest, FullySharded) { ShardingParam sharding_param{/*dim_shards=*/{2, 3}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_sharding, @@ -142,7 +140,7 @@ TEST(ShardingParamToOpShardingEquivalentTest, FullySharded) { AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); } -TEST(ShardingParamToOpShardingEquivalentTest, WithPermutation) { +TEST_P(ShardingParamToOpShardingEquivalentTest, WithPermutation) { ShardingParam sharding_param{/*dim_shards=*/{2, 3}, 
{/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_sharding, @@ -150,7 +148,7 @@ TEST(ShardingParamToOpShardingEquivalentTest, WithPermutation) { AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); } -TEST(ShardingParamToOpShardingEquivalentTest, WithReplication) { +TEST_P(ShardingParamToOpShardingEquivalentTest, WithReplication) { ShardingParam sharding_param{/*dim_shards=*/{2, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_sharding, @@ -158,6 +156,10 @@ TEST(ShardingParamToOpShardingEquivalentTest, WithReplication) { AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); } +INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingParamToOpShardingEquivalentTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 4})); + } // namespace } // namespace support } // namespace ifrt diff --git a/tensorflow/compiler/xla/python/ifrt/test_util.cc b/tensorflow/compiler/xla/python/ifrt/test_util.cc index 47ecc393db8269..4e73e6d834884d 100644 --- a/tensorflow/compiler/xla/python/ifrt/test_util.cc +++ b/tensorflow/compiler/xla/python/ifrt/test_util.cc @@ -19,10 +19,12 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/python/ifrt/client.h" +#include "tensorflow/compiler/xla/python/ifrt/device.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/tsl/platform/test.h" namespace xla { namespace ifrt { @@ -32,20 +34,20 @@ namespace { class ClientFactory { public: - void Register(std::function>()> factory) { + void Register(std::function>()> factory) { absl::MutexLock lock(&mu_); CHECK(!factory_) << "Client factory has been already registered."; factory_ = std::move(factory); } - std::function>()> Get() const { + std::function>()> Get() const { absl::MutexLock lock(&mu_); return factory_; } private: mutable absl::Mutex mu_; - std::function>()> factory_ + std::function>()> factory_ ABSL_GUARDED_BY(mu_); }; @@ -57,11 +59,11 @@ ClientFactory& GetGlobalClientFactory() { } // namespace void RegisterClientFactory( - std::function>()> factory) { + std::function>()> factory) { GetGlobalClientFactory().Register(std::move(factory)); } -StatusOr> GetClient() { +StatusOr> GetClient() { auto factory = GetGlobalClientFactory().Get(); CHECK(factory) << "Client factory has not been registered."; return factory(); @@ -80,6 +82,20 @@ void SetTestFilterIfNotUserSpecified(absl::string_view custom_filter) { #endif } +absl::StatusOr GetDevices(Client* client, + absl::Span device_indices) { + DeviceList::Devices devices; + devices.reserve(device_indices.size()); + for (int device_index : device_indices) { + if (device_index < 0 || device_index >= client->devices().size()) { + return absl::InvalidArgumentError( + absl::StrCat("Out of range device index: ", device_index)); + } + devices.push_back(client->devices()[device_index]); + } + return DeviceList(std::move(devices)); +} + } // namespace test_util } // namespace ifrt } // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/test_util.h b/tensorflow/compiler/xla/python/ifrt/test_util.h index 18bb418431f757..b65d6c1cf948e4 100644 --- a/tensorflow/compiler/xla/python/ifrt/test_util.h +++ b/tensorflow/compiler/xla/python/ifrt/test_util.h @@ -39,13 +39,13 @@ namespace test_util { // Registers an IFRT client factory function. Must be called only once. 
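[Editor's note] For completeness, a fragment showing how a test might call the free `test_util::GetDevices` helper added above directly, outside the fixture. The surrounding test body and the `client` variable are hypothetical; the return type is assumed to be `absl::StatusOr<DeviceList>`, matching the `return DeviceList(...)` in the implementation.

```cpp
// Fragment of a hypothetical test body; `client` is an xla::ifrt::Client*
// (e.g. from test_util::GetClient() or the ShardingTest fixture's client()).
TF_ASSERT_OK_AND_ASSIGN(
    xla::ifrt::DeviceList devices,
    xla::ifrt::test_util::GetDevices(client, /*device_indices=*/{0, 2, 4}));
// An index outside [0, client->devices().size()) instead yields
// InvalidArgument("Out of range device index: ...").
```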
void RegisterClientFactory( - std::function>()> factory); + std::function>()> factory); // Returns true iff an IFRT client factory function has been registered. bool IsClientFactoryRegistered(); // Gets a new IFRT client using the registered client factory. -StatusOr> GetClient(); +StatusOr> GetClient(); // Set a default test filter if user doesn't provide one using --gtest_filter. void SetTestFilterIfNotUserSpecified(absl::string_view custom_filter); @@ -80,6 +80,11 @@ void AssertPerShardData( } } +// Helper function that makes `DeviceList` containing devices at given +// indexes (not ids) within `client.devices()`. +absl::StatusOr GetDevices(Client* client, + absl::Span device_indices); + } // namespace test_util } // namespace ifrt } // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/types.proto b/tensorflow/compiler/xla/python/ifrt/types.proto new file mode 100644 index 00000000000000..e9c799bcc1ed6c --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/types.proto @@ -0,0 +1,31 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +// Wire format for `DeviceList`. +message DeviceListProto { + // Serialization and deserialization are expected to ensure that device ids + // are stable across proto construction and consumption. + repeated int32 device_ids = 1; +} + +// Wire format for `Shape`. Currently support static shapes with all dimension +// sizes greater than or equal to 0. +message ShapeProto { + repeated int64 dims = 1; +} diff --git a/tensorflow/compiler/xla/python/mlir.cc b/tensorflow/compiler/xla/python/mlir.cc index f121daaaec0a29..12e2f13b7a234d 100644 --- a/tensorflow/compiler/xla/python/mlir.cc +++ b/tensorflow/compiler/xla/python/mlir.cc @@ -233,17 +233,21 @@ void BuildMlirSubmodule(py::module& m) { py::arg("mlir_module")); mlir_module.def( "refine_polymorphic_shapes", - [](std::string mlir_module) -> py::bytes { + [](std::string mlir_module, bool enable_shape_assertions, + bool validate_static_shapes) -> py::bytes { std::string buffer; llvm::raw_string_ostream os(buffer); - xla::ThrowIfError(RefinePolymorphicShapes(mlir_module, os)); + xla::ThrowIfError(RefinePolymorphicShapes( + mlir_module, os, enable_shape_assertions, validate_static_shapes)); return py::bytes(buffer); }, - py::arg("mlir_module"), + py::arg("mlir_module"), py::arg("enable_shape_assertions") = true, + py::arg("validate_static_shapes") = true, R"(Refines the dynamic shapes for a module. The "main" function must have static shapes and all the intermediate dynamic shapes depend only on the input static - shapes. + shapes. Optionally, also validates that the resulting module has + only static shapes. 
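[Editor's note] The new `types.proto` wire formats are plain repeated-field messages, so their generated C++ API is the standard protobuf one. A small sketch of the accessors they imply; the generated header path is assumed from the proto's location.

```cpp
#include <string>

#include "tensorflow/compiler/xla/python/ifrt/types.pb.h"  // assumed generated header path

// Round-trips a DeviceListProto and a ShapeProto through their wire form.
bool RoundTripExample() {
  xla::ifrt::DeviceListProto device_list_proto;
  device_list_proto.add_device_ids(10);
  device_list_proto.add_device_ids(11);
  device_list_proto.add_device_ids(12);

  xla::ifrt::ShapeProto shape_proto;
  shape_proto.add_dims(2);
  shape_proto.add_dims(3);

  const std::string wire = device_list_proto.SerializeAsString();
  xla::ifrt::DeviceListProto parsed;
  return parsed.ParseFromString(wire) &&
         parsed.device_ids_size() == device_list_proto.device_ids_size();
}
```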
)"); } diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD index e1c1a36bb4ef62..a7023e0b73459d 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD @@ -90,7 +90,7 @@ cc_library( "@stablehlo//:stablehlo_portable_api", "@stablehlo//:stablehlo_serialization", ], - alwayslink = 1, + alwayslink = True, ) xla_cc_test( @@ -110,10 +110,46 @@ xla_cc_test( ], ) +tf_proto_library( + name = "xla_sharding_proto", + srcs = ["xla_sharding.proto"], + protodeps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/python/ifrt:types_proto", + ], +) + +cc_library( + name = "xla_sharding_serdes", + srcs = ["xla_sharding_serdes.cc"], + deps = [ + ":xla_ifrt", + ":xla_sharding_proto_cc", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/python/ifrt:serdes", + "//tensorflow/compiler/xla/python/ifrt:sharding_serdes", + ], + alwayslink = True, +) + +xla_cc_test( + name = "xla_sharding_serdes_test", + srcs = ["xla_sharding_serdes_test.cc"], + deps = [ + ":xla_ifrt", + ":xla_sharding_serdes", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/python/ifrt:sharding_serdes", + "//tensorflow/compiler/xla/python/ifrt:sharding_test_util", + "@com_google_absl//absl/functional:bind_front", + "@com_google_googletest//:gtest_main", + ], +) + # TODO(hyeontaek): Move this target out of pjrt_ifrt. cc_library( name = "xla_executable_impl_test_lib", - testonly = 1, + testonly = True, srcs = ["xla_executable_impl_test_lib.cc"], deps = [ ":xla_ifrt", @@ -124,7 +160,7 @@ cc_library( "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/strings", ], - alwayslink = 1, + alwayslink = True, ) # TODO(hyeontaek): Move this target out of pjrt_ifrt. @@ -146,6 +182,8 @@ xla_cc_test( deps = [ ":tfrt_cpu_client_test_lib", ":xla_ifrt", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/python/ifrt:sharding_test_util", "//tensorflow/compiler/xla/python/ifrt:tuple_impl_test_lib", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status_matchers", @@ -200,14 +238,14 @@ cc_library( cc_library( name = "tfrt_cpu_client_test_lib", - testonly = 1, + testonly = True, srcs = ["tfrt_cpu_client_test_lib.cc"], deps = [ ":pjrt_ifrt", "//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client", "//tensorflow/compiler/xla/python/ifrt:test_util", ], - alwayslink = 1, + alwayslink = True, ) xla_cc_test( diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc index 181f3be1ff8224..8e399194e9e3ed 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc @@ -26,11 +26,11 @@ namespace { const bool kUnused = (test_util::RegisterClientFactory( - []() -> StatusOr> { + []() -> StatusOr> { TF_ASSIGN_OR_RETURN(auto pjrt_client, xla::GetTfrtCpuClient(/*asynchronous=*/true, /*cpu_device_count=*/2)); - return StatusOr>( + return std::shared_ptr( PjRtClient::Create(std::move(pjrt_client))); }), true); diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto new file mode 100644 index 00000000000000..0ff8040b66233e --- /dev/null +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto @@ -0,0 +1,27 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "tensorflow/compiler/xla/python/ifrt/types.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; + +// Wire format for `HloSharding`. +message HloShardingProto { + DeviceListProto devices = 1; + xla.OpSharding xla_op_sharding = 2; +} diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc new file mode 100644 index 00000000000000..daff5e2149ff6c --- /dev/null +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc @@ -0,0 +1,80 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/serdes.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" +#include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.h" +#include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.pb.h" + +namespace xla { +namespace ifrt { + +namespace { + +// Serialization/deserialization for `HloSharding`. 
+class HloShardingSerDes : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::HloSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const HloSharding& sharding = llvm::cast(serializable); + HloShardingProto proto; + *proto.mutable_devices() = sharding.devices().ToProto(); + *proto.mutable_xla_op_sharding() = sharding.xla_hlo_sharding().ToProto(); + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + + HloShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized HloSharding"); + } + TF_ASSIGN_OR_RETURN( + auto devices, + DeviceList::FromProto(deserialize_sharding_options->lookup_device, + proto.devices())); + TF_ASSIGN_OR_RETURN(auto xla_hlo_sharding, + xla::HloSharding::FromProto(proto.xla_op_sharding())); + return HloSharding::Create(std::move(devices), std::move(xla_hlo_sharding)); + } + + static char ID; // NOLINT +}; + +[[maybe_unused]] char HloShardingSerDes::ID = 0; // NOLINT + +// clang-format off +bool register_hlo_sharding_serdes = ([] { + RegisterSerDes( + std::make_unique()); +}(), true); +// clang-format on + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc new file mode 100644 index 00000000000000..1a5bc8ef7b5804 --- /dev/null +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
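[Editor's note] The `register_hlo_sharding_serdes` initializer above uses a common static-registration idiom: evaluate a lambda for its side effect at namespace scope, then keep a throwaway `true` via the comma operator. A dependency-free sketch of the same pattern (the registry here is hypothetical); this is also why the BUILD rule sets `alwayslink = True`, so the object file is linked in even though nothing references the variable.

```cpp
#include <string>
#include <vector>

// Hypothetical registry standing in for RegisterSerDes().
std::vector<std::string>& Registry() {
  static std::vector<std::string> registry;
  return registry;
}

// Runs during static initialization of the translation unit that links this in.
const bool registered =
    ([] { Registry().push_back("xla::ifrt::HloSharding"); }(), true);
```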
+==============================================================================*/ + +#include +#include + +#include +#include +#include "absl/functional/bind_front.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" +#include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.h" + +namespace xla { +namespace ifrt { +namespace { + +using ::testing::ElementsAreArray; + +class XlaShardingSerDesTest : public test_util::ShardingTest {}; + +TEST_P(XlaShardingSerDesTest, HloShardingRoundTrip) { + auto device_list = GetDevices({0, 1}); + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment(absl::Span({2, 1}))); + auto sharding = HloSharding::Create(device_list, + /*xla_hlo_sharding=*/xla_hlo_sharding); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, + std::make_unique( + absl::bind_front(&Client::LookupDevice, client())))); + + const auto* out_sharding = llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); + EXPECT_EQ(out_sharding->xla_hlo_sharding(), sharding->xla_hlo_sharding()); +} + +INSTANTIATE_TEST_SUITE_P(NumDevices, XlaShardingSerDesTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 2, .num_addressable_devices = 2})); + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc index 408c978fab375f..304b22247a0b3c 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/status_matchers.h" @@ -33,17 +35,10 @@ using ::testing::HasSubstr; using ::testing::SizeIs; using ::tsl::testing::StatusIs; -DeviceList CreateDummyDevices(int count) { - DeviceList::Devices devices; - devices.reserve(count); - for (int i = 0; i < count; ++i) { - devices.push_back(reinterpret_cast(i + 1)); - } - return DeviceList(std::move(devices)); -} +class HloShardingTest : public test_util::ShardingTest {}; -TEST(HloShardingTest, IndexDomainsWithReplication) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, IndexDomainsWithReplication) { + auto device_list = GetDevices({0, 1}); // Fully replicated. auto xla_hlo_sharding = xla::HloSharding::Replicate(); std::shared_ptr sharding = @@ -59,8 +54,8 @@ TEST(HloShardingTest, IndexDomainsWithReplication) { ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); } -TEST(HloShardingTest, DisassembleWithReplication) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, DisassembleWithReplication) { + auto device_list = GetDevices({0, 1}); // Fully replicated. 
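[Editor's note] The deserialize path in the test above needs a device-lookup callback so serialized device ids can be resolved against a live client; `absl::bind_front(&Client::LookupDevice, client())` packages the member function and the client pointer into such a callable. A self-contained sketch of that binding pattern with a hypothetical `Resolver` type:

```cpp
#include <iostream>

#include "absl/functional/bind_front.h"

struct Resolver {
  int Lookup(int id) const { return id + 100; }
};

int main() {
  Resolver resolver;
  // Binds the member function and the object; the result takes only `id`.
  auto lookup = absl::bind_front(&Resolver::Lookup, &resolver);
  std::cout << lookup(7) << "\n";  // prints 107
  return 0;
}
```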
auto xla_hlo_sharding = xla::HloSharding::Replicate(); std::shared_ptr sharding = @@ -79,10 +74,11 @@ TEST(HloShardingTest, DisassembleWithReplication) { } } -TEST(HloShardingTest, IndexDomainsWithTile) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, IndexDomainsWithTile) { + auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile( + xla::TileAssignment((absl::Span){2, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -97,10 +93,11 @@ TEST(HloShardingTest, IndexDomainsWithTile) { ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); } -TEST(HloShardingTest, DisassembleWithTile) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, DisassembleWithTile) { + auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile( + xla::TileAssignment((absl::Span){2, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -117,10 +114,11 @@ TEST(HloShardingTest, DisassembleWithTile) { } } -TEST(HloShardingTest, IndexDomainsWithUnevenTile) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, IndexDomainsWithUnevenTile) { + auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile( + xla::TileAssignment((absl::Span){2, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -135,10 +133,11 @@ TEST(HloShardingTest, IndexDomainsWithUnevenTile) { ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); } -TEST(HloShardingTest, DisassembleWithUnevenTile) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, DisassembleWithUnevenTile) { + auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile( + xla::TileAssignment((absl::Span){2, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -159,8 +158,8 @@ TEST(HloShardingTest, DisassembleWithUnevenTile) { } } -TEST(HloShardingTest, IndexDomainsWithPartialTile) { - auto device_list = CreateDummyDevices(6); +TEST_P(HloShardingTest, IndexDomainsWithPartialTile) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // 2-way sharded along axis 0, 1-way sharded along axis 1, each shard // replicated by 3 times. auto xla_hlo_sharding = @@ -183,8 +182,8 @@ TEST(HloShardingTest, IndexDomainsWithPartialTile) { ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); } -TEST(HloShardingTest, DisassembleWithPartialTile) { - auto device_list = CreateDummyDevices(6); +TEST_P(HloShardingTest, DisassembleWithPartialTile) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // 2-way sharded along axis 0, 1-way sharded along axis 1, each shard // replicated by 3 times. 
auto xla_hlo_sharding = @@ -205,8 +204,8 @@ TEST(HloShardingTest, DisassembleWithPartialTile) { } } -TEST(HloShardingTest, IndexDomainsWithSubgroupReplicated) { - auto device_list = CreateDummyDevices(6); +TEST_P(HloShardingTest, IndexDomainsWithSubgroupReplicated) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // 2-way sharded along axis 0, 1-way sharded along axis 1, each shard // replicated by 3 times. auto xla_hlo_sharding = xla::HloSharding::Subgroup( @@ -229,8 +228,8 @@ TEST(HloShardingTest, IndexDomainsWithSubgroupReplicated) { ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); } -TEST(HloShardingTest, DisassembleWithSubgroupReplicated) { - auto device_list = CreateDummyDevices(6); +TEST_P(HloShardingTest, DisassembleWithSubgroupReplicated) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // 2-way sharded along axis 0, 1-way sharded along axis 1, each shard // replicated by 3 times. auto xla_hlo_sharding = xla::HloSharding::Subgroup( @@ -251,8 +250,8 @@ TEST(HloShardingTest, DisassembleWithSubgroupReplicated) { } } -TEST(HloShardingTest, IndexDomainsWithSubgroupMaximalSlowPath) { - auto device_list = CreateDummyDevices(6); +TEST_P(HloShardingTest, IndexDomainsWithSubgroupMaximalSlowPath) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // 2-way sharded along axis 0, 1-way sharded along axis 1, each shard // maximal-replicated by 3 times, device#0 in each replication is maximal. auto xla_hlo_sharding = xla::HloSharding::Subgroup( @@ -275,8 +274,8 @@ TEST(HloShardingTest, IndexDomainsWithSubgroupMaximalSlowPath) { ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); } -TEST(HloShardingTest, DisassembleWithSubgroupMaximalSlowPath) { - auto device_list = CreateDummyDevices(6); +TEST_P(HloShardingTest, DisassembleWithSubgroupMaximalSlowPath) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // 2-way sharded along axis 0, 1-way sharded along axis 1, each shard // maximal-replicated by 3 times, device#0 in each replication is maximal. auto xla_hlo_sharding = xla::HloSharding::Subgroup( @@ -297,10 +296,11 @@ TEST(HloShardingTest, DisassembleWithSubgroupMaximalSlowPath) { } } -TEST(HloShardingTest, DisassembleFailsWithInvalidDeviceCount) { - auto device_list = CreateDummyDevices(1); +TEST_P(HloShardingTest, DisassembleFailsWithInvalidDeviceCount) { + auto device_list = GetDevices({0}); // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile( + xla::TileAssignment((absl::Span){2, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -311,10 +311,11 @@ TEST(HloShardingTest, DisassembleFailsWithInvalidDeviceCount) { "device count does not match: 2 vs. 1"))); } -TEST(HloShardingTest, DisassembleFailsWithMismatchingShapeDimsSize) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, DisassembleFailsWithMismatchingShapeDimsSize) { + auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. 
- auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile( + xla::TileAssignment((absl::Span){2, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -326,6 +327,10 @@ TEST(HloShardingTest, DisassembleFailsWithMismatchingShapeDimsSize) { HasSubstr("shape must have 2 dimensions, but has 1 dimensions"))); } +INSTANTIATE_TEST_SUITE_P(NumDevices, HloShardingTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 4})); + } // namespace } // namespace ifrt } // namespace xla diff --git a/tensorflow/compiler/xla/python/profiler/internal/BUILD b/tensorflow/compiler/xla/python/profiler/internal/BUILD index 8f4f0a413e8b5f..5ffef7c8e5e284 100644 --- a/tensorflow/compiler/xla/python/profiler/internal/BUILD +++ b/tensorflow/compiler/xla/python/profiler/internal/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("//tensorflow/tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") @@ -12,7 +12,7 @@ cc_library( name = "python_hooks", srcs = ["python_hooks.cc"], hdrs = ["python_hooks.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_profiler_copts() + ["-fexceptions"], features = ["-use_header_modules"], # Incompatible with -fexceptions. visibility = [ diff --git a/tensorflow/compiler/xla/python/pytree.cc b/tensorflow/compiler/xla/python/pytree.cc index bae3fc0083e294..2ae556a19b53b7 100644 --- a/tensorflow/compiler/xla/python/pytree.cc +++ b/tensorflow/compiler/xla/python/pytree.cc @@ -185,7 +185,12 @@ void PyTreeDef::FlattenIntoImpl( } else { node.kind = GetKind(handle, &node.custom); auto recurse = [this, &leaf_predicate, &leaves](py::handle child) { + if (Py_EnterRecursiveCall( + " in flatten; PyTree may have cyclical node references.")) { + return; + } FlattenInto(child, leaves, leaf_predicate); + Py_LeaveRecursiveCall(); }; switch (node.kind) { case PyTreeKind::kNone: @@ -265,6 +270,11 @@ PyTreeDef::Flatten(py::handle x, std::optional leaf_predicate) { std::vector leaves; auto tree = std::make_unique(); tree->FlattenInto(x, leaves, leaf_predicate); + // Handle the unbounded recursion error for trees with cyclical node + // references. + if (PyErr_Occurred()) { + throw py::error_already_set(); + } return std::make_pair(std::move(leaves), std::move(tree)); } diff --git a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc index 66b0a459438cdf..01a09c1c244909 100644 --- a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc +++ b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc @@ -12,15 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
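[Editor's note] The `pytree.cc` change above relies on CPython's recursion guard so a cyclic PyTree raises a Python `RecursionError` instead of overflowing the C stack. A condensed sketch of the guard pattern; the traversal function is hypothetical, while `Py_EnterRecursiveCall`/`Py_LeaveRecursiveCall`/`PyErr_Occurred` are the real CPython API.

```cpp
#include <Python.h>

// Hypothetical recursive traversal over a Python object graph.
void Visit(PyObject* node) {
  // Returns nonzero and sets a RecursionError once the interpreter's recursion
  // limit is exceeded (which a cyclic structure will eventually hit).
  if (Py_EnterRecursiveCall(" in Visit; object graph may be cyclic.")) {
    return;  // leave the pending Python error set for the caller to surface
  }
  // ... visit the children of `node`, calling Visit() recursively ...
  (void)node;
  Py_LeaveRecursiveCall();
}

// Callers then check PyErr_Occurred() and propagate, as Flatten() does above
// by throwing pybind11::error_already_set.
```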
==============================================================================*/ - #include "tensorflow/compiler/xla/python/refine_polymorphic_shapes.h" -#include "absl/log/log.h" +#include +#include + #include "absl/status/status.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Regex.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Bytecode/BytecodeWriter.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -29,49 +35,315 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/xla/mlir/utils/error_util.h" -#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/tsl/platform/errors.h" namespace xla { -xla::Status RefinePolymorphicShapes(llvm::StringRef module_str, - llvm::raw_ostream &os) { - mlir::MLIRContext context; - if (VLOG_IS_ON(3)) context.disableMultithreading(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); +namespace { - mlir::DialectRegistry registry; - mlir::func::registerAllExtensions(registry); - context.appendDialectRegistry(registry); +constexpr llvm::StringRef shapeAssertionName = "shape_assertion"; +constexpr llvm::StringRef errorMessageAttrName = "error_message"; +// We bound the number of error_message_inputs for using llvm::formatv +constexpr int maxErrorMessageInputs = 4; - auto module = mlir::parseSourceString( - llvm::StringRef(module_str.data(), module_str.size()), &context); - if (!module || failed(module->verifyInvariants())) { - return absl::InvalidArgumentError("Cannot parse module."); +// This pass is needed when we have shape assertions. A shape assertion is +// represented via the `stablehlo.custom_call @shape_assertion` +// custom call, and represents an assertion that the first operand +// (`assert_what`) evaluates to `true`. The custom call also has an +// `error_message` string attribute, and a variadic number +// of integer scalar operands that may be used to format the error message. +// The `error_message` may contain format specifiers `{0}`, `{1}`, ..., that +// are replaced with the values of the error message inputs. The formatting is +// done with the `llvm::formatv` function +// (https://llvm.org/docs/ProgrammersManual.html#formatting-strings-the-formatv-function). 
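[Editor's note] A small example of the `llvm::formatv` substitution that the `error_message` attribute relies on, per the comment above; the message text is made up.

```cpp
#include <string>

#include "llvm/Support/FormatVariadic.h"

int main() {
  // "{0}" and "{1}" are positional references to the trailing arguments,
  // matching the error_message_inputs convention described above.
  std::string msg =
      llvm::formatv("Expected dimension 0 of operand to be {0}, got {1}", 3, 5);
  // msg == "Expected dimension 0 of operand to be 3, got 5"
  return 0;
}
```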
+// +struct CheckShapeAssertionsPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CheckShapeAssertionsPass) + + explicit CheckShapeAssertionsPass(bool enable_shape_assertions = true) + : PassWrapper() { + this->enable_shape_assertions = enable_shape_assertions; + } + + CheckShapeAssertionsPass(const CheckShapeAssertionsPass &pass) { + enable_shape_assertions = pass.enable_shape_assertions; + } + + private: + void runOnOperation() final { + mlir::func::FuncOp func_op = getOperation(); + func_op.walk([this](mlir::stablehlo::CustomCallOp op) { + if (!op.getCallTargetName().equals(shapeAssertionName)) return; + if (!enable_shape_assertions) { + op.erase(); + return; + } + // Check first for ill-formed assertions, rather than silently fail. + if (mlir::failed(verifyShapeAssertion(op))) { + signalPassFailure(); + return; + } + mlir::OperandRange inputs = op.getInputs(); + mlir::SmallVector assertWhat; + if (mlir::failed(mlir::hlo::matchInts(inputs[0], assertWhat))) { + op.emitError() << "expects static assert_what (operand #0)"; + signalPassFailure(); + return; + } + if (assertWhat[0] != 0) { + op.erase(); + return; + } + llvm::StringRef errorMessage = getErrorMessage(op); + mlir::SmallVector errorMessageInputs; + for (int i = 1; i < inputs.size(); ++i) { + mlir::SmallVector input; + if (failed(mlir::hlo::matchInts(inputs[i], input))) { + op.emitError() << "expects static error_message_input (operand #" << i + << ")"; + signalPassFailure(); + return; + } + errorMessageInputs.push_back(input[0]); + } + op.emitError(formatErrorMessage(errorMessage, errorMessageInputs)); + signalPassFailure(); + }); + } + + mlir::LogicalResult verifyShapeAssertion(mlir::stablehlo::CustomCallOp op) { + if (!(1 <= op->getNumOperands() && + op->getNumOperands() <= 1 + maxErrorMessageInputs)) + return op.emitError() << "expects 1 <= size(operands) <= " + << (1 + maxErrorMessageInputs); + int nrErrorMessageInputs = op.getNumOperands() - 1; + if (op->getNumResults() != 0) + return op.emitError("expects size(results) = 0"); + for (const auto &attr : op->getAttrs()) { + if (attr.getName() != "api_version" && + attr.getName() != "backend_config" && + attr.getName() != "call_target_name" && + attr.getName() != "error_message" && + attr.getName() != "has_side_effect") + return op.emitError() + << attr.getName() << " is not a supported attribute"; + } + if (!op.getBackendConfig().empty()) + return op.emitError() << "expects an empty backend_config"; + if (!op.getCallTargetName().equals(shapeAssertionName)) + return op.emitError() << "expects @shape_assertion"; + if (!op.getHasSideEffect()) + return op.emitError() << "expects has_side_effect=true"; + + // input[0] (assert_what) : tensor + auto assertWhatType = + op.getInputs()[0].getType().dyn_cast(); + if (!assertWhatType || !assertWhatType.hasRank() || + assertWhatType.getRank() != 0 || + !assertWhatType.getElementType().isSignlessInteger() || + assertWhatType.getElementTypeBitWidth() != 1) + return op.emitError() << "expects assert_what (operand #0) " + << "to be a constant of type tensor"; + + // input[1:] (error_message_inputs) : tensor or tensor + for (int i = 0; i < nrErrorMessageInputs; ++i) { + auto errorMessageInputType = + op.getInputs()[i + 1].getType().dyn_cast(); + if (!errorMessageInputType || !errorMessageInputType.hasRank() || + errorMessageInputType.getRank() != 0 || + !errorMessageInputType.getElementType().isSignlessInteger() || + (errorMessageInputType.getElementTypeBitWidth() != 32 && + 
errorMessageInputType.getElementTypeBitWidth() != 64)) + return op.emitError() + << "expects error_message_input (operand #" << (i + 1) << ") " + << "to be a constant of type tensor or tensor"; + } + + if (!op->hasAttr(errorMessageAttrName)) + return op.emitError() << "expects an error_message attribute"; + + // error_message contains valid format specifiers. + std::string errorMessage = getErrorMessage(op).data(); + // format specs: "{" index ["," layout] [":" format] "}" + llvm::Regex formatSpecifierRE = llvm::Regex("{([0-9]+)[,:}]"); + do { + mlir::SmallVector formatSpec; + if (!formatSpecifierRE.match(errorMessage, &formatSpec)) { + break; + } + int index = std::stoi(formatSpec[1].data()); + if (!(0 <= index && index < nrErrorMessageInputs)) { + return op.emitError() + << "expects error_message to contain format specifiers with " + << "error_message_input index less than " << nrErrorMessageInputs + << ". Found specifier " << formatSpec[0]; + } + errorMessage = formatSpecifierRE.sub("", errorMessage); + } while (true); + + return mlir::success(); + } + + llvm::StringRef getErrorMessage(mlir::stablehlo::CustomCallOp op) const { + return op->getAttr(errorMessageAttrName) + .cast() + .getValue(); + } + + std::string formatErrorMessage( + llvm::StringRef errorMessage, + const mlir::SmallVector &errorMessageInputs) const { + int nrErrorMessageInputs = errorMessageInputs.size(); + auto errorMessageFormat = errorMessage.data(); + switch (nrErrorMessageInputs) { + case 0: + return errorMessageFormat; + case 1: + return llvm::formatv(errorMessageFormat, errorMessageInputs[0]); + case 2: + return llvm::formatv(errorMessageFormat, errorMessageInputs[0], + errorMessageInputs[1]); + case 3: + return llvm::formatv(errorMessageFormat, errorMessageInputs[0], + errorMessageInputs[1], errorMessageInputs[2]); + case 4: + return llvm::formatv(errorMessageFormat, errorMessageInputs[0], + errorMessageInputs[1], errorMessageInputs[2], + errorMessageInputs[3]); + default: + return errorMessageFormat; + } + } + + mlir::StringRef getArgument() const override { + return "check-shape-assertions"; } - mlir::PassManager pm(&context); + mlir::StringRef getDescription() const override { + return "Check stablehlo.custom_call @shape_assertion ops."; + } + + Option enable_shape_assertions{ + *this, "enable-shape-assertions", + llvm::cl::desc("Whether shape assertions may generate errors."), + llvm::cl::init(true)}; +}; + +} // namespace + +absl::Status RefinePolymorphicShapes(mlir::ModuleOp module, + bool enable_shape_assertions) { + mlir::MLIRContext *context = module->getContext(); + if (VLOG_IS_ON(3)) context->disableMultithreading(); + + // Verify the module before running passes on it. + // If the module doesn't pass verification, all sorts of weirdness might + // happen if we run the pass manager. + mlir::BaseScopedDiagnosticHandler diag_handler(context); + + if (mlir::failed(mlir::verify(module))) { + return absl::InvalidArgumentError( + absl::StrCat("Module verification failed: ", + diag_handler.ConsumeStatus().ToString())); + } + + mlir::PassManager pm(context); if (VLOG_IS_ON(3)) { auto print_before = [](mlir::Pass *, mlir::Operation *) { return true; }; auto print_after = [](mlir::Pass *, mlir::Operation *) { return true; }; pm.enableIRPrinting(print_before, print_after, /*printModuleScope=*/true, - /*printAfterOnlyOnChange=*/false); + /*printAfterOnlyOnChange=*/true); } + // TODO(necula): we should not need the inliner. 
pm.addPass(mlir::createInlinerPass()); pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::stablehlo::createStablehloRefineShapesPass()); pm.addNestedPass( mlir::stablehlo::createStablehloCanonicalizeDynamismPass()); - if (!mlir::succeeded(pm.run(*module))) { - return absl::InternalError("Cannot refine shapes."); + pm.addNestedPass( + std::make_unique(enable_shape_assertions)); + if (!mlir::succeeded(pm.run(module))) { + return absl::InvalidArgumentError( + absl::StrCat("Module shape refinement failed: ", + diag_handler.ConsumeStatus().ToString())); } + return absl::OkStatus(); +} - if (failed(mlir::writeBytecodeToFile(*module, os))) { +absl::Status RefinePolymorphicShapes(llvm::StringRef module_str, + llvm::raw_ostream &os, + bool enable_shape_assertions, + bool validate_static_shapes) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + + mlir::DialectRegistry registry; + mlir::func::registerAllExtensions(registry); + context.appendDialectRegistry(registry); + + mlir::OwningOpRef module = + mlir::parseSourceString( + llvm::StringRef(module_str.data(), module_str.size()), &context); + if (!module) { + return absl::InvalidArgumentError("Cannot parse module."); + } + TF_RETURN_IF_ERROR(RefinePolymorphicShapes(*module, enable_shape_assertions)); + if (validate_static_shapes) TF_RETURN_IF_ERROR(ValidateStaticShapes(*module)); + if (mlir::failed(mlir::writeBytecodeToFile(*module, os))) { return absl::InternalError("Cannot serialize module."); } - return xla::OkStatus(); + return absl::OkStatus(); +} + +absl::Status ValidateStaticShapes(mlir::ModuleOp module) { + mlir::BaseScopedDiagnosticHandler diag_handler(module->getContext()); + bool moduleHasDynamicShapes = false; + bool moduleHasShapeAssertions = false; + + module->walk([&](mlir::Operation *op) { + // It's sufficient to only check results because operands either come from + // results or from block arguments which are checked below. + auto hasDynamicShape = [](mlir::Value value) { + auto shaped_type = value.getType().dyn_cast(); + return shaped_type ? !shaped_type.hasStaticShape() : false; + }; + bool opHasDynamicShapes = false; + opHasDynamicShapes |= llvm::any_of(op->getResults(), hasDynamicShape); + for (mlir::Region ®ion : op->getRegions()) { + opHasDynamicShapes |= + llvm::any_of(region.getArguments(), hasDynamicShape); + } + if (opHasDynamicShapes) { + moduleHasDynamicShapes = true; + op->emitOpError() << "has dynamic shapes"; + } + + auto customCall = mlir::dyn_cast(op); + if (customCall && + customCall.getCallTargetName().equals(shapeAssertionName)) { + moduleHasShapeAssertions = true; + op->emitOpError() << "has residual shape assertions"; + } + }); + + if (moduleHasDynamicShapes) { + return absl::InvalidArgumentError( + absl::StrCat("Module has dynamic shapes: ", + diag_handler.ConsumeStatus().ToString())); + } + if (moduleHasShapeAssertions) { + return absl::InvalidArgumentError( + absl::StrCat("Module has residual shape assertions: ", + diag_handler.ConsumeStatus().ToString())); + } + return absl::OkStatus(); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h index ac020be1d75977..75237aeff01e89 100644 --- a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h +++ b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h @@ -16,16 +16,32 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_REFINE_POLYMORPHIC_SHAPES_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_REFINE_POLYMORPHIC_SHAPES_H_ +#include "absl/status/status.h" #include "llvm/Support/raw_ostream.h" -#include "tensorflow/compiler/xla/status.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project namespace xla { // Refines the dynamic shapes for a module whose "main" has static shapes // and all the intermediate dynamic shapes depend only on the input static -// shapes. Serializes the refined module to the given `os`. -xla::Status RefinePolymorphicShapes(llvm::StringRef module_str, - llvm::raw_ostream &os); +// shapes. Upon refinement, validates that the module does not contain remaining +// dynamic shapes. +// If `enable_shape_assertions` is false, then the shape assertions +// are removed from the module, otherwise they are removed only if the +// assertions hold, and result in an error otherwise. +absl::Status RefinePolymorphicShapes(mlir::ModuleOp module, + bool enable_shape_assertions); + +// Like the above but with serialized input and output modules. +// If `validate_static_shapes` is true, then checks that only static shapes +// are left after refinement. +absl::Status RefinePolymorphicShapes(llvm::StringRef module_str, + llvm::raw_ostream &os, + bool enable_shape_assertions, + bool validate_static_shapes); + +// Validates that the module has only static shapes. +absl::Status ValidateStaticShapes(mlir::ModuleOp module); } // namespace xla diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 8b1deb8abf7372..ca088373b92bef 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -44,10 +44,10 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. -_version = 165 +_version = 167 # Version number for MLIR:Python components. -mlir_api_version = 51 +mlir_api_version = 53 xla_platform_names = { 'cpu': 'Host', diff --git a/tensorflow/compiler/xla/python/xla_compiler.cc b/tensorflow/compiler/xla/python/xla_compiler.cc index eb21fd9fa7788a..738e9672ab13d4 100644 --- a/tensorflow/compiler/xla/python/xla_compiler.cc +++ b/tensorflow/compiler/xla/python/xla_compiler.cc @@ -867,6 +867,8 @@ void BuildXlaCompilerSubmodule(py::module& m) { py::class_(m, "ExecutableBuildOptions") .def(py::init<>()) .def("__repr__", &ExecutableBuildOptions::ToString) + .def_property("fdo_profile", &ExecutableBuildOptions::fdo_profile, + &ExecutableBuildOptions::set_fdo_profile) .def_property( "result_layout", [](const ExecutableBuildOptions& options) -> std::optional { diff --git a/tensorflow/compiler/xla/python/xla_extension/__init__.pyi b/tensorflow/compiler/xla/python/xla_extension/__init__.pyi index 38cef1047a4d4f..98e3069d1bef9b 100644 --- a/tensorflow/compiler/xla/python/xla_extension/__init__.pyi +++ b/tensorflow/compiler/xla/python/xla_extension/__init__.pyi @@ -268,6 +268,7 @@ class ExecutableBuildOptions: def __init__(self) -> None: ... def __repr__(self) -> str: ... 
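[Editor's note] A sketch of calling the string-based `RefinePolymorphicShapes` overload declared above from C++. The wrapper function, its name, and the error handling are placeholders; `module_str` is assumed to hold serialized StableHLO, and on success `out` receives the refined module as MLIR bytecode.

```cpp
#include <string>

#include "absl/status/status.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/python/refine_polymorphic_shapes.h"

absl::Status RefineToStatic(const std::string& module_str, std::string* out) {
  llvm::raw_string_ostream os(*out);
  absl::Status status = xla::RefinePolymorphicShapes(
      module_str, os, /*enable_shape_assertions=*/true,
      /*validate_static_shapes=*/true);
  os.flush();
  return status;
}
```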
result_layout: Optional[Shape] + fdo_profile: Optional[bytes] num_replicas: int num_partitions: int debug_options: DebugOptions diff --git a/tensorflow/compiler/xla/python/xla_extension/mlir.pyi b/tensorflow/compiler/xla/python/xla_extension/mlir.pyi index 2e47b1746833e9..f62c6565dc0a41 100644 --- a/tensorflow/compiler/xla/python/xla_extension/mlir.pyi +++ b/tensorflow/compiler/xla/python/xla_extension/mlir.pyi @@ -24,4 +24,6 @@ def mhlo_to_stablehlo(mlir_module: Union[bytes, str]) -> str: ... def stablehlo_to_mhlo(mlir_module: Union[bytes, str]) -> str: ... def serialize_portable_artifact(mlir_module: str, target:str) -> bytes: ... def deserialize_portable_artifact(mlir_module: bytes) -> str: ... -def refine_polymorphic_shapes(mlir_module: Union[bytes, str]) -> bytes: ... +def refine_polymorphic_shapes(mlir_module: Union[bytes, str], + enable_shape_assertions: bool = ..., + validate_static_shapes: bool = ...) -> bytes: ... diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD index 1185709fd30075..6dddb9d69858b8 100644 --- a/tensorflow/compiler/xla/python_api/BUILD +++ b/tensorflow/compiler/xla/python_api/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") + # Description: # Python API for XLA. load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") @@ -9,7 +11,7 @@ package( generate_backend_suites() -py_library( +py_strict_library( name = "types", srcs = ["types_.py"], srcs_version = "PY3", @@ -22,7 +24,7 @@ py_library( ], ) -py_library( +py_strict_library( name = "xla_shape", srcs = ["xla_shape.py"], srcs_version = "PY3", @@ -30,10 +32,11 @@ py_library( deps = [ ":types", "//tensorflow/compiler/xla:xla_data_proto_py", + "//third_party/py/numpy", ], ) -py_library( +py_strict_library( name = "xla_literal", srcs = ["xla_literal.py"], srcs_version = "PY3", @@ -42,10 +45,11 @@ py_library( ":types", ":xla_shape", "//tensorflow/compiler/xla:xla_data_proto_py", + "//third_party/py/numpy", ], ) -py_test( +py_strict_test( name = "xla_shape_test", srcs = ["xla_shape_test.py"], python_version = "PY3", @@ -55,12 +59,13 @@ py_test( ], deps = [ ":xla_shape", + "//tensorflow/compiler/xla:xla_data_proto_py", "//third_party/py/numpy", "@absl_py//absl/testing:absltest", ], ) -py_test( +py_strict_test( name = "xla_literal_test", srcs = ["xla_literal_test.py"], python_version = "PY3", @@ -70,6 +75,7 @@ py_test( ], deps = [ ":xla_literal", + "//tensorflow/compiler/xla:xla_data_proto_py", "//third_party/py/numpy", "@absl_py//absl/testing:absltest", ], diff --git a/tensorflow/compiler/xla/runtime/BUILD b/tensorflow/compiler/xla/runtime/BUILD index de4875741b977c..719ea5e796e715 100644 --- a/tensorflow/compiler/xla/runtime/BUILD +++ b/tensorflow/compiler/xla/runtime/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") -load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load("//tensorflow/tsl/platform:build_config.bzl", "if_llvm_system_z_available", "tf_platform_deps") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") @@ -16,7 +16,7 @@ cc_library( name = "arguments", srcs = ["arguments.cc"], hdrs = ["arguments.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":async_runtime", ":types", @@ -42,7 +42,7 @@ cc_library( name = "async_runtime", srcs = ["async_runtime.cc"], hdrs = 
["async_runtime.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:platform_port", @@ -66,7 +66,7 @@ xla_cc_test( cc_library( name = "async_values_cache", hdrs = ["async_values_cache.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform", ] + tf_platform_deps( @@ -79,7 +79,7 @@ cc_library( name = "constraints", srcs = ["constraints.cc"], hdrs = ["constraints.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -123,7 +123,7 @@ cc_library( name = "custom_call", srcs = ["custom_call.cc"], hdrs = ["custom_call.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":async_runtime", ":diagnostics", @@ -170,7 +170,7 @@ cc_library( name = "custom_call_registry", srcs = ["custom_call_registry.cc"], hdrs = ["custom_call_registry.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":custom_call", "@llvm-project//llvm:Support", @@ -181,7 +181,7 @@ cc_library( name = "diagnostics", srcs = ["diagnostics.cc"], hdrs = ["diagnostics.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":logical_result", "//tensorflow/tsl/platform:logging", @@ -204,7 +204,7 @@ xla_cc_test( cc_library( name = "errors", hdrs = ["errors.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", @@ -215,7 +215,7 @@ cc_library( name = "executable", srcs = ["executable.cc"], hdrs = ["executable.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":arguments", ":async_runtime", @@ -269,7 +269,7 @@ cc_library( name = "execution_engine", srcs = ["execution_engine.cc"], hdrs = ["execution_engine.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":errors", "@com_google_absl//absl/status", @@ -340,7 +340,7 @@ cc_library( name = "jit_executable", srcs = ["jit_executable.cc"], hdrs = ["jit_executable.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":async_values_cache", ":constraints", @@ -359,14 +359,14 @@ cc_library( cc_library( name = "logical_result", hdrs = ["logical_result.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = ["@llvm-project//mlir:Support"], ) cc_library( name = "map_by_type", hdrs = ["map_by_type.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":type_id", "@llvm-project//llvm:Support", @@ -388,7 +388,7 @@ cc_library( name = "memory_mapper", srcs = ["memory_mapper.cc"], hdrs = ["memory_mapper.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform", "@llvm-project//llvm:ExecutionEngine", @@ -402,7 +402,7 @@ cc_library( cc_library( name = "memref_view", hdrs = ["memref_view.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = 
get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla:xla_data_proto_cc", "@com_google_absl//absl/types:span", @@ -412,7 +412,7 @@ cc_library( cc_library( name = "module", hdrs = ["module.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":custom_call_registry", "@com_google_absl//absl/status", @@ -424,7 +424,7 @@ cc_library( name = "module_registry", srcs = ["module_registry.cc"], hdrs = ["module_registry.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":module", ], @@ -444,7 +444,7 @@ xla_cc_test( cc_library( name = "results", hdrs = ["results.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":logical_result", ":types", @@ -467,13 +467,13 @@ xla_cc_test( cc_library( name = "runtime", hdrs = ["runtime.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) cc_library( name = "state", hdrs = ["state.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", @@ -495,7 +495,7 @@ cc_library( name = "symbolic_shape", srcs = ["symbolic_shape.cc"], hdrs = ["symbolic_shape.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":arguments", ":constraints", @@ -526,7 +526,7 @@ cc_library( name = "types", srcs = ["types.cc"], hdrs = ["types.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -540,7 +540,7 @@ cc_library( cc_library( name = "tracing", hdrs = ["tracing.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":custom_call", ":type_id", @@ -551,7 +551,7 @@ cc_library( name = "type_id", srcs = ["type_id.cc"], hdrs = ["type_id.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@com_google_absl//absl/container:flat_hash_map", "@llvm-project//mlir:Support", @@ -561,7 +561,7 @@ cc_library( cc_library( name = "compiler", hdrs = ["compiler.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) cc_library( diff --git a/tensorflow/compiler/xla/runtime/custom_call_test.cc b/tensorflow/compiler/xla/runtime/custom_call_test.cc index 28888354a2e09c..0098ec53cf5f9a 100644 --- a/tensorflow/compiler/xla/runtime/custom_call_test.cc +++ b/tensorflow/compiler/xla/runtime/custom_call_test.cc @@ -72,7 +72,8 @@ static absl::StatusOr Compile( }; opts.compiler.create_compilation_pipeline = [=](PassManager& passes) { - CreateDefaultXlaGpuRuntimeCompilationPipeline(passes, copts); + CreateDefaultXlaGpuRuntimeCompilationPipeline(passes, copts, + /*add_async_passes=*/true); }; return JitExecutable::Instantiate(source, opts, exported); diff --git a/tensorflow/compiler/xla/runtime/executable.cc b/tensorflow/compiler/xla/runtime/executable.cc index 357d3ce9b03a75..c2e07c5bcffda1 100644 --- a/tensorflow/compiler/xla/runtime/executable.cc +++ b/tensorflow/compiler/xla/runtime/executable.cc @@ -470,6 +470,11 @@ bool Executable::IsAsync(unsigned ordinal) const { return functions_[ordinal].results_memory_layout.has_async_results; } 
+std::string_view Executable::function_name(unsigned ordinal) const { + assert(ordinal < functions_.size() && "function ordinal out of bounds"); + return functions_[ordinal].name; +} + unsigned Executable::num_results(unsigned ordinal) const { assert(ordinal < functions_.size() && "function ordinal out of bounds"); return functions_[ordinal].runtime_signature.num_results(); diff --git a/tensorflow/compiler/xla/runtime/executable.h b/tensorflow/compiler/xla/runtime/executable.h index b6f76800340afd..44f6332300f473 100644 --- a/tensorflow/compiler/xla/runtime/executable.h +++ b/tensorflow/compiler/xla/runtime/executable.h @@ -162,6 +162,10 @@ class Executable { bool IsAsync(unsigned ordinal) const; bool IsAsync() const { return IsAsync(0); } + // Returns the name of the exported function with the given ordinal. + std::string_view function_name(unsigned ordinal) const; + std::string_view function_name() const { return function_name(0); } + // Returns the number of results of the exported function with given ordinal. unsigned num_results(unsigned ordinal) const; unsigned num_results() const { return num_results(0); } diff --git a/tensorflow/compiler/xla/runtime/ffi/BUILD b/tensorflow/compiler/xla/runtime/ffi/BUILD index c94568b48c9008..c96dcf221ce7d4 100644 --- a/tensorflow/compiler/xla/runtime/ffi/BUILD +++ b/tensorflow/compiler/xla/runtime/ffi/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,13 +18,13 @@ filegroup( cc_library( name = "ffi_abi", hdrs = ["ffi_abi.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) cc_library( name = "ffi_api", hdrs = ["ffi_api.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":ffi_abi", ":ffi_c_api_hdrs", @@ -34,5 +34,5 @@ cc_library( cc_library( name = "ffi_c_api_hdrs", hdrs = ["ffi_c_api.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 00a89e2b928924..6286bef5515464 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1,6 +1,7 @@ # Description: # XLA service implementation. 
+load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load( @@ -1803,7 +1804,6 @@ cc_library( deps = [ ":heap_simulator", ":hlo_alias_analysis", - ":hlo_ordering", ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", @@ -2309,6 +2309,7 @@ cc_library( srcs = ["algebraic_simplifier.cc"], hdrs = ["algebraic_simplifier.h"], deps = [ + ":hlo_cost_analysis", ":hlo_creation_utils", ":hlo_pass", ":pattern_matcher", @@ -4604,22 +4605,19 @@ cc_library( srcs = ["hlo_rematerialization.cc"], hdrs = ["hlo_rematerialization.h"], deps = [ - ":buffer_value", ":call_graph", - ":flatten_call_graph", + ":hlo_dataflow_analysis", ":hlo_dce", - ":hlo_memory_scheduler", ":hlo_ordering", + ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/hlo/utils:hlo_query", - "//tensorflow/tsl/platform:logging", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -4635,16 +4633,11 @@ cc_library( testonly = 1, hdrs = ["hlo_rematerialization_test_utils.h"], deps = [ - ":hlo_ordering", - ":hlo_rematerialization", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/hlo/utils:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/tsl/lib/core:status_test_util", ], ) @@ -4652,16 +4645,10 @@ xla_cc_test( name = "hlo_rematerialization_test_utils_test", srcs = ["hlo_rematerialization_test_utils_test.cc"], deps = [ - ":hlo_ordering", ":hlo_rematerialization_test_utils", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/hlo/utils:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/tsl/lib/core:status_test_util", ], ) @@ -4669,6 +4656,7 @@ xla_cc_test( name = "hlo_rematerialization_test", srcs = ["hlo_rematerialization_test.cc"], deps = [ + ":hlo_memory_scheduler", ":hlo_ordering", ":hlo_rematerialization", ":hlo_rematerialization_test_utils", @@ -5115,6 +5103,7 @@ cc_library( "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/tsl/platform:protobuf", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -6425,19 +6414,19 @@ xla_py_proto_library( deps = [":hlo_proto"], ) -py_library( +py_strict_library( name = "generate_test_hlo_checks", srcs = ["generate_test_hlo_checks.py"], srcs_version = "PY3", ) -py_test( +py_strict_test( name = "generate_test_hlo_checks_test", srcs = ["generate_test_hlo_checks_test.py"], python_version = "PY3", - # TODO(b/200806426): Test fails in OSS. 
tags = [ "no_oss", + "nopip", ], deps = [ ":generate_test_hlo_checks", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 16c374647b7aef..63889515ff4a37 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/overflow_util.h" #include "tensorflow/compiler/xla/permutation_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape.h" @@ -395,19 +396,70 @@ bool ValidateTilingOfBitcast( return true; } -double GetDotFlops(const HloInstruction* dot) { - // A dot of arrays of size ab and bc requires ac(2b-1) flops - // In general, we compute the flops per element in the output shape - double contraction_prod = 1; - auto lhs_contracting_dims = - dot->dot_dimension_numbers().lhs_contracting_dimensions(); - for (auto dim : lhs_contracting_dims) { - contraction_prod *= dot->operand(0)->shape().dimensions(dim); +// Constructs the maps that take dims of A and dims of B to dims of AB, mapping +// to -1 for dimensions not present in AB. For an example, consider we are +// computing a dot whose operands have shapes [m,n,p] and [n,q]. Assuming we +// contract over n, this produces an array with shape [m,p,q]. This function +// will return vectors map_a_ab = {0, -1, 1} and map_b_ab = {-1, 2} +std::pair, std::vector> ConstructToDotMaps( + DotDimensionNumbers dnums, const Shape& a_shape, const Shape& b_shape) { + std::vector map_a_ab, map_b_ab; + int ab_index = 0; + // Extract a and b contraction dimensions from dnums + auto a_contracting_dims = dnums.lhs_contracting_dimensions(); + auto b_contracting_dims = dnums.rhs_contracting_dimensions(); + // Iterating through the dimensions of a + for (int a_index = 0; a_index < a_shape.rank(); a_index++) { + if (absl::c_linear_search(a_contracting_dims, a_index)) { + map_a_ab.push_back(-1); + } else { + map_a_ab.push_back(ab_index); + ab_index++; + } } - // Flops include multiplications and adds - double flops_per_output_elem = 2 * contraction_prod - 1; - // We then multiply this number by the number of elements in the output shape - return flops_per_output_elem * ShapeUtil::ElementsIn(dot->shape()); + // Iterating through the dimensions of b + for (int b_index = 0; b_index < b_shape.rank(); b_index++) { + if (absl::c_linear_search(b_contracting_dims, b_index)) { + map_b_ab.push_back(-1); + } else { + map_b_ab.push_back(ab_index); + ab_index++; + } + } + return {map_a_ab, map_b_ab}; +} + +// Constructs the maps that take dims of AB to dims of A and dims of B mapping +// to -1 for dimensions not present in A/B. For an example, consider we are +// computing a dot whose operands have shapes [m,n,p] and [n,q]. Assuming we +// contract over n, this produces an array with shape [m,p,q]. 
This function +// will return vectors map_ab_a = {0,2,-1} and map_ab_b = {-1,-1,1} +std::pair, std::vector> ConstructFromDotMaps( + const HloInstruction* dot, const Shape& a_shape, const Shape& b_shape) { + // Reserve space for new maps + std::vector map_ab_a, map_ab_b; + map_ab_a.resize(dot->shape().rank(), -1); + map_ab_b.resize(dot->shape().rank(), -1); + // Construct the maps going in the opposite direction + std::vector map_a_ab, map_b_ab; + std::tie(map_a_ab, map_b_ab) = + ConstructToDotMaps(dot->dot_dimension_numbers(), a_shape, b_shape); + // Construct these maps by inverting those above + int a_index = 0; + for (auto ab_index : map_a_ab) { + if (ab_index != -1) { + map_ab_a[ab_index] = a_index; + } + a_index++; + } + int b_index = 0; + for (auto ab_index : map_b_ab) { + if (ab_index != -1) { + map_ab_b[ab_index] = b_index; + } + b_index++; + } + return {map_ab_a, map_ab_b}; } } // namespace @@ -1078,8 +1130,26 @@ Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { ReplaceWithNewInstruction(bitcast, std::move(new_bitcast))); bitcast = new_bitcast_ptr; } + // All bitcasts can be eliminated (assuming layout constraints are satisfied). - ReplaceInstructionIfCompatible(bitcast, bitcast->mutable_operand(0)); + HloInstruction* new_bitcast = bitcast->mutable_operand(0); + if (ReplaceInstructionIfCompatible(bitcast, new_bitcast)) { + bitcast = new_bitcast; + } + + // Check whether we can potentially simplify the bitcast into a broadcast + // operand. + if (bitcast->opcode() == HloOpcode::kBitcast && + bitcast->operand(0)->opcode() == HloOpcode::kBroadcast) { + // DeduceTransposeDimensionsForBitcast() checks whether the bitcast is a + // transpose and returns the dimensions attribute if it is. + auto dimensions = ShapeUtil::DeduceTransposeDimensionsForBitcast( + bitcast->operand(0)->shape(), bitcast->shape()); + if (dimensions.has_value()) { + return SimplifyTransposeOfBroadcast(bitcast, dimensions.value()); + } + } + return OkStatus(); } @@ -2180,6 +2250,84 @@ StatusOr AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( return true; } +// transpose(broadcast(x)) -> broadcast(x), if the transpose leaves the relative +// order of the dimensions of `x` unchanged. +// +// To understand the permutations logic here, consider a simple case. +// +// bcast = f32[1,2,3,4] broadcast(f32[2,4] x), dimensions={1,3} +// trans = f32[2,3,1,4] transpose(f32[1,2,3,4] bcast), dimensions={1,2,0,3} +// +// We want to transform this into +// +// bcast' = f32[2,3,1,4] broadcast(f32[2,4] x), dimensions={0,3} +Status AlgebraicSimplifierVisitor::SimplifyTransposeOfBroadcast( + HloInstruction* transpose, absl::Span dimensions) { + HloInstruction* broadcast = transpose->mutable_operand(0); + if (broadcast->opcode() != HloOpcode::kBroadcast || + !absl::c_is_sorted(broadcast->dimensions())) { + return OkStatus(); + } + + // The algorithm to compute bcast'.dimensions() is: + // + // * Let p' be the inverse of trans.dimensions(); in the example, {2,0,1,3}. + // * bcast'.dimensions() is [p'[dim] for dim in bcast.dimensions()]. In the + // example, p'[1] = 0, meaning that broadcast dim 1 (size 2) ends up at + // index 0 after the transpose. + // + // We also need to check that bcast'.dimensions() is "sorted the same" as + // bcast.dimensions() -- otherwise, we're simply moving the transpose into the + // broadcast op. For now we cowardly refuse to consider broadcasts except + // where their dimensions() are sorted, so we need only check that + // bcast'.dimensions() is sorted. 
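As a quick sanity check of the bcast' algorithm described above, the following standalone sketch (toy code, not the XLA helpers; `Inverse` is a stand-in for `InversePermutation`) reproduces the worked example and confirms that the new broadcast dimensions come out sorted:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Inverse of a permutation given as a vector: out[perm[i]] = i.
std::vector<int64_t> Inverse(const std::vector<int64_t>& perm) {
  std::vector<int64_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}

int main() {
  // bcast = f32[1,2,3,4] broadcast(f32[2,4] x), dimensions={1,3}
  // trans = f32[2,3,1,4] transpose(bcast), dimensions={1,2,0,3}
  const std::vector<int64_t> transpose_dims = {1, 2, 0, 3};
  const std::vector<int64_t> bcast_dims = {1, 3};

  const std::vector<int64_t> inv_perm = Inverse(transpose_dims);  // {2,0,1,3}
  std::vector<int64_t> new_bcast_dims;
  for (int64_t dim : bcast_dims) new_bcast_dims.push_back(inv_perm[dim]);

  // The rewrite only fires if the new dimensions are still sorted; otherwise the
  // transpose would merely be hidden inside the broadcast.
  std::cout << "bcast' dimensions = {" << new_bcast_dims[0] << ","
            << new_bcast_dims[1] << "}, sorted = " << std::boolalpha
            << std::is_sorted(new_bcast_dims.begin(), new_bcast_dims.end())
            << "\n";
  // Prints: bcast' dimensions = {0,3}, sorted = true
}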
+ // + // No one-user requirement on the transpose because having two different + // broadcasts of x should be cheap -- certainly cheaper than using the + // fully-materialized broadcasted+transposed value. + + auto inv_perm = InversePermutation(dimensions); + absl::InlinedVector new_bcast_dims; + for (int64_t dim : broadcast->dimensions()) { + new_bcast_dims.push_back(inv_perm[dim]); + } + if (!absl::c_is_sorted(new_bcast_dims)) { + return OkStatus(); + } + // We don't want to create broadcasts that create implicit transposes. Check + // whether the relative order of the layout of the broadcasted dimensions is + // the same as the broadcast operand layout. + if (options_.is_layout_sensitive()) { + std::vector perm1(new_bcast_dims.size()); + absl::c_iota(perm1, 0); + std::vector perm2 = perm1; + Layout operand_layout = broadcast->mutable_operand(0)->shape().layout(); + absl::c_sort(perm1, [&](int a, int b) { + return operand_layout.minor_to_major(a) < + operand_layout.minor_to_major(b); + }); + Layout transpose_layout = transpose->shape().layout(); + // Extract the part of the layout that corresponds to the broadcasted + // dimensions. + std::vector extracted_layout; + extracted_layout.reserve(new_bcast_dims.size()); + for (int64_t dim : transpose_layout.minor_to_major()) { + if (absl::c_binary_search(new_bcast_dims, dim)) { + extracted_layout.push_back(dim); + } + } + absl::c_sort(perm2, [&](int a, int b) { + return extracted_layout[a] < extracted_layout[b]; + }); + if (perm1 != perm2) { + return OkStatus(); + } + } + return ReplaceInstruction( + transpose, MakeBroadcastHlo(broadcast->mutable_operand(0), new_bcast_dims, + transpose->shape())); +} + StatusOr AlgebraicSimplifierVisitor::RemoveTransposesFromDotOperands( HloInstruction* dot) { const int64_t rank = dot->shape().rank(); @@ -2832,6 +2980,90 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceWithNewInstruction(dot, std::move(new_instruction)); } + // Reorder nested dots with associativity using flops as a heuristic + if (options_.use_associative_reordering()) { + HloInstruction *a, *b, *c; + HloInstruction *old_inner, *old_outer, *new_inner, *new_outer; + DotDimensionNumbers ab_dnums, ac_dnums, bc_dnums; + // Here we extract the contracting dimensions shared between A and B, A and + // C, and B and C, and use these to build up the dimension numbers for the + // reordered dot A(BC). 
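The dimension maps that ConstructToDotMaps/ConstructFromDotMaps build are easiest to see on the worked example from their comments. The sketch below is a simplified stand-in (no batch dimensions, plain vectors instead of DotDimensionNumbers, helper names invented) that recomputes map_a_ab and map_b_ab for A=[m,n,p] and B=[n,q] contracted over n; inverting them, as ConstructFromDotMaps does, gives map_ab_a = {0,2,-1} and map_ab_b = {-1,-1,1}:

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Simplified ConstructToDotMaps: map each operand dimension to its position in
// the dot output AB, or to -1 if it is contracted away (non-batch case only).
std::pair<std::vector<int64_t>, std::vector<int64_t>> ToDotMaps(
    int64_t a_rank, int64_t b_rank, const std::vector<int64_t>& a_contracting,
    const std::vector<int64_t>& b_contracting) {
  auto is_contracted = [](const std::vector<int64_t>& dims, int64_t d) {
    for (int64_t c : dims) {
      if (c == d) return true;
    }
    return false;
  };
  std::vector<int64_t> map_a_ab, map_b_ab;
  int64_t ab_index = 0;
  for (int64_t d = 0; d < a_rank; ++d)
    map_a_ab.push_back(is_contracted(a_contracting, d) ? -1 : ab_index++);
  for (int64_t d = 0; d < b_rank; ++d)
    map_b_ab.push_back(is_contracted(b_contracting, d) ? -1 : ab_index++);
  return {map_a_ab, map_b_ab};
}

int main() {
  // A = [m,n,p], B = [n,q], contracting over n, so AB = [m,p,q].
  auto maps = ToDotMaps(/*a_rank=*/3, /*b_rank=*/2, /*a_contracting=*/{1},
                        /*b_contracting=*/{0});
  for (int64_t v : maps.first) std::cout << v << " ";   // 0 -1 1
  std::cout << "| ";
  for (int64_t v : maps.second) std::cout << v << " ";  // -1 2
  std::cout << "\n";
}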
+ if (Match(dot, m::Dot(m::Dot(m::Op(&a), m::Op(&b)), m::Op(&c))) && + dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0) { + // We already have the ab_dnums for free + ab_dnums = dot->operand(0)->dot_dimension_numbers(); + // Get maps for converting AB dimensions to A and B + std::vector map_ab_a, map_ab_b; + std::tie(map_ab_a, map_ab_b) = + ConstructFromDotMaps(dot->operand(0), a->shape(), b->shape()); + // Recover ac_dnums and bc_dnums from ab_c_dnums + DotDimensionNumbers ab_c_dnums = dot->dot_dimension_numbers(); + for (int i = 0; i < ab_c_dnums.lhs_contracting_dimensions_size(); i++) { + auto ab_index = ab_c_dnums.lhs_contracting_dimensions(i); + auto c_index = ab_c_dnums.rhs_contracting_dimensions(i); + if (map_ab_b[ab_index] == -1) { + ac_dnums.add_lhs_contracting_dimensions(map_ab_a[ab_index]); + ac_dnums.add_rhs_contracting_dimensions(c_index); + } else { + bc_dnums.add_lhs_contracting_dimensions(map_ab_b[ab_index]); + bc_dnums.add_rhs_contracting_dimensions(c_index); + } + } + + // Get maps for converting B and C dimensions to BC + std::vector map_b_bc, map_c_bc; + std::tie(map_b_bc, map_c_bc) = + ConstructToDotMaps(bc_dnums, b->shape(), c->shape()); + // Now build a_bc_dnums from ab_dnums and bc_dnums + DotDimensionNumbers a_bc_dnums; + for (int i = 0; i < ab_dnums.lhs_contracting_dimensions_size(); i++) { + auto a_index = ab_dnums.lhs_contracting_dimensions(i); + auto b_index = ab_dnums.rhs_contracting_dimensions(i); + a_bc_dnums.add_lhs_contracting_dimensions(a_index); + a_bc_dnums.add_rhs_contracting_dimensions(map_b_bc[b_index]); + } + for (int i = 0; i < ac_dnums.lhs_contracting_dimensions_size(); i++) { + auto a_index = ac_dnums.lhs_contracting_dimensions(i); + auto c_index = ac_dnums.rhs_contracting_dimensions(i); + a_bc_dnums.add_lhs_contracting_dimensions(a_index); + a_bc_dnums.add_rhs_contracting_dimensions(map_c_bc[c_index]); + } + + // Make Hlo for reordering dot + old_inner = lhs; + old_outer = dot; + TF_ASSIGN_OR_RETURN(new_inner, + MakeDotHlo(b, c, bc_dnums, dot->precision_config(), + dot->shape().element_type())); + TF_ASSIGN_OR_RETURN(new_outer, MakeDotHlo(a, new_inner, a_bc_dnums, + dot->precision_config(), + dot->shape().element_type())); + + // Use HloCostAnalysis to compute flops for both the original and + // reordered instructions, and reorder if doing so decreases flops by a + // factor of the reordering threshold. 
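To see why the flops comparison matters, take the shapes from the old DotAttentionReorder test (A=[1024,2], B=[2,1024], C=[1024,2]). The sketch below is a rough standalone estimate using the usual a*c*(2b-1) flop count for an [a,b]x[b,c] dot rather than the real HloCostAnalysis, and kThreshold is a made-up stand-in for associative_reordering_threshold():

#include <cstdint>
#include <iostream>

// Flops for an [a,b] x [b,c] dot: a*c output elements, each needing b
// multiplies and b-1 additions.
int64_t DotFlops(int64_t a, int64_t b, int64_t c) { return a * c * (2 * b - 1); }

int main() {
  const int64_t m = 1024, k = 2, n = 1024, p = 2;  // A=[m,k], B=[k,n], C=[n,p]

  // (AB)C: an [m,k]x[k,n] dot followed by an [m,n]x[n,p] dot.
  const int64_t old_flops = DotFlops(m, k, n) + DotFlops(m, n, p);
  // A(BC): a [k,n]x[n,p] dot followed by an [m,k]x[k,p] dot.
  const int64_t new_flops = DotFlops(k, n, p) + DotFlops(m, k, p);

  const double kThreshold = 2.0;  // illustrative only
  std::cout << "old=" << old_flops << " new=" << new_flops << " reorder="
            << std::boolalpha
            << (static_cast<double>(old_flops) / new_flops > kThreshold) << "\n";
  // Prints: old=7337984 new=14332 reorder=true
}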
+ const int64_t old_flops = + HloCostAnalysis::GetDotFlops(old_inner->operand(0)->shape(), + old_inner->shape(), + old_inner->dot_dimension_numbers()) + + HloCostAnalysis::GetDotFlops(old_outer->operand(0)->shape(), + old_outer->shape(), + old_outer->dot_dimension_numbers()); + const int64_t new_flops = + HloCostAnalysis::GetDotFlops(new_inner->operand(0)->shape(), + new_inner->shape(), + new_inner->dot_dimension_numbers()) + + HloCostAnalysis::GetDotFlops(new_outer->operand(0)->shape(), + new_outer->shape(), + new_outer->dot_dimension_numbers()); + if (old_flops / new_flops > options_.associative_reordering_threshold()) { + VLOG(10) << "Reordering with associativity"; + return ReplaceInstruction(dot, new_outer); + } + } + // TODO(b/289120301) Implement other direction after first looks good + } + // If the lhs or rhs have only batch and contracting dimensions, a dot can be // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) if (!is_packed_nibble && options_.enable_dot_strength_reduction() && @@ -2944,52 +3176,6 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { if (removed_transposes) { return OkStatus(); } - - // Reorder nested dots with associativity using flops as a heuristic - if (options_.use_associative_reordering()) { - // TODO(b/289120301): Update with symmetric contraction form - HloInstruction *a, *b, *c; - HloInstruction *dot_a_b, *dot_ab_c, *dot_b_c, *dot_a_bc; - int64_t left_first_flops, right_first_flops; - if (Match(dot, m::Dot(m::Dot(m::Op(&a), m::Op(&b)), m::Op(&c)))) { - dot_a_b = lhs; - dot_ab_c = dot; - TF_ASSIGN_OR_RETURN( - dot_b_c, - MakeDotHlo(b, c, dot->dot_dimension_numbers(), - dot->precision_config(), dot->shape().element_type())); - TF_ASSIGN_OR_RETURN( - dot_a_bc, - MakeDotHlo(a, dot_b_c, dot_a_b->dot_dimension_numbers(), - dot->precision_config(), dot->shape().element_type())); - left_first_flops = GetDotFlops(dot_a_b) + GetDotFlops(dot_ab_c); - right_first_flops = GetDotFlops(dot_b_c) + GetDotFlops(dot_a_bc); - if (left_first_flops > - options_.associative_reordering_threshold() * right_first_flops) { - return ReplaceInstruction(dot, dot_a_bc); - } - } else if (Match(dot, m::Dot(m::Op(&a), m::Dot(m::Op(&b), m::Op(&c))))) { - dot_b_c = rhs; - dot_a_bc = dot; - TF_ASSIGN_OR_RETURN( - dot_a_b, - MakeDotHlo(a, b, dot->dot_dimension_numbers(), - dot->precision_config(), dot->shape().element_type())); - TF_ASSIGN_OR_RETURN( - dot_ab_c, - MakeDotHlo(dot_a_b, c, dot_b_c->dot_dimension_numbers(), - dot->precision_config(), dot->shape().element_type())); - left_first_flops = GetDotFlops(dot_a_b) + GetDotFlops(dot_ab_c); - right_first_flops = GetDotFlops(dot_b_c) + GetDotFlops(dot_a_bc); - if (right_first_flops > - options_.associative_reordering_threshold() * left_first_flops) { - return ReplaceInstruction(dot, dot_ab_c); - } - } else { - return OkStatus(); - } - } - return OkStatus(); } @@ -6885,47 +7071,8 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { } } - // transpose(broadcast(x)) -> broadcast(x), if the transpose leaves the - // relative order of the dimensions of `x` unchanged. - // - // To understand the permutations logic here, consider a simple case. 
- // - // bcast = f32[1,2,3,4] broadcast(f32[2,4] x), dimensions={1,3} - // trans = f32[2,3,1,4] transpose(f32[1,2,3,4] bcast), dimensions={1,2,0,3} - // - // We want to transform this into - // - // bcast' = f32[2,3,1,4] broadcast(f32[2,4] x), dimensions={0,3} - // - // The algorithm to compute bcast'.dimensions() is: - // - // * Let p' be the inverse of trans.dimensions(); in the example, {2,0,1,3}. - // * bcast'.dimensions() is [p'[dim] for dim in bcast.dimensions()]. In the - // example, p'[1] = 0, meaning that broadcast dim 1 (size 2) ends up at - // index 0 after the transpose. - // - // We also need to check that bcast'.dimensions() is "sorted the same" as - // bcast.dimensions() -- otherwise, we're simply moving the transpose into the - // broadcast op. For now we cowardly refuse to consider broadcasts except - // where their dimensions() are sorted, so we need only check that - // bcast'.dimensions() is sorted. - // - // No one-user requirement on the transpose because having two different - // broadcasts of x should be cheap -- certainly cheaper than using the - // fully-materialized broadcasted+transposed value. - if (operand->opcode() == HloOpcode::kBroadcast && - absl::c_is_sorted(operand->dimensions())) { - auto inv_perm = InversePermutation(transpose->dimensions()); - absl::InlinedVector new_bcast_dims; - for (int64_t dim : operand->dimensions()) { - new_bcast_dims.push_back(inv_perm[dim]); - } - if (absl::c_is_sorted(new_bcast_dims)) { - return ReplaceInstruction( - transpose, MakeBroadcastHlo(operand->mutable_operand(0), - new_bcast_dims, transpose->shape())); - } - } + TF_RETURN_IF_ERROR( + SimplifyTransposeOfBroadcast(transpose, transpose->dimensions())); return OkStatus(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 0d016cf2ffae3a..0d5a13db092a9c 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -443,6 +443,11 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Removes degenerate dimension from dot. StatusOr RemoveDegenerateDimensionFromDot(HloInstruction* dot); + // Moves the transpose to the broadcast if possible. Can also be called with a + // bitcast transpose. + Status SimplifyTransposeOfBroadcast(HloInstruction* transpose, + absl::Span dimensions); + // Converts to primitive type if the input hlo is not that type, otherwise // returns the original hlo. 
HloInstruction* AsType(HloInstruction* hlo, diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 162aac070b98b7..18eea15b0b1066 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -5695,20 +5695,20 @@ TEST_F(AlgebraicSimplifierTest, TransposeOfDot) { PrecisionConfig::HIGHEST); } -TEST_F(AlgebraicSimplifierTest, DotAttentionReorder) { +TEST_F(AlgebraicSimplifierTest, DotAssociativeReorder) { const char* hlo_string = R"( HloModule module ENTRY test { - a = f32[1024,2] parameter(0) - b = f32[2,1024] parameter(1) - c = f32[1024,2] parameter(2) - inner_dot = f32[1024,1024] dot(a,b), - lhs_contracting_dims={1}, - rhs_contracting_dims={0} - ROOT outer_dot = f32[1024,2] dot(inner_dot, c), - lhs_contracting_dims={1}, - rhs_contracting_dims={0} + a = f32[2,3,4,5] parameter(0) + b = f32[6,7,5] parameter(1) + c = f32[4,7] parameter(2) + inner_dot = f32[2,3,4,6,7] dot(a,b), + lhs_contracting_dims={3}, + rhs_contracting_dims={2} + ROOT outer_dot = f32[2,3,6] dot(inner_dot,c), + lhs_contracting_dims={2,4}, + rhs_contracting_dims={0,1} } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -9451,6 +9451,42 @@ TEST_F(AlgebraicSimplifierTest, TransposeOfBroadcast) { }))); } +TEST_F(AlgebraicSimplifierTest, TransposeBitcastOfBroadcast) { + const char* kModuleStr = R"( + HloModule m + test { + bcast = f32[10,2,3,4]{3,2,1,0} broadcast(f32[2,4]{1,0} parameter(0)), dimensions={1,3} + ROOT trans = f32[2,3,10,4]{3,1,0,2} bitcast(bcast) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + EXPECT_TRUE(RunHloPass(AlgebraicSimplifier(options), m.get()).value()); + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch( + m::Broadcast(m::Parameter(0)) + .WithPredicate([](const HloInstruction* instr) { + return instr->dimensions() == std::vector({0, 3}); + }))); +} + +TEST_F(AlgebraicSimplifierTest, TransposeOfBroadcastWithLayoutCheckSkipped) { + const char* kModuleStr = R"( + HloModule m + test { + bcast = f32[10,2,3,4]{3,2,1,0} broadcast(f32[2,4]{1,0} parameter(0)), dimensions={1,3} + ROOT trans = f32[2,3,10,4]{0,1,2,3} transpose(bcast), dimensions={1,2,0,3} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + EXPECT_FALSE(RunHloPass(AlgebraicSimplifier(options), m.get()).value()); +} + TEST_F(AlgebraicSimplifierTest, TransposeOfBroadcastSkipped) { const char* kModuleStr = R"( HloModule m diff --git a/tensorflow/compiler/xla/service/collective_permute_decomposer.cc b/tensorflow/compiler/xla/service/collective_permute_decomposer.cc index 0359dc6d8bfd68..d17a2af5c1576a 100644 --- a/tensorflow/compiler/xla/service/collective_permute_decomposer.cc +++ b/tensorflow/compiler/xla/service/collective_permute_decomposer.cc @@ -98,6 +98,7 @@ Status DecomposeCollectivePermute( int64_t channel_id = collective_permute->channel_id().value_or(0); HloInstruction* data = collective_permute->mutable_operand(0); const Shape& data_shape = data->shape(); + const OpMetadata& metadata = collective_permute->metadata(); xla::FrontendAttributes attributes; std::string source_target_pairs_string = @@ -121,10 +122,12 @@ Status DecomposeCollectivePermute( HloInstruction* recv = 
computation->AddInstruction( HloInstruction::CreateRecv(data_shape, after_all, channel_id)); recv->set_frontend_attributes(attributes); + recv->set_metadata(metadata); HloInstruction* send = computation->AddInstruction( HloInstruction::CreateSend(data, after_all, channel_id)); send->set_frontend_attributes(attributes); + send->set_metadata(metadata); // We want the Recv to be scheduled before the Send, enforce this with a // control dependency. TF_RETURN_IF_ERROR(recv->AddControlDependencyTo(send)); diff --git a/tensorflow/compiler/xla/service/collective_permute_decomposer_test.cc b/tensorflow/compiler/xla/service/collective_permute_decomposer_test.cc index 20d01ed2e00f49..0aec6e6ceba85b 100644 --- a/tensorflow/compiler/xla/service/collective_permute_decomposer_test.cc +++ b/tensorflow/compiler/xla/service/collective_permute_decomposer_test.cc @@ -94,7 +94,8 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedDefaultChannelId) { ENTRY test_computation { p = u32[] replica-id() start = (u32[], u32[]) collective-permute-start(p), - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, + metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} ROOT done = u32[] collective-permute-done(start) } )"; @@ -105,6 +106,12 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedDefaultChannelId) { TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); EXPECT_TRUE(changed); + auto check_metadata = [](const HloInstruction* inst) { + EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add"); + EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py"); + EXPECT_EQ(inst->metadata().source_line(), 35); + }; + HloInstruction* after_all = FindInstruction(module.get(), "after-all"); HloInstruction* recv = FindInstruction(module.get(), "recv"); EXPECT_EQ(recv->operand(0), after_all); @@ -113,6 +120,7 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedDefaultChannelId) { recv->ToString(), HasSubstr( "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\"")); + check_metadata(recv); HloInstruction* recv_done = FindInstruction(module.get(), "recv-done"); EXPECT_EQ(recv_done->operand(0), recv); @@ -124,6 +132,7 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedDefaultChannelId) { send->ToString(), HasSubstr( "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\"")); + check_metadata(send); HloInstruction* send_done = FindInstruction(module.get(), "send-done"); EXPECT_EQ(send_done->operand(0), send); EXPECT_EQ(send_done->control_predecessors()[0], recv_done); diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index de568a260c64eb..d75cab95226ce0 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -244,6 +244,8 @@ class Compiler { std::function, Shape>>( const HloModule& module)> layout_canonicalization_callback = {}; + + bool enable_debug_info_manager = true; }; virtual ~Compiler() = default; diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 0e04e5295c7351..6c434c35acf17e 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -74,13 +74,25 @@ filegroup( "runtime_single_threaded_conv2d.cc", "runtime_single_threaded_conv3d.cc", "runtime_single_threaded_fft.cc", - "runtime_single_threaded_matmul.cc", + "runtime_single_threaded_matmul_c128.cc", + "runtime_single_threaded_matmul_c64.cc", + 
"runtime_single_threaded_matmul_common.h", + "runtime_single_threaded_matmul_f16.cc", + "runtime_single_threaded_matmul_f32.cc", + "runtime_single_threaded_matmul_f64.cc", + "runtime_single_threaded_matmul_s32.cc", "runtime_topk.cc", # Multi-threaded support. "runtime_conv2d.cc", "runtime_conv3d.cc", "runtime_fft.cc", - "runtime_matmul.cc", + "runtime_matmul_c128.cc", + "runtime_matmul_c64.cc", + "runtime_matmul_common.h", + "runtime_matmul_f16.cc", + "runtime_matmul_f32.cc", + "runtime_matmul_f64.cc", + "runtime_matmul_s32.cc", "runtime_fork_join.cc", ], visibility = [":friends"], @@ -410,15 +422,16 @@ cc_library( "@llvm-project//mlir:BufferizationToMemRef", "@llvm-project//mlir:BufferizationTransforms", "@llvm-project//mlir:ComplexToStandard", - "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:Pass", "@llvm-project//mlir:ReconcileUnrealizedCasts", + "@llvm-project//mlir:SCFToControlFlow", "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:ShapeToStandard", "@llvm-project//mlir:ShapeTransforms", @@ -426,6 +439,7 @@ cc_library( "@llvm-project//mlir:TensorToLinalg", "@llvm-project//mlir:TensorTransforms", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorToLLVM", "@llvm-project//mlir:VectorToSCF", "@llvm-project//mlir:VectorTransforms", ], @@ -944,7 +958,15 @@ cc_library( cc_library( name = "runtime_matmul", - srcs = ["runtime_matmul.cc"], + srcs = [ + "runtime_matmul_c128.cc", + "runtime_matmul_c64.cc", + "runtime_matmul_common.h", + "runtime_matmul_f16.cc", + "runtime_matmul_f32.cc", + "runtime_matmul_f64.cc", + "runtime_matmul_s32.cc", + ], hdrs = ["runtime_matmul.h"], copts = runtime_copts(), visibility = ["//visibility:public"], @@ -1066,7 +1088,15 @@ cc_library( cc_library( name = "runtime_single_threaded_matmul_impl", - srcs = ["runtime_single_threaded_matmul.cc"], + srcs = [ + "runtime_single_threaded_matmul_c128.cc", + "runtime_single_threaded_matmul_c64.cc", + "runtime_single_threaded_matmul_common.h", + "runtime_single_threaded_matmul_f16.cc", + "runtime_single_threaded_matmul_f32.cc", + "runtime_single_threaded_matmul_f64.cc", + "runtime_single_threaded_matmul_s32.cc", + ], hdrs = ["runtime_single_threaded_matmul.h"], compatible_with = get_compatible_with_portable(), copts = runtime_copts(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index bb74517b91cfb3..a358c25fa471d6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -75,8 +75,8 @@ CpuExecutable::CpuExecutable( std::make_shared(assignment_->ToProto()); } if (has_module()) { - XlaDebugInfoManager::Get()->RegisterModule( - module().unique_id(), shared_module(), buffer_assignment_); + XlaDebugInfoManager::Get()->RegisterModule(shared_module(), + buffer_assignment_); } // Resolve symbols in the constructor rather than at execution time to avoid @@ -110,8 +110,8 @@ CpuExecutable::CpuExecutable( std::make_shared(assignment_->ToProto()); } if (has_module()) { - XlaDebugInfoManager::Get()->RegisterModule( - module().unique_id(), shared_module(), buffer_assignment_); + XlaDebugInfoManager::Get()->RegisterModule(shared_module(), + buffer_assignment_); } } diff --git 
a/tensorflow/compiler/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/tensorflow/compiler/xla/service/cpu/hlo_xla_runtime_pipeline.cc index dd9221df95e793..08496d0e72c8d3 100644 --- a/tensorflow/compiler/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ b/tensorflow/compiler/xla/service/cpu/hlo_xla_runtime_pipeline.cc @@ -20,10 +20,12 @@ limitations under the License. #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" // from @llvm-project #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" // from @llvm-project -#include "mlir/Conversion/Passes.h" // from @llvm-project +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project #include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" // from @llvm-project +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" // from @llvm-project diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_c128.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_c128.cc new file mode 100644 index 00000000000000..c53692de7f3a2b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_c128.cc @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC128( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, + int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { + xla::MatMulDispatch>(run_options_ptr, out, lhs, rhs, m, + n, k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_c64.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_c64.cc new file mode 100644 index 00000000000000..9c3482d6ef5049 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_c64.cc @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC64( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, + int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { + xla::MatMulDispatch>(run_options_ptr, out, lhs, rhs, m, n, + k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h similarity index 66% rename from tensorflow/compiler/xla/service/cpu/runtime_matmul.cc rename to tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h index 21ca5ed6402578..6acefadba6f5d4 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_COMMON_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_COMMON_H_ + +#include #define EIGEN_USE_THREADS @@ -26,9 +29,9 @@ limitations under the License. 
#include "tensorflow/tsl/framework/contraction/eigen_contraction_kernel.h" #endif -namespace { +namespace xla { -bool Is16BytesAligned(void* ptr) { +static inline bool Is16BytesAligned(void* ptr) { return reinterpret_cast(ptr) % 16 == 0; } @@ -146,59 +149,6 @@ void BatchMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs, batch_size, transpose_lhs, transpose_rhs); } -} // namespace - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16( - const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, - Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, - int32_t transpose_rhs) { - MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32( - const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m, - int64_t n, int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { - MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( - const void* run_options_ptr, double* out, double* lhs, double* rhs, - int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, - int32_t transpose_rhs) { - MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); -} +} // namespace xla -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC64( - const void* run_options_ptr, std::complex* out, - std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, - int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { - MatMulDispatch>(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC128( - const void* run_options_ptr, std::complex* out, - std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, - int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { - MatMulDispatch>(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulS32( - const void* run_options_ptr, int32_t* out, int32_t* lhs, int32_t* rhs, - int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, - int32_t transpose_rhs) { - MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenBatchMatMulF32( - const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m, - int64_t n, int64_t k, int64_t batch_size, int32_t transpose_lhs, - int32_t transpose_rhs) { - BatchMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - batch_size, transpose_lhs, transpose_rhs); -} +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_COMMON_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_f16.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_f16.cc new file mode 100644 index 00000000000000..d18516805bb45d --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_f16.cc @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_f32.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_f32.cc new file mode 100644 index 00000000000000..6d84a3ac5c8193 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_f32.cc @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m, + int64_t n, int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { + xla::MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenBatchMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m, + int64_t n, int64_t k, int64_t batch_size, int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::BatchMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + batch_size, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_f64.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_f64.cc new file mode 100644 index 00000000000000..1424d17fa5f6e6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_f64.cc @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( + const void* run_options_ptr, double* out, double* lhs, double* rhs, + int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_s32.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_s32.cc new file mode 100644 index 00000000000000..6c93cb53f5eb31 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_s32.cc @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulS32( + const void* run_options_ptr, int32_t* out, int32_t* lhs, int32_t* rhs, + int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc deleted file mode 100644 index d5f0b6b93a6258..00000000000000 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" - -#include "absl/base/attributes.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - -#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) -#include "tensorflow/tsl/framework/contraction/eigen_contraction_kernel.h" -#endif - -namespace { - -bool Is16BytesAligned(void* ptr) { - return reinterpret_cast(ptr) % 16 == 0; -} - -template -void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64_t m, - int64_t n, int64_t k, int32_t transpose_lhs, - int32_t transpose_rhs) { - int64_t lhs_rows = m; - int64_t lhs_cols = k; - if (transpose_lhs) { - std::swap(lhs_rows, lhs_cols); - } - - int64_t rhs_rows = k; - int64_t rhs_cols = n; - if (transpose_rhs) { - std::swap(rhs_rows, rhs_cols); - } - - const Eigen::TensorMap, Alignment> A(lhs, lhs_rows, - lhs_cols); - const Eigen::TensorMap, Alignment> B(rhs, rhs_rows, - rhs_cols); - Eigen::TensorMap, Alignment> C(out, m, n); - - typedef typename Eigen::Tensor::DimensionPair DimPair; - int lhs_contract_dim = transpose_lhs ? 0 : 1; - int rhs_contract_dim = transpose_rhs ? 1 : 0; - const Eigen::array dims( - {DimPair(lhs_contract_dim, rhs_contract_dim)}); - - // Matrix multiply is a special case of the "contract" operation where - // the contraction is performed along dimension 1 of the lhs and dimension - // 0 of the rhs. - C = A.contract(B, dims); -} - -template -void SingleThreadedMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, - T* rhs, int64_t m, int64_t n, int64_t k, - int32_t transpose_lhs, - int32_t transpose_rhs) { - bool all_buffers_16b_aligned = - Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs); - - if (!all_buffers_16b_aligned) { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); - } - - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -} // namespace - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulF16( - const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, - Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, - int32_t transpose_rhs) { - SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, - n, k, transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr, - float* out, float* lhs, - float* rhs, int64_t m, int64_t n, - int64_t k, int32_t transpose_lhs, - int32_t transpose_rhs) { - SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, - double* out, double* lhs, - double* rhs, int64_t m, - int64_t n, int64_t k, - int32_t transpose_lhs, - int32_t transpose_rhs) { - SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulC64( - const void* run_options_ptr, std::complex* out, - std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, - int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { - SingleThreadedMatMulDispatch>( - run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void 
-__xla_cpu_runtime_EigenSingleThreadedMatMulC128( - const void* run_options_ptr, std::complex* out, - std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, - int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { - SingleThreadedMatMulDispatch>( - run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulS32(const void* run_options_ptr, - int32_t* out, int32_t* lhs, - int32_t* rhs, int64_t m, - int64_t n, int64_t k, - int32_t transpose_lhs, - int32_t transpose_rhs) { - SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h index 9473eb7f56fc52..1ac85a4f125404 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ #include +#include #include "third_party/eigen3/Eigen/Core" diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c128.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c128.cc new file mode 100644 index 00000000000000..81199c14daf7f8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c128.cc @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulC128( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, + int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch>( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c64.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c64.cc new file mode 100644 index 00000000000000..6a176435912403 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c64.cc @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulC64( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, + int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch>( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h new file mode 100644 index 00000000000000..d91d8f5258c71e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h @@ -0,0 +1,88 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_COMMON_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_COMMON_H_ + +#include + +#include "absl/base/attributes.h" +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/tsl/framework/contraction/eigen_contraction_kernel.h" +#endif + +namespace xla { + +static inline bool Is16BytesAligned(void* ptr) { + return reinterpret_cast(ptr) % 16 == 0; +} + +template +void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, + int64_t m, int64_t n, int64_t k, + int32_t transpose_lhs, int32_t transpose_rhs) { + int64_t lhs_rows = m; + int64_t lhs_cols = k; + if (transpose_lhs) { + std::swap(lhs_rows, lhs_cols); + } + + int64_t rhs_rows = k; + int64_t rhs_cols = n; + if (transpose_rhs) { + std::swap(rhs_rows, rhs_cols); + } + + const Eigen::TensorMap, Alignment> A(lhs, lhs_rows, + lhs_cols); + const Eigen::TensorMap, Alignment> B(rhs, rhs_rows, + rhs_cols); + Eigen::TensorMap, Alignment> C(out, m, n); + + typedef typename Eigen::Tensor::DimensionPair DimPair; + int lhs_contract_dim = transpose_lhs ? 0 : 1; + int rhs_contract_dim = transpose_rhs ? 
1 : 0; + const Eigen::array dims( + {DimPair(lhs_contract_dim, rhs_contract_dim)}); + + // Matrix multiply is a special case of the "contract" operation where + // the contraction is performed along dimension 1 of the lhs and dimension + // 0 of the rhs. + C = A.contract(B, dims); +} + +template +void SingleThreadedMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, + T* rhs, int64_t m, int64_t n, int64_t k, + int32_t transpose_lhs, + int32_t transpose_rhs) { + bool all_buffers_16b_aligned = + Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs); + + if (!all_buffers_16b_aligned) { + SingleThreadedMatMul( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + } + + SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, + n, k, transpose_lhs, transpose_rhs); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_COMMON_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f16.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f16.cc new file mode 100644 index 00000000000000..76a5b93af75e5e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f16.cc @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f32.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f32.cc new file mode 100644 index 00000000000000..6f3271180e9b2a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f32.cc @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
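The SingleThreadedMatMulDispatch<T> template in the new _common.h header above implements the product as an Eigen tensor contraction over one dimension pair, and each per-type runtime_single_threaded_matmul_*.cc file is now a thin wrapper around it. The following standalone snippet is an illustrative sketch of that same contraction idiom, not part of this patch; it uses the plain Eigen include path rather than the third_party/eigen3 one, and the sizes are made up for the example.

#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>

int main() {
  constexpr int m = 2, k = 3, n = 4;
  Eigen::Tensor<float, 2> lhs(m, k);
  Eigen::Tensor<float, 2> rhs(k, n);
  lhs.setRandom();
  rhs.setRandom();

  // Contract dimension 1 of lhs with dimension 0 of rhs: matrix multiply as a
  // special case of the "contract" operation, exactly as in the header above.
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
  Eigen::Tensor<float, 2> out = lhs.contract(rhs, dims);

  std::cout << "result is " << out.dimension(0) << "x" << out.dimension(1)
            << std::endl;  // prints "result is 2x4"
  return 0;
}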
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
+
+#include <cstdint>
+
+#include "absl/base/attributes.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h"
+
+ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr,
+                                               float* out, float* lhs,
+                                               float* rhs, int64_t m, int64_t n,
+                                               int64_t k, int32_t transpose_lhs,
+                                               int32_t transpose_rhs) {
+  xla::SingleThreadedMatMulDispatch<float>(run_options_ptr, out, lhs, rhs, m, n,
+                                           k, transpose_lhs, transpose_rhs);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f64.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f64.cc
new file mode 100644
index 00000000000000..15191c7f151dc8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f64.cc
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
+
+#include <cstdint>
+
+#include "absl/base/attributes.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h"
+
+ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr,
+                                               double* out, double* lhs,
+                                               double* rhs, int64_t m,
+                                               int64_t n, int64_t k,
+                                               int32_t transpose_lhs,
+                                               int32_t transpose_rhs) {
+  xla::SingleThreadedMatMulDispatch<double>(run_options_ptr, out, lhs, rhs, m,
+                                            n, k, transpose_lhs, transpose_rhs);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_s32.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_s32.cc
new file mode 100644
index 00000000000000..cf854e5c8f3527
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_s32.cc
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulS32(const void* run_options_ptr, + int32_t* out, int32_t* lhs, + int32_t* rhs, int64_t m, + int64_t n, int64_t k, + int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/dynamic_window_utils.h b/tensorflow/compiler/xla/service/dynamic_window_utils.h index 40e891fad4b59e..11392ea33884e4 100644 --- a/tensorflow/compiler/xla/service/dynamic_window_utils.h +++ b/tensorflow/compiler/xla/service/dynamic_window_utils.h @@ -31,7 +31,7 @@ struct DynamicWindowDims { HloInstruction* output_size; }; -// This mirrors the logic in GetWindowedOutputSizeVerboseV2 but with HLOs as +// This mirrors the logic in GetWindowedOutputSizeVerbose but with HLOs as // inputs and outputs. DynamicWindowDims GetWindowedOutputSize(HloInstruction* input_size, int64_t window_size, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index fbf0aeb320537e..f39efc59ba4397 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -27,7 +27,7 @@ load( "//tensorflow/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) -load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -79,7 +79,7 @@ cc_library( name = "gpu_executable_run_options", srcs = ["gpu_executable_run_options.cc"], hdrs = ["gpu_executable_run_options.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla:status_macros", @@ -118,14 +118,10 @@ cc_library( hdrs = [ "launch_dimensions.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":gpu_device_info", - "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/mlir_hlo:lhlo", - "//tensorflow/tsl/platform:logging", - "@llvm-project//mlir:IR", ], ) @@ -218,7 +214,7 @@ cc_library( name = "target_util", srcs = ["target_util.cc"], hdrs = ["target_util.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -250,7 +246,7 @@ cc_library( name = "gpu_device_info", srcs = ["gpu_device_info.cc"], hdrs = ["gpu_device_info.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/stream_executor:device_description_proto_cc", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", @@ -262,7 +258,7 @@ cc_library( testonly = 1, srcs = ["gpu_device_info_for_tests.cc"], hdrs = ["gpu_device_info_for_tests.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ 
":gpu_device_info", ], @@ -439,6 +435,7 @@ cc_library( "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:statusor", "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -464,7 +461,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:ToLLVMIRTranslation", "@llvm-project//mlir:Transforms", - "@triton//:TritonDialect", + "@triton//:TritonDialects", "@triton//:TritonTransforms", ] + if_cuda_is_configured([ "@triton//:TritonGPUToLLVM", @@ -496,11 +493,15 @@ xla_test( "//tensorflow/compiler/xla:autotuning_proto_cc", "//tensorflow/compiler/xla:error_spec", "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/compiler/xla/service:pattern_matcher_gmock", "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin", "//tensorflow/compiler/xla/tests:verified_hlo_module", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/tsl/lib/core:status_test_util", + "//tensorflow/tsl/platform:path", "//tensorflow/tsl/platform:status_matchers", "//tensorflow/tsl/platform:statusor", "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", @@ -568,7 +569,6 @@ cc_library( ":buffer_comparator", ":gemm_rewriter", ":gemm_rewriter_triton", - ":gpu_asm_opts_util", ":gpu_device_info", ":gpu_float_support", ":gpu_fusible", @@ -636,7 +636,7 @@ cc_library( name = "parallel_loop_emitter", srcs = ["parallel_loop_emitter.cc"], hdrs = ["parallel_loop_emitter.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":launch_dimensions", ":target_util", @@ -988,7 +988,8 @@ cc_library( name = "ir_emission_utils", srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), + defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":target_util", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -1004,7 +1005,10 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:Core", "@llvm-project//mlir:ArithDialect", - ], + ] + if_cuda_is_configured([ + ":gpu_asm_opts_util", + "//tensorflow/compiler/xla/stream_executor/gpu:asm_compiler", + ]), ) xla_cc_test( @@ -1027,7 +1031,7 @@ cc_library( name = "cublas_cudnn", srcs = ["cublas_cudnn.cc"], hdrs = ["cublas_cudnn.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/tsl/platform:statusor", @@ -1153,20 +1157,25 @@ cc_library( ":matmul_utils", "//tensorflow/compiler/xla:autotuning_proto_cc", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:permutation_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/hlo/utils:hlo_query", "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:instruction_fusion", + 
"//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -1202,6 +1211,7 @@ xla_cc_test( ":gemm_rewriter_triton", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:pattern_matcher", "//tensorflow/compiler/xla/service:pattern_matcher_gmock", "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1240,6 +1250,7 @@ cc_library( "//tensorflow/compiler/xla/stream_executor:device_memory", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:statusor", ]) + if_cuda_is_configured([ "//tensorflow/compiler/xla/stream_executor/cuda:cublas_lt_header", "//tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin", @@ -1325,15 +1336,12 @@ cc_library( "@com_google_absl//absl/types:span", "//tensorflow/compiler/xla:autotune_results_proto_cc", "//tensorflow/compiler/xla:autotuning_proto_cc", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/stream_executor", "//tensorflow/compiler/xla/stream_executor/gpu:gpu_stream_header", "//tensorflow/compiler/xla/stream_executor/gpu:gpu_timer_header", @@ -1373,7 +1381,7 @@ cc_library( name = "matmul_utils", srcs = ["matmul_utils.cc"], hdrs = ["matmul_utils.h"], - compatible_with = get_compatible_with_cloud(), + # compatible_with = get_compatible_with_cloud(), defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ ":backend_configs_cc", @@ -1510,9 +1518,9 @@ xla_cc_test( ) cc_library( - name = "gpu_conv_algorithm_picker", - srcs = if_gpu_is_configured(["gpu_conv_algorithm_picker.cc"]), - hdrs = if_gpu_is_configured(["gpu_conv_algorithm_picker.h"]), + name = "conv_algorithm_picker", + srcs = if_gpu_is_configured(["conv_algorithm_picker.cc"]), + hdrs = if_gpu_is_configured(["conv_algorithm_picker.h"]), copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured([ "-DTENSORFLOW_USE_ROCM=1", ]), @@ -1549,8 +1557,8 @@ cc_library( ) xla_cc_test( - name = "gpu_conv_algorithm_picker_test", - srcs = if_gpu_is_configured(["gpu_conv_algorithm_picker_test.cc"]), + name = "conv_algorithm_picker_test", + srcs = if_gpu_is_configured(["conv_algorithm_picker_test.cc"]), tags = [ "gpu", "noasan", @@ -1558,7 +1566,7 @@ xla_cc_test( "requires-gpu-sm70", ], deps = [ - ":gpu_conv_algorithm_picker", + ":conv_algorithm_picker", ":gpu_conv_rewriter", "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service:pattern_matcher", @@ -1875,6 +1883,7 @@ xla_cc_test( srcs = ["softmax_rewriter_triton_test.cc"], deps = [ ":softmax_rewriter_triton", + "//tensorflow/compiler/xla/service:pattern_matcher", 
"//tensorflow/compiler/xla/service:pattern_matcher_gmock", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep @@ -2384,7 +2393,6 @@ cc_library( "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", "@llvm-project//llvm:TransformUtils", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", @@ -2399,7 +2407,6 @@ cc_library( "//tensorflow/compiler/xla/hlo/transforms:hlo_constant_splitter", "//tensorflow/compiler/xla/mlir/backends/gpu/transforms:passes", "//tensorflow/compiler/xla/mlir/runtime/transforms:compilation_pipeline_gpu", - "//tensorflow/compiler/xla/mlir_hlo:transforms_gpu_passes", "//tensorflow/compiler/xla/runtime:jit_executable", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:all_gather_broadcast_reorder", @@ -2489,7 +2496,6 @@ cc_library( "//tensorflow/compiler/xla/stream_executor:device_description_proto_cc_impl", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", "//tensorflow/compiler/xla/stream_executor/cuda:cuda_platform_id", - "//tensorflow/compiler/xla/stream_executor/rocm:rocm_platform_id", "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:location_exporter", "//tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla", @@ -2515,6 +2521,7 @@ xla_cc_test( "//tensorflow/compiler/xla:autotune_results_proto_cc", "//tensorflow/compiler/xla/hlo/utils:hlo_matchers", "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:xla_debug_info_manager", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/tsl/lib/core:status_test_util", @@ -2573,7 +2580,7 @@ cc_library( ":gemm_algorithm_picker", ":gpu_asm_opts_util", ":gpu_compiler", - ":gpu_conv_algorithm_picker", + ":conv_algorithm_picker", ":gpu_conv_padding_legalization", ":gpu_conv_rewriter", ":gpu_executable", @@ -2584,6 +2591,7 @@ cc_library( ":triangular_solve_rewriter", ":triton_autotuner", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:IRReader", @@ -2698,7 +2706,7 @@ cc_library( ":cusolver_rewriter", ":gemm_rewriter", ":gpu_compiler", - ":gpu_conv_algorithm_picker", + ":conv_algorithm_picker", ":gpu_conv_padding_legalization", ":gpu_conv_rewriter", ":gpu_layout_assignment", @@ -2949,6 +2957,7 @@ cc_library( name = "gpu_asm_opts_util", srcs = ["gpu_asm_opts_util.cc"], hdrs = ["gpu_asm_opts_util.h"], + compatible_with = get_compatible_with_portable(), copts = tsl_copts(), deps = [ "//tensorflow/compiler/xla:xla_proto_cc", @@ -2961,11 +2970,10 @@ cc_library( name = "gpu_hlo_cost_analysis", srcs = ["gpu_hlo_cost_analysis.cc"], hdrs = ["gpu_hlo_cost_analysis.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":backend_configs_cc", ":cublas_cudnn", - ":ir_emission_utils", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo_cost_analysis", @@ -3177,12 +3185,14 @@ xla_cc_test( xla_cc_test( name = "cudnn_fused_conv_rewriter_test", srcs = ["cudnn_fused_conv_rewriter_test.cc"], + shard_count = 10, tags = [ "gpu", "no_oss", "noasan", "nomsan", - "requires-gpu-sm70", + # This 
test runs some fusions that are only supported on Ampere+. + "requires-gpu-sm80", ], deps = [ ":backend_configs_cc", diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc index 416027ad998c7e..6365652cb7d1c6 100644 --- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc @@ -20,9 +20,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc index 7555b6fcf69e3f..fd4440bcaeefaa 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc @@ -41,9 +41,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h" @@ -103,15 +100,36 @@ std::vector ExecutionInputsFromBuffers( } // namespace -StatusOr> +AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config, + Compiler* compiler, + se::StreamExecutor& stream_executor, + se::Stream& stream, + se::DeviceMemoryAllocator& allocator, + const DebugOptions& opts) + : config_(config), + compiler_(compiler), + stream_executor_(stream_executor), + stream_(stream), + allocator_(allocator), + opts_(opts) { + // Avoid dumping compilation steps. + opts_.set_xla_dump_to(""); + opts_.set_xla_gpu_dump_autotune_results_to(""); + opts_.set_xla_gpu_load_autotune_results_from(""); + opts_.set_xla_gpu_dump_llvmir(false); + // Avoid using another thread pool. + opts_.set_xla_gpu_force_compilation_parallelism(1); + // Avoid using GPU graphs as we don't want to measure graph construction time. 
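+  // These scrubbed options are applied to every module compiled through this
+  // utility: CompileNoCache() installs opts_ on each extracted module before
+  // handing it to the compiler.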
+ opts_.set_xla_gpu_cuda_graph_level(0); +} + +StatusOr> AutotunerCompileUtil::GenerateAndProfileExecutable( - const HloComputation& hlo_computation, const AutotuneResult& config, - const AutotuneCacheKey& cache_key, se::Stream* stream, - absl::Span input_buffers, - se::DeviceMemoryBase output_buffer, ExtractModuleFn extractor) { - TF_ASSIGN_OR_RETURN( - Executable * executable, - Compile(hlo_computation, config, cache_key, std::move(extractor))); + const AutotuneResult& config, const AutotuneCacheKey& cache_key, + se::Stream* stream, absl::Span input_buffers, + GenerateModuleFn extractor) { + TF_ASSIGN_OR_RETURN(Executable * executable, + Compile(config, cache_key, std::move(extractor))); if (!executable) { return {std::nullopt}; @@ -133,18 +151,13 @@ AutotunerCompileUtil::GenerateAndProfileExecutable( Execute(*executable, std::move(execution_inputs))); TF_ASSIGN_OR_RETURN(absl::Duration timer_duration, timer.GetElapsedDuration()); - ScopedShapedBuffer result = execution_output.ConsumeResult(); - TF_RET_CHECK(output_buffer.size() == result.root_buffer().size()); - // TODO(cheshire): Copying should not be required. Instead, we can add a new - // aliased parameter. - stream->ThenMemcpy(&output_buffer, result.root_buffer(), - result.root_buffer().size()); - return std::make_optional(timer_duration); + return std::make_optional( + timer_duration, execution_output.Commit().ConsumeResult()); } StatusOr AutotunerCompileUtil::Compile( - const HloComputation& hlo_computation, const AutotuneResult& res, - const AutotuneCacheKey& cache_key, ExtractModuleFn extractor) { + const AutotuneResult& res, const AutotuneCacheKey& cache_key, + GenerateModuleFn extractor) { CompilationKey key{cache_key, res}; { absl::MutexLock lock(&executable_cache_mutex); @@ -156,15 +169,14 @@ StatusOr AutotunerCompileUtil::Compile( } TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - CompileNoCache(hlo_computation, std::move(extractor))); + CompileNoCache(std::move(extractor))); absl::MutexLock lock(&executable_cache_mutex); auto [it, inserted] = executable_cache.emplace(key, std::move(executable)); return it->second.get(); } StatusOr> AutotunerCompileUtil::CompileNoCache( - const HloComputation& original_computation, - ExtractModuleFn module_extractor) { + GenerateModuleFn module_extractor) { StatusOr> new_hlo_module = module_extractor(); if (new_hlo_module.status().GetPayload(kUncompilableFusion).has_value()) { // Incompatible value of split-k is an expected failure. @@ -172,15 +184,33 @@ StatusOr> AutotunerCompileUtil::CompileNoCache( } else if (!new_hlo_module.status().ok()) { return new_hlo_module.status(); } - return RunBackend(original_computation, std::move(*new_hlo_module)); + (*new_hlo_module)->config().set_debug_options(opts_); + + StatusOr> out = compiler_->RunBackend( + std::move(*new_hlo_module), &stream_executor_, + Compiler::CompileOptions{&allocator_, /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*enable_debug_info_manager=*/false}); + if (out.status().code() == absl::StatusCode::kResourceExhausted) { + // Being out of shared memory budget is an expected failure. 
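+    // The empty Executable is still cached by Compile(), so an expected
+    // failure for this (config, cache key) pair is not recompiled later.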
+ return std::unique_ptr(); + } + return out; } -/*static*/ StatusOr AutotunerCompileUtil::Create( - se::Stream& stream, se::DeviceMemoryAllocator& allocator) { - se::StreamExecutor& stream_executor = *stream.parent(); +/*static*/ StatusOr> +AutotunerCompileUtil::Create(const AutotuneConfig& config, + const DebugOptions& opts) { + if (config.IsDeviceless()) { + return std::nullopt; + } + se::StreamExecutor* stream_exec = config.GetExecutor(); + se::DeviceMemoryAllocator* allocator = config.GetAllocator(); + TF_ASSIGN_OR_RETURN(se::Stream* const stream, config.GetStream()); TF_ASSIGN_OR_RETURN(Compiler * compiler, - Compiler::GetForPlatform(stream_executor.platform())); - return AutotunerCompileUtil(compiler, stream_executor, stream, allocator); + Compiler::GetForPlatform(stream_exec->platform())); + return AutotunerCompileUtil(config, compiler, *stream_exec, *stream, + *allocator, opts); } StatusOr AutotunerCompileUtil::Execute( @@ -202,29 +232,6 @@ StatusOr AutotunerCompileUtil::Execute( return std::move(output); } -StatusOr> AutotunerCompileUtil::RunBackend( - const HloComputation& original_computation, - std::unique_ptr module) { - DebugOptions options = - original_computation.parent()->config().debug_options(); - // Avoid dumping compilation steps. - options.set_xla_dump_to(""); - options.set_xla_gpu_dump_autotune_results_to(""); - options.set_xla_gpu_load_autotune_results_from(""); - options.set_xla_gpu_dump_llvmir(false); - // Avoid using another thread pool. - options.set_xla_gpu_force_compilation_parallelism(1); - options.set_xla_gpu_enable_xla_runtime_executable(false); - module->config().set_debug_options(options); - StatusOr> out = - compiler_->RunBackend(std::move(module), &stream_executor_, &allocator_); - if (out.status().code() == absl::StatusCode::kResourceExhausted) { - // Being out of shared memory budget is an expected failure. - return std::unique_ptr(); - } - return out; -} - /*static*/ void AutotunerCompileUtil::ClearCompilationCache() { absl::MutexLock lock(&executable_cache_mutex); executable_cache.clear(); diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h index da3f56ab774bde..1ad8ddb5ba9acc 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h @@ -16,19 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_ -#include - -#include -#include -#include -#include -#include #include #include -#include -#include -#include -#include #include #include "tensorflow/compiler/xla/autotune_results.pb.h" @@ -47,63 +36,72 @@ namespace gpu { // Autotuning utils which require compiling fusions separately. Requires a // separate target, as runtime autotuning cannot perform compilation. +// +// Uses a global cache, *not* unique per instance. class AutotunerCompileUtil { public: - using ExtractModuleFn = + using GenerateModuleFn = absl::AnyInvocable>()>; - static StatusOr Create( - se::Stream& stream, se::DeviceMemoryAllocator& allocator); - - AutotunerCompileUtil(Compiler* compiler, se::StreamExecutor& stream_executor, - se::Stream& stream, se::DeviceMemoryAllocator& allocator) - : compiler_(compiler), - stream_executor_(stream_executor), - stream_(stream), - allocator_(allocator) {} - - // Runs the compiled executable with the given extractor, cached with - // . 
Returns std::nullopt on expected failure, bad Status - // otherwise. - // Uses a global cache, *not* unique per instance. - StatusOr> GenerateAndProfileExecutable( - const HloComputation& hlo_computation, const AutotuneResult& config, - const AutotuneCacheKey& cache_key, se::Stream* stream, - absl::Span input_buffers, - se::DeviceMemoryBase output_buffer, ExtractModuleFn extractor); - - // Generic method to compile a given computation in isolation using a given - // pipeline, cached on AutotuneResult and AutotuneCacheKey. + // Generates a compile util for a platform associated with the `stream`. + // + // Returns an empty optional if the AutotuneConfig is deviceless, as + // autotuning is impossible in that case. + static StatusOr> Create( + const AutotuneConfig& config, const DebugOptions& opts); + + struct ProfilingOutput { + ProfilingOutput(absl::Duration duration, ScopedShapedBuffer&& buffer) + : duration(duration), output(std::move(buffer)) {} + + absl::Duration duration; + ScopedShapedBuffer output; + }; + + // Generates an executable first, given the module generator function in + // `extractor`. + // + // Runs the resulting executable with the given extractor, cached with + // `(cache_key, config)`. Returns `std::nullopt` on expected failure, bad + // `Status` otherwise. + StatusOr> GenerateAndProfileExecutable( + const AutotuneResult& config, const AutotuneCacheKey& cache_key, + se::Stream* stream, absl::Span input_buffers, + GenerateModuleFn extractor); + + // Generic method to compile a generated module from `extractor` in isolation. // // On *expected* failures we will store an empty unique_ptr in cache. // // Returns: - // - on *expected* failure - // - Executable if everything goes fine. - // - Status on *unexpected* failure. + // - `nullptr` on *expected* failure + // - `Executable` if everything goes fine. + // - `Status` on *unexpected* failure. StatusOr Compile( - const HloComputation& hlo_computation, const AutotuneResult& res, - const AutotuneCacheKey& cache_key, - AutotunerCompileUtil::ExtractModuleFn extractor); + const AutotuneResult& res, const AutotuneCacheKey& cache_key, + AutotunerCompileUtil::GenerateModuleFn extractor); + // Clears the global compilation cache. static void ClearCompilationCache(); private: - StatusOr> RunBackend( - const HloComputation& original_computation, - std::unique_ptr module); + AutotunerCompileUtil(const AutotuneConfig& config, Compiler* compiler, + se::StreamExecutor& stream_executor, se::Stream& stream, + se::DeviceMemoryAllocator& allocator, + const DebugOptions& opts); StatusOr> CompileNoCache( - const HloComputation& original_computation, - AutotunerCompileUtil::ExtractModuleFn module_extractor); + AutotunerCompileUtil::GenerateModuleFn module_extractor); StatusOr Execute(Executable& executable, std::vector arguments); + AutotuneConfig config_; Compiler* compiler_; se::StreamExecutor& stream_executor_; se::Stream& stream_; se::DeviceMemoryAllocator& allocator_; + DebugOptions opts_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_util.cc index 76037c8dd3b75f..69016104aa5b64 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_util.cc @@ -22,6 +22,7 @@ limitations under the License. 
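The reworked factory returns StatusOr<std::optional<AutotunerCompileUtil>>, so a caller unwraps two layers: a hard error, or std::nullopt for a deviceless AutotuneConfig where there is nothing to profile. The sketch below shows how a hypothetical caller might consume the new API; it only relies on the signatures visible in this patch, and every name it introduces (RunCandidateOnce, autotune_config, debug_options, candidate, cache_key, stream, input_buffers, extractor) is an illustrative assumption, not code from the patch.

// Sketch only: one possible caller of the optional-returning factory.
StatusOr<bool> RunCandidateOnce(
    const AutotuneConfig& autotune_config, const DebugOptions& debug_options,
    const AutotuneResult& candidate, const AutotuneCacheKey& cache_key,
    se::Stream* stream, absl::Span<se::DeviceMemoryBase const> input_buffers,
    AutotunerCompileUtil::GenerateModuleFn extractor) {
  TF_ASSIGN_OR_RETURN(
      std::optional<AutotunerCompileUtil> util,
      AutotunerCompileUtil::Create(autotune_config, debug_options));
  if (!util.has_value()) {
    // Deviceless config: profiling-based autotuning is skipped entirely.
    return false;
  }
  TF_ASSIGN_OR_RETURN(
      std::optional<AutotunerCompileUtil::ProfilingOutput> profile,
      util->GenerateAndProfileExecutable(candidate, cache_key, stream,
                                         input_buffers, std::move(extractor)));
  if (!profile.has_value()) {
    return false;  // Expected compilation failure for this candidate.
  }
  VLOG(1) << "Candidate ran in " << profile->duration;
  return true;
}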
#include #include +#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" namespace xla { @@ -259,5 +260,20 @@ AutotunerUtil::ExtractComputationIntoNewModule( return new_hlo_module; } +/*static*/ StatusOr AutotunerUtil::CreateRedzoneAllocator( + const AutotuneConfig& config, const DebugOptions& opts, + se::Stream* force_stream) { + se::Stream* stream = force_stream; + if (stream == nullptr) { + TF_ASSIGN_OR_RETURN(stream, config.GetStream()); + } + return se::RedzoneAllocator( + stream, config.GetAllocator(), PtxOptsFromDebugOptions(opts), + /*memory_limit=*/std::numeric_limits::max(), + /*redzone_size=*/config.should_check_correctness() + ? opts.xla_gpu_redzone_padding_bytes() + : 0); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_util.h b/tensorflow/compiler/xla/service/gpu/autotuner_util.h index c50ad60e2c3601..f759e1d8f2ca71 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_util.h +++ b/tensorflow/compiler/xla/service/gpu/autotuner_util.h @@ -43,7 +43,7 @@ struct DeviceConfig { // If the `allocator` parameter is not null, we will use it to allocate temp // memory while timing the various convolution algorithms. If it's null, // we'll use the default allocator on the StreamExecutor. - se::DeviceMemoryAllocator* allocator; // may be null + se::DeviceMemoryAllocator* allocator = nullptr; // may be null }; struct DevicelessConfig { @@ -123,7 +123,13 @@ class AutotuneConfig { se::DeviceMemoryAllocator* GetAllocator() const { CHECK(std::holds_alternative(config_)); - return std::get(config_).allocator; + auto& cf = std::get(config_); + return cf.allocator ? cf.allocator : GetExecutor()->GetAllocator(); + } + + StatusOr GetStream() const { + CHECK(std::holds_alternative(config_)); + return GetAllocator()->GetStream(GetExecutor()->device_ordinal()); } se::CudaComputeCapability GetCudaComputeCapability() const { @@ -159,6 +165,12 @@ struct AutotunerUtil { const HloInstruction* instr, const AutotuneConfig& config, const AutotuneNoCacheFn& autotune_fn); + // Creates a RedzoneAllocator from a given config. If `force_stream` is + // provided, than it is used for checking redzones. + static StatusOr CreateRedzoneAllocator( + const AutotuneConfig& config, const DebugOptions& opts, + se::Stream* force_stream = nullptr); + // Functions to save/load XLA's autotuning results. // // This is used for ahead-of-time autotuning. Specifically: diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto index ffabfd56d3bb93..cd2f21a9eb8656 100644 --- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -39,6 +39,15 @@ message CudnnConvBackendConfig { // is provided, this field must be 0. double side_input_scale = 5; + // The negative slope coefficient alpha for leaky_relu activation, used only + // when activation_mode is kLeakyRelu. + // + // leakyrelu(x) is defined as x > 0 ? x : alpha * x. + // + // Since this is a proto3 proto, leakyrelu_alpha is 0 if not specified (in + // which case the leakyrelu activation is equivalent to relu). + double leakyrelu_alpha = 8; + // If the filter (and bias, if present) have been reordered, set this flag. // It's placed into a `oneof` block to skip the serialization when not set. 
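The new leakyrelu_alpha field above only has meaning for the kLeakyRelu activation mode, and the comment's definition is easy to restate in code. A minimal sketch, not part of the patch, of the activation this field parameterizes:

#include <cstdio>

// leakyrelu(x) = x          if x > 0
//              = alpha * x  otherwise.
// With alpha == 0 (the proto3 default) this degenerates to plain relu.
double LeakyRelu(double x, double alpha) { return x > 0 ? x : alpha * x; }

int main() {
  std::printf("%f %f\n", LeakyRelu(2.0, 0.2), LeakyRelu(-2.0, 0.2));  // 2.0 -0.4
  return 0;
}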
oneof filter_and_bias_reordering_oneof { diff --git a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc index b27ffc0fbd6e07..e8771d4c4b2bb8 100644 --- a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -117,6 +117,7 @@ static Status LowerToXlaGpuRuntime(mlir::ModuleOp module, GpuPipelineOpts opts; opts.cuda_graph_level = debug_options.xla_gpu_cuda_graph_level(); + opts.min_graph_size = debug_options.xla_gpu_cuda_graph_min_graph_size(); opts.enable_concurrent_region = debug_options.xla_gpu_cuda_graph_enable_concurrent_region(); populateXlaGpuRuntimePasses(pm, thunk_sequence, opts); @@ -320,18 +321,17 @@ Status CompileModuleToLlvmIrImpl( { HloPassPipeline pipeline("remat-pipeline"); - HloRematerialization::RematerializationSizes sizes; - pipeline.AddPass( + HloRematerialization::Options options( [pointer_size](const Shape& shape) { return GetSizeOfShape(shape, pointer_size); }, // Assume 75% of the total device memory is available for XLA. /*memory_limit_bytes=*/gpu_device_info.device_memory_size * 0.75, - /*sizes=*/&sizes, - HloRematerialization::RematerializationPass::kPostFusion, /*block_size_limit=*/1, /*block_rematerialization_factor=*/1, /*compact_shape_function=*/nullptr, HloRematerialization::RematerializationMode::kRecomputeAndCompress); + HloRematerialization::RematerializationSizes sizes; + pipeline.AddPass(options, sizes); TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(hlo_module)); if (changed) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc similarity index 95% rename from tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc rename to tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc index fb2af10e7b3880..46e87ec03531e8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h" #include #include @@ -142,10 +142,10 @@ StatusOr> GetAlgorithms( BiasTypeForInputType(input_type), output_type, /* conv_input_scale = */ config.conv_result_scale, /* side_input_scale = */ config.fusion->side_input_scale, - /* leakyrelu_alpha = */ 0.0, stream, config.input_descriptor, - config.filter_descriptor, config.bias_descriptor, - config.output_descriptor, config.conv_desc, use_fallback, - config.fusion->mode, numeric_options, &runners)); + /* leakyrelu_alpha = */ config.fusion->leakyrelu_alpha, stream, + config.input_descriptor, config.filter_descriptor, + config.bias_descriptor, config.output_descriptor, config.conv_desc, + use_fallback, config.fusion->mode, numeric_options, &runners)); for (auto& runner : runners) { TF_ASSIGN_OR_RETURN( auto runner_cache, @@ -262,7 +262,6 @@ void PrintPlatformInfo(const se::Stream* stream) { } } -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) // Returns true if the redzones in `allocator`'s allocations are unmodified. 
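+// (A "redzone" is a block of guard bytes written before and after each
+// allocation; if a candidate algorithm writes out of bounds, the guard
+// pattern changes and the check below reports a failure.)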
// // If the redzones are modified, logs an error, sets the appropriate failure @@ -306,7 +305,6 @@ StatusOr CheckRedzones(const se::RedzoneAllocator& allocator, PrintPlatformInfo(stream); return false; } -#endif } // anonymous namespace @@ -360,18 +358,9 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCache( // allocator either points to this->allocator_ or, if that's null, to a // se::StreamExecutorMemoryAllocator for stream_exec. - se::DeviceMemoryAllocator* device_allocator = config_.GetAllocator(); - se::DeviceMemoryAllocator* allocator; - optional se_allocator; - if (device_allocator != nullptr) { - allocator = device_allocator; - } else { - se_allocator.emplace(stream_exec); - allocator = &*se_allocator; - } + se::DeviceMemoryAllocator* allocator = config_.GetAllocator(); - TF_ASSIGN_OR_RETURN(se::Stream* const stream, - allocator->GetStream(stream_exec->device_ordinal())); + TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream()); StatusOr result_or(InternalError("Unknown platform.")); // Check StreamExecutor on which platform it is. ROCm and Cuda implementation // have diverged. Specifically, we need to make sure redzone allocator related @@ -380,15 +369,10 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCache( result_or = PickBestAlgorithmNoCacheRocm(instr, allocator, stream); } else if (stream_exec->platform_kind() == se::PlatformKind::kCuda) { #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) - // Right now Redzone allocator is available in Cuda target only. - auto hlo_module_config = instr->GetModule()->config(); - se::RedzoneAllocator input_output_allocator( - stream, allocator, - PtxOptsFromDebugOptions(hlo_module_config.debug_options()), - /*memory_limit=*/std::numeric_limits::max(), - ShouldCheckConv(hlo_module_config) - ? hlo_module_config.debug_options().xla_gpu_redzone_padding_bytes() - : 0); + DebugOptions debug_opts = instr->GetModule()->config().debug_options(); + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator input_output_allocator, + AutotunerUtil::CreateRedzoneAllocator(config_, debug_opts)); TF_ASSIGN_OR_RETURN( AutotuneRuntimeArguments runtime_arguments, @@ -519,19 +503,12 @@ StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( "Disqualified for implicit RELU."); } - const int64_t rz_space_limit = - runtime_arguments.hlo_module_config.debug_options() - .xla_gpu_redzone_scratch_max_megabytes() * - (1LL << 20); - se::RedzoneAllocator scratch_allocator( - stream, allocator, - PtxOptsFromDebugOptions( - runtime_arguments.hlo_module_config.debug_options()), - /*memory_limit=*/rz_space_limit, - ShouldCheckConv(runtime_arguments.hlo_module_config) - ? 
runtime_arguments.hlo_module_config.debug_options() - .xla_gpu_redzone_padding_bytes() - : 0); + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator scratch_allocator, + AutotunerUtil::CreateRedzoneAllocator( + config_, runtime_arguments.hlo_module_config.debug_options(), + stream)); + se::dnn::ProfileResult profile_result; VLOG(4) << "Trying algorithm " << alg.ToString() << " for " << instr_str; @@ -846,23 +823,20 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmWithAllocatedBuffer( - const GpuConvConfig conv_config, + const AutotuneConfig& config, const GpuConvConfig conv_config, const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, + const DebugOptions& debug_options, const std::vector buffers, const se::DeviceMemoryBase result_buffer) { #if GOOGLE_CUDA Shape output_shape = conv_config.output_shape; HloModuleConfig hlo_module_config; - hlo_module_config.set_debug_options(*debug_options); + hlo_module_config.set_debug_options(debug_options); se::Stream* stream = run_options->stream(); se::DeviceMemoryAllocator* allocator = run_options->allocator(); - se::RedzoneAllocator input_output_allocator( - stream, allocator, PtxOptsFromDebugOptions(*debug_options), - /*memory_limit=*/std::numeric_limits::max(), - ShouldCheckConv(hlo_module_config) - ? debug_options->xla_gpu_redzone_padding_bytes() - : 0); + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator input_output_allocator, + AutotunerUtil::CreateRedzoneAllocator(config, debug_options, stream)); GpuConvAlgorithmPicker::AutotuneRuntimeArguments autotune_runtime_arguments = {output_shape, hlo_module_config, buffers, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h similarity index 94% rename from tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h rename to tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h index 4b903e228755c2..0d265a2e7f680b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_ #include #include @@ -96,8 +96,9 @@ class GpuConvAlgorithmPicker : public HloModulePass { // Run autotuning on allocated buffers and pick the best algorithm. 
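+  // The AutotuneConfig and the already-allocated operand/result buffers are
+  // supplied by the caller.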
StatusOr PickBestAlgorithmWithAllocatedBuffer( - GpuConvConfig conv_config, const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, + const AutotuneConfig& config, GpuConvConfig conv_config, + const ServiceExecutableRunOptions* run_options, + const DebugOptions& debug_options, std::vector buffers, se::DeviceMemoryBase result_buffer); @@ -162,4 +163,4 @@ class GpuConvAlgorithmPicker : public HloModulePass { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker_test.cc b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker_test.cc similarity index 96% rename from tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker_test.cc rename to tensorflow/compiler/xla/service/gpu/conv_algorithm_picker_test.cc index 3a9c0ad4919258..73b2b6a8d52e2a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" @@ -54,7 +54,8 @@ ENTRY main { bool changed = false; TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvRewriter(), m.get())); changed = false; - DebugOptions opts; + DebugOptions opts = DefaultDebugOptionsIgnoringFlags(); + AutotuneConfig cfg{DeviceConfig{stream_exec, nullptr}, opts}; TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvAlgorithmPicker(cfg), m.get())); diff --git a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc index 7abb4848a187c3..78ea68419d5367 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h" #if GOOGLE_CUDA || TF_HIPBLASLT + +#include #include #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" @@ -32,7 +34,8 @@ namespace xla { namespace gpu { CublasLtMatmulThunk::CublasLtMatmulThunk( - ThunkInfo thunk_info, cublas_lt::MatmulPlan plan, int64_t algorithm_idx, + ThunkInfo thunk_info, GemmConfig gemm_config, + se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, BufferAllocation::Slice bias_buffer, BufferAllocation::Slice aux_buffer, @@ -40,7 +43,8 @@ CublasLtMatmulThunk::CublasLtMatmulThunk( BufferAllocation::Slice c_scale, BufferAllocation::Slice d_scale, BufferAllocation::Slice d_amax) : Thunk(Kind::kCublasLtMatmul, thunk_info), - plan_(std::move(plan)), + gemm_config_(std::move(gemm_config)), + epilogue_(epilogue), algorithm_idx_(algorithm_idx), a_buffer_(a_buffer), b_buffer_(b_buffer), @@ -55,10 +59,12 @@ CublasLtMatmulThunk::CublasLtMatmulThunk( d_amax_buffer_(d_amax) {} Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { + TF_ASSIGN_OR_RETURN(cublas_lt::MatmulPlan * plan, + GetMatmulPlan(params.stream)); if (!algorithm_) { TF_ASSIGN_OR_RETURN( std::vector algorithms, - plan_.GetAlgorithms(params.stream)); + plan->GetAlgorithms(params.stream)); TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size()); algorithm_ = algorithms[algorithm_idx_]; } @@ -93,14 +99,29 @@ Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { se::OwningScratchAllocator<> scratch_allocator(allocs.device_ordinal(), allocs.memory_allocator()); - return plan_.ExecuteOnStream( + return plan->ExecuteOnStream( params.stream, allocs.GetDeviceAddress(a_buffer_), allocs.GetDeviceAddress(b_buffer_), allocs.GetDeviceAddress(c_buffer_), allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, *algorithm_, scratch_allocator); } +StatusOr CublasLtMatmulThunk::GetMatmulPlan( + const stream_executor::Stream* stream) { + absl::MutexLock lock(&matmul_plans_cache_mutex_); + auto it = matmul_plans_cache_.find(stream); + if (it == matmul_plans_cache_.end()) { + TF_ASSIGN_OR_RETURN(cublas_lt::MatmulPlan plan, + cublas_lt::MatmulPlan::From(gemm_config_, epilogue_)); + it = matmul_plans_cache_ + .insert({stream, + std::make_unique(std::move(plan))}) + .first; + } + return it->second.get(); +} + } // namespace gpu } // namespace xla -#endif \ No newline at end of file +#endif diff --git a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h index da4a7bb47019b1..662e677426a54d 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h @@ -21,6 +21,7 @@ limitations under the License. #endif #if GOOGLE_CUDA || TF_HIPBLASLT +#include #include #include @@ -33,13 +34,15 @@ limitations under the License. 
#else #include "tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.h" #endif +#include "tensorflow/tsl/platform/statusor.h" namespace xla { namespace gpu { class CublasLtMatmulThunk : public Thunk { public: - CublasLtMatmulThunk(ThunkInfo thunk_info, cublas_lt::MatmulPlan plan, + CublasLtMatmulThunk(ThunkInfo thunk_info, GemmConfig gemm_config, + se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, BufferAllocation::Slice c_buffer, @@ -55,7 +58,16 @@ class CublasLtMatmulThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - cublas_lt::MatmulPlan plan_; + StatusOr GetMatmulPlan( + const stream_executor::Stream* stream); + + absl::Mutex matmul_plans_cache_mutex_; + absl::flat_hash_map> + matmul_plans_cache_ ABSL_GUARDED_BY(matmul_plans_cache_mutex_); + + GemmConfig gemm_config_; + se::gpu::BlasLt::Epilogue epilogue_; int64_t algorithm_idx_; BufferAllocation::Slice a_buffer_; BufferAllocation::Slice b_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc index 1041dd4de5f5b9..8bf17cca183d0f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -59,17 +59,24 @@ bool IsConvDepthwise(const HloInstruction* instr) { return input_feature_count == feature_group_count; } +// We don't want to upgrade depthwise convolutions to ConvBiasActivation, +// because the fused CUDNN functions are slower for some of those. bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) { return IsConvCustomCall(instr) && !IsConvDepthwise(instr); } -bool IsExponentialMinusOne(const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kExpm1; -} - -bool HasThreeUsers(const HloInstruction* instr) { - int64_t user_count = instr->user_count(); - return user_count == 3; +// elu, relu6, and leaky-relu activations are supported in cudnn via the +// "runtime fusion" engine, which JIT compiles C++ code. This can be slow to +// compile, so we guard it with a debug option. +// +// nvidia currently recommends that we enable this only on Ampere+, but we've +// tested on Turing (sm75) and it seems to work fine. +// +// Note that as of writing, xla_gpu_use_runtime_fusion is disabled by default +// due to apparent bugs in cudnn 8.9.0. See debug_options_flags.cc for details. +bool ShouldUseCudnnRuntimeFusion(const DebugOptions& debug_opts, + se::CudaComputeCapability cc) { + return debug_opts.xla_gpu_use_runtime_fusion() && cc.IsAtLeast(7, 5); } // Can instr be converted to type `dst_ty` without losing any precision? For @@ -247,8 +254,6 @@ StatusOr FuseConvAlpha(HloComputation* comp) { HloInstruction* gte = nullptr; HloInstruction* alpha = nullptr; - // We don't want to upgrade depthwise convolutions to ConvBiasActivation, - // because the fused CUDNN functions are slower for some of those. auto pattern = m::MultiplyAnyOrder( m::GetTupleElement( >e, m::Op(&conv).WithPredicate(IsNonDepthwiseConvCustomCall), 0) @@ -298,8 +303,6 @@ StatusOr FuseBiasOrSideInput(HloComputation* comp) { HloInstruction* gte = nullptr; HloInstruction* addend = nullptr; - // We don't want to upgrade depthwise convolutions to ConvBiasActivation, - // because the fused CUDNN functions are slower for some of those. 
auto pattern = m::AddAnyOrder( m::GetTupleElement(>e, m::Op(&conv) @@ -510,43 +513,42 @@ StatusOr FuseSideInputAlpha(HloComputation* comp) { } StatusOr FuseElu(HloComputation* comp, se::CudaComputeCapability cc) { + if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), + cc)) { + return false; + } + bool changed = false; for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { - const DebugOptions& debug_options = - instr->GetModule()->config().debug_options(); - if (!debug_options.xla_gpu_use_runtime_fusion() || - !cc.IsAtLeast(se::CudaComputeCapability::AMPERE)) { - return false; - } - - HloInstruction* gte; + HloInstruction *gte1, *gte2, *gte3; HloInstruction* conv; HloInstruction* expm1; - // In Elu computation, the GetTupleElement node will have three users: - // Compare, ExponentialMinusOnem, and Select. - // We don't want to upgrade depthwise convolutions to ConvBiasActivation, - // because the fused CUDNN functions are slower for some of those. - auto gte_pattern = - m::GetTupleElement(>e, - m::Op(&conv) - .WithPredicate(IsNonDepthwiseConvCustomCall) - .WithOneUse()) - .WithElementType(F16) - .WithPredicate(HasThreeUsers); if (!Match(instr, - m::Select(m::Compare(gte_pattern, + m::Select(m::Compare(m::GetTupleElement(>e1, m::Op()), m::Broadcast(m::ConstantEffectiveScalar(0))) .WithComparisonDirection(ComparisonDirection::kGt) .WithOneUse(), - gte_pattern, + m::GetTupleElement( + >e2, + m::Op(&conv) + .WithPredicate(IsNonDepthwiseConvCustomCall) + .WithOneUse(), + /*tuple_index=*/0) + // TODO(jlebar): Why only fp16? + .WithElementType(F16), m::Op(&expm1) - .WithPredicate(IsExponentialMinusOne) - .WithOperand(0, gte_pattern) + .WithOpcode(HloOpcode::kExpm1) + .WithOperand(0, m::GetTupleElement(>e3, m::Op())) .WithOneUse()))) { continue; } + // The three GTEs should be the same, and these should be the only uses. + if (gte1 != gte2 || gte2 != gte3 || gte1->user_count() != 3) { + continue; + } + // In some cases, the XLA optimizes the inputs of the convolution by // moving and broadcasting the bias to the side input, e.g., when the input // spatial dimensions are all ones and filter spatial dimentsions are all @@ -584,7 +586,7 @@ StatusOr FuseElu(HloComputation* comp, se::CudaComputeCapability cc) { TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); config.set_activation_mode(se::dnn::kElu); TF_RETURN_IF_ERROR(conv->set_backend_config(config)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1)); changed = true; } return changed; @@ -595,8 +597,6 @@ StatusOr FuseRelu(HloComputation* comp) { for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { HloInstruction* gte; HloInstruction* conv; - // We don't want to upgrade depthwise convolutions to ConvBiasActivation, - // because the fused CUDNN functions are slower for some of those. 
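For reference, the HLO graph that FuseElu() matches above, select(compare(x, 0, GT), x, expm1(x)), computes the ELU activation. A scalar C++ equivalent, for illustration only:

#include <cmath>

// elu(x) = x for x > 0, e^x - 1 otherwise; std::expm1 computes e^x - 1
// accurately for small x.
float EluReference(float x) { return x > 0.0f ? x : std::expm1(x); }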
if (!Match(instr, m::MaximumAnyOrder( m::Broadcast(m::ConstantEffectiveScalar(0)), @@ -627,6 +627,115 @@ StatusOr FuseRelu(HloComputation* comp) { return changed; } +StatusOr FuseRelu6(HloComputation* comp, se::CudaComputeCapability cc) { + if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), + cc)) { + return false; + } + + bool changed = false; + for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { + HloInstruction *gte, *conv; + if (!Match( + instr, + m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(0)), + m::GetTupleElement( + >e, m::Op(&conv) + .WithPredicate(IsNonDepthwiseConvCustomCall) + .WithOneUse()) + // TODO(jlebar): Why only fp16? + .WithElementType(F16) + .WithOneUse(), + m::Broadcast(m::ConstantEffectiveScalar(6))))) { + continue; + } + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, + conv->backend_config()); + if (config.activation_mode() != se::dnn::kNone) { + continue; + } + + // cudnn runtime fusions seem to be very slow when a side input is present. + // TODO(kaixih@nvidia): remove this check when cuDNN fixes it. + if (conv->operands().size() > 3) { + continue; + } + + if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] { + return absl::StrCat("FuseRelu6: ", conv->ToString()); + })) { + continue; + } + TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); + config.set_activation_mode(se::dnn::kRelu6); + TF_RETURN_IF_ERROR(conv->set_backend_config(config)); + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); + changed = true; + } + return changed; +} + +StatusOr FuseLeakyRelu(HloComputation* comp, + se::CudaComputeCapability cc) { + if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), + cc)) { + return false; + } + + bool changed = false; + for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { + HloInstruction *gte1, *gte2, *gte3, *conv, *alpha; + if (!Match(instr, + m::Select( + m::Compare(m::GetTupleElement(>e1, m::Op()), + m::Broadcast(m::ConstantEffectiveScalar(0))) + .WithComparisonDirection(ComparisonDirection::kGt) + .WithOneUse(), + m::GetTupleElement( + >e2, m::Op(&conv) + .WithPredicate(IsNonDepthwiseConvCustomCall) + .WithOneUse()) + // TODO(jlebar): Why only fp16? + .WithElementType(F16), + m::Multiply(m::GetTupleElement(>e3, m::Op()), + m::Broadcast(m::ConstantEffectiveScalar(&alpha))) + .WithOneUse()))) { + continue; + } + + // The three GTEs should be the same, and these should be the only uses. + if (gte1 != gte2 || gte2 != gte3 || gte1->user_count() != 3) { + continue; + } + + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, + conv->backend_config()); + if (config.activation_mode() != se::dnn::kNone) { + continue; + } + + // cudnn runtime fusions seem to be very slow when a side input is present. + // TODO(kaixih@nvidia): remove this check when cuDNN fixes it. 
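Similarly, the patterns targeted by FuseRelu6() and FuseLeakyRelu() above correspond to clamp(0, x, 6) and select(x > 0, x, alpha * x). Scalar C++ equivalents for reference (illustrative only):

#include <algorithm>

float Relu6Reference(float x) { return std::clamp(x, 0.0f, 6.0f); }

// The matched alpha constant is recorded in the backend config via
// set_leakyrelu_alpha(), as shown in FuseLeakyRelu below.
float LeakyReluReference(float x, float alpha) {
  return x > 0.0f ? x : alpha * x;
}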
+ if (conv->operands().size() > 3) { + continue; + } + + if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] { + return absl::StrCat("FuseLeakyRelu: ", conv->ToString()); + })) { + continue; + } + TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); + config.set_activation_mode(se::dnn::kLeakyRelu); + TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64)); + config.set_leakyrelu_alpha(alpha_f64.GetFirstElement()); + TF_RETURN_IF_ERROR(conv->set_backend_config(config)); + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1)); + changed = true; + } + return changed; +} + StatusOr FuseConvertToF16(HloComputation* comp) { bool changed = false; for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { @@ -934,6 +1043,10 @@ StatusOr CudnnFusedConvRewriter::Run( any_changed |= changed; TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_)); any_changed |= changed; + TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_)); + any_changed |= changed; + TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_)); + any_changed |= changed; TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp)); any_changed |= changed; @@ -953,6 +1066,10 @@ StatusOr CudnnFusedConvRewriter::Run( any_changed |= changed; TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_)); any_changed |= changed; + TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_)); + any_changed |= changed; + TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_)); + any_changed |= changed; // Check that we don't have any convs outputting integer types other than // s8 - cudnn does not support these. They should have been transformed to diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 1ce89568c4948e..0ac90d1ca480aa 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -80,6 +80,7 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { HloModuleConfig config = GetModuleConfigForTest(); DebugOptions debug_opts = config.debug_options(); debug_opts.add_xla_disable_hlo_passes("cudnn_vectorize_convolutions"); + debug_opts.set_xla_gpu_use_runtime_fusion(true); config.set_debug_options(debug_opts); auto result = backend().compiler()->RunHloPasses( @@ -215,11 +216,6 @@ TEST_F(CudnnFusedConvRewriterTest, DontFuseBiasWithDepthwiseConv) { } TEST_F(CudnnFusedConvRewriterTest, TestElu) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Conv-Bias-Elu fusion is supported and recommended with " - "the Nvidia Ampere+ GPUs."; - } // sum = conv(x, w) + bias // select(compare(sum, 0, GT), sum, exponential-minus-one(sum)); TestMatchWithAllTypes(R"( @@ -243,12 +239,6 @@ TEST_F(CudnnFusedConvRewriterTest, TestElu) { } TEST_F(CudnnFusedConvRewriterTest, DontFuseEluWithDepthwiseConv) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Conv-Bias-Elu fusion is supported and recommended with " - "the Nvidia Ampere+ GPUs."; - } - // sum = conv(x, w) + bias // select(compare(sum, 0, GT), sum, exponential-minus-one(sum)); TestNotMatchWithAllTypes(R"( @@ -271,6 +261,59 @@ TEST_F(CudnnFusedConvRewriterTest, DontFuseEluWithDepthwiseConv) { })"); } +TEST_F(CudnnFusedConvRewriterTest, TestRelu6) { + if (!GetCudaComputeCapability().IsAtLeast( + 
se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "Conv-Bias-Relu6 fusion is supported and recommended with " + "the Nvidia Ampere+ GPUs."; + } + // sum = conv(x, w) + bias + // clamp(0, sum, 6); + TestMatchWithAllTypes(R"( + HloModule Test + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + six = TYPE[] constant(6) + sixes = TYPE[1,3,3,64] broadcast(six), dimensions={} + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + sum = TYPE[1,3,3,64] add(conv, broadcasted_bias) + ROOT relu6 = TYPE[1,3,3,64] clamp(zeros, sum, sixes) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestLeakyRelu) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() + << "Conv-Bias-LeakyRelu fusion is supported and recommended with " + "the Nvidia Ampere+ GPUs."; + } + // sum = conv(x, w) + bias + // select(compare(sum, 0, GT), sum, multiply(sum, alpha)); + TestMatchWithAllTypes(R"( + HloModule Test + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha = TYPE[] constant(0.2) + alphas = TYPE[1,3,3,64] broadcast(alpha), dimensions={} + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + sum = TYPE[1,3,3,64] add(conv, broadcasted_bias) + cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT + mul = TYPE[1,3,3,64] multiply(sum, alphas) + ROOT elu = TYPE[1,3,3,64] select(cmp, sum, mul) + })"); +} + TEST_F(CudnnFusedConvRewriterTest, TestSideInputOnly) { // max(0, conv(x, w) + side_input); TestMatchWithAllTypes(R"( @@ -931,11 +974,6 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) { } TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Conv-Bias-Elu fusion is supported and recommended with " - "the Nvidia Ampere+ GPUs."; - } const std::string module_str = R"( HloModule Test @@ -955,10 +993,14 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) { ROOT elu = select(cmp, sum, expm1) })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + DebugOptions debug_opts = m->config().debug_options(); + debug_opts.set_xla_gpu_use_runtime_fusion(true); + m->config().set_debug_options(debug_opts); GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + // elu fusion is only active on Ampere+. 
+ CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0)}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -998,6 +1040,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) { ROOT root = tuple(elu, not_elu) })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + DebugOptions debug_opts = m->config().debug_options(); + debug_opts.set_xla_gpu_use_runtime_fusion(true); + m->config().set_debug_options(debug_opts); GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); @@ -1028,6 +1073,190 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) { EXPECT_EQ(config.activation_mode(), se::dnn::kNone); } +TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) { + const std::string module_str = R"( + HloModule Test + ENTRY Test { + inputs = f16[1,17,9,9] parameter(0) + filters = f16[3,3,17,32] parameter(1) + bias = f16[32] parameter(2) + bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1} + zero = f16[] constant(0) + zeros = f16[1,32,9,9] broadcast(zero), dimensions={} + sixes = f16[1,32,9,9] broadcast(f16[] constant(6)), dimensions={} + conv = f16[1,32,9,9] convolution(inputs, filters), + window={size=3x3 pad=1_1x1_1}, + dim_labels=bf01_01io->bf01 + sum = add(conv, bias_broadcast) + ROOT relu = clamp(zeros, sum, sixes) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + DebugOptions debug_opts = m->config().debug_options(); + debug_opts.set_xla_gpu_use_runtime_fusion(true); + m->config().set_debug_options(debug_opts); + + GpuConvRewriter rewriter; + TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); + // relu6 fusion is only enabled on Ampere+. + CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0)}; + TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); + SCOPED_TRACE(m->ToString()); + const HloInstruction* conv; + ASSERT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch( + m::GetTupleElement( + m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget}, + m::Parameter(0), m::Parameter(1), m::Parameter(2)), + 0) + .WithShape(F16, {1, 32, 9, 9}))); + TF_ASSERT_OK_AND_ASSIGN(auto config, + conv->backend_config()); + EXPECT_EQ(config.activation_mode(), se::dnn::kRelu6); +} + +TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) { + const std::string module_str = R"( + HloModule Test + ENTRY Test { + inputs = f16[1,17,9,9] parameter(0) + filters = f16[3,3,17,32] parameter(1) + bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1} + zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={} + sixes = f16[1,32,9,9] broadcast(f16[] constant(6)), dimensions={} + conv = f16[1,32,9,9] convolution(inputs, filters), + window={size=3x3 pad=1_1x1_1}, + dim_labels=bf01_01io->bf01 + sum = add(conv, bias) + relu = clamp(zeros, sum, sixes) + not_relu = minimum(sum, zeros) + ROOT root = tuple(relu, not_relu) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + DebugOptions debug_opts = m->config().debug_options(); + debug_opts.set_xla_gpu_use_runtime_fusion(true); + m->config().set_debug_options(debug_opts); + + GpuConvRewriter rewriter; + TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); + CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); + + SCOPED_TRACE(m->ToString()); + const HloInstruction* conv; + ASSERT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + 
m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(0)), + m::GetTupleElement( + m::CustomCall( + &conv, {kCudnnConvBiasActivationForwardCallTarget}, + m::Parameter(0), m::Parameter(1), m::Parameter(2)), + 0) + .WithShape(F16, {1, 32, 9, 9}), + m::Broadcast(m::ConstantEffectiveScalar(6))), + m::Minimum()))); + TF_ASSERT_OK_AND_ASSIGN(auto config, + conv->backend_config()); + EXPECT_EQ(config.activation_mode(), se::dnn::kNone); +} + +TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) { + const std::string module_str = R"( + HloModule Test + ENTRY Test { + inputs = f16[1,16,9,9] parameter(0) + filters = f16[3,3,16,32] parameter(1) + bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1} + zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={} + alphas = f16[1,32,9,9] broadcast(f16[] constant(0.2)), dimensions={} + conv = f16[1,32,9,9] convolution(inputs, filters), + window={size=3x3 pad=1_1x1_1}, + dim_labels=bf01_01io->bf01 + sum = add(conv, bias) + cmp = compare(sum, zeros), direction=GT + mul = multiply(sum, alphas) + ROOT leaky_relu = select(cmp, sum, mul) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + DebugOptions debug_opts = m->config().debug_options(); + debug_opts.set_xla_gpu_use_runtime_fusion(true); + m->config().set_debug_options(debug_opts); + + GpuConvRewriter rewriter; + TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); + // Leaky-relu fusion is only enabled on Ampere+. + CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0)}; + TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); + + SCOPED_TRACE(m->ToString()); + const HloInstruction* conv; + ASSERT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch( + m::GetTupleElement( + m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget}, + m::Parameter(0), m::Parameter(1), m::Parameter(2)), + 0) + .WithShape(F16, {1, 32, 9, 9}))); + TF_ASSERT_OK_AND_ASSIGN(auto config, + conv->backend_config()); + EXPECT_EQ(config.activation_mode(), se::dnn::kLeakyRelu); +} + +TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) { + const std::string module_str = R"( + HloModule Test + ENTRY Test { + inputs = f16[1,16,9,9] parameter(0) + filters = f16[3,3,16,32] parameter(1) + bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1} + zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={} + alphas = f16[1,32,9,9] broadcast(f16[] constant(0.2)), dimensions={} + conv = f16[1,32,9,9] convolution(inputs, filters), + window={size=3x3 pad=1_1x1_1}, + dim_labels=bf01_01io->bf01 + sum = add(conv, bias) + cmp = compare(sum, zeros), direction=GT + mul = multiply(sum, alphas) + leaky_relu = select(cmp, sum, mul) + not_leaky_relu = minimum(sum, zeros) + ROOT root = tuple(leaky_relu, not_leaky_relu) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + DebugOptions debug_opts = m->config().debug_options(); + debug_opts.set_xla_gpu_use_runtime_fusion(true); + m->config().set_debug_options(debug_opts); + + GpuConvRewriter rewriter; + TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); + CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); + + SCOPED_TRACE(m->ToString()); + const HloInstruction* conv; + auto gte_pattern = + m::GetTupleElement( + m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget}, + m::Parameter(0), m::Parameter(1), m::Parameter(2)), + 0) + .WithShape(F16, {1, 32, 9, 9}); + ASSERT_THAT( + 
m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::Select(m::Compare(gte_pattern, + m::Broadcast(m::ConstantEffectiveScalar(0))) + .WithComparisonDirection(ComparisonDirection::kGt) + .WithOneUse(), + gte_pattern, + m::Multiply(gte_pattern, + m::Broadcast(m::ConstantEffectiveScalar()))), + m::Minimum()))); + TF_ASSERT_OK_AND_ASSIGN(auto config, + conv->backend_config()); + EXPECT_EQ(config.activation_mode(), se::dnn::kNone); +} + TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) { const std::string module_str = R"( HloModule Test diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_mha_rewriter.cc index f2b7d0a8ac85ad..49bbbf562e91a2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_mha_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_mha_rewriter.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/statusor.h" diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index 4f4fe23d379d4f..98a01d2824694f 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -53,22 +53,6 @@ limitations under the License. namespace xla { namespace gpu { -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -static se::RedzoneAllocator CreateRedzoneAllocator( - se::Stream* stream, se::DeviceMemoryAllocator* allocator, - const DebugOptions& debug_options, const AutotuneConfig& config) { - // TODO(jlebar): The memory limit here should by rights be - // debug_options.xla_gpu_redzone_scratch_max_megabytes(), but tests OOM when - // we do that. Are the tests wrong, or is the option named incorrectly? - return se::RedzoneAllocator( - stream, allocator, PtxOptsFromDebugOptions(debug_options), - /*memory_limit=*/std::numeric_limits::max(), - /*redzone_size=*/config.should_check_correctness() - ? debug_options.xla_gpu_redzone_padding_bytes() - : 0); -} -#endif - // Returns the index (into `algorithms`) of the fastest algorithm. template StatusOr GetBestAlgorithm( @@ -244,13 +228,8 @@ StatusOr DoGemmAutotuneNoCache( } VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString(); - se::StreamExecutor* executor = autotune_config.GetExecutor(); se::DeviceMemoryAllocator* allocator = autotune_config.GetAllocator(); - if (allocator == nullptr) { - allocator = executor->GetAllocator(); - } - TF_ASSIGN_OR_RETURN(se::Stream* const stream, - allocator->GetStream(executor->device_ordinal())); + TF_ASSIGN_OR_RETURN(se::Stream* const stream, autotune_config.GetStream()); GemmBackendConfig gemm_config = gemm->backend_config().value(); const DebugOptions& debug_options = @@ -261,8 +240,9 @@ StatusOr DoGemmAutotuneNoCache( // Don't run autotuning concurrently on the same GPU. 
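The gemm_algorithm_picker change above drops the local CreateRedzoneAllocator() helper in favor of AutotunerUtil::CreateRedzoneAllocator(). The red-zone allocator pads candidate buffers with a known byte pattern so that out-of-bounds writes by an autotuned algorithm can be detected when correctness checking is enabled. A toy host-memory sketch of the red-zone idea, not the StreamExecutor implementation (names and sizes are illustrative):

#include <cstdint>
#include <cstring>
#include <vector>

constexpr uint8_t kRedzonePattern = 0xFE;
constexpr size_t kRedzoneSize = 64;

struct RedzoneBuffer {
  std::vector<uint8_t> storage;
  // Pointer handed to the algorithm under test; red zones sit on both sides.
  uint8_t* data() { return storage.data() + kRedzoneSize; }
};

RedzoneBuffer AllocateWithRedzones(size_t user_size) {
  RedzoneBuffer buf;
  buf.storage.assign(user_size + 2 * kRedzoneSize, 0);
  std::memset(buf.storage.data(), kRedzonePattern, kRedzoneSize);
  std::memset(buf.storage.data() + kRedzoneSize + user_size, kRedzonePattern,
              kRedzoneSize);
  return buf;
}

bool RedzonesIntact(const RedzoneBuffer& buf) {
  const size_t total = buf.storage.size();
  for (size_t i = 0; i < kRedzoneSize; ++i) {
    if (buf.storage[i] != kRedzonePattern ||
        buf.storage[total - 1 - i] != kRedzonePattern) {
      return false;  // something wrote past the user-visible buffer
    }
  }
  return true;
}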
absl::MutexLock gpu_lock(&GetGpuMutex(stream->parent())); - se::RedzoneAllocator buffer_allocator = - CreateRedzoneAllocator(stream, allocator, debug_options, autotune_config); + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator buffer_allocator, + AutotunerUtil::CreateRedzoneAllocator(autotune_config, debug_options)); int64_t rng_state = 0; TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index 6b28352ccd61ab..97c54e52a91626 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -22,12 +22,15 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/autotuning.pb.h" @@ -37,18 +40,23 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_query.h" #include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/permutation_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/cublas_padding_requirements.h" #include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" @@ -57,6 +65,25 @@ limitations under the License. namespace xla { namespace gpu { + +bool TensorIterationSpec::operator==(const TensorIterationSpec& other) const { + for (int dim = 0; dim < TensorIterationSpec::kMaxDimsPerTensor; ++dim) { + if (dim_iteration_specs_[dim].size() != other[dim].size()) { + return false; + } + for (int fragment = 0; fragment < dim_iteration_specs_[dim].size(); + ++fragment) { + if (dim_iteration_specs_[dim][fragment].stride != + other[dim][fragment].stride || + dim_iteration_specs_[dim][fragment].count != + other[dim][fragment].count) { + return false; + } + } + } + return true; +} + namespace { // Batch dimensions of an operand of a dot instruction. @@ -95,10 +122,10 @@ int64_t NonContractingDimensionIndex(const HloInstruction& dot, } // Data types that are tested to work in the triton GEMM emitter. 
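TensorIterationSpec::operator== above compares two iteration specs fragment by fragment on stride and count. A simplified standalone analogue using plain std types; Fragment, DimSpec, and TensorSpec are illustrative, not the XLA definitions:

#include <cstdint>
#include <vector>

struct Fragment { int64_t stride; int64_t count; };
using DimSpec = std::vector<Fragment>;
using TensorSpec = std::vector<DimSpec>;  // indexed by dimension number

// Two specs describe the same physical layout iff every dimension carries the
// same sequence of (stride, count) fragments.
bool SpecsEqual(const TensorSpec& a, const TensorSpec& b) {
  if (a.size() != b.size()) return false;
  for (size_t dim = 0; dim < a.size(); ++dim) {
    if (a[dim].size() != b[dim].size()) return false;
    for (size_t i = 0; i < a[dim].size(); ++i) {
      if (a[dim][i].stride != b[dim][i].stride ||
          a[dim][i].count != b[dim][i].count) {
        return false;
      }
    }
  }
  return true;
}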
-bool IsSupportedDataType(PrimitiveType t, GpuVersion gpu_version) { +bool IsSupportedDataType(PrimitiveType type, GpuVersion gpu_version) { auto cuda_compute_capability = std::get(gpu_version); - switch (t) { + switch (type) { case PRED: case S8: case S16: @@ -114,21 +141,34 @@ bool IsSupportedDataType(PrimitiveType t, GpuVersion gpu_version) { } } -Status RequireTritonFusibleConvert(const HloInstruction* input, - GpuVersion gpu_version) { - if (!IsSupportedDataType(input->operand(0)->shape().element_type(), - gpu_version)) { - return Unimplemented("unsupported data type"); +// Tells if f(a+b) == f(a) + f(b). +bool IsDistributiveOverAddition(const HloInstruction& hlo) { + // The list is most likely incomplete. + // For example division can be added too but only for operand #0. + if (hlo.opcode() == HloOpcode::kMultiply || + hlo.opcode() == HloOpcode::kNegate || + hlo.opcode() == HloOpcode::kBitcast || + hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kCopy || + hlo.opcode() == HloOpcode::kTranspose || + hlo.opcode() == HloOpcode::kConvert || + hlo.opcode() == HloOpcode::kBroadcast || + hlo.opcode() == HloOpcode::kSlice) { + return true; } + return false; +} + +FusionDecision RequireTritonFusibleConvert(const HloInstruction* input, + GpuVersion gpu_version) { // TODO(b/266862494): Can pick up almost any // convert, but if it's reducing the data volume it should rather be fused // to the output of the producer kernel. However not all operations support // output fusion - then it should be fused here anyway! if (ShapeUtil::ByteSizeOf(input->operand(0)->shape()) > ShapeUtil::ByteSizeOf(input->shape())) { - return FailedPrecondition("narrowing conversion"); + return "Narrowing conversion."; } - return OkStatus(); + return FusionDecision{}; } // Handles numbers of dimensions of a target HLO instruction @@ -142,6 +182,13 @@ class DimensionOrder { int64_t target_dim_number; int subdim_number; int64_t size; + bool operator==(const DimDescription& other) const { + return target_dim_number == other.target_dim_number && + subdim_number == other.subdim_number && size == other.size; + } + std::string ToString() const { + return absl::StrCat(target_dim_number, ":", subdim_number, ":", size); + } }; // Sequence describing all dimensions of HLO's output shape // in layout minor-to-major (physical) order. @@ -152,9 +199,12 @@ class DimensionOrder { // `hlo` is currently supposed to be an operand of dot(); // dimension indices describing the operand // are stored along with the dimension order for later analysis. - explicit DimensionOrder(const HloInstruction* hlo, - const int64_t splittable_dimension_index = -1) - : splittable_dimension_index_(splittable_dimension_index) { + explicit DimensionOrder( + const HloInstruction* hlo, const int64_t splittable_dimension_index = -1, + const int64_t splittable_dimension_supported_major_size = 0) + : splittable_dimension_index_(splittable_dimension_index), + splittable_dimension_supported_major_part_size_( + splittable_dimension_supported_major_size) { dim_order_.reserve(hlo->shape().rank()); for (const int64_t i : hlo->shape().layout().minor_to_major()) { dim_order_.push_back({i, 0, hlo->shape().dimensions(i)}); @@ -167,38 +217,42 @@ class DimensionOrder { int operand_number, int64_t split_k = 1); // Create dimension order describing dot's output. 
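The rewrite above switches these helpers from Status to FusionDecision, which either carries an explanation of why fusion is rejected or converts to true when fusion is allowed. A rough standalone sketch of that idiom; Decision and RequireSmallEnough are made up for illustration, the real FusionDecision comes from instruction_fusion.h:

#include <cstdint>
#include <string>

class Decision {
 public:
  Decision() = default;                              // empty reason: can fuse
  Decision(const char* reason) : reason_(reason) {}  // non-empty: cannot fuse
  explicit operator bool() const { return reason_.empty(); }
  const std::string& Explain() const { return reason_; }

 private:
  std::string reason_;
};

// Example of how such a helper reads at a call site:
//   if (Decision d = RequireSmallEnough(bytes); !d) return d;
Decision RequireSmallEnough(int64_t bytes) {
  if (bytes > (1 << 20)) return "Operand too large.";
  return Decision{};
}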
- static DimensionOrder FromDotOutput(const HloInstruction& dot); + static DimensionOrder FromDotOutput( + const HloInstruction& dot, int64_t split_k = 1, + int64_t splittable_dimension_supported_major_part_size = 0); + + enum class TransformDirection { kInputToOutput, kOutputToInput }; - // Transforms the DimensionOrder so that from a description of the output - // of `hlo` it becomes a description of the input of `hlo`. - Status HandleInstruction(const HloInstruction* hlo) { + // Transforms the DimensionOrder so that from a description one side + // of `hlo` it becomes a description of the other side of `hlo`. + FusionDecision HandleInstruction(const HloInstruction* hlo, + TransformDirection direction) { VLOG(7) << hlo->ToString(); - if (hlo->opcode() == HloOpcode::kParameter) { - return OkStatus(); + if (hlo->opcode() == HloOpcode::kParameter || + hlo_query::IsScalarConstant(hlo)) { + return FusionDecision{}; } else if (hlo->opcode() == HloOpcode::kTranspose || hlo->opcode() == HloOpcode::kCopy) { - return HandleCopyOrTranspose(hlo); + return HandleCopyOrTransposeOrBroadcast(hlo, direction); + } else if (hlo->opcode() == HloOpcode::kBroadcast) { + if (direction != TransformDirection::kOutputToInput) { + return "Unsupported broadcast direction."; + } + return HandleCopyOrTransposeOrBroadcast(hlo, direction); } else if (hlo->operand_count() > 0 && IsTritonSupportedElementwise( hlo->opcode(), hlo->operand(0)->shape().element_type())) { - return OkStatus(); + return FusionDecision{}; } else if (hlo->opcode() == HloOpcode::kBitcast) { - return HandleBitcast(hlo); + return HandleBitcast(hlo, direction); } else if (hlo->opcode() == HloOpcode::kReshape) { if (!ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape())) { - return Unimplemented("Non-bitcast reshape."); + return "Non-bitcast reshape."; } - return HandleBitcast(hlo); - } else if (hlo_query::IsScalarConstant(hlo) || - hlo_query::IsBroadcastOfScalarConstant(*hlo)) { - // Dimension order collapses on a scalar, for simplicity leave it equal - // to the output one for now. - return OkStatus(); - } else { - return Unimplemented("Instruction: %s", hlo->ToString()); + return HandleBitcast(hlo, direction); } - return OkStatus(); + return "Unimplemented instruction."; } // Get the raw data of the dimension order. @@ -210,20 +264,41 @@ class DimensionOrder { return splittable_dimension_index_; } + // Tells whether `size` major part of a dimension can be physically split. + bool IsSupportedSplittableDimensionMajorPartSize(int64_t size) const { + // 0 means no specific size requirement. + return splittable_dimension_supported_major_part_size_ == 0 || + splittable_dimension_supported_major_part_size_ == size; + } + + // Tells that two dimension orders describe the same tensor physical layout. + bool IsPhysicallyEquivalent(const DimensionOrder& other) const; + + std::string ToString() const { + return absl::StrJoin(dim_order_, "-", + [](std::string* out, const DimDescription& d) { + absl::StrAppend(out, d.ToString()); + }); + } + private: // See HandleInstruction() for the general description of Handle*(). 
- Status HandleBitcast(const HloInstruction* hlo); - Status HandleCopyOrTranspose(const HloInstruction* hlo); + FusionDecision HandleBitcast(const HloInstruction*, TransformDirection); + FusionDecision HandleCopyOrTransposeOrBroadcast(const HloInstruction*, + TransformDirection); DimOrderVector dim_order_; - int64_t splittable_dimension_index_; + const int64_t splittable_dimension_index_; + const int64_t splittable_dimension_supported_major_part_size_; }; -DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( +using DimIterationSpec = TensorIterationSpec::DimIterationSpec; + +TensorIterationSpec DimensionOrderToTensorIterationSpec( const DimensionOrder& order) { const DimensionOrder::DimOrderVector& dim_order_vector = order.GetDimOrderVector(); - DotFusionAnalysis::TensorIterationSpec tensor_spec; + TensorIterationSpec tensor_spec; int64_t accumulated_stride = 1; for (int dim_order_index = 0; dim_order_index < dim_order_vector.size(); ++dim_order_index) { @@ -236,8 +311,7 @@ DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( continue; } - DotFusionAnalysis::DimIterationSpec& dim_spec = - tensor_spec[dim.target_dim_number]; + DimIterationSpec& dim_spec = tensor_spec[dim.target_dim_number]; if (dim_order_index > 0 && dim_order_vector[dim_order_index - 1].target_dim_number == dim.target_dim_number) { @@ -257,7 +331,7 @@ DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( accumulated_stride *= dim.size; } // Create all absent dimensions as degenerate ones to simplify later queries. - for (DotFusionAnalysis::DimIterationSpec& dim_spec : tensor_spec) { + for (DimIterationSpec& dim_spec : tensor_spec) { if (dim_spec.empty()) { dim_spec.push_back({/*stride=*/0, /*count=*/1, /*subfragments=*/{1}}); } @@ -265,6 +339,11 @@ DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( return tensor_spec; } +bool DimensionOrder::IsPhysicallyEquivalent(const DimensionOrder& other) const { + return DimensionOrderToTensorIterationSpec(*this) == + DimensionOrderToTensorIterationSpec(other); +} + DimensionOrder DimensionOrder::FromDotOperand(const HloInstruction& dot, const int operand_number, const int64_t split_k) { @@ -283,104 +362,124 @@ DimensionOrder DimensionOrder::FromDotOperand(const HloInstruction& dot, return DimensionOrder(operand); } -DimensionOrder DimensionOrder::FromDotOutput(const HloInstruction& dot) { - return DimensionOrder(&dot); +DimensionOrder DimensionOrder::FromDotOutput( + const HloInstruction& dot, const int64_t split_k, + const int64_t splittable_dimension_supported_major_part_size) { + // Allow non-contracting dimension originating from LHS to split if + // this dimension is split at the output at the same ratio as + // at the input. + int64_t splittable_dimension_index = -1; + if (splittable_dimension_supported_major_part_size > 1) { + // Split-K dimension is the first one in the output if present; + // LHS non-contracting follows (batch is absent in this case). + splittable_dimension_index = (split_k > 1) ? 1 : 0; + } + return DimensionOrder(&dot, splittable_dimension_index, + splittable_dimension_supported_major_part_size); } -Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { - const Shape& operand_shape = hlo->operand(0)->shape(); - DimOrderVector operand_dim_order; - operand_dim_order.reserve(dim_order_.size()); - // Size of not yet assigned part of current operand dimension. 
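DimensionOrderToTensorIterationSpec() above derives strides by walking the dimension order in minor-to-major order and multiplying up an accumulated stride. The core of that computation in isolation (illustrative helper, not the XLA function):

#include <cstdint>
#include <vector>

// Given dimension sizes listed minor-to-major, each dimension's stride is the
// product of the sizes of all more-minor dimensions.
std::vector<int64_t> StridesMinorToMajor(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t accumulated_stride = 1;
  for (size_t i = 0; i < sizes.size(); ++i) {  // index 0 is the most minor dim
    strides[i] = accumulated_stride;
    accumulated_stride *= sizes[i];
  }
  return strides;
}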
- int64_t operand_remaining_size = 1; - // Iterate in parallel over output dimension order and operand dimensions +FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo, + TransformDirection direction) { + const Shape& target_shape = (direction == TransformDirection::kOutputToInput) + ? hlo->operand(0)->shape() + : hlo->shape(); + DimOrderVector target_dim_order; + target_dim_order.reserve(dim_order_.size()); + // Size of not yet assigned part of current target dimension. + int64_t target_remaining_size = 1; + // Iterate in parallel over source dimension order and target dimensions // in minor_to_major order. Find groups of dimensions of equal size - // and project the output dimension order onto the operand. - auto operand_dim_iter = operand_shape.layout().minor_to_major().cbegin(); - for (auto out_dim = dim_order_.cbegin(); out_dim != dim_order_.cend(); - ++out_dim) { - if (operand_remaining_size >= out_dim->size) { - if (operand_remaining_size % out_dim->size) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + // and project the source dimension order onto the target. + auto target_dim_iter = target_shape.layout().minor_to_major().cbegin(); + for (auto src_dim = dim_order_.cbegin(); src_dim != dim_order_.cend(); + ++src_dim) { + if (target_remaining_size >= src_dim->size) { + if (target_remaining_size % src_dim->size) { + return "Unsupported bitcast"; } - // Output dimension fragment completely fits into the operand one: + // Source dimension fragment completely fits into the target one: // just copy it as is. - operand_dim_order.push_back(*out_dim); - // Update the size of the remaining part of the operand that is - // carried over to next output dimensions. - operand_remaining_size /= out_dim->size; + target_dim_order.push_back(*src_dim); + // Update the size of the remaining part of the target that is + // carried over to next source dimensions. + target_remaining_size /= src_dim->size; } else { - // Output is larger than input. Assign further operand dimensions. - // Size of the not yet assigned part of the output dimension. - int64_t out_remaining_size = out_dim->size; + // Source is larger than target. Assign further target dimensions. + // Size of the not yet assigned part of the source dimension. + int64_t src_remaining_size = src_dim->size; // Subdimension index tracking dimension splits. - int subdim_index = out_dim->subdim_number; - if (operand_remaining_size > 1) { - // If there is a remaining fragment of a previous operand dimension + int subdim_index = src_dim->subdim_number; + if (target_remaining_size > 1) { + // If there is a remaining fragment of a previous target dimension // assign it first. - if (out_remaining_size % operand_remaining_size) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + if (src_remaining_size % target_remaining_size) { + return "Unsupported bitcast"; } - operand_dim_order.push_back( - {out_dim->target_dim_number, subdim_index, operand_remaining_size}); + target_dim_order.push_back( + {src_dim->target_dim_number, subdim_index, target_remaining_size}); ++subdim_index; // Update the size of the fragment remaining to assign. - out_remaining_size /= operand_remaining_size; - operand_remaining_size = 1; + src_remaining_size /= target_remaining_size; + target_remaining_size = 1; } - while (out_remaining_size > 1) { - // Assign operand dimensions until the output remainder is covered. 
- int64_t operand_dim_size = operand_shape.dimensions(*operand_dim_iter); - int64_t new_fragment_size = operand_dim_size; - if (operand_dim_size > out_remaining_size) { - // If adding the next operand dimension exceeds output fragment size - // assign the remainder of the output and carry over the remainder - // of the operand. - if (operand_dim_size % out_remaining_size) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + while (src_remaining_size > 1) { + // Assign target dimensions until the source remainder is covered. + int64_t target_dim_size = target_shape.dimensions(*target_dim_iter); + int64_t new_fragment_size = target_dim_size; + if (target_dim_size > src_remaining_size) { + // If adding the next target dimension exceeds source fragment size + // assign the remainder of the source and carry over the remainder + // of the target. + if (target_dim_size % src_remaining_size) { + return "Unsupported bitcast"; } - operand_remaining_size = operand_dim_size / out_remaining_size; - new_fragment_size = out_remaining_size; + target_remaining_size = target_dim_size / src_remaining_size; + new_fragment_size = src_remaining_size; } - operand_dim_order.push_back( - {out_dim->target_dim_number, subdim_index, new_fragment_size}); - out_remaining_size /= new_fragment_size; - ++operand_dim_iter; + target_dim_order.push_back( + {src_dim->target_dim_number, subdim_index, new_fragment_size}); + src_remaining_size /= new_fragment_size; + ++target_dim_iter; ++subdim_index; } } } - CHECK_EQ(operand_remaining_size, 1); + CHECK_EQ(target_remaining_size, 1); - // Handle remaining major dimensions of the operand. Call all degenerate + // Handle remaining major dimensions of the target. Call all degenerate // ones subdimensions of the most-major non-degenerate one. Otherwise // give up. - int subdim_index = operand_dim_order.back().subdim_number + 1; - while (operand_dim_iter != operand_shape.layout().minor_to_major().cend()) { - if (operand_shape.dimensions(*operand_dim_iter) != 1) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + int subdim_index = target_dim_order.back().subdim_number + 1; + while (target_dim_iter != target_shape.layout().minor_to_major().cend()) { + if (target_shape.dimensions(*target_dim_iter) != 1) { + return "Unsupported bitcast"; } - operand_dim_order.push_back( - {operand_dim_order.back().target_dim_number, subdim_index, 1}); + target_dim_order.push_back( + {target_dim_order.back().target_dim_number, subdim_index, 1}); ++subdim_index; - ++operand_dim_iter; + ++target_dim_iter; } - dim_order_ = operand_dim_order; - return OkStatus(); + dim_order_ = target_dim_order; + return FusionDecision{}; } -Status DimensionOrder::HandleCopyOrTranspose(const HloInstruction* hlo) { +FusionDecision DimensionOrder::HandleCopyOrTransposeOrBroadcast( + const HloInstruction* hlo, const TransformDirection direction) { // Every HLO dimension can correspond to a group of subdimensions in // dim_order_. For the easier handling of permutations: group dim_order_ by // dimension, apply permutations, then finally remove the grouping. // Group subdimensions by iterating over them in the same order as over // dimensions and matching by total size. - std::vector out_physical; - out_physical.reserve(hlo->shape().rank()); + const HloInstruction* src = + (direction == TransformDirection::kOutputToInput) ? hlo : hlo->operand(0); + const HloInstruction* dst = + (direction == TransformDirection::kOutputToInput) ? 
hlo->operand(0) : hlo; + std::vector src_physical; + src_physical.reserve(src->shape().rank()); auto dim_order_it = dim_order_.cbegin(); - for (int64_t dim_index : hlo->shape().layout().minor_to_major()) { - const int64_t dim_size = hlo->shape().dimensions(dim_index); + for (int64_t dim_index : src->shape().layout().minor_to_major()) { + const int64_t dim_size = src->shape().dimensions(dim_index); int64_t subdim_size_accumulator = 1; DimOrderVector subdim_group; do { @@ -389,55 +488,65 @@ Status DimensionOrder::HandleCopyOrTranspose(const HloInstruction* hlo) { ++dim_order_it; } while (subdim_size_accumulator < dim_size); CHECK_EQ(subdim_size_accumulator, dim_size); - out_physical.push_back(subdim_group); + src_physical.push_back(subdim_group); } - // Out physical -> out logical. - std::vector out_logical; - out_logical.resize(out_physical.size()); - for (int i = 0; i < out_physical.size(); ++i) { - out_logical[hlo->shape().layout().minor_to_major(i)] = out_physical[i]; + // Source physical -> source logical. + std::vector src_logical; + src_logical.resize(src_physical.size()); + for (int i = 0; i < src_physical.size(); ++i) { + src_logical[src->shape().layout().minor_to_major(i)] = src_physical[i]; } - // Out logical -> operand logical. - std::vector operand_logical; + // Source logical -> destination logical. + std::vector dst_logical; if (hlo->opcode() == HloOpcode::kTranspose) { - auto transpose = ::xla::Cast(hlo); - operand_logical.resize(out_logical.size()); - for (int i = 0; i < out_logical.size(); ++i) { - operand_logical[transpose->dimensions()[i]] = out_logical[i]; + const auto transpose = Cast(hlo); + std::vector permutation(transpose->dimensions().cbegin(), + transpose->dimensions().cend()); + if (direction == TransformDirection::kInputToOutput) { + permutation = InversePermutation(permutation); + } + dst_logical.resize(permutation.size()); + for (int i = 0; i < permutation.size(); ++i) { + dst_logical[permutation[i]] = src_logical[i]; + } + } else if (hlo->opcode() == HloOpcode::kBroadcast) { + const auto broadcast = Cast(hlo); + dst_logical.resize(broadcast->dimensions().size()); + for (int i = 0; i < broadcast->dimensions().size(); ++i) { + dst_logical[i] = src_logical[broadcast->dimensions()[i]]; } } else { // Copy preserves the logical shape, just permutes the layout. - const Shape& operand_shape = hlo->operand(0)->shape(); - CHECK(ShapeUtil::SameDimensions(hlo->shape(), operand_shape)); - operand_logical = out_logical; + CHECK(ShapeUtil::SameDimensions(src->shape(), dst->shape())); + dst_logical = src_logical; } - // Operand logical -> operand physical and ungroup subdimensions. - const Layout& operand_layout = hlo->operand(0)->shape().layout(); + // Destination logical -> destination physical and ungroup subdimensions. + const Layout& dst_layout = dst->shape().layout(); dim_order_.clear(); - for (int64_t dim_idx : operand_layout.minor_to_major()) { - for (const DimDescription& subdim : operand_logical[dim_idx]) { + for (int64_t dim_idx : dst_layout.minor_to_major()) { + for (const DimDescription& subdim : dst_logical[dim_idx]) { dim_order_.push_back(subdim); } } - return OkStatus(); + return FusionDecision{}; } // Tells if the dimension order is supported by the triton GEMM emitter. // Only the dimension indicated by SplittableDimensionIndex() can be split // physically once by other dimensions. Other ones can be only split logically. // All subdimensions within a dimension have to be ordered. 
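When HandleCopyOrTransposeOrBroadcast() above walks a transpose in the input-to-output direction it first inverts the transpose's dimension permutation (the InversePermutation call). A standalone sketch of that inversion, with an illustrative helper name:

#include <cstdint>
#include <vector>

std::vector<int64_t> InversePermutationSketch(
    const std::vector<int64_t>& perm) {
  std::vector<int64_t> inverse(perm.size());
  for (int64_t i = 0; i < static_cast<int64_t>(perm.size()); ++i) {
    inverse[perm[i]] = i;  // perm maps i -> perm[i]; the inverse maps back
  }
  return inverse;
}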
-Status RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { - std::array subdim_counters = { +FusionDecision RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { + std::array subdim_counters = { -1, -1, -1, -1}; - std::array split_counters = { + std::array split_counters = { -1, -1, -1, -1}; const DimensionOrder::DimOrderVector& dim_order_vector = order.GetDimOrderVector(); + VLOG(8) << order.ToString(); for (int i = 0; i < dim_order_vector.size(); i++) { const auto [dim_number, subdim_number, size] = dim_order_vector[i]; - VLOG(8) << dim_number << "\t" << subdim_number << "\t" << size; if (subdim_counters[dim_number] != subdim_number - 1) { - return Unimplemented("Transpose within a dimension."); + return "Transpose within a dimension."; } ++subdim_counters[dim_number]; if (size == 1) { @@ -445,33 +554,380 @@ Status RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { } if (i == 0 || dim_order_vector[i - 1].target_dim_number != dim_number) { ++split_counters[dim_number]; - if (dim_number == order.SplittableDimensionIndex()) { + if (dim_number == order.SplittableDimensionIndex() && + order.IsSupportedSplittableDimensionMajorPartSize(size)) { if (split_counters[dim_number] > 1) { - return Unimplemented("2nd split of a splittable dimension."); + return "2nd split of a splittable dimension."; } } else if (split_counters[dim_number] > 0) { - return Unimplemented("Split of a non-splittable dimension."); + return "Split of a non-splittable dimension."; } } } - return OkStatus(); + return FusionDecision{}; +} + +// Difference of input and output data volumes of an instruction. +int64_t InputMinusOutputBytes(const HloInstruction& hlo) { + CHECK(!hlo.shape().IsTuple()); + int64_t input_size = 0; + for (const HloInstruction* operand : hlo.operands()) { + CHECK(!operand->shape().IsTuple()); + input_size += ShapeUtil::ByteSizeOf(operand->shape()); + } + return input_size - ShapeUtil::ByteSizeOf(hlo.shape()); +} + +// Tells if an instruction has no user into which it could be fused. +// More cases should be added here. +bool CanNotBeFusedIntoAUser(const HloInstruction& hlo) { + return hlo.IsRoot() || (hlo.user_count() == 1 && hlo.users()[0]->IsRoot() && + hlo.users()[0]->opcode() == HloOpcode::kTuple); +} + +// Tells if an instruction has no input into which it could be fused. +// More cases should be added here. +bool CanNotBeFusedIntoAProducer(const HloInstruction& hlo) { + return hlo_query::AllOperandsAreParametersOrConstants(hlo); +} + +// Let input and output data volumes of a fusion grow by small amounts. +constexpr int kIoToleranceBytes = 1024; + +// Tells that fusing an instruction as an input is efficient. +bool IsInputWorthFusing(const HloInstruction& hlo) { + if (hlo.user_count() > 1) { + return false; + } + return hlo_query::AllOperandsAreParametersOrConstants(hlo) || + InputMinusOutputBytes(hlo) <= kIoToleranceBytes; } -// Transforms dim_order describing the output of `hlo` into a -// description of its input if it is supported by the triton GEMM emitter. -Status CanFuse(const HloInstruction* hlo, DimensionOrder& dim_order, - const GpuVersion gpu_version) { - if (hlo->opcode() == HloOpcode::kConvert) { - return RequireTritonFusibleConvert(hlo, gpu_version); - } else if (hlo->IsElementwise() && hlo->opcode() != HloOpcode::kCopy) { - // Temporarily forbid fusing elementwise operations - // other than copy and convert. 
- return Unimplemented("Unsupported elementwise operation"); - } - TF_RETURN_IF_ERROR(dim_order.HandleInstruction(hlo)); +// Tells that fusing an instruction as an output is efficient. +bool IsOutputWorthFusing(const HloInstruction& hlo) { + return CanNotBeFusedIntoAUser(hlo) || + InputMinusOutputBytes(hlo) >= -kIoToleranceBytes; +} + +// Checks if the instruction is possible and profitable to fuse. +// If so tries to transform dim_order describing one side of `hlo` into a +// description of its other side if it is supported by the triton GEMM emitter. +FusionDecision CanFuse(const HloInstruction& hlo, bool as_input, + DimensionOrder& dim_order, + absl::flat_hash_map& old_to_new_mapping, + const GpuVersion gpu_version) { + if (hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return "Unsupported instruction."; + } + for (const HloInstruction* operand : hlo.operands()) { + if (!IsSupportedDataType(operand->shape().element_type(), gpu_version)) { + return "Unsupported input data type."; + } + } + if (!IsSupportedDataType(hlo.shape().element_type(), gpu_version)) { + return "Unsupported output data type."; + } + if (hlo.opcode() == HloOpcode::kBroadcast && + !hlo_query::IsScalarConstant(hlo.operand(0))) { + return "Skipping unsupported broadcast."; + } + if (as_input) { + if (hlo.GetModule() + ->config() + .debug_options() + .xla_gpu_triton_fusion_level() < 2) { + if (hlo.opcode() == HloOpcode::kConvert) { + if (FusionDecision decision = + RequireTritonFusibleConvert(&hlo, gpu_version); + !decision) { + return decision; + } + } else if (hlo.IsElementwise() && hlo.opcode() != HloOpcode::kCopy) { + return "Ignored elementwise operation"; + } + } else { + if (!CanNotBeFusedIntoAProducer(hlo) && !IsInputWorthFusing(hlo)) { + return "Not obviously profitable to fuse as input."; + } + } + } else { + if (hlo.GetModule() + ->config() + .debug_options() + .xla_gpu_triton_fusion_level() < 2) { + return "Skipping fusing outputs at low fusion levels."; + } + for (const HloInstruction* operand : hlo.operands()) { + // Skip already fused operands. + if (old_to_new_mapping.contains(operand)) { + continue; + } + // Currently only broadcasts of scalar constants or parameters + // are accepted as other inputs of non-unary operations + // in the output fusion. + if (hlo_query::IsBroadcastOfScalarConstant(*operand) || + operand->opcode() == HloOpcode::kParameter) { + continue; + } + return "Has multiple inputs - not properly analyzed yet."; + } + if (!IsOutputWorthFusing(hlo)) { + return "Not obviously profitable to fuse as output."; + } + } + + if (FusionDecision decision = dim_order.HandleInstruction( + &hlo, as_input ? DimensionOrder::TransformDirection::kOutputToInput + : DimensionOrder::TransformDirection::kInputToOutput); + !decision) { + return decision; + } + return RequireTritonGemmSupportedDimOrder(dim_order); } +// Clone an instruction into the fusion. 
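The fusion-profitability helpers above (InputMinusOutputBytes, IsInputWorthFusing, kIoToleranceBytes) implement a simple volume heuristic: a producer whose inputs are much larger than its output is not absorbed as an input, because pulling it in would make the fusion read the larger inputs instead of the small output. A toy standalone version of that check; the names and the reuse of the 1 KiB tolerance are illustrative:

#include <cstdint>
#include <numeric>
#include <vector>

constexpr int64_t kToleranceBytes = 1024;  // mirrors kIoToleranceBytes above

int64_t InputMinusOutputBytesSketch(const std::vector<int64_t>& operand_bytes,
                                    int64_t output_bytes) {
  const int64_t input_bytes =
      std::accumulate(operand_bytes.begin(), operand_bytes.end(), int64_t{0});
  return input_bytes - output_bytes;
}

// Fusing the producer as an input is considered cheap only if it does not
// read much more data than it writes.
bool WorthFusingAsInputSketch(const std::vector<int64_t>& operand_bytes,
                              int64_t output_bytes) {
  return InputMinusOutputBytesSketch(operand_bytes, output_bytes) <=
         kToleranceBytes;
}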
+void Fuse(HloInstruction& hlo, + absl::flat_hash_map& + old_to_new_mapping, + std::vector& fusion_inputs, + HloComputation::Builder& builder) { + if (old_to_new_mapping.contains(&hlo)) { + return; + } + VLOG(3) << "Fusing " << hlo.ToString(); + auto get_or_add_parameter = [&](HloInstruction& instr) { + if (auto it = old_to_new_mapping.find(&instr); + it != old_to_new_mapping.end()) { + return it->second; + } + fusion_inputs.push_back(&instr); + return old_to_new_mapping + .insert({&instr, + builder.AddInstruction(HloInstruction::CreateParameter( + fusion_inputs.size() - 1, instr.shape(), + absl::StrCat("parameter_", fusion_inputs.size() - 1)))}) + .first->second; + }; + if (hlo.opcode() == HloOpcode::kParameter || + hlo.opcode() == HloOpcode::kGetTupleElement) { + get_or_add_parameter(hlo); + } else { + std::vector hlo_new_operands; + for (HloInstruction* operand : hlo.operands()) { + hlo_new_operands.push_back(get_or_add_parameter(*operand)); + } + old_to_new_mapping[&hlo] = builder.AddInstruction( + hlo.CloneWithNewOperands(hlo.shape(), hlo_new_operands)); + } +} + +// Tells how many new parameters does a fusion gain by fusing the operation as +// an input. +int64_t NumAddedParameters(const HloInstruction& hlo) { + // Non-scalar constant is equivalent to a parameter: one input, one output. + if (hlo.opcode() == HloOpcode::kConstant && + !ShapeUtil::IsScalar(hlo.shape())) { + return 0; + } + // All other instructions add all own inputs and remove own single output. + return hlo.operand_count() - 1; +} + +// Fuse an instruction with all its fusible inputs. +// If an input is not fusible stop there and make a parameter of the new +// fusion, otherwise put it onto stack and check its own inputs first. +void FuseWithInputsRecursively( + HloInstruction* root, DimensionOrder root_dim_order, + // Dimension orders describing inputs of corresponding instructions. + absl::flat_hash_map& dim_orders, + const GpuVersion gpu_version, + absl::flat_hash_map& + old_to_new_mapping, + std::vector& fusion_inputs, + HloComputation::Builder& builder) { + absl::flat_hash_set visited; + std::stack to_fuse; + // Instructions at the edge 'to_fuse' that can either get fused too or + // become parameters of the fusion. Used to track the number of parameters + // of the fusion. + absl::flat_hash_set inputs; + // Currently only one physically unique dim order per scope is supported. + // Let it change while the scope has one input; afterwards require all + // of them to be physically compatible. + const HloInstruction* reference_dim_order_hlo = nullptr; + if (CanFuse(*root, /*as_input=*/true, root_dim_order, old_to_new_mapping, + gpu_version)) { + to_fuse.push(root); + inputs.insert(root->operands().begin(), root->operands().end()); + // root_dim_order went through output -> input transformation here. + CHECK(dim_orders.insert({root, root_dim_order}).second) << root->ToString(); + } + visited.insert(root); + while (!to_fuse.empty()) { + bool top_is_ready_to_fuse = true; + HloInstruction* hlo = to_fuse.top(); + if (reference_dim_order_hlo == nullptr && hlo->operand_count() > 1) { + reference_dim_order_hlo = hlo; + } + for (HloInstruction* operand : hlo->mutable_operands()) { + if (visited.insert(operand).second) { + // Stop adding new parameters. + if (inputs.size() >= DotFusionAnalysis::kMaxParameterPerScope && + NumAddedParameters(*operand) > 0) { + continue; + } + // Operand's output is described by its consumer's input. 
+ DimensionOrder operand_dim_order(dim_orders.at(hlo)); + // CanFuse() makes output -> input transformation of + // operand_dim_order if succeeds. + if (CanFuse(*operand, /*as_input=*/true, operand_dim_order, + old_to_new_mapping, gpu_version)) { + if (reference_dim_order_hlo != nullptr && + !operand_dim_order.IsPhysicallyEquivalent( + dim_orders.at(reference_dim_order_hlo))) { + continue; + } + to_fuse.push(operand); + if (operand->opcode() != HloOpcode::kParameter) { + inputs.erase(operand); + } + inputs.insert(operand->operands().begin(), operand->operands().end()); + // Save the dimension order description of operand's input. + CHECK(dim_orders.insert({operand, operand_dim_order}).second) + << operand->ToString(); + top_is_ready_to_fuse = false; + } + } + } + if (top_is_ready_to_fuse) { + Fuse(*hlo, old_to_new_mapping, fusion_inputs, builder); + to_fuse.pop(); + } + } +} + +// Fuses dot and the compatible and profitable to fuse operations around it +// into a new fusion computation constructed using the builder. fusion_inputs +// get populated with the non-fused instructions that become operands of the +// call to this fusion. fusion_output_ptr (if not nullptr) gets assigned the +// original instruction that has to be replaced by the call to the fusion. +StatusOr FuseDot(HloInstruction& dot, + const GpuVersion gpu_version, + HloComputation::Builder& builder, + std::vector& fusion_inputs, + HloInstruction** fusion_output_ptr) { + VLOG(5) << dot.ToString(); + if (FusionDecision can_handle = CanTritonHandleGEMM(dot, gpu_version); + !can_handle) { + VLOG(3) << can_handle.Explain(); + return can_handle; + } + + // Original instruction -> fused one. + absl::flat_hash_map + old_to_new_mapping; + + // Separate traversal from LHS and RHS inputs of the dot: they use + // differently shaped tiles but may go through same HLO graph nodes. + // Direct dot inputs have well defined dimension orders. + + auto fuse_inputs = [&](int operand_number) + -> StatusOr> { + absl::flat_hash_map dim_orders; + int operand_count_before = fusion_inputs.size(); + // Direct dot inputs have well defined dimension orders. + FuseWithInputsRecursively( + dot.mutable_operand(operand_number), + DimensionOrder::FromDotOperand(dot, operand_number), dim_orders, + gpu_version, old_to_new_mapping, fusion_inputs, builder); + TF_RET_CHECK(fusion_inputs.size() - operand_count_before <= + DotFusionAnalysis::kMaxParameterPerScope); + return dim_orders; + }; + // Check if non-contracting dimension originating from LHS operand in the + // output can be split. This currently requires this dimension being split + // in the operand the same way. + int64_t lhs_nc_split_major_part = -1; + { + TF_ASSIGN_OR_RETURN(const auto lhs_dim_orders, fuse_inputs(0)); + // Looking at first LHS parameter to find split non-contracting dimension + // is sufficient because currently all parameters of one scope have to use + // the same tiling. 
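+    // lhs_nc_split_major_part stays -1 if the dimension is not split;
+    // otherwise the size of its major fragment is recorded here and passed
+    // to DimensionOrder::FromDotOutput below so that the output is tiled
+    // with the same split.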
+ auto first_lhs_parameter_it = lhs_dim_orders.cbegin(); + while (first_lhs_parameter_it != lhs_dim_orders.cend()) { + if (first_lhs_parameter_it->first->opcode() == HloOpcode::kParameter) { + break; + } + ++first_lhs_parameter_it; + } + if (first_lhs_parameter_it != lhs_dim_orders.cend()) { + const auto lhs_nc_iter_spec = DimensionOrderToTensorIterationSpec( + first_lhs_parameter_it->second)[NonContractingDimensionIndex(dot, 0)]; + if (lhs_nc_iter_spec.size() > 1) { + lhs_nc_split_major_part = lhs_nc_iter_spec.at(1).count; + } + } + } + TF_RET_CHECK(fuse_inputs(1).ok()); + + Fuse(dot, old_to_new_mapping, fusion_inputs, builder); + + // Fusion at dot's output. + + // These describe _outputs_ of corresponding HLOs. + absl::flat_hash_map out_dim_orders; + out_dim_orders.insert( + {&dot, DimensionOrder::FromDotOutput(dot, /*split_k=*/1, + lhs_nc_split_major_part)}); + HloInstruction* fusion_output = ˙ + bool output_changed = true; + while (output_changed) { + output_changed = false; + if (fusion_output->user_count() != 1) { + break; + } + HloInstruction* user = fusion_output->users()[0]; + if (!IsDistributiveOverAddition(*user)) { + break; + } + // Describes the output of `current_output` = input of `user`. + DimensionOrder dim_order(out_dim_orders.at(fusion_output)); + if (CanFuse(*user, /*as_input=*/false, dim_order, old_to_new_mapping, + gpu_version)) { + // Now it describes the output of the user. + CHECK(out_dim_orders.insert({user, dim_order}).second); + for (HloInstruction* operand : user->operands()) { + if (!old_to_new_mapping.contains(operand)) { + // Here we need again a dim order describing inputs of the user. + FuseWithInputsRecursively( + operand, DimensionOrder(out_dim_orders.at(fusion_output)), + out_dim_orders, gpu_version, old_to_new_mapping, fusion_inputs, + builder); + } + } + Fuse(*user, old_to_new_mapping, fusion_inputs, builder); + fusion_output = user; + output_changed = true; + } + } + if (fusion_output_ptr != nullptr) { + *fusion_output_ptr = fusion_output; + } + if (dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any()) { + return FusionDecision{}; + } + for (const auto& iter : old_to_new_mapping) { + if (iter.second->opcode() == HloOpcode::kConvert || + iter.second->opcode() == HloOpcode::kTranspose) { + return FusionDecision{}; + } + } + return "No profitable operations to fuse."; +} + // Extracts into fused computations parts of HLO graph including dot() // operations that can target the triton GEMM emitter. class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { @@ -482,117 +938,50 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { // if so - fuses all its compatible inputs and outputs as a new computation // and replaces the original dot() with a call to the computation. Status HandleDot(HloInstruction* dot) override { - VLOG(5) << dot->ToString(); - - if (!CanTritonHandleGEMM(*dot, gpu_version_)) { + std::string fusion_name = absl::StrCat("triton_gemm_", dot->name()); + HloComputation::Builder builder(absl::StrCat(fusion_name, "_computation")); + std::vector fusion_inputs; + HloInstruction* fusion_output = nullptr; + TF_ASSIGN_OR_RETURN( + const FusionDecision should_fuse, + FuseDot(*dot, gpu_version_, builder, fusion_inputs, &fusion_output)); + if (builder.last_added_instruction() == nullptr) { return OkStatus(); } - // If a GEMM requiring padding for cuBLAS is encountered here this // happened because earlier ShouldTritonHandleGEMM() accepted it and padding - // was skipped. 
Do not check ShouldTritonHandleGEMM() again then. + // was skipped. Accept it ignoring profitability checks. if (!CublasRequiresPadding( - *xla::Cast(dot), + *Cast(dot), std::get(gpu_version_)) && - !ShouldTritonHandleGEMM(*dot, gpu_version_)) { + !should_fuse) { return OkStatus(); } - // TODO(b/266857789): also fuse convert(dot()) at output if present: - // seen on s8xf32->bf16 - std::string suggested_name = absl::StrCat("triton_gemm_", dot->name()); - HloComputation::Builder builder( - absl::StrCat(suggested_name, "_computation")); - // Original instruction -> fused one. - absl::flat_hash_map - old_to_new_mapping; - absl::flat_hash_set visited; - std::vector call_operands; - // Traverse and fuse dot() inputs bottom-up starting from direct operands. - // If an input is not fusible stop there and make it a parameter of the new - // fusion, otherwise put it onto stack and check its own inputs first. - std::stack to_fuse; - // Dimension orders describing inputs of corresponding instructions. - absl::flat_hash_map dim_orders; - to_fuse.push(dot); - while (!to_fuse.empty()) { - bool top_is_ready_to_fuse = true; - HloInstruction* hlo = to_fuse.top(); - for (HloInstruction* operand : hlo->mutable_operands()) { - if (visited.insert(operand).second) { - DimensionOrder operand_dim_order = [&] { - // Direct dot inputs are described by default dimension orders. - if (operand == dot->operand(0)) { - return DimensionOrder::FromDotOperand(*dot, 0); - } else if (operand == dot->operand(1)) { - return DimensionOrder::FromDotOperand(*dot, 1); - } - // Otherwise operand's output is described by its consumer's input. - return DimensionOrder(dim_orders.at(hlo)); - }(); - // CanFuse() makes output -> input transformation of - // operand_dim_order if succeeds. - if (CanFuse(operand, operand_dim_order, gpu_version_).ok()) { - VLOG(3) << "Fusing " << operand->ToString(); - to_fuse.push(operand); - // Save the dimension order description of operand's input. 
- dim_orders.insert({operand, operand_dim_order}); - top_is_ready_to_fuse = false; - } - } - } - if (top_is_ready_to_fuse) { - if (hlo->opcode() == HloOpcode::kParameter || - hlo->opcode() == HloOpcode::kGetTupleElement) { - old_to_new_mapping[hlo] = - builder.AddInstruction(HloInstruction::CreateParameter( - call_operands.size(), hlo->shape(), - absl::StrCat("parameter_", call_operands.size()))); - call_operands.push_back(hlo); - } else { - std::vector hlo_new_operands; - for (HloInstruction* operand : hlo->operands()) { - const auto iter = old_to_new_mapping.find(operand); - if (iter != old_to_new_mapping.end()) { - hlo_new_operands.push_back(iter->second); - } else { - hlo_new_operands.push_back( - builder.AddInstruction(HloInstruction::CreateParameter( - call_operands.size(), operand->shape(), - absl::StrCat("parameter_", call_operands.size())))); - call_operands.push_back(operand); - } - } - old_to_new_mapping[hlo] = builder.AddInstruction( - hlo->CloneWithNewOperands(hlo->shape(), hlo_new_operands)); - } - to_fuse.pop(); - } - } HloComputation* computation = dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), /*is_entry=*/false); HloInstruction* dot_fusion = dot->parent()->AddInstruction(HloInstruction::CreateFusion( computation->root_instruction()->shape(), - HloInstruction::FusionKind::kCustom, call_operands, computation)); - dot_fusion->GetModule()->SetAndUniquifyInstrName(dot_fusion, - suggested_name); + HloInstruction::FusionKind::kCustom, fusion_inputs, computation)); + dot_fusion->GetModule()->SetAndUniquifyInstrName(dot_fusion, fusion_name); TF_ASSIGN_OR_RETURN(auto backend_config, dot_fusion->backend_config()); backend_config.set_kind(std::string(kTritonGemmFusionKind)); TF_RETURN_IF_ERROR(dot_fusion->set_backend_config(backend_config)); - if (dot->IsRoot()) { - dot->parent()->set_root_instruction(dot_fusion); + if (fusion_output->IsRoot()) { + fusion_output->parent()->set_root_instruction(dot_fusion); TF_RETURN_IF_ERROR( - dot->parent()->RemoveInstructionAndUnusedOperands(dot)); + fusion_output->parent()->RemoveInstructionAndUnusedOperands( + fusion_output)); MarkAsChanged(); } else { - TF_RETURN_IF_ERROR(ReplaceInstruction(dot, dot_fusion)); + TF_RETURN_IF_ERROR(ReplaceInstruction(fusion_output, dot_fusion)); } - VLOG(5) << computation->ToString(); + XLA_VLOG_LINES(5, computation->ToString()); return OkStatus(); } @@ -643,7 +1032,7 @@ StatusOr MakeSplitKOperand( for (const HloInstruction* param : analysis.ScopeParameters(scope)) { // If an operand of dot does not read any parameters its K dimension // does not need analysis for fragmentation. - const DotFusionAnalysis::DimIterationSpec* spec = + const DimIterationSpec* spec = analysis.IterSpec(scope, param, contracting_dim_idx); // Split contracting dimension is not implemented yet. 
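     // A split would show up as more than one fragment in `spec`, hence the
     // single-fragment requirement below.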
CHECK_EQ(spec->size(), 1); @@ -704,9 +1093,6 @@ Status MakeDotComputationSplitKBatch( DotDimensionNumbers new_dim_numbers; const int64_t lhs_contracting_idx = ContractingDimensionIndex(*dot, 0); - TF_ASSIGN_OR_RETURN( - HloInstruction * lhs, - MakeSplitKOperand(*dot, analysis, tiling, lhs_contracting_idx, 0)); CopyIncrementingAboveThreshold( old_dim_numbers.lhs_contracting_dimensions(), *new_dim_numbers.mutable_lhs_contracting_dimensions(), @@ -717,9 +1103,6 @@ Status MakeDotComputationSplitKBatch( *new_dim_numbers.mutable_lhs_batch_dimensions(), lhs_contracting_idx); const int64_t rhs_contracting_idx = ContractingDimensionIndex(*dot, 1); - TF_ASSIGN_OR_RETURN( - HloInstruction * rhs, - MakeSplitKOperand(*dot, analysis, tiling, rhs_contracting_idx, 1)); CopyIncrementingAboveThreshold( old_dim_numbers.rhs_contracting_dimensions(), *new_dim_numbers.mutable_rhs_contracting_dimensions(), @@ -729,16 +1112,67 @@ Status MakeDotComputationSplitKBatch( old_dim_numbers.rhs_batch_dimensions(), *new_dim_numbers.mutable_rhs_batch_dimensions(), rhs_contracting_idx); - HloInstruction* new_dot = - MakeDotHlo(lhs, rhs, new_dim_numbers, dot->precision_config(), - dot->shape().element_type()) - .value(); - // `new_dot` will have default output layout even if `dot` had a custom one. - // We will set the original output layout on the reduce operation. + // Collect HLOs to transform between dot output and root. These will + // get a new major most batch dimension sized as split K factor. Other inputs + // of these HLOs will get broadcasted. + std::stack to_process; + // Store the same HLOs also in a hash set for quick lookups. + absl::flat_hash_set to_process_set; + HloInstruction* current = dot; + do { + to_process.push(current); + CHECK(to_process_set.insert(current).second); + if (current->users().empty()) { + break; + } + CHECK_EQ(current->user_count(), 1); + current = current->users()[0]; + if (!IsDistributiveOverAddition(*current)) { + return Cancelled("Operation non-distributive over addition after dot."); + } + } while (true); + + // Process the collected HLOs from computation root to dot. + while (!to_process.empty()) { + HloInstruction* current = to_process.top(); + to_process.pop(); + // Add split-K dimension to `current`. + HloInstruction* expanded; + if (current == dot) { + TF_ASSIGN_OR_RETURN( + HloInstruction * lhs, + MakeSplitKOperand(*dot, analysis, tiling, lhs_contracting_idx, 0)); + TF_ASSIGN_OR_RETURN( + HloInstruction * rhs, + MakeSplitKOperand(*dot, analysis, tiling, rhs_contracting_idx, 1)); + expanded = MakeDotHlo(lhs, rhs, new_dim_numbers, dot->precision_config(), + dot->shape().element_type()) + .value(); + dot->SetupDerivedInstruction(expanded); + } else { + expanded = computation->AddInstruction( + current->CloneWithNewShape(ShapeUtil::PrependMajorDimension( + tiling.split_k(), current->shape()))); + } + TF_RETURN_IF_ERROR(current->ReplaceAllUsesWithDifferentShape(expanded)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(current)); + // Broadcast operands. 
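+    // The dot itself already received its split-K operands above; for any
+    // other cloned instruction, operands coming from outside the traversed
+    // chain get broadcast so that they gain the new major split-K dimension.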
+ if (current == dot) { + continue; + } + for (int i = 0; i < expanded->operands().size(); ++i) { + HloInstruction* operand = expanded->mutable_operand(i); + if (!to_process_set.contains(operand)) { + std::vector broadcast_dimensions(operand->shape().rank()); + absl::c_iota(broadcast_dimensions, 1); + TF_RETURN_IF_ERROR(expanded->ReplaceOperandWithDifferentShape( + i, MakeBroadcastHlo(operand, broadcast_dimensions, + ShapeUtil::PrependMajorDimension( + tiling.split_k(), operand->shape())))); + } + } + } - dot->SetupDerivedInstruction(new_dot); - TF_RETURN_IF_ERROR(dot->ReplaceAllUsesWithDifferentShape(new_dot)); - TF_RETURN_IF_ERROR(dot->parent()->RemoveInstruction(dot)); if (disable_reduced_precision_reduction) { PrimitiveType output_type = computation->root_instruction()->shape().element_type(); @@ -752,6 +1186,65 @@ Status MakeDotComputationSplitKBatch( return OkStatus(); } +// Propagate dimension orders in consumer->producer direction starting at +// `origin` with input `origin_dim_order` till parameters of the computation. +// Store the found parameters and their iteration specs. +Status PropagateDimensionOrdersToParameters( + const HloInstruction& origin, const DimensionOrder& origin_dim_order, + absl::flat_hash_set& parameters, + absl::flat_hash_map& + iter_specs) { + absl::flat_hash_set visited; + std::queue to_process; + // Dimension orders describing inputs of corresponding instructions. + absl::flat_hash_map dim_orders; + TF_RET_CHECK(RequireTritonGemmSupportedDimOrder(origin_dim_order)); + dim_orders.insert({&origin, origin_dim_order}); + visited.insert(&origin); + to_process.push(&origin); + while (!to_process.empty()) { + const HloInstruction* hlo = to_process.front(); + to_process.pop(); + if (hlo->opcode() == HloOpcode::kParameter) { + // One parameter corresponds to one iteration spec in the results of the + // analysis. This describes well situations when a parameter has one or + // more elementwise users - they share the same tiling. Situations when + // one instruction is read differently by different users in the same + // scope of the dot are currently prevented during the fusion. + TF_RET_CHECK(parameters.insert(hlo).second); + VLOG(5) << hlo->ToString(); + } + for (const HloInstruction* operand : hlo->operands()) { + if (!visited.insert(operand).second) { + continue; + } + if (operand->opcode() == HloOpcode::kDot) { + // Encountering the dot itself happens during the processing of the + // output fusion. The propagation should stop at it. + continue; + } + // Operand's output is described by its consumer's input. + auto [it, inserted] = + dim_orders.insert({operand, DimensionOrder(dim_orders.at(hlo))}); + TF_RET_CHECK(inserted); + DimensionOrder& hlo_operand_dim_order = it->second; + TF_RET_CHECK(hlo_operand_dim_order.HandleInstruction( + operand, DimensionOrder::TransformDirection::kOutputToInput)) + << operand->ToString(); + TF_RET_CHECK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)); + to_process.push(operand); + } + } + // For now all parameters of one scope have to use the same tiling. 
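+  // This is enforced by requiring every parameter's dimension order to be
+  // physically equivalent to that of an arbitrary (first) parameter.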
+ for (const HloInstruction* parameter : parameters) { + TF_RET_CHECK(dim_orders.at(parameter).IsPhysicallyEquivalent( + dim_orders.at(*parameters.cbegin()))); + iter_specs[parameter] = + DimensionOrderToTensorIterationSpec(dim_orders.at(parameter)); + } + return OkStatus(); +} + } // anonymous namespace // BF16 is supported in a sense that all operations on it are implemented @@ -878,70 +1371,66 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, for (const Scope scope : {Scope::LHS, Scope::RHS}) { const int operand_number = static_cast(scope); - const HloInstruction* dot_operand = dot->operand(operand_number); - absl::flat_hash_set visited; - std::queue to_process; - // Dimension orders describing inputs of corresponding instructions. - absl::flat_hash_map dim_orders; + const HloInstruction* operand = dot->operand(operand_number); DimensionOrder dot_operand_dim_order = DimensionOrder::FromDotOperand(*dot, operand_number, split_k); - TF_CHECK_OK(dot_operand_dim_order.HandleInstruction(dot_operand)); - TF_CHECK_OK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) - << dot_computation->ToString(); - dim_orders.insert({dot_operand, dot_operand_dim_order}); - visited.insert(dot_operand); - to_process.push(dot_operand); - while (!to_process.empty()) { - const HloInstruction* hlo = to_process.front(); - to_process.pop(); - if (hlo->opcode() == HloOpcode::kParameter) { - CHECK(parameters_[scope].insert(hlo).second); - VLOG(5) << hlo->ToString(); - } - for (const HloInstruction* hlo_operand : hlo->operands()) { - if (!visited.insert(hlo_operand).second) { - continue; - } - // Operand's output is described by its consumer's input. - auto [it, inserted] = dim_orders.insert( - {hlo_operand, DimensionOrder(dim_orders.at(hlo))}); - CHECK(inserted); - DimensionOrder& hlo_operand_dim_order = it->second; - TF_CHECK_OK(hlo_operand_dim_order.HandleInstruction(hlo_operand)); - TF_CHECK_OK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) - << " " << dot_computation->ToString(); - to_process.push(hlo_operand); - } - } + CHECK(dot_operand_dim_order.HandleInstruction( + operand, DimensionOrder::TransformDirection::kOutputToInput)); + CHECK_OK(PropagateDimensionOrdersToParameters( + *operand, dot_operand_dim_order, parameters_[scope], + iter_specs_[scope])); + } - for (const HloInstruction* parameter : parameters_[scope]) { - iter_specs_[scope][parameter] = - DimensionOrderToTensorIterationSpec(dim_orders.at(parameter)); + int64_t lhs_nc_split_major_part_size = -1; + if (!ScopeParameters(Scope::LHS).empty()) { + const TensorIterationSpec::DimIterationSpec* lhs_nc_iter_spec = + IterSpec(Scope::LHS, *ScopeParameters(Scope::LHS).cbegin(), + NonContractingDimensionIndex(*dot, 0)); + if (lhs_nc_iter_spec->size() > 1) { + lhs_nc_split_major_part_size = lhs_nc_iter_spec->at(1).count; } } - - DimensionOrder dim_order = DimensionOrder::FromDotOutput(*dot); + DimensionOrder dim_order = DimensionOrder::FromDotOutput( + *dot, split_k, lhs_nc_split_major_part_size); + const HloInstruction* output = dot; + // Currently supported is one fusion output and one path from dot to it. + // Propagate dimension order from dot to root. 
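+  // Every step of the walk requires a single user; the dimension order is
+  // transformed in the input -> output direction and has to stay supported
+  // by the Triton GEMM emitter.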
+ while (!output->IsRoot()) { + CHECK_EQ(output->user_count(), 1); + output = output->users()[0]; + CHECK(dim_order.HandleInstruction( + output, DimensionOrder::TransformDirection::kInputToOutput)); + CHECK(RequireTritonGemmSupportedDimOrder(dim_order)); + } CHECK(iter_specs_[Scope::OUTPUT] - .insert({dot, DimensionOrderToTensorIterationSpec(dim_order)}) + .insert({output, DimensionOrderToTensorIterationSpec(dim_order)}) .second); + if (output != dot) { + // Propagate back to parameters of the output fusion. + CHECK(dim_order.HandleInstruction( + output, DimensionOrder::TransformDirection::kOutputToInput)); + CHECK_OK(PropagateDimensionOrdersToParameters(*output, dim_order, + parameters_[Scope::OUTPUT], + iter_specs_[Scope::OUTPUT])); + } } -const DotFusionAnalysis::DimIterationSpec* DotFusionAnalysis::IterSpec( +const DimIterationSpec* DotFusionAnalysis::IterSpec( const DotFusionAnalysis::Scope scope, const HloInstruction* hlo, const int dimension) const { auto ret = iter_specs_.at(scope).find(hlo); if (ret != iter_specs_.at(scope).end()) { - return &ret->second.at(dimension); + return &ret->second[dimension]; } return nullptr; } -bool CanTritonHandleGEMM(const HloInstruction& dot, - const GpuVersion gpu_version) { +FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, + const GpuVersion gpu_version) { if (dot.opcode() != HloOpcode::kDot || absl::c_any_of(dot.precision_config().operand_precision(), [](int x) { return x != PrecisionConfig::DEFAULT; })) { - return false; + return "Non-default precision."; } auto supported_output_type = [&](const PrimitiveType t) { @@ -961,21 +1450,21 @@ bool CanTritonHandleGEMM(const HloInstruction& dot, // TODO(b/266862493): Support more output types. if (!supported_output_type(dot.shape().element_type())) { - return false; + return "Unsupported output data type."; } if (!IsSupportedDataType(dot.operand(0)->shape().element_type(), gpu_version) || !IsSupportedDataType(dot.operand(1)->shape().element_type(), gpu_version)) { - return false; + return "Unsupported input data type."; } const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); // TODO(b/269580541): support multiple batch dimensions. if (dim_numbers.lhs_batch_dimensions().size() > 1) { - return false; + return "Multiple batch dimensions."; } // Cases where lhs or rhs have no non-contracting dims are not handled. @@ -985,48 +1474,18 @@ bool CanTritonHandleGEMM(const HloInstruction& dot, dim_numbers.rhs_batch_dimensions().size() + dim_numbers.rhs_contracting_dimensions().size() == dot.operand(1)->shape().rank()) { - return false; + return "No non-contracting dimensions."; } - return true; + return FusionDecision{}; } -bool ShouldTritonHandleGEMM(const HloInstruction& dot, - const GpuVersion gpu_version) { - if (dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any()) { - return true; - } - - // Traverse HLO graph part checking that it both can be fused - // and is worth fusing. - auto has_triton_fusible_inputs = [&gpu_version](const HloInstruction& dot, - const int operand_number) { - DimensionOrder dim_order = - DimensionOrder::FromDotOperand(dot, operand_number); - std::queue queue; - queue.push(dot.operand(operand_number)); - while (!queue.empty()) { - const HloInstruction* current = queue.front(); - queue.pop(); - if (!CanFuse(current, dim_order, gpu_version).ok()) { - continue; - } - // Stop as soon as a profitable operation is fused. 
- if (current->opcode() == HloOpcode::kConvert || - current->opcode() == HloOpcode::kTranspose) { - return true; - } - for (const HloInstruction* operand : current->operands()) { - queue.push(operand); - } - } - return false; - }; - - return has_triton_fusible_inputs(dot, 0) || has_triton_fusible_inputs(dot, 1); - - // TODO(b/266857789): either check that no output fusion (axpy, relu etc) - // is expected or actually support it. +bool ShouldTritonHandleGEMM(HloInstruction& dot, const GpuVersion gpu_version) { + std::vector fusion_inputs; + HloComputation::Builder builder("disposable"); + return FuseDot(dot, gpu_version, builder, fusion_inputs, + /*fusion_output_ptr=*/nullptr) + ->CanFuse(); } StatusOr GemmRewriterTriton::Run( diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h index 715c79d9114659..6619d16196d1b3 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/instruction_fusion.h" namespace xla { namespace gpu { @@ -52,13 +53,13 @@ Status MakeDotSplitKBatch(HloInstruction* dot_fusion, const AutotuneResult::TritonGemmKey& tiling); // Filters GEMMs which can be handled using Triton. -bool CanTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); +FusionDecision CanTritonHandleGEMM(const HloInstruction&, + GpuVersion gpu_version); // Filters GEMMs which are better to handle using Triton. -bool ShouldTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); +bool ShouldTritonHandleGEMM(HloInstruction&, GpuVersion gpu_version); -// Analysis of iteration of HLO shapes within a fusion around dot(). -class DotFusionAnalysis { +class TensorIterationSpec { public: // Description of basic iteration: `count` elements separated by `stride`. struct IterationSpecFragment { @@ -68,16 +69,42 @@ class DotFusionAnalysis { // of several HLO dimensions. Product of subfragments equals `count`. std::vector subfragments; }; - // Description of complex iteration over a sequence of several strides. // Describes a logically contiguous dimension of a tensor physically // separated into multiple fragments by other dimensions. using DimIterationSpec = std::vector; // At most: contracting, non-contracting, split-K, another batch. - static const int kMaxDimsPerTensor = 4; - using TensorIterationSpec = std::array; + static constexpr int kMaxDimsPerTensor = 4; + using StorageType = std::array; + + const DimIterationSpec& operator[](int dimension) const { + return dim_iteration_specs_[dimension]; + } + + DimIterationSpec& operator[](int dimension) { + return dim_iteration_specs_[dimension]; + } + + // Compares physical layouts of tensors ignoring subfragments of dimensions. + bool operator==(const TensorIterationSpec& other) const; + + StorageType::iterator begin() { return dim_iteration_specs_.begin(); } + StorageType::iterator end() { return dim_iteration_specs_.end(); } + StorageType::const_iterator cbegin() const { + return dim_iteration_specs_.cbegin(); + } + StorageType::const_iterator cend() const { + return dim_iteration_specs_.cend(); + } + + private: + StorageType dim_iteration_specs_; +}; +// Analysis of iteration of HLO shapes within a fusion around dot(). 
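+// For each scope (LHS, RHS, OUTPUT) the analysis maps an analyzed instruction
+// (a parameter of the scope, or the fusion output) and a dot dimension number
+// to a DimIterationSpec describing how that logical dimension is traversed in
+// the instruction's physical layout. E.g. in the TransposeOutput test below
+// the 24-element output dimension maps to a single fragment
+// {stride=1, count=24, subfragments={2,12}}.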
+class DotFusionAnalysis { + public: // Execute analysis of dot fusion computation. // split_k indicates whether this operation was converted to the split-K // form and tells the analysis how to interpret the batch dimensions. @@ -88,9 +115,15 @@ class DotFusionAnalysis { // defined by left operand, right operand and output. enum class Scope { LHS = 0, RHS = 1, OUTPUT = 2 }; + // Every parameter requires a separate piece of shared memory for asynchronous + // loads. Multiple parameters are approximately equivalent to multiple + // pipeline stages. + static constexpr int kMaxParameterPerScope = 4; + // Scope -> HLO -> dot dimension number -> iteration spec at the HLO's output. - const DimIterationSpec* IterSpec(Scope scope, const HloInstruction*, - int dimension) const; + const TensorIterationSpec::DimIterationSpec* IterSpec(Scope scope, + const HloInstruction*, + int dimension) const; // Parameter HLO instructions used in a scope of `dot`. const absl::flat_hash_set& ScopeParameters( const Scope scope) const { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc index d02faa5b3abdc9..f256ef6b0c3a15 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/service/gpu/cublas_padding_requirements.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" @@ -94,7 +95,7 @@ ENTRY e { GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); } -TEST_F(GemmRewriterTritonTest, DoNotFuseConstant) { +TEST_F(GemmRewriterTritonTest, DoNotFuseVectorConstants) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( HloModule m @@ -102,14 +103,27 @@ HloModule m ENTRY e { p0 = s8[60,5] parameter(0) c0 = f16[60,5] convert(p0) - cst1 = f16[600] constant({...}) - r1 = f16[5,120] reshape(cst1) + cst1 = f16[5] constant({...}) + r1 = f16[5,120] broadcast(cst1), dimensions={0} ROOT d = f16[60,120] dot(c0, r1), lhs_contracting_dims={1}, rhs_contracting_dims={0} })")); EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Constant(), m::Parameter()))); + GmockMatch(m::Fusion(m::Parameter(), m::Broadcast()))); +} + +TEST_F(GemmRewriterTritonTest, DoNotTriggerOnUnsupportedOutputConversions) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f16[128,256] parameter(0) + p1 = f16[256,512] parameter(1) + r = f16[128,512] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT c = u8[128,512] convert(r) +})")); + EXPECT_FALSE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); } using TritonDotAnalysisTest = HloTestBase; @@ -404,6 +418,169 @@ ENTRY e { /*subfragments=*/ElementsAre(3)))); } +TEST_F(TritonDotAnalysisTest, TransposeOutput) { + const std::string hlo_text = R"( +HloModule t + +triton_dot { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + dot = bf16[24,3]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, 
rhs_contracting_dims={0} + bc = bf16[12,2,3]{2,1,0} bitcast(dot) + ROOT t = bf16[3,12,2]{2,1,0} transpose(bc), dimensions={2,0,1} +} + +ENTRY e { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + ROOT r = bf16[3,12,2]{2,1,0} fusion(p0, p1), kind=kCustom, + calls=triton_dot +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* dot_output = dot_computation->root_instruction(); + const DotFusionAnalysis analysis(dot_computation); + EXPECT_THAT( + *analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, dot_output, 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24, + /*subfragments=*/ElementsAre(2, 12)))); + EXPECT_THAT( + *analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, dot_output, 1), + ElementsAre(FieldsAre(/*stride=*/24, /*count=*/3, + /*subfragments=*/ElementsAre(3)))); +} + +TEST_F(TritonDotAnalysisTest, OutputParameterIsHandled) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +triton_dot { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + dot = bf16[24,3]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + p2 = f16[3,24]{1,0} parameter(2) + p2t = f16[24,3]{1,0} transpose(p2), dimensions={1,0} + p2tc = bf16[24,3]{1,0} convert(p2t) + ROOT r = bf16[24,3]{1,0} divide(p2tc, dot) +} + +ENTRY e { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + p2 = f16[3,24]{1,0} parameter(2) + ROOT r = bf16[24,3]{1,0} fusion(p0, p1, p2), kind=kCustom, + calls=triton_dot +})")); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* output_param = + dot_computation->parameter_instruction(2); + const DotFusionAnalysis analysis(dot_computation); + EXPECT_EQ( + analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, output_param, 0) + ->size(), + 1); + EXPECT_THAT( + *analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, output_param, 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24, + /*subfragments=*/ElementsAre(24)))); + EXPECT_EQ( + analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, output_param, 1) + ->size(), + 1); + EXPECT_THAT( + *analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, output_param, 1), + ElementsAre(FieldsAre(/*stride=*/24, /*count=*/3, + /*subfragments=*/ElementsAre(3)))); +} + +TEST_F(TritonDotAnalysisTest, InputBroadcastFromScalarIsHandled) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +triton_dot { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[] parameter(1) + p1b = bf16[4,3] broadcast(p1) + ROOT dot = bf16[24,3]{1,0} dot(p0, p1b), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[] parameter(1) + ROOT r = bf16[24,3]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_dot +})")); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* scalar = dot_computation->parameter_instruction(1); + const DotFusionAnalysis analysis(dot_computation); + EXPECT_EQ(analysis.IterSpec(DotFusionAnalysis::Scope::RHS, scalar, 0)->size(), + 1); + EXPECT_THAT(*analysis.IterSpec(DotFusionAnalysis::Scope::RHS, scalar, 0), + ElementsAre(FieldsAre(/*stride=*/0, /*count=*/1, + 
/*subfragments=*/ElementsAre(1)))); +} + +TEST_F(TritonDotAnalysisTest, InputBroadcastFromVectorIsHandled) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +triton_dot { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4] parameter(1) + p1b = bf16[4,3] broadcast(p1), dimensions={0} + ROOT dot = bf16[24,3]{1,0} dot(p0, p1b), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4] parameter(1) + ROOT r = bf16[24,3]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_dot +})")); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* vector = dot_computation->parameter_instruction(1); + const DotFusionAnalysis analysis(dot_computation); + EXPECT_EQ(analysis.IterSpec(DotFusionAnalysis::Scope::RHS, vector, 0)->size(), + 1); + EXPECT_THAT(*analysis.IterSpec(DotFusionAnalysis::Scope::RHS, vector, 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/4, + /*subfragments=*/ElementsAre(4)))); +} + +TEST_F(TritonDotAnalysisTest, OutputBroadcastIsNotAccepted) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +ENTRY e { + p0 = f16[1,35] parameter(0) + p0c = bf16[1,35] convert(p0) + p1 = bf16[35,1] parameter(1) + dot = bf16[1,1] dot(p0c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + b = bf16[] bitcast(dot) + ROOT bc = bf16[100] broadcast(b) +})")); + EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kBroadcast); +} + using SplitKTest = HloTestBase; class SplitKTestWithMorePreciseReduction @@ -454,6 +631,79 @@ ENTRY e { HloOpcode::kReduce); } +TEST_F(SplitKTest, MakeSplitKWithOutputFusion) { + const std::string hlo_text = R"( +HloModule t + +triton_gemm_dot { + p0 = f16[480,128]{1,0} parameter(0) + p1 = f16[16,128]{1,0} parameter(1) + d = f16[480,16]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} + c = bf16[] constant(123) + n = bf16[] negate(c) + bc = bf16[480,16]{1,0} broadcast(n) + cv = bf16[480,16]{1,0} convert(d) + ROOT a = bf16[480,16]{1,0} multiply(bc, cv) +} + +ENTRY e { + p0 = f16[480,128]{1,0} parameter(0) + p1 = f16[16,128]{1,0} parameter(1) + ROOT fusion = bf16[480,16]{1,0} fusion(p0, p1), + kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm" +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + AutotuneResult::TritonGemmKey key; + key.set_block_m(16); + key.set_block_n(16); + key.set_block_k(16); + key.set_split_k(4); + key.set_num_stages(1); + key.set_num_warps(4); + TF_EXPECT_OK( + MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key)); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kReduce); +} + +TEST_F(SplitKTest, PreventSplitKWithNonDistributiveOperations) { + const std::string hlo_text = R"( +HloModule t + +triton_gemm_dot { + p0 = f16[480,128]{1,0} parameter(0) + p1 = f16[16,128]{1,0} parameter(1) + d = f16[480,16]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} + c = f32[480,16]{1,0} convert(d) + ROOT s = f32[480,16]{1,0} tanh(c) +} + +ENTRY e { + p0 = f16[480,128]{1,0} parameter(0) + p1 = f16[16,128]{1,0} parameter(1) + ROOT fusion = f32[480,16]{1,0} fusion(p0, p1), + kind=kCustom, 
calls=triton_gemm_dot, backend_config="__triton_gemm" +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + AutotuneResult::TritonGemmKey key; + key.set_block_m(16); + key.set_block_n(16); + key.set_block_k(16); + key.set_split_k(4); + key.set_num_stages(1); + key.set_num_warps(4); + EXPECT_THAT( + MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key), + tsl::testing::StatusIs( + tsl::error::CANCELLED, + absl::StrFormat( + "Operation non-distributive over addition after dot."))); +} + TEST_F(SplitKTest, MakeSplitKWithNonStandardOutputLayout) { const std::string kHloText = R"( HloModule t @@ -570,8 +820,6 @@ ENTRY e { } TEST_F(SplitKTestWithMorePreciseReduction, MakeSplitKWithOutputFusion) { - GTEST_SKIP() << "Output fusion support is temporarily rolled back."; - const std::string hlo_text = R"( HloModule t @@ -793,6 +1041,206 @@ ENTRY e { EXPECT_TRUE(GemmRewriterTriton(cc).Run(module.get()).value()); } +class GemmRewriterTritonLevel2Test : public GemmRewriterTritonTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_triton_fusion_level(2); + return debug_options; + } +}; + +TEST_F(GemmRewriterTritonLevel2Test, DoNotFuseIncompatibleDimOrders) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY e { + p0 = f16[5,3] parameter(0) + p1 = f16[5,7] parameter(1) + p2 = f16[7,5] parameter(2) + t = f16[5,7] transpose(p2), dimensions={1,0} + a = f16[5,7] add(t, p1) + ROOT d = f16[3,7] dot(p0, a), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Transpose()))); +} + +TEST_F(GemmRewriterTritonLevel2Test, DoNotFuseTooManyParameters) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + tmp_0 = f32[] constant(1) + tmp_1 = f32[3,49]{1,0} broadcast(tmp_0), dimensions={} + tmp_2 = f32[3,49]{1,0} parameter(6) + tmp_3 = f32[] constant(0) + tmp_4 = f32[3,49]{1,0} broadcast(tmp_3), dimensions={} + tmp_5 = pred[3,49]{1,0} compare(tmp_2, tmp_4), direction=GT + tmp_6 = f32[3,49]{1,0} convert(tmp_5) + tmp_7 = f32[3,49]{1,0} subtract(tmp_1, tmp_6) + tmp_8 = s32[] parameter(13) + tmp_9 = f32[] convert(tmp_8) + tmp_10 = f32[] maximum(tmp_9, tmp_0) + tmp_11 = f32[] divide(tmp_3, tmp_10) + tmp_12 = f32[3,49]{1,0} broadcast(tmp_11), dimensions={} + tmp_13 = pred[3,49]{1,0} parameter(7) + tmp_14 = pred[3,49]{1,0} parameter(10) + tmp_15 = pred[3,49]{1,0} and(tmp_13, tmp_14) + tmp_16 = f32[3,49]{1,0} convert(tmp_15) + tmp_17 = f32[3,49]{1,0} multiply(tmp_12, tmp_16) + tmp_18 = f32[3,49]{1,0} negate(tmp_17) + tmp_19 = f32[3,49]{1,0} multiply(tmp_7, tmp_18) + tmp_20 = f32[3,49]{1,0} parameter(19) + tmp_21 = f32[3,49]{1,0} subtract(tmp_1, tmp_20) + tmp_22 = f32[3,49]{1,0} divide(tmp_19, tmp_21) + tmp_23 = f32[3,49]{1,0} negate(tmp_22) + tmp_24 = f32[3,49]{1,0} negate(tmp_6) + tmp_25 = f32[3,49]{1,0} multiply(tmp_24, tmp_17) + tmp_26 = f32[3,49]{1,0} divide(tmp_25, tmp_20) + tmp_27 = f32[3,49]{1,0} add(tmp_23, tmp_26) + tmp_28 = f32[3,49]{1,0} parameter(18) + tmp_29 = f32[3,49]{1,0} multiply(tmp_27, tmp_28) + tmp_30 = f32[3,49]{1,0} parameter(17) + tmp_31 = f32[3,49]{1,0} multiply(tmp_29, tmp_30) + tmp_32 = f32[3,49]{1,0} parameter(16) + 
tmp_33 = f32[3,49]{1,0} multiply(tmp_31, tmp_32) + tmp_34 = f32[3,49]{1,0} parameter(15) + tmp_35 = f32[3,49]{1,0} add(tmp_33, tmp_34) + tmp_36 = f32[3,49]{1,0} parameter(14) + tmp_37 = f32[3,49]{1,0} add(tmp_35, tmp_36) + tmp_38 = f32[1,1]{1,0} constant({ {0} }) + tmp_39 = f32[1,1]{1,0} broadcast(tmp_38), dimensions={0,1} + tmp_40 = f32[] reshape(tmp_39) + tmp_41 = f32[3,32]{1,0} broadcast(tmp_40), dimensions={} + tmp_42 = u32[48]{0} parameter(11) + tmp_43 = u32[48]{0} parameter(5) + tmp_44 = u32[96]{0} concatenate(tmp_42, tmp_43), dimensions={0} + tmp_45 = u32[3,32]{1,0} reshape(tmp_44) + tmp_46 = u32[96]{0} reshape(tmp_45) + tmp_47 = u32[] constant(1) + tmp_48 = u32[3,32]{1,0} broadcast(tmp_47), dimensions={} + tmp_49 = u32[96]{0} reshape(tmp_48) + tmp_50 = u32[96]{0} shift-right-logical(tmp_46, tmp_49) + tmp_51 = u32[3,32]{1,0} reshape(tmp_50) + tmp_52 = u32[3,32]{1,0} or(tmp_51, tmp_48) + tmp_53 = f32[3,32]{1,0} bitcast-convert(tmp_52) + tmp_54 = f32[3,32]{1,0} broadcast(tmp_0), dimensions={} + tmp_55 = f32[3,32]{1,0} subtract(tmp_53, tmp_54) + tmp_56 = f32[1,1]{1,0} constant({ {1} }) + tmp_57 = f32[1,1]{1,0} broadcast(tmp_56), dimensions={0,1} + tmp_58 = f32[] reshape(tmp_57) + tmp_59 = f32[3,32]{1,0} broadcast(tmp_58), dimensions={} + tmp_60 = f32[3,32]{1,0} multiply(tmp_55, tmp_59) + tmp_61 = f32[3,32]{1,0} add(tmp_60, tmp_41) + tmp_62 = f32[3,32]{1,0} maximum(tmp_41, tmp_61) + tmp_63 = f32[3,32]{1,0} broadcast(tmp_3), dimensions={} + tmp_64 = pred[3,32]{1,0} compare(tmp_62, tmp_63), direction=LT + tmp_65 = f32[3,32]{1,0} convert(tmp_64) + tmp_66 = f32[3,49]{1,0} parameter(9) + tmp_67 = f32[49]{0} parameter(4) + tmp_68 = f32[3,49]{1,0} broadcast(tmp_67), dimensions={1} + tmp_69 = f32[3,49]{1,0} add(tmp_66, tmp_68) + tmp_70 = f32[1,49]{1,0} parameter(12) + tmp_71 = f32[1,49]{1,0} broadcast(tmp_0), dimensions={} + tmp_72 = f32[1,49]{1,0} divide(tmp_70, tmp_71) + tmp_73 = f32[1,49]{1,0} broadcast(tmp_72), dimensions={0,1} + tmp_74 = f32[49]{0} reshape(tmp_73) + tmp_75 = f32[3,49]{1,0} broadcast(tmp_74), dimensions={1} + tmp_76 = f32[3,49]{1,0} subtract(tmp_69, tmp_75) + tmp_77 = f32[1,49]{1,0} parameter(3) + tmp_78 = f32[1,49]{1,0} parameter(8) + tmp_79 = f32[1,49]{1,0} divide(tmp_78, tmp_71) + tmp_80 = f32[1,49]{1,0} multiply(tmp_72, tmp_72) + tmp_81 = f32[1,49]{1,0} subtract(tmp_79, tmp_80) + tmp_82 = f32[1,49]{1,0} add(tmp_81, tmp_71) + tmp_83 = f32[1,49]{1,0} rsqrt(tmp_82) + tmp_84 = f32[1,49]{1,0} multiply(tmp_77, tmp_83) + tmp_85 = f32[1,49]{1,0} broadcast(tmp_84), dimensions={0,1} + tmp_86 = f32[49]{0} reshape(tmp_85) + tmp_87 = f32[3,49]{1,0} broadcast(tmp_86), dimensions={1} + tmp_88 = f32[3,49]{1,0} multiply(tmp_76, tmp_87) + tmp_89 = f32[1,49]{1,0} parameter(2) + tmp_90 = f32[1,49]{1,0} broadcast(tmp_89), dimensions={0,1} + tmp_91 = f32[49]{0} reshape(tmp_90) + tmp_92 = f32[3,49]{1,0} broadcast(tmp_91), dimensions={1} + tmp_93 = f32[3,49]{1,0} add(tmp_88, tmp_92) + tmp_94 = f32[49,32]{1,0} parameter(1) + tmp_95 = f32[3,32]{1,0} dot(tmp_93, tmp_94), lhs_contracting_dims={1}, rhs_contracting_dims={0} + tmp_96 = f32[32]{0} parameter(0) + tmp_97 = f32[3,32]{1,0} broadcast(tmp_96), dimensions={1} + tmp_98 = f32[3,32]{1,0} add(tmp_95, tmp_97) + tmp_99 = f32[3,32]{1,0} multiply(tmp_65, tmp_98) + tmp_100 = f32[3,32]{1,0} divide(tmp_99, tmp_63) + tmp_101 = f32[3,32]{1,0} maximum(tmp_100, tmp_63) + ROOT tmp_102 = f32[49,32]{1,0} dot(tmp_37, tmp_101), lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); 
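+  // With at most kMaxParameterPerScope (4) parameters allowed on each side of
+  // the dot, the fused computation can end up with at most 8 operands, which
+  // is the bound checked below.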
+ EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kFusion); + EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), + HloInstruction::FusionKind::kCustom); + EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), + DotFusionAnalysis::kMaxParameterPerScope * 2); +} + +TEST_F(GemmRewriterTritonLevel2Test, ParameterUsedElementwiseTwiceIsFused) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +ENTRY e { + p0 = f32[1,35] parameter(0) + p0n = f32[1,35] negate(p0) + p0e = f32[1,35] exponential(p0) + a = f32[1,35] add(p0e, p0n) + p1 = f16[35,1] parameter(1) + p1c = f32[35,1] convert(p1) + ROOT dot = f32[1,1] dot(a, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ + se::CudaComputeCapability::VOLTA, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Parameter(), m::Parameter())))); + const DotFusionAnalysis analysis(module->entry_computation() + ->root_instruction() + ->called_computations()[0]); + EXPECT_EQ(analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).size(), 1); + EXPECT_EQ(analysis.ScopeParameters(DotFusionAnalysis::Scope::RHS).size(), 1); +} + +TEST_F(GemmRewriterTritonLevel2Test, + ParameterUsedNonElementwiseTwiceIsFusedOnlyOnOnePath) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +ENTRY e { + p0 = f32[4,4] parameter(0) + p0t = f32[4,4] transpose(p0), dimensions={1,0} + a = f32[4,4] add(p0, p0t) + p1 = f16[4,5] parameter(1) + p1c = f32[4,5] convert(p1) + ROOT dot = f32[4,5] dot(a, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ + se::CudaComputeCapability::VOLTA, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Parameter(), m::Transpose(), m::Parameter())))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index f471b0726db2f7..53c2c3ec29b716 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -19,12 +19,10 @@ limitations under the License. #include #include #include -#include #include #include #include #include -#include // NOLINT #include #include #include @@ -40,9 +38,8 @@ limitations under the License. #include "llvm/IR/Verifier.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/SplitModule.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" @@ -50,7 +47,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter.h" #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" -#include "tensorflow/compiler/xla/mlir_hlo/transforms/gpu_passes.h" #include "tensorflow/compiler/xla/runtime/jit_executable.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/all_gather_broadcast_reorder.h" @@ -115,8 +111,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/gpu/metrics.h" #include "tensorflow/compiler/xla/service/gpu/move_copy_to_users.h" @@ -180,7 +174,6 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/stream_executor/device_description.pb.h" #include "tensorflow/compiler/xla/stream_executor/dnn.h" -#include "tensorflow/compiler/xla/stream_executor/rocm/rocm_platform_id.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #include "tensorflow/compiler/xla/util.h" @@ -990,6 +983,29 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( }); } + GpuFloatSupport bf16_support(BF16); + GpuFloatSupport f8e5m2_support(F8E5M2); + GpuFloatSupport f8e4m3fn_support(F8E4M3FN); + FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); + FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); + FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); + + auto add_float_normalization = [&](HloPassPipeline& pipeline) { + auto& sub_pipeline = + pipeline.AddPass("float_normalization"); + sub_pipeline.AddPass(&bf16_support); + sub_pipeline.AddPass(&f8e5m2_support); + sub_pipeline.AddPass(&f8e4m3fn_support); + sub_pipeline.AddPass(&f8e4m3b11fnuz_support); + sub_pipeline.AddPass(&f8e5m2fnuz_support); + sub_pipeline.AddPass(&f8e4m3fnuz_support); + // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. + if (debug_options.xla_gpu_simplify_all_fp_conversions()) { + sub_pipeline.AddPass(); + } + }; + add_float_normalization(pipeline); + // By default use an externally provided thread pool. tsl::thread::ThreadPool* thread_pool = options.thread_pool; std::optional overriding_thread_pool; @@ -1011,18 +1027,8 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( &pipeline, hlo_module, stream_exec, debug_options, options, gpu_target_config, autotune_results, thread_pool)); - GpuFloatSupport bf16_support(BF16); - pipeline.AddPass(&bf16_support); - GpuFloatSupport f8e5m2_support(F8E5M2); - pipeline.AddPass(&f8e5m2_support); - GpuFloatSupport f8e4m3fn_support(F8E4M3FN); - pipeline.AddPass(&f8e4m3fn_support); - FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); - pipeline.AddPass(&f8e4m3b11fnuz_support); - FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); - pipeline.AddPass(&f8e5m2fnuz_support); - FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); - pipeline.AddPass(&f8e4m3fnuz_support); + // The Triton autotuner can insert new reductions. 
+ add_float_normalization(pipeline); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. if (debug_options.xla_gpu_simplify_all_fp_conversions()) { @@ -1476,7 +1482,7 @@ StatusOr> GpuCompiler::RunBackend( .xla_gpu_enable_persistent_temp_buffers(), std::move(buffer_assignment_proto), [buffer_assignment] { return buffer_assignment->ToVerboseString(); }, - std::move(module)})); + std::move(module), options.enable_debug_info_manager})); if (embed_ir_in_executable) { DCHECK_NE("", ir_module_string_before_opt); gpu_executable->set_ir_module_string(ir_module_string_before_opt); @@ -1692,8 +1698,9 @@ std::optional GpuCompiler::FusionCanShareBufferHint( } // We need to make sure that the fusion parameter is accessed in the same - // iteration order as the fusion output. Also, there should not be any other - // fusion output that accesses it in a different iteration order. To make sure + // iteration order as the fusion output. Also, there should not be two fusion + // outputs that consume the fusion parameter, because we do not want to share + // the same fusion operand with two different fusion outputs. To make sure // that the iteration order is the same, we only allow ops on the path from // fusion parameter to fusion output which are elementwise (no copy) or // bitcast or an elementwise dynamic update slice (i.e. with the first operand @@ -1718,8 +1725,12 @@ std::optional GpuCompiler::FusionCanShareBufferHint( q.pop(); if (hlo_operand == output) { found_path_to_output = true; - // We still need to process the users of 'hlo_operand'. There can be other - // users in addition to the tuple user. + // The output should have at most 1 user: the tuple op (in case of a + // multi-output fusion) + if (hlo_operand->user_count() > 1) { + return false; + } + continue; } for (HloInstruction* hlo : hlo_operand->users()) { if (non_bitcast_root->opcode() == HloOpcode::kDynamicUpdateSlice && @@ -1746,8 +1757,10 @@ std::optional GpuCompiler::FusionCanShareBufferHint( } else if ((!hlo->IsElementwiseOnOperand( hlo->operand_index(hlo_operand)) || hlo->opcode() == HloOpcode::kCopy) && - hlo->opcode() != HloOpcode::kBitcast && - hlo->opcode() != HloOpcode::kTuple) { + hlo->opcode() != HloOpcode::kBitcast) { + // This check also catches the case that we reach a different fusion + // output, as that fusion output would have a tuple op as user, which we + // do not allow here. // Even if 'hlo' is not elementwise on the operand, it is ok if we are // coming from the second operand and 'hlo' is a DynamicUpdateSlice // which is the non_bitcast_root. This corresponds to the special case @@ -1761,11 +1774,9 @@ std::optional GpuCompiler::FusionCanShareBufferHint( return false; } } - if (visited.contains(hlo)) { - continue; + if (visited.insert(hlo).second) { + q.push(hlo); } - visited.insert(hlo); - q.push(hlo); } } return found_path_to_output; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index bea4b498b0f1b0..7ec383489e0cfb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -23,24 +23,20 @@ limitations under the License. 
#include #include -#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/xla/autotune_results.pb.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/gpu/executable.pb.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/stream_executor/device_description.pb.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler_test.cc index a673377bc96d67..7743d0c7266f2f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include @@ -24,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/autotune_results.pb.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_matchers.h" #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h" +#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/tsl/lib/core/status_test_util.h" @@ -33,16 +35,58 @@ namespace { namespace op = xla::testing::opcode_matchers; -using ::absl::LogSeverity; -using ::absl::ScopedMockLog; -using ::testing::EndsWith; using ::testing::IsEmpty; using ::testing::Not; -using ::testing::StartsWith; using ::testing::TempDir; using GpuCompilerTest = HloTestBase; +TEST_F(GpuCompilerTest, DebugInfoManagerEnabled) { + const char* hlo_text = R"( +HloModule test + +ENTRY main { + p = f32[10]{0} parameter(0) + ROOT neg = f32[10]{0} negate(p) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_text).value(); + std::unique_ptr executable = + backend() + .compiler() + ->RunBackend(std::move(module), backend().default_stream_executor(), + {/*device_allocator=*/nullptr, + /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*enable_debug_info_manager=*/true}) + .value(); + EXPECT_TRUE(XlaDebugInfoManager::Get()->TracksModule( + executable->module().unique_id())); +} + +TEST_F(GpuCompilerTest, DebugInfoManagerDisabled) { + const char* hlo_text = R"( +HloModule test + +ENTRY main { + p = f32[10]{0} parameter(0) + ROOT neg = f32[10]{0} negate(p) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_text).value(); + std::unique_ptr executable = + backend() + .compiler() + ->RunBackend(std::move(module), backend().default_stream_executor(), + {/*device_allocator=*/nullptr, + /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*enable_debug_info_manager=*/false}) + .value(); + EXPECT_FALSE(XlaDebugInfoManager::Get()->TracksModule( + 
executable->module().unique_id())); +} + TEST_F(GpuCompilerTest, CopyInsertionFusion) { const char* hlo_text = R"( HloModule cluster diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc index 3915160c065e7a..3e9ba809238283 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc @@ -133,7 +133,7 @@ Status RunGpuConvForwardActivation(const GpuConvParams& params, output_type, params.config->conv_result_scale, params.config->fusion->side_input_scale, - /* leakyrelu_alpha = */ 0.0, + params.config->fusion->leakyrelu_alpha, params.config->input_descriptor, params.config->filter_descriptor, params.config->bias_descriptor, @@ -295,15 +295,18 @@ StatusOr GetGpuConvConfig( } if (config.kind == CudnnConvKind::kForwardActivation) { - config.fusion.emplace(); - GpuConvConfig::FusionConfig& fusion = *config.fusion; if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) { return InternalError("Bad activation mode: %s", backend_config.ShortDebugString()); } + + GpuConvConfig::FusionConfig fusion; fusion.mode = static_cast(backend_config.activation_mode()); fusion.side_input_scale = backend_config.side_input_scale(); + fusion.leakyrelu_alpha = backend_config.leakyrelu_alpha(); + + config.fusion = fusion; } const Window& window = desc.window; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h index 0a1d22f7073d3a..0c27d10099121e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h @@ -50,6 +50,7 @@ struct GpuConvConfig { struct FusionConfig { se::dnn::ActivationMode mode; double side_input_scale; + double leakyrelu_alpha = 0.0; }; PrimitiveType input_type; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion_test.cc index eb953cb2c395c4..9889142b22d2b1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion_test.cc @@ -204,14 +204,13 @@ fused_computation { param_1.1 = f32[2,3]{1,0} parameter(1) neg = f32[2,3]{1,0} negate(param_1.1) mul = f32[2,3]{1,0} multiply(param_0.1, neg) - transpose = f32[3,2]{1,0} transpose(neg), dimensions={1,0} - ROOT tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[3,2]{1,0}) tuple(mul, neg, transpose) + ROOT tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}) tuple(mul, neg) } ENTRY main { param_0 = f32[2,3]{1,0} parameter(0) param_1 = f32[2,3]{1,0} parameter(1) - ROOT fusion = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[3,2]{1,0}) fusion(param_0, param_1), kind=kLoop, calls=fused_computation + ROOT fusion = (f32[2,3]{1,0}, f32[2,3]{1,0}) fusion(param_0, param_1), kind=kLoop, calls=fused_computation } )"; @@ -221,7 +220,7 @@ ENTRY main { ExpectOptionalTrue( GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {0})); // The second operand cannot share the buffer with the second fusion output, - // because the 'neg' op is also used by a non-elementwise op. + // because the 'neg' op is also used on the path to the first fusion output. 
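The comment above summarizes the tightened FusionCanShareBufferHint rule from gpu_compiler.cc earlier in this diff: a fusion parameter may only alias a fusion output if every instruction on the path between them is elementwise or a bitcast, and the output feeds at most the tuple. A minimal standalone sketch of that reachability check follows; Node, users, and is_elementwise_or_bitcast are illustrative stand-ins rather than XLA types, and the dynamic-update-slice special cases are omitted.

#include <queue>
#include <unordered_set>
#include <vector>

// Hypothetical, simplified instruction node; not an XLA class.
struct Node {
  bool is_elementwise_or_bitcast = true;
  std::vector<Node*> users;
};

// Returns true if `output` is reachable from `param` and every node visited
// on the way is elementwise or a bitcast. Mirrors the BFS in
// FusionCanShareBufferHint, including the single-user restriction on the
// output and the visited.insert(...).second idiom.
bool CanShareBuffer(Node* param, Node* output) {
  std::queue<Node*> q;
  std::unordered_set<Node*> visited;
  q.push(param);
  visited.insert(param);
  bool found_path_to_output = false;
  while (!q.empty()) {
    Node* node = q.front();
    q.pop();
    if (node == output) {
      found_path_to_output = true;
      // A multi-output fusion root may feed the tuple, but nothing else.
      if (node->users.size() > 1) return false;
      continue;
    }
    for (Node* user : node->users) {
      if (!user->is_elementwise_or_bitcast) return false;
      if (visited.insert(user).second) q.push(user);
    }
  }
  return found_path_to_output;
}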
ExpectOptionalFalse( GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(1), {1})); // The first operand cannot share the buffer with the second fusion output, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 6e008a43807950..10595fc9bbb5b7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -102,8 +102,9 @@ StatusOr> GpuExecutable::Create(Params params) { if (std::holds_alternative(executable)) { auto& program = std::get(executable); - TF_ASSIGN_OR_RETURN(result->gpu_runtime_executable_, - GpuRuntimeExecutable::Create(std::move(program))); + TF_ASSIGN_OR_RETURN( + result->gpu_runtime_executable_, + GpuRuntimeExecutable::Create(result->module_name_, std::move(program))); return result; } @@ -126,7 +127,8 @@ GpuExecutable::GpuExecutable(GpuExecutable::Params params) verbose_buffer_assignment_string_dumper_( params.verbose_buffer_assignment_string_dumper), constants_(std::move(params.constants)), - output_info_(std::move(params.output_info)) { + output_info_(std::move(params.output_info)), + enable_debug_info_manager_(params.enable_debug_info_manager) { #if TENSORFLOW_USE_ROCM // ROCm uses hsaco hashes to distinguish between modules. // Bad things happen if multiple modules with identical code are loaded. @@ -134,14 +136,14 @@ GpuExecutable::GpuExecutable(GpuExecutable::Params params) *(uint64_t*)(&binary_[binary_.size() - 16]) = tsl::EnvTime::NowNanos(); *(uint64_t*)(&binary_[binary_.size() - 8]) = tsl::random::New64(); #endif - if (has_module()) { - XlaDebugInfoManager::Get()->RegisterModule( - module().unique_id(), shared_module(), debug_buffer_assignment_); + if (has_module() && enable_debug_info_manager_) { + XlaDebugInfoManager::Get()->RegisterModule(shared_module(), + debug_buffer_assignment_); } } GpuExecutable::~GpuExecutable() { - if (has_module()) { + if (has_module() && enable_debug_info_manager_) { XlaDebugInfoManager::Get()->UnregisterModule(module().unique_id()); } @@ -925,9 +927,12 @@ GpuExecutable::GpuExecutable( output_shape_(xla_output_shape), allocations_(std::move(allocations)), constants_(std::move(constants)), - output_info_(std::move(output_info)) { - XlaDebugInfoManager::Get()->RegisterModule( - module().unique_id(), shared_module(), debug_buffer_assignment_); + output_info_(std::move(output_info)), + enable_debug_info_manager_(true) { + if (has_module()) { + XlaDebugInfoManager::Get()->RegisterModule(shared_module(), + debug_buffer_assignment_); + } } // Returns a list of functions exported from the `module` that should be loaded @@ -1054,10 +1059,10 @@ StatusOr> GpuExecutable::LoadFromObjFile( executable.status().message()); // Move runtime::Executable ownership to the GpuRuntimeExecutable. - TF_ASSIGN_OR_RETURN( - auto gpu_runtime_executable, - GpuRuntimeExecutable::Create(buffer_sizes, std::move(*executable), - std::move(debug_options))); + TF_ASSIGN_OR_RETURN(auto gpu_runtime_executable, + GpuRuntimeExecutable::Create( + hlo_module->name(), buffer_sizes, + std::move(*executable), std::move(debug_options))); // Construct GpuExecutable for the loaded XLA Runtime executable. 
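The constructor and destructor changes above follow a familiar pattern: the executable registers itself with a process-wide tracker when it is created and unregisters when it is destroyed, with a flag to opt out (for example when compiling ahead of time with no intent to run). A self-contained sketch of that pattern follows; Registry and Executable here are illustrative placeholders, not the XlaDebugInfoManager or GpuExecutable APIs.

#include <cstdint>
#include <mutex>
#include <unordered_set>

// Hypothetical process-wide registry standing in for a debug-info manager.
class Registry {
 public:
  static Registry* Get() {
    static Registry* instance = new Registry();
    return instance;
  }
  void Register(int64_t id) {
    std::lock_guard<std::mutex> lock(mu_);
    ids_.insert(id);
  }
  void Unregister(int64_t id) {
    std::lock_guard<std::mutex> lock(mu_);
    ids_.erase(id);
  }
  bool Tracks(int64_t id) {
    std::lock_guard<std::mutex> lock(mu_);
    return ids_.count(id) > 0;
  }

 private:
  std::mutex mu_;
  std::unordered_set<int64_t> ids_;
};

// The executable-like object registers itself only when asked to, and always
// unregisters exactly what it registered.
class Executable {
 public:
  Executable(int64_t module_id, bool enable_tracking)
      : module_id_(module_id), enable_tracking_(enable_tracking) {
    if (enable_tracking_) Registry::Get()->Register(module_id_);
  }
  ~Executable() {
    if (enable_tracking_) Registry::Get()->Unregister(module_id_);
  }

 private:
  int64_t module_id_;
  bool enable_tracking_;
};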
std::string name = hlo_module->name(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 51010a8fa3124e..830cc20914e517 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -103,6 +104,7 @@ class GpuExecutable : public Executable { }; std::unique_ptr debug_module = nullptr; + bool enable_debug_info_manager = true; }; // Analyze the entry function to construct buffer allocation and other output @@ -314,6 +316,7 @@ class GpuExecutable : public Executable { // Retains shared ownership of on-device constants that are managed by XLA and // potentially shared with other executables. std::vector> shared_constants_; + bool enable_debug_info_manager_; GpuExecutable(const GpuExecutable&) = delete; GpuExecutable& operator=(const GpuExecutable&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc index 211dcc285df083..d9f4927628b43e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 6f9cf219715754..fb62f0e2b5c947 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -241,9 +241,15 @@ SchedulerConfig GetSchedulerConfig(const GpuDeviceInfo& gpu_info) { } // GPU specific resources for latency hiding scheduler. +// +// We use two resources to model collective operations: a resource for sending +// data and a resource for receiving data. All collective operations require +// both resources while the Send and Recv operations requires only the single +// resource corresponding to the operation. enum class GpuResourceType { - kGpuAsyncStream = 0, // The async stream for collectives. - kNumTargetResources = 1, + kGpuAsyncStreamSend = 0, // The resource for sending data. + kGpuAsyncStreamRecv = 1, // The resource for receiving data. + kNumTargetResources = 2, }; // Base GPU async tracker that enables async tracking only for async collectives @@ -285,11 +291,20 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { ResourceUsageType usage = op.outer == HloOpcode::kAsyncStart ? 
ResourceUsageType::kResourceRelease : ResourceUsageType::kResourceOccupy; - - const int64_t gpu_stream_resource = - GetFirstTargetDefinedResource() + - static_cast(GpuResourceType::kGpuAsyncStream); - return {std::make_pair(gpu_stream_resource, usage)}; + ResourcesVector resources; + auto add_resource = [&](GpuResourceType resource_type) { + const int64_t gpu_stream_resource = GetFirstTargetDefinedResource() + + static_cast(resource_type); + resources.push_back(std::make_pair(gpu_stream_resource, usage)); + }; + + if (op.inner != HloOpcode::kRecv) { + add_resource(GpuResourceType::kGpuAsyncStreamSend); + } + if (op.inner != HloOpcode::kSend) { + add_resource(GpuResourceType::kGpuAsyncStreamRecv); + } + return resources; } return GpuAsyncTrackerBase::GetResourcesFromInstruction(instr); } @@ -304,9 +319,9 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { if (resource_type < first_target_resource) { return GpuAsyncTrackerBase::GetNumAvailableResources(resource_type); } - CHECK_EQ(resource_type, + CHECK_LT(resource_type, first_target_resource + - static_cast(GpuResourceType::kGpuAsyncStream)); + static_cast(GpuResourceType::kNumTargetResources)); // We will allow upto 1 outstanding collective on the async stream. This // controls the number of collectives in flight in the schedule (a @@ -329,8 +344,10 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { CHECK_LE(resource_type, first_target_resource + GetNumTargetDefinedResources()); switch (resource_type - first_target_resource) { - case static_cast(GpuResourceType::kGpuAsyncStream): - return "kGpuAsyncStream"; + case static_cast(GpuResourceType::kGpuAsyncStreamSend): + return "kGpuAsyncStreamSend"; + case static_cast(GpuResourceType::kGpuAsyncStreamRecv): + return "kGpuAsyncStreamRecv"; default: return "kUnsupportedResource"; } @@ -420,31 +437,37 @@ std::optional ReadPGLEProfile( return std::nullopt; } tsl::Env* env = tsl::Env::Default(); + auto read_text_or_binary_profile = [&profile, env]( + const std::string& text_path, + const std::string& binary_path) + -> std::optional { + Status s = tsl::ReadTextProto(env, text_path, &profile); + if (s.ok()) { + LOG(INFO) << "Using PGLE profile from " << text_path; + return profile; + } + profile.Clear(); + s = tsl::ReadBinaryProto(env, binary_path, &profile); + if (s.ok()) { + LOG(INFO) << "Using PGLE profile from " << binary_path; + return profile; + } + return std::nullopt; + }; + // If its a directory, use fingerprint to look for the profile for this // specific module. if (env->IsDirectory(pgle_profile_file_or_dir_path).ok()) { - std::string pgle_profile_path = - pgle_profile_file_or_dir_path + "/" + fingerprint + ".pbtxt"; - Status s = - tsl::ReadTextProto(tsl::Env::Default(), pgle_profile_path, &profile); - if (!s.ok()) { - // Unable to read PGLE using fingerprint. - return std::nullopt; - } - LOG(INFO) << "Using PGLE profile from " << pgle_profile_path; - return profile; + std::string pgle_profile_path_prefix = + pgle_profile_file_or_dir_path + "/" + fingerprint; + return read_text_or_binary_profile(pgle_profile_path_prefix + ".pbtxt", + pgle_profile_path_prefix + ".pb"); } - // The pgle_profile_file_or_dir is a file. 
Read the profile and see if its - // applicable for this HLO module (all instruction names in the profile should - // be present in the HLO module) - Status s = tsl::ReadTextProto(tsl::Env::Default(), - pgle_profile_file_or_dir_path, &profile); - if (s.ok()) { - LOG(INFO) << "Using PGLE profile from " << pgle_profile_file_or_dir_path; - return profile; - } - return std::nullopt; + // The pgle_profile_file_or_dir is a file. Attempt to read the profile as text + // proto or binary proto. + return read_text_or_binary_profile(pgle_profile_file_or_dir_path, + pgle_profile_file_or_dir_path); } // Return true if the profile is applicable to the module. That is true if every diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 57d9157215ffc4..c77806b89fcdf6 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -501,6 +501,103 @@ TEST_F(GpuHloScheduleTest, LHSSendRecv) { EXPECT_TRUE(HasValidFingerprint(module.get())); } +// Checks that the two pairs of (Recv, RecvDone) and (Send, SendDone) do not +// interleave. +TEST_F(GpuHloScheduleTest, LHSSendRecvPairs2) { + const char* hlo_text = R"( + HloModule test + while_cond { + param = (u32[], f32[1, 1024, 1024]) parameter(0) + count = get-tuple-element(%param), index=0 + ub = u32[] constant(25) + ROOT cond_result = pred[] compare(count, ub), direction=LT + } + + while_body { + param = (u32[], f32[1, 1024, 1024]) parameter(0) + count = get-tuple-element(%param), index=0 + send-data = get-tuple-element(%param), index=1 + + after-all-0 = token[] after-all() + recv-0 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all-0), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0, 1}}" + } + send-0 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all-0), + channel_id=1, control-predecessors={recv-0}, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0, 1}}" + } + recv-done-0 = (f32[1, 1024, 1024], token[]) recv-done(recv-0), channel_id=1 + send-done-0 = token[] send-done(send-0), control-predecessors={recv-done-0}, channel_id=1 + recv-data-0 = f32[1, 1024, 1024] get-tuple-element(recv-done-0), index=0 + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + replica = u32[] replica-id() + c10 = u32[] constant(10) + sum = u32[] add(replica, c10) + sum2 = u32[] add(sum, count) + conv = f32[] convert(sum2) + s1 = f32[1, 1024, 1024] broadcast(conv), dimensions={} + + after-all-1 = token[] after-all() + recv-1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all-1), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1, 0}}" + } + send-1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all-1), + channel_id=2, control-predecessors={recv-1}, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1, 0}}" + } + recv-done-1 = (f32[1, 1024, 1024], token[]) recv-done(recv-1), channel_id=2 + send-done-1 = token[] send-done(send-1), control-predecessors={recv-done-1}, channel_id=2 + recv-data-1 = f32[1, 1024, 1024] get-tuple-element(recv-done-1), index=0 + + s2 = f32[1, 1024, 1024] add(recv-data-0, s1) + s = f32[1, 1024, 1024] add(recv-data-1, s2) + + ROOT result = (u32[], f32[1, 1024, 1024]) tuple(new_count, s) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + f0 = f32[] constant(0.0) + init = f32[1, 1024, 1024] broadcast(f0), dimensions={} + while_init = (u32[], f32[1, 1024, 1024]) 
tuple(c0, init) + while_result = (u32[], f32[1, 1024, 1024]) while(while_init), + body=while_body, condition=while_cond + ROOT entry_result = f32[1, 1024, 1024] get-tuple-element(while_result), index=1 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + hlo_text, GetModuleConfig(/*enable_latency_hiding_scheduler=*/true, + /*enable_gpu_async_tracker=*/true))); + SequentialHloOrdering order = BuildHloOrdering(module.get()); + HloComputation* while_body = module->GetComputationWithName("while_body"); + const std::vector& instruction_sequence = + order.SequentialOrder(*while_body)->instructions(); + auto get_index = [&](absl::string_view hlo_name) { + return absl::c_find_if(instruction_sequence, + [hlo_name](HloInstruction* instruction) { + return instruction->name() == hlo_name; + }) - + instruction_sequence.begin(); + }; + + EXPECT_TRUE(HasValidFingerprint(module.get())); + + EXPECT_LT(get_index("recv-1"), get_index("send-1")); + EXPECT_LT(get_index("send-1"), get_index("recv-done-1")); + EXPECT_GE(get_index("send-done-1") - get_index("send-1"), 14); + EXPECT_LT(abs(get_index("send-done-1") - get_index("result")), 2); + + EXPECT_LT(get_index("recv-done-0"), get_index("recv-1")); + EXPECT_LT(get_index("send-done-0"), get_index("send-1")); +} + class GpuHloScheduleParameterizedTest : public GpuHloScheduleTest, public ::testing::WithParamInterface {}; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc b/tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc index 308230501664b5..c0a20c96e5d9d7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc @@ -138,11 +138,14 @@ std::optional EstimateThreadCount( bool use_experimental_block_size) { auto fusion = DynCast(instr); if (fusion != nullptr && cc.has_value()) { - HloFusionAnalysis fusion_analysis(fusion, &gpu_device_info, cc.value()); - auto launch_dimensions = - fusion_analysis.GetLaunchDimensions(use_experimental_block_size); - if (launch_dimensions.ok()) { - return launch_dimensions->launch_bound(); + auto analysis = + HloFusionAnalysis::Create(fusion, &gpu_device_info, cc.value()); + if (analysis.ok()) { + auto launch_dimensions = + analysis->GetLaunchDimensions(use_experimental_block_size); + if (launch_dimensions.ok()) { + return launch_dimensions->launch_bound(); + } } } return std::nullopt; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc index 250ff1c54af1cb..58215f84138c9e 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc @@ -48,41 +48,6 @@ const auto kDimX = TilingScheme::DimX; const auto kLinearIndexingX = TilingScheme::LinearIndexingX; const auto kStridedIndexingX = TilingScheme::StridedIndexingX; -// Returns true if the fusion has consistent transpose heros. 
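The EstimateThreadCount change above goes through the new HloFusionAnalysis::Create factory: the work that can fail (such as reading the fusion backend config) happens in a static Create that returns a status-like value, and the constructor itself stays private and infallible. A dependency-free sketch of that shape, using std::optional in place of StatusOr and with made-up Analysis/Config names, follows.

#include <optional>
#include <string>
#include <utility>

struct Config {
  std::string kind;
};

// Pretend this parse can fail, like fetching a backend config proto.
std::optional<Config> ParseConfig(const std::string& raw) {
  if (raw.empty()) return std::nullopt;
  return Config{raw};
}

class Analysis {
 public:
  // Fallible work happens here; on success the private constructor only
  // stores already-validated data.
  static std::optional<Analysis> Create(const std::string& raw_config) {
    std::optional<Config> config = ParseConfig(raw_config);
    if (!config.has_value()) return std::nullopt;
    return Analysis(std::move(*config));
  }
  const Config& config() const { return config_; }

 private:
  explicit Analysis(Config config) : config_(std::move(config)) {}
  Config config_;
};

// Usage mirrors EstimateThreadCount: proceed only if creation succeeded.
// auto analysis = Analysis::Create("triton_gemm");
// if (analysis) { /* use analysis->config() */ }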
-bool HasConsistentTransposeHeros(HloComputation* fusion) { - std::vector hlo_roots = GetFusionRoots(fusion); - if (!HasAnyTiledTransposeRoot(fusion)) { - return false; - } - const HloInstruction* first_transpose = &FindNonTrivialHero(**absl::c_find_if( - hlo_roots, - [](HloInstruction* instr) { return FindAnyTiledTranspose(*instr); })); - const Shape& transpose_in_shape = first_transpose->operand(0)->shape(); - std::optional first_tiled_transpose = - FindAnyTiledTranspose(*first_transpose); - - // We need the following invariant: - // For every tuple element: - // -> EITHER it's a kCopy: S{L} -> S{L'} - // -> OR it's an elementwise op of shape S{L} - for (HloInstruction* root : hlo_roots) { - std::optional tiled_transpose = - FindAnyTiledTranspose(*root); - if (tiled_transpose) { - if (*tiled_transpose != *first_tiled_transpose) { - return false; - } - } else { - if (!ShapeUtil::IsReshapeOrTransposeBitcast( - root->shape(), transpose_in_shape, - /*ignore_element_type=*/true)) { - return false; - } - } - } - return true; -} - // Returns true if the fusion output contains non-strided slices only. bool IsInputFusibleNonStridedSlices(const HloInstruction* root) { if (root->opcode() == HloOpcode::kTuple) { @@ -258,13 +223,47 @@ int64_t NearestPowerOfTwo(int64_t v) { } // namespace -StatusOr -HloFusionAnalysis::GetEmitterFusionKind() const { +// Returns true if the fusion has consistent transpose heros. +bool HloFusionAnalysis::HasConsistentTransposeHeros() const { + if (!tiled_transpose_) { + return false; + } + + auto* fusion = fusion_->fused_instructions_computation(); + std::vector hlo_roots = GetFusionRoots(fusion); + const HloInstruction* first_transpose = + &FindNonTrivialHero(*root_with_tiled_transpose_); + const Shape& transpose_in_shape = first_transpose->operand(0)->shape(); + std::optional first_tiled_transpose = + FindAnyTiledTranspose(*first_transpose); + + // We need the following invariant: + // For every tuple element: + // -> EITHER it's a kCopy: S{L} -> S{L'} + // -> OR it's an elementwise op of shape S{L} + for (HloInstruction* root : hlo_roots) { + std::optional tiled_transpose = + FindAnyTiledTranspose(*root); + if (tiled_transpose) { + if (*tiled_transpose != *first_tiled_transpose) { + return false; + } + } else { + if (!ShapeUtil::IsReshapeOrTransposeBitcast( + root->shape(), transpose_in_shape, + /*ignore_element_type=*/true)) { + return false; + } + } + } + return true; +} + +HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() + const { #if GOOGLE_CUDA - TF_ASSIGN_OR_RETURN(auto backend_config, - fusion_->backend_config()); - if (backend_config.kind() == kTritonGemmFusionKind || - backend_config.kind() == kTritonSoftmaxFusionKind) { + if (fusion_backend_config_.kind() == kTritonGemmFusionKind || + fusion_backend_config_.kind() == kTritonSoftmaxFusionKind) { return EmitterFusionKind::kTriton; } #endif @@ -273,7 +272,8 @@ HloFusionAnalysis::GetEmitterFusionKind() const { if (HasAnyUnnestedReductionRoot(fused_computation)) { return EmitterFusionKind::kReduction; } - if (HasConsistentTransposeHeros(fused_computation)) { + // We expect that the last dimension is swapped with a different dimension. 
+ if (HasConsistentTransposeHeros() && tiled_transpose_->permutation[2] != 2) { return EmitterFusionKind::kTranspose; } @@ -297,7 +297,7 @@ HloFusionAnalysis::GetEmitterFusionKind() const { StatusOr HloFusionAnalysis::GetLaunchDimensions( bool use_experimental_block_size) { - TF_ASSIGN_OR_RETURN(auto emitter_fusion_kind, GetEmitterFusionKind()); + auto emitter_fusion_kind = GetEmitterFusionKind(); switch (emitter_fusion_kind) { case EmitterFusionKind::kLoop: { // Disable experimental block size if few_waves or row_vectorized enabled. @@ -309,8 +309,7 @@ StatusOr HloFusionAnalysis::GetLaunchDimensions( *loop_fusion_config); } case EmitterFusionKind::kReduction: { - TF_ASSIGN_OR_RETURN(auto reduction_codegen_info, - GetReductionCodegenInfo()); + auto* reduction_codegen_info = GetReductionCodegenInfo(); const TilingScheme& tiling_scheme = reduction_codegen_info->GetTilingScheme(); size_t blocks_y = reduction_codegen_info->GetIndexGroups().size(); @@ -321,57 +320,73 @@ StatusOr HloFusionAnalysis::GetLaunchDimensions( /*y=*/1, /*z=*/1}); } case EmitterFusionKind::kTranspose: { - TF_ASSIGN_OR_RETURN(auto tiling_scheme, GetTransposeTilingScheme()); + auto* tiling_scheme = GetTransposeTilingScheme(); return LaunchDimensions(tiling_scheme->GetNumberOfBlocksPhysical(), tiling_scheme->GetNumThreadsPerBlockPhysical()); } - default: + case EmitterFusionKind::kInputSlices: { + auto* root = + fusion_->fused_instructions_computation()->root_instruction(); + xla::Shape shape; + if (root->opcode() == HloOpcode::kSlice) { + shape = root->operands()[0]->shape(); + } else { + CHECK_EQ(root->opcode(), HloOpcode::kTuple); + // We already verified that the shapes are compatible in + // `GetEmitterFusionKind`. + shape = root->operands()[0]->operands()[0]->shape(); + } + constexpr int kUnrollFactor = 1; + return CalculateLaunchDimensions( + shape, *device_info_, use_experimental_block_size, {kUnrollFactor}); + } + case EmitterFusionKind::kScatter: { + const auto& root_shape = fusion_->fused_instructions_computation() + ->root_instruction() + ->shape(); + int64_t num_elements = ShapeUtil::ElementsIn(root_shape); + int unroll_factor = num_elements % 4 == 0 ? 4 + : num_elements % 2 == 0 ? 2 + : 1; + return CalculateLaunchDimensions(root_shape, *device_info_, + use_experimental_block_size, + {unroll_factor, /*few_waves=*/false}); + } + case EmitterFusionKind::kTriton: return Unimplemented("GetLaunchDimensions"); } } -StatusOr -HloFusionAnalysis::GetReductionCodegenInfo() { +const ReductionCodegenInfo* HloFusionAnalysis::GetReductionCodegenInfo() { if (reduction_codegen_info_.has_value()) { return &reduction_codegen_info_.value(); } - HloInstruction* first_reduce = - *absl::c_find_if(fusion_roots_, [](HloInstruction* instr) { - return IsReductionFromOrToContiguousDimensions(*instr); - }); + HloInstruction* hero_reduction = + FindHeroReduction(absl::Span(fusion_roots_)); + CHECK_NE(hero_reduction, nullptr); - // We always use the first reduce as representative to construct - // ReductionCodegenInfo, since all the reductions are required to have the - // same shape and layout as verified by `IsFusedReductionOutputConsistent()`. 
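GetReductionCodegenInfo above now follows a compute-once pattern: the first call fills a std::optional member and later calls return a pointer into the cached value (or nullptr when the fusion kind does not apply, as in GetTransposeTilingScheme). A small sketch of that accessor shape, with placeholder types, follows.

#include <optional>

struct TilingInfo {
  int tile_x = 0;
  int tile_y = 0;
};

class FusionInfoCache {
 public:
  explicit FusionInfoCache(bool is_reduction) : is_reduction_(is_reduction) {}

  // Returns nullptr when the analysis does not apply; otherwise computes the
  // result once and hands out a pointer to the cached value on later calls.
  const TilingInfo* GetTilingInfo() {
    if (tiling_.has_value()) return &tiling_.value();
    if (!is_reduction_) return nullptr;
    tiling_.emplace(ComputeTilingInfo());
    return &tiling_.value();
  }

 private:
  TilingInfo ComputeTilingInfo() const { return TilingInfo{32, 4}; }

  bool is_reduction_;
  std::optional<TilingInfo> tiling_;
};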
- TF_ASSIGN_OR_RETURN(auto reduction_codegen_info, - ComputeReductionCodegenInfo(first_reduce)); + auto reduction_codegen_info = ComputeReductionCodegenInfo(hero_reduction); reduction_codegen_info_.emplace(std::move(reduction_codegen_info)); return &reduction_codegen_info_.value(); } -StatusOr HloFusionAnalysis::GetTransposeTilingScheme() { +const TilingScheme* HloFusionAnalysis::GetTransposeTilingScheme() { if (transpose_tiling_scheme_.has_value()) { return &transpose_tiling_scheme_.value(); } - std::optional dims_and_order = FindAnyTiledTranspose( - **absl::c_find_if(fusion_roots_, [](HloInstruction* instr) { - return FindAnyTiledTranspose(*instr); - })); - - // TODO(cheshire): have a more robust way of checking this. - TF_RET_CHECK(dims_and_order.has_value()); + if (!tiled_transpose_) { + return nullptr; + } constexpr int kNumRows = 4; - TF_RET_CHECK(WarpSize() % kNumRows == 0); + static_assert(WarpSize() % kNumRows == 0); // 3D view over the input shape. - Vector3 dims = dims_and_order->dimensions; - Vector3 order = dims_and_order->permutation; + Vector3 dims = tiled_transpose_->dimensions; + Vector3 order = tiled_transpose_->permutation; - // We expect that the last dimension is swapped with a different dimension. - TF_RET_CHECK(order[2] != 2); Vector3 permuted_dims = {dims[order[0]], dims[order[1]], dims[order[2]]}; Vector3 tile_sizes{1, 1, 1}; tile_sizes[order[2]] = WarpSize() / kNumRows; @@ -670,11 +685,11 @@ int HloFusionAnalysis::CalculateVirtualThreadScalingFactorForReduction( return 1; } -StatusOr HloFusionAnalysis::ComputeReductionCodegenInfo( - HloInstruction* first_reduce) const { - Shape input_shape = first_reduce->operand(0)->shape(); +ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo( + HloInstruction* hero_reduction) const { + Shape input_shape = hero_reduction->operand(0)->shape(); ReductionDimensions reduction_dimensions = - GetReductionKindAndContiguousComponents(*first_reduce); + GetReductionKindAndContiguousComponents(*hero_reduction); VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction << " " << reduction_dimensions.dimensions[0] << " " << reduction_dimensions.dimensions[1] << " " @@ -692,9 +707,9 @@ StatusOr HloFusionAnalysis::ComputeReductionCodegenInfo( // Use 512 as default block size (threads per block) for row reductions. // For multi-output fusions, reduce the block size further to decrease // register pressure when multiple outputs are computed by each thread. - int64_t max_block_size = - std::max(MinThreadsXRowReduction(), - static_cast(512LL / NearestPowerOfTwo(fan_out))); + int64_t max_block_size = std::max( + MinThreadsXRowReduction(hero_reduction->GetModule()->config()), + static_cast(512LL / NearestPowerOfTwo(fan_out))); return std::min(max_block_size, RoundUpTo(CeilOfRatio(reduction_dimensions.dimensions[2], reduction_tiling[2]), @@ -710,7 +725,8 @@ StatusOr HloFusionAnalysis::ComputeReductionCodegenInfo( int64_t shmem_usage = ProjectedShmemUsageBytes(reduction_dimensions, instr_index_groups); const int64_t shmem_budget = device_info_->shared_memory_per_block; - bool reduction_is_race_free = ReductionIsRaceFree(reduction_dimensions); + bool reduction_is_race_free = ReductionIsRaceFree( + hero_reduction->GetModule()->config(), reduction_dimensions); bool vectorize = // Vectorization might cause us to run out of budget. 
(shmem_usage * 2 <= shmem_budget) && @@ -772,7 +788,7 @@ StatusOr HloFusionAnalysis::ComputeReductionCodegenInfo( virtual_thread_scaling_factor); return ReductionCodegenInfo( tiling_scheme, num_partial_results, reduction_dimensions.is_row_reduction, - reduction_is_race_free, std::move(instr_index_groups), first_reduce); + reduction_is_race_free, std::move(instr_index_groups), hero_reduction); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h index e647357eb078cb..0d771a8b897d3c 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h @@ -17,10 +17,13 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_FUSION_ANALYSIS_H_ #include +#include #include #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" @@ -43,38 +46,69 @@ class HloFusionAnalysis { kScatter, }; - HloFusionAnalysis(const HloFusionInstruction* fusion, - const GpuDeviceInfo* device_info, - se::CudaComputeCapability compute_capability) - : fusion_(fusion), - fused_computation_(fusion->fused_instructions_computation()), - fusion_roots_(GetFusionRoots(fusion->fused_instructions_computation())), - device_info_(device_info), - compute_capability_(compute_capability) {} + static StatusOr Create( + const HloFusionInstruction* fusion, const GpuDeviceInfo* device_info, + se::CudaComputeCapability compute_capability) { + TF_ASSIGN_OR_RETURN(auto backend_config, + fusion->backend_config()); + + auto hlo_roots = GetFusionRoots(fusion->fused_instructions_computation()); + HloInstruction* root_with_tiled_transpose; + std::optional tiled_transpose; + + for (auto* root : hlo_roots) { + if ((tiled_transpose = FindAnyTiledTranspose(*root))) { + root_with_tiled_transpose = root; + break; + } + } + + return HloFusionAnalysis(fusion, std::move(backend_config), device_info, + compute_capability, root_with_tiled_transpose, + tiled_transpose); + } - // Simple getters. const HloComputation* fused_computation() const { return fused_computation_; } absl::Span fusion_roots() const { return absl::MakeSpan(fusion_roots_); } - // Determine the fusion type for the emitter. - StatusOr GetEmitterFusionKind() const; + // Determines the fusion type for the emitter. + EmitterFusionKind GetEmitterFusionKind() const; - // Determine the launch dimensions for the fusion. + // Determines the launch dimensions for the fusion. The fusion kind must not + // be `kTriton`. StatusOr GetLaunchDimensions( bool use_experimental_block_size = false); - // Calculate reduction information (kind: kReduction). - StatusOr GetReductionCodegenInfo(); + // Calculates the reduction information. Returns `nullptr` if the fusion is + // not a reduction. + const ReductionCodegenInfo* GetReductionCodegenInfo(); - // Calculate transpose tiling information (kind: kTranspose). - StatusOr GetTransposeTilingScheme(); + // Calculates the transpose tiling information. Returns `nullptr` if the + // fusion is not a transpose. 
+ const TilingScheme* GetTransposeTilingScheme(); - // Calculate loop fusion config (kind: kLoop). + // Calculates the loop fusion config. Returns `nullptr` if the fusion is not a + // loop. const LaunchDimensionsConfig* GetLoopFusionConfig(); private: + HloFusionAnalysis(const HloFusionInstruction* fusion, + FusionBackendConfig fusion_backend_config, + const GpuDeviceInfo* device_info, + se::CudaComputeCapability compute_capability, + HloInstruction* root_with_tiled_transpose, + std::optional tiled_transpose) + : fusion_(fusion), + fusion_backend_config_(std::move(fusion_backend_config)), + fused_computation_(fusion->fused_instructions_computation()), + fusion_roots_(GetFusionRoots(fusion->fused_instructions_computation())), + device_info_(device_info), + compute_capability_(compute_capability), + root_with_tiled_transpose_(root_with_tiled_transpose), + tiled_transpose_(tiled_transpose) {} + const Shape& GetElementShape() const; int SmallestInputDtypeBits() const; int64_t MaxBeneficialColumnReductionUnrollBasedOnBlockSize() const; @@ -88,14 +122,18 @@ class HloFusionAnalysis { bool reduction_is_race_free) const; int CalculateVirtualThreadScalingFactorForReduction( const ReductionDimensions& reduction_dimensions) const; - StatusOr ComputeReductionCodegenInfo( - HloInstruction* first_reduce) const; + ReductionCodegenInfo ComputeReductionCodegenInfo( + HloInstruction* hero_reduction) const; + bool HasConsistentTransposeHeros() const; const HloFusionInstruction* fusion_; + FusionBackendConfig fusion_backend_config_; const HloComputation* fused_computation_; std::vector fusion_roots_; const GpuDeviceInfo* device_info_; se::CudaComputeCapability compute_capability_; + HloInstruction* root_with_tiled_transpose_; + std::optional tiled_transpose_; std::optional reduction_codegen_info_; std::optional transpose_tiling_scheme_; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 083b483fcfb06d..7f9f8e6593fd06 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -36,10 +35,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/type_to_shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#ifdef GOOGLE_CUDA +#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.h" +#endif // GOOGLE_CUDA + namespace xla { namespace gpu { @@ -136,6 +139,23 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { return true; } +int64_t MinThreadsXRowReduction(const HloModuleConfig& hlo_module_config) { +#ifdef GOOGLE_CUDA + auto ptxas_config = + PtxOptsFromDebugOptions(hlo_module_config.debug_options()); + auto ptxas_version_tuple = + se::GetAsmCompilerVersion(ptxas_config.preferred_cuda_dir); + // ptxas versions prior to 12.2 have a very rare bug when very high register + // spilling occurs with some order of instructions, so use less threads to + // reduce register pressure. 
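The comment above motivates a version gate: ptxas releases before 12.2 get a lower thread bound to ease register pressure, newer ones keep 1024. The comparison that follows relies on the lexicographic ordering of a {major, minor, patch} triple, which std::array provides directly. A standalone sketch of the same decision, with the version supplied by the caller instead of queried from ptxas, is below.

#include <array>
#include <cstdint>
#include <optional>

// Picks the row-reduction thread bound from a ptxas-style version triple.
// std::nullopt models "version could not be determined".
int64_t MinThreadsForRowReduction(
    std::optional<std::array<int64_t, 3>> ptxas_version) {
  // std::array compares lexicographically, so {12, 1, 105} < {12, 2, 0}.
  if (!ptxas_version.has_value() ||
      *ptxas_version < std::array<int64_t, 3>{12, 2, 0}) {
    return 512;  // Older ptxas: use fewer threads to ease register pressure.
  }
  return 1024;
}

// Example: a version of {12, 1, 105} yields 512, while {12, 2, 91} yields 1024.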
+ if (!ptxas_version_tuple.ok() || + ptxas_version_tuple.value() < std::array{12, 2, 0}) { + return 512; + } +#endif // GOOGLE_CUDA + return 1024; +} + Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions) { if (reduction_dimensions.is_row_reduction) { int64_t tile_z = std::min(reduction_dimensions.dimensions[0], @@ -777,11 +797,13 @@ Shape GetShape(mlir::Value value) { return {}; } -bool ReductionIsRaceFree(const ReductionDimensions& reduction_dimensions) { +bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, + const ReductionDimensions& reduction_dimensions) { Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); return (reduction_dimensions.is_row_reduction && reduction_dimensions.dimensions[2] <= - MinThreadsXRowReduction() * reduction_tiling[2] && + MinThreadsXRowReduction(hlo_module_config) * + reduction_tiling[2] && reduction_dimensions.dimensions[0] <= BatchedReductionRaceFreeBound()) || (!reduction_dimensions.is_row_reduction && @@ -975,6 +997,16 @@ bool HasAnyUnnestedReductionRoot(HloComputation* computation) { }); } +HloInstruction* FindHeroReduction(absl::Span roots) { + auto it = absl::c_find_if(roots, [](HloInstruction* instr) { + return IsReductionFromOrToContiguousDimensions(*instr); + }); + if (it == roots.end()) { + return nullptr; + } + return *it; +} + void LogAndVerify(const llvm::Module* m) { if (VLOG_IS_ON(5)) { XLA_VLOG_LINES(5, llvm_ir::DumpToString(m)); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index fb72d566d87a81..29fb463544d373 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include "llvm/IR/IRBuilder.h" @@ -26,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" namespace xla { namespace gpu { @@ -51,7 +49,7 @@ inline constexpr int64_t WarpSize() { return 32; } // Need at least 1024 threads/block for reasonable tree reduction // performance (assuming all data fits). -inline constexpr int64_t MinThreadsXRowReduction() { return 1024; } +int64_t MinThreadsXRowReduction(const HloModuleConfig& hlo_module_config); // When doing batched row reduction, how big the batch dimension could be. inline constexpr int64_t BatchedReductionRaceFreeBound() { return 8; } @@ -175,7 +173,8 @@ Shape GetShape(mlir::Value value); // Returns whether the given reduction can be safely generated without atomics: // that is, at most one block will write to every output element. -bool ReductionIsRaceFree(const ReductionDimensions& reduction_dimensions); +bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, + const ReductionDimensions& reduction_dimensions); // Description of how to emit a given transposition. // @@ -228,6 +227,12 @@ std::vector GetFusionRoots(HloComputation* computation); // reduction emitter. bool HasAnyUnnestedReductionRoot(HloComputation* computation); +// Returns the hero reduction of the computation. 
+// We always use the first reduce root that triggers unnested reduction emitter +// as the hero reduction, since all the reductions are required to have the same +// shape and layout as verified by `IsFusedReductionOutputConsistent()`. +HloInstruction* FindHeroReduction(absl::Span roots); + const HloInstruction& FindNonTrivialHero(const HloInstruction& instr); // Whether there is a fusion root triggering transposition emitter. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc index 709f3e40b52c3f..f9ddf728322bc0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc @@ -92,6 +92,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/path.h" +#include "tensorflow/tsl/platform/statusor.h" #include "tensorflow/tsl/platform/tensor_float_32_utils.h" #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" @@ -412,24 +413,16 @@ Value EmitElementwise(mlir::ImplicitLocOpBuilder& b, } } -Value EmitParameter(mlir::ImplicitLocOpBuilder& b, - const HloInstruction& parameter, mlir::triton::FuncOp fn, - Value load_offsets, Value load_mask) { - Value param = fn.getArgument(parameter.parameter_number()); - mlir::ArrayRef tile_shape = - load_offsets.dyn_cast().getType().getShape(); - if (load_mask != nullptr) { - Value zeros_like = CreateConst( - b, TritonType(b, parameter.shape().element_type()), 0, tile_shape); - return b.create( - AddPtr(b, Splat(b, param, tile_shape), load_offsets), load_mask, - zeros_like, mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL, - /*isVolatile=*/false); +Value EmitParameterLoad(mlir::ImplicitLocOpBuilder& b, Value tensor_pointer, + mlir::ArrayRef boundary_checks) { + std::optional padding; + if (!boundary_checks.empty()) { + padding = mt::PaddingOption::PAD_ZERO; } - return b.create( - AddPtr(b, Splat(b, param, tile_shape), load_offsets), - mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL, - /*isVolatile=*/false); + return b.create(tensor_pointer, boundary_checks, padding, + mt::CacheModifier::NONE, + mt::EvictionPolicy::NORMAL, + /*isVolatile=*/false); } Value EmitConstant(mlir::ImplicitLocOpBuilder& b, @@ -457,16 +450,16 @@ Value EmitBroadcast(mlir::ImplicitLocOpBuilder& b, return input; } -Value EmitScope(mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, - mlir::triton::FuncOp fn, - absl::Span instructions, - absl::flat_hash_map& values, - Value load_offsets, Value load_mask); +StatusOr EmitScope( + mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + absl::Span instructions, + absl::flat_hash_map& values, + mlir::ArrayRef tile_shape, Value tile_mask); -Value EmitReduce(mlir::ImplicitLocOpBuilder& b, - const HloInstruction& hlo_reduce, - absl::string_view libdevice_path, mlir::triton::FuncOp fn, - Value input, Value tile_mask) { +StatusOr EmitReduce(mlir::ImplicitLocOpBuilder& b, + const HloInstruction& hlo_reduce, + absl::string_view libdevice_path, Value input, + Value tile_mask) { llvm::ArrayRef input_shape = input.cast().getType().getShape(); @@ -490,12 +483,14 @@ Value EmitReduce(mlir::ImplicitLocOpBuilder& b, // reduction is computed correctly, since it is the neutral value with regards // to the reducer. 
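The comment above describes the standard trick for reducing a padded tile: lanes outside the logical bounds are overwritten with the reduction's neutral element (the init value), so they cannot change the result. A scalar C++ illustration of the same idea, using a hypothetical tile and a sum reduction, follows.

#include <cstddef>
#include <vector>

// Sums the first `valid` elements of a fixed-size tile by first forcing the
// padded tail to the neutral element (0 for addition), then reducing the
// whole tile unconditionally, the same order of operations the emitter uses.
float ReducePaddedTile(std::vector<float> tile, size_t valid,
                       float neutral = 0.0f) {
  for (size_t i = 0; i < tile.size(); ++i) {
    if (i >= valid) tile[i] = neutral;  // select(mask, value, neutral)
  }
  float acc = neutral;
  for (float v : tile) acc += v;
  return acc;
}

// ReducePaddedTile({1, 2, 3, 99, 99}, /*valid=*/3) == 6: the padded 99s are
// replaced by the neutral value before the reduction touches them.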
Value neutral = EmitConstant(b, *hlo_reduce.operand(1)); - Value masked_input = - b.create(tile_mask, input, Splat(b, neutral, input_shape)); + if (tile_mask) { + input = b.create(tile_mask, input, + Splat(b, neutral, input_shape)); + } // Triton actually only performs reductions on float32 inputs, and we must // thus upcast/downcast our input if its data type is different. - Value casted_input = Cast(b, masked_input, b.getF32Type()); + Value casted_input = Cast(b, input, b.getF32Type()); mt::ReduceOp reduction = b.create( SmallVector({casted_input}), (int)input_shape.size() - 1); @@ -525,8 +520,9 @@ Value EmitReduce(mlir::ImplicitLocOpBuilder& b, CHECK(!to_emit.empty()); b.setInsertionPointToStart(reducer); - Value result = - EmitScope(b, libdevice_path, fn, to_emit, region_values, {}, {}); + TF_ASSIGN_OR_RETURN(Value result, + EmitScope(b, libdevice_path, to_emit, region_values, + /*tile_shape=*/{}, /*tile_mask=*/{})); b.create(SmallVector({result})); b.setInsertionPointAfter(reduction); } @@ -537,24 +533,25 @@ Value EmitReduce(mlir::ImplicitLocOpBuilder& b, // Emit sequence of instructions using compatible tiling ordered producers // before consumers. -Value EmitScope(mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, - mlir::triton::FuncOp fn, - absl::Span instructions, - absl::flat_hash_map& values, - Value load_offsets, Value load_mask) { +StatusOr EmitScope( + mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + absl::Span instructions, + absl::flat_hash_map& values, + mlir::ArrayRef tile_shape, Value tile_mask) { for (const HloInstruction* hlo : instructions) { Value result; if (hlo->opcode() == HloOpcode::kParameter) { - result = EmitParameter(b, *hlo, fn, load_offsets, load_mask); + // Parameter loads are handled outside EmitScope. 
+ TF_RET_CHECK(values.contains(hlo)) << hlo->ToString(); + continue; } else if (hlo->opcode() == HloOpcode::kConstant) { result = EmitConstant(b, *hlo); } else if (hlo->opcode() == HloOpcode::kBroadcast) { - mlir::ArrayRef tile_shape = - load_offsets.dyn_cast().getType().getShape(); result = EmitBroadcast(b, *hlo, values[hlo->operand(0)], tile_shape); } else if (hlo->opcode() == HloOpcode::kReduce) { - result = EmitReduce(b, *hlo, libdevice_path, fn, values[hlo->operand(0)], - load_mask); + TF_ASSIGN_OR_RETURN( + result, EmitReduce(b, *hlo, libdevice_path, values[hlo->operand(0)], + tile_mask)); } else if (hlo->IsElementwise()) { std::vector operands; operands.reserve(hlo->operands().size()); @@ -563,14 +560,14 @@ Value EmitScope(mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, } result = EmitElementwise(b, libdevice_path, *hlo, operands); } else if (hlo->opcode() == HloOpcode::kTuple) { - CHECK(hlo->IsRoot()) << hlo->ToString(); + TF_RET_CHECK(hlo->IsRoot()) << hlo->ToString(); } else if (hlo->opcode() == HloOpcode::kBitcast || hlo->opcode() == HloOpcode::kReshape) { result = values[hlo->operand(0)]; } else { LOG(FATAL) << hlo->ToString(); } - CHECK(values.insert({hlo, result}).second) << hlo->ToString(); + TF_RET_CHECK(values.insert({hlo, result}).second) << hlo->ToString(); VLOG(8) << "Emitted " << hlo->ToString(); } return values[instructions.back()]; @@ -582,6 +579,7 @@ void CreateTritonPipeline(mlir::OpPassManager& pm, const int ccAsInt = cc.major * 10 + cc.minor; // Based on optimize_ttir() in // @triton//:python/triton/compiler/compiler.py + pm.addPass(mt::createRewriteTensorPointerPass()); pm.addPass(mlir::createInlinerPass()); pm.addPass(mt::createCombineOpsPass()); pm.addPass(mlir::createCanonicalizerPass()); @@ -615,6 +613,7 @@ void CreateTritonPipeline(mlir::OpPassManager& pm, pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createSymbolDCEPass()); + // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass. } // Extract additional attributes from an LLVM function that are not passed @@ -712,11 +711,12 @@ StatusOr MatMulImpl( auto loc = mlir::NameLoc::get(builder.getStringAttr(dot_instr->name())); mlir::ImplicitLocOpBuilder b(loc, builder); Type i32_ty = b.getI32Type(); + Type i64_ty = b.getI64Type(); Type int_ty; if constexpr (std::is_same_v) { - int_ty = b.getI64Type(); + int_ty = i64_ty; } else { - int_ty = b.getI32Type(); + int_ty = i32_ty; } const DotDimensionNumbers& dims = dot_instr->dot_dimension_numbers(); const DotFusionAnalysis analysis(dot_instr->parent(), config.split_k()); @@ -745,12 +745,12 @@ StatusOr MatMulImpl( const bool have_batch = dims.lhs_batch_dimensions_size() - have_split_k; CHECK_EQ(dot_instr->operand(0)->shape().rank(), 2 + have_split_k + have_batch); - const int64_t lhs_noncontracting_dim_idx = + const int lhs_noncontracting_dim_idx = GetNonContractingDims(dot_instr->operand(0)->shape(), dims.lhs_batch_dimensions(), dims.lhs_contracting_dimensions()) .value()[0]; - const int64_t rhs_noncontracting_dim_idx = + const int rhs_noncontracting_dim_idx = GetNonContractingDims(dot_instr->operand(1)->shape(), dims.rhs_batch_dimensions(), dims.rhs_contracting_dimensions()) @@ -785,14 +785,12 @@ StatusOr MatMulImpl( bool lhs_nc_split = false; // Either batch size or upper part of the length of a split nc dimension. 
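The GetNonContractingDims lookups above amount to set subtraction over dimension indices: every operand dimension that is neither a batch nor a contracting dimension is noncontracting. A minimal version of that helper, sufficient for the 2D-plus-batch shapes handled here, is sketched below as an illustration rather than the real XLA helper.

#include <algorithm>
#include <cstdint>
#include <vector>

// Returns the operand dimensions that are neither batch nor contracting,
// in increasing order. `rank` is the operand rank.
std::vector<int64_t> NonContractingDims(
    int64_t rank, const std::vector<int64_t>& batch_dims,
    const std::vector<int64_t>& contracting_dims) {
  std::vector<int64_t> result;
  for (int64_t dim = 0; dim < rank; ++dim) {
    bool is_batch = std::find(batch_dims.begin(), batch_dims.end(), dim) !=
                    batch_dims.end();
    bool is_contracting =
        std::find(contracting_dims.begin(), contracting_dims.end(), dim) !=
        contracting_dims.end();
    if (!is_batch && !is_contracting) result.push_back(dim);
  }
  return result;
}

// For a [batch, m, k] LHS with batch dim {0} and contracting dim {2}, the
// single noncontracting dimension is 1 (the "m" dimension).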
int batch_size = 1; - IndexT stride_lhs_m = 0; - IndexT stride_lhs_k = 0; IndexT stride_lhs_batch = 0; IndexT stride_rhs_batch = 0; if (!analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).empty()) { const HloInstruction* lhs_param0 = *analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).begin(); - const DotFusionAnalysis::DimIterationSpec* lhs_nc_iter_spec = + const TensorIterationSpec::DimIterationSpec* lhs_nc_iter_spec = analysis.IterSpec(DotFusionAnalysis::Scope::LHS, lhs_param0, lhs_noncontracting_dim_idx); lhs_nc_split = lhs_nc_iter_spec->size() > 1; @@ -827,20 +825,12 @@ StatusOr MatMulImpl( dims.lhs_contracting_dimensions(0)) ->size(), 1); - stride_lhs_m = lhs_nc_iter_spec->at(0).stride; - stride_lhs_k = analysis - .IterSpec(DotFusionAnalysis::Scope::LHS, lhs_param0, - dims.lhs_contracting_dimensions(0)) - ->at(0) - .stride; // Just the fastest-varying part of it if the dimension is split. m = lhs_nc_iter_spec->at(0).count; } CHECK_GE(m, 1); - IndexT stride_rhs_k = 0; - IndexT stride_rhs_n = 0; if (!analysis.ScopeParameters(DotFusionAnalysis::Scope::RHS).empty()) { const HloInstruction* rhs_param0 = *analysis.ScopeParameters(DotFusionAnalysis::Scope::RHS).begin(); @@ -850,16 +840,6 @@ StatusOr MatMulImpl( rhs_noncontracting_dim_idx) ->size(), 1); - stride_rhs_k = analysis - .IterSpec(DotFusionAnalysis::Scope::RHS, rhs_param0, - dims.rhs_contracting_dimensions(0)) - ->at(0) - .stride; - stride_rhs_n = analysis - .IterSpec(DotFusionAnalysis::Scope::RHS, rhs_param0, - rhs_noncontracting_dim_idx) - ->at(0) - .stride; if (have_batch) { const int64_t rhs_batch_dim_idx = *(dims.rhs_batch_dimensions().cend() - 1); @@ -874,44 +854,11 @@ StatusOr MatMulImpl( constexpr int group_m = 8; - IndexT stride_out_m = - analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, root, lhs_nc_out_idx) - ->at(0) - .stride; - const int64_t n = + const int n = analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, root, rhs_nc_out_idx) ->at(0) .count; CHECK_GE(n, 1); - IndexT stride_out_n = - analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, root, rhs_nc_out_idx) - ->at(0) - .stride; - IndexT stride_out_split_k = 0; - if (have_split_k) { - stride_out_split_k = - analysis - .IterSpec(DotFusionAnalysis::Scope::OUTPUT, root, split_k_out_idx) - ->at(0) - .stride; - CHECK_GE(stride_out_split_k, 1); - } - IndexT stride_out_batch = 0; - if (have_batch) { - stride_out_batch = - analysis - .IterSpec(DotFusionAnalysis::Scope::OUTPUT, root, batch_out_idx) - ->at(0) - .stride; - CHECK_GE(stride_out_batch, 1); - } else if (lhs_nc_split) { - // Dimension of the output produced by the non-contracting LHS one - // is physically contiguous even if the producing LHS one is split. - // Because the major part of the split is implemented using the batch - // logic stride_out_batch is populated here as the stride of the minor - // part times its size. 
- stride_out_batch = stride_out_m * m; - } const int block_m = config.block_m(); const int block_k = config.block_k(); @@ -978,125 +925,120 @@ StatusOr MatMulImpl( } return value; }; - auto convert_range = [&](Value value) -> Value { - if constexpr (std::is_same_v) { - auto type = mlir::RankedTensorType::get( - value.dyn_cast().getType().getShape(), int_ty); - return b.create(type, value); - } - return value; - }; auto pid_m = b.create(first_pid_m, b.create(pid_nc, group_size)); - auto pid_m_stride = + auto pid_m_offset = b.create(pid_m, CreateConst(b, i32_ty, block_m)); - // TODO(b/270351731): Consider regenerating range_m to reduce register - // pressure if we figure out how to make this optimization survive CSE. - auto range_m = - b.create(Splat(b, pid_m_stride, block_m), Range(b, block_m)); auto pid_n = b.create( b.create(pid_nc, CreateConst(b, i32_ty, width)), group_size); - auto pid_n_stride = + auto pid_n_offset = b.create(pid_n, CreateConst(b, i32_ty, block_n)); - auto range_n = - b.create(Splat(b, pid_n_stride, block_n), Range(b, block_n)); - - auto range_k = b.create( - Splat(b, b.create(pid_k, CreateConst(b, i32_ty, block_k)), - block_k), - Range(b, block_k)); - - SmallVector shape_m_1{block_m, 1}; - auto range_lhs_m = convert_range( - b.create(range_m, CreateConst(b, i32_ty, m, block_m))); - auto lhs_offsets_m = - b.create(b.create(range_lhs_m, 1), - CreateConst(b, int_ty, stride_lhs_m, shape_m_1)); - SmallVector shape_1_k{1, block_k}; - auto lhs_offsets_k = b.create( - b.create(convert_range(range_k), 0), - CreateConst(b, int_ty, stride_lhs_k, shape_1_k)); - SmallVector shape_m_k{block_m, block_k}; - auto lhs_offset_batch = b.create( - convert_scalar(pid_batch), CreateConst(b, int_ty, stride_lhs_batch)); - auto lhs_offsets_init = b.create( - Broadcast(b, lhs_offsets_m.getResult().template cast(), - shape_m_k), - Broadcast(b, lhs_offsets_k.getResult().template cast(), - shape_m_k)); - lhs_offsets_init = b.create( - lhs_offsets_init, Splat(b, lhs_offset_batch, shape_m_k)); - - SmallVector shape_k_1{block_k, 1}; - auto rhs_offsets_k = b.create( - b.create(convert_range(range_k), 1), - CreateConst(b, int_ty, stride_rhs_k, shape_k_1)); - SmallVector shape_1_n{1, block_n}; - auto range_rhs_n = convert_range( - b.create(range_n, CreateConst(b, i32_ty, n, block_n))); - auto rhs_offsets_n = - b.create(b.create(range_rhs_n, 0), - CreateConst(b, int_ty, stride_rhs_n, shape_1_n)); - SmallVector shape_k_n{block_k, block_n}; - auto rhs_offset_batch = b.create( - convert_scalar(pid_batch), CreateConst(b, int_ty, stride_rhs_batch)); - auto rhs_offsets_init = b.create( - Broadcast(b, rhs_offsets_k.getResult().template cast(), - shape_k_n), - Broadcast(b, rhs_offsets_n.getResult().template cast(), - shape_k_n)); - rhs_offsets_init = b.create( - rhs_offsets_init, Splat(b, rhs_offset_batch, shape_k_n)); - SmallVector shape_m_n{block_m, block_n}; - ma::ConstantOp accumulator_init = CreateConst(b, acc_ty, 0, shape_m_n); + + auto pid_k_offset = + b.create(pid_k, CreateConst(b, i32_ty, block_k)); + + ma::ConstantOp accumulator_init = + CreateConst(b, acc_ty, 0, {block_m, block_n}); + + // Numbers of dimensions of tensor pointers that need masking on loads or + // stores. 
+ std::vector boundary_checks_lhs; + std::vector boundary_checks_rhs; + std::vector boundary_checks_out; + if (m % block_m != 0) { + boundary_checks_lhs.push_back(0); + boundary_checks_out.push_back(0); + } + if (k % (block_k * config.split_k()) != 0) { + boundary_checks_lhs.push_back(1); + boundary_checks_rhs.push_back(0); + } + if (n % block_n != 0) { + boundary_checks_rhs.push_back(1); + boundary_checks_out.push_back(1); + } + + // Parameters are passed to the loop in non-trivial order, this map helps + // finding them. + absl::flat_hash_map iter_args_to_parameters; auto body_builder = [&](mlir::OpBuilder&, mlir::Location, Value ki, - mlir::ValueRange iterArgs) { - Value lhs_offsets = iterArgs[0]; - Value rhs_offsets = iterArgs[1]; - Value accumulator = iterArgs[2]; - Value lhs_mask = nullptr; - Value rhs_mask = nullptr; + mlir::ValueRange iter_args) { + SmallVector iter_args_next; + iter_args_next.reserve(iter_args.size()); + absl::flat_hash_map values_lhs; + absl::flat_hash_map values_rhs; + // Load tiles of all parameters of LHS and RHS scopes and advance pointers. + for (int i = 0; i < iter_args.size() - 1; ++i) { + const bool is_lhs = + i < analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).size(); + const int increment_dim0 = block_k * config.split_k() * (is_lhs ? 0 : 1); + const int increment_dim1 = block_k * config.split_k() * (is_lhs ? 1 : 0); + absl::flat_hash_map& values = + is_lhs ? values_lhs : values_rhs; + CHECK(values + .insert({iter_args_to_parameters[i], + EmitParameterLoad(b, iter_args[i], + is_lhs ? boundary_checks_lhs + : boundary_checks_rhs)}) + .second); + iter_args_next.push_back(b.create( + iter_args[i].getType(), iter_args[i], + mlir::ValueRange{CreateConst(b, i32_ty, increment_dim0), + CreateConst(b, i32_ty, increment_dim1)})); + } + // TODO(b/269726484): Peel the loop instead of inserting a masked load in // every iteration, even the ones that do not need it. const bool need_masking = k % (block_k * config.split_k()) > 0; + Value lhs_mask; + Value rhs_mask; if (need_masking) { auto elements_in_tile = b.create(CreateConst(b, i32_ty, k), ki); - lhs_mask = - Broadcast(b, - b.create(ma::CmpIPredicate::slt, - b.create(range_k, 0), - Splat(b, elements_in_tile, shape_1_k)) - .getResult() - .template cast(), - shape_m_k); - rhs_mask = - Broadcast(b, - b.create(ma::CmpIPredicate::slt, - b.create(range_k, 1), - Splat(b, elements_in_tile, shape_k_1)) - .getResult() - .template cast(), - shape_k_n); + auto range_k = b.create( + Splat(b, b.create(pid_k, CreateConst(b, i32_ty, block_k)), + block_k), + Range(b, block_k)); + lhs_mask = Broadcast( + b, + b.create(ma::CmpIPredicate::slt, + b.create(range_k, 0), + Splat(b, elements_in_tile, {1, block_k})) + .getResult() + .template cast(), + {block_m, block_k}); + rhs_mask = Broadcast( + b, + b.create(ma::CmpIPredicate::slt, + b.create(range_k, 1), + Splat(b, elements_in_tile, {block_k, 1})) + .getResult() + .template cast(), + {block_k, block_n}); } - // For now use one shape for LHS inputs and one for RHS. - absl::flat_hash_map values_lhs; - Value dot_input_lhs = - EmitScope(b, libdevice_path, fn, + // Emit all operations of LHS and RHS scopes. 
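The boundary_checks_lhs/rhs/out lists filled in above record which dimensions of each tile can run past the logical matrix: a dimension needs a check exactly when the corresponding extent (m, n, or the split-K step) does not divide evenly into the block size. A small standalone sketch that computes the same three lists follows; the names are illustrative.

#include <cstdint>
#include <vector>

struct BoundaryChecks {
  std::vector<int32_t> lhs;  // dims of the [block_m, block_k] LHS tile
  std::vector<int32_t> rhs;  // dims of the [block_k, block_n] RHS tile
  std::vector<int32_t> out;  // dims of the [block_m, block_n] output tile
};

BoundaryChecks ComputeBoundaryChecks(int64_t m, int64_t n, int64_t k,
                                     int64_t block_m, int64_t block_n,
                                     int64_t k_step) {
  BoundaryChecks checks;
  if (m % block_m != 0) {  // ragged rows affect LHS dim 0 and output dim 0
    checks.lhs.push_back(0);
    checks.out.push_back(0);
  }
  if (k % k_step != 0) {  // ragged K affects LHS dim 1 and RHS dim 0
    checks.lhs.push_back(1);
    checks.rhs.push_back(0);
  }
  if (n % block_n != 0) {  // ragged columns affect RHS dim 1 and output dim 1
    checks.rhs.push_back(1);
    checks.out.push_back(1);
  }
  return checks;
}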
+ TF_ASSIGN_OR_RETURN( + Value dot_input_lhs, + EmitScope(b, libdevice_path, dot_instr->parent()->MakeInstructionPostOrderFrom( const_cast(*dot_instr->operand(0))), - values_lhs, lhs_offsets, lhs_mask); - absl::flat_hash_map values_rhs; - Value dot_input_rhs = - EmitScope(b, libdevice_path, fn, + values_lhs, {block_m, block_k}, lhs_mask)); + TF_ASSIGN_OR_RETURN( + Value dot_input_rhs, + EmitScope(b, libdevice_path, dot_instr->parent()->MakeInstructionPostOrderFrom( const_cast(*dot_instr->operand(1))), - values_rhs, rhs_offsets, rhs_mask); + values_rhs, {block_k, block_n}, rhs_mask)); + // Operation in the fusion before the dot can alter the elements of the + // tiles that were zero masked during loads. These have to be zeroed here + // again just before the dot so that they do not affect the output. + // Only the K dimension needs masking here because unnecessary elements in + // the other two get discarded by the masked store at the end. if (need_masking) { dot_input_lhs = b.create(lhs_mask, dot_input_lhs, ZerosLike(b, dot_input_lhs)); @@ -1104,22 +1046,90 @@ StatusOr MatMulImpl( ZerosLike(b, dot_input_rhs)); } - auto accumulator_next = b.create( - dot_input_lhs, dot_input_rhs, accumulator, + // Execute matrix multiplication of input tiles and pass the accumulator. + Value accumulator_next = b.create( + dot_input_lhs, dot_input_rhs, iter_args.back(), /*allowTF32=*/tsl::tensor_float_32_execution_enabled()); + iter_args_next.push_back(accumulator_next); - Value lhs_offsets_next = b.create( - lhs_offsets, - CreateConst(b, int_ty, block_k * config.split_k() * stride_lhs_k, - shape_m_k)); - Value rhs_offsets_next = b.create( - rhs_offsets, - CreateConst(b, int_ty, block_k * config.split_k() * stride_rhs_k, - shape_k_n)); - - b.create( - mlir::ValueRange{lhs_offsets_next, rhs_offsets_next, accumulator_next}); + b.create(iter_args_next); + return OkStatus(); }; + + // Pointers to parameters of LHS scope, then RHS, then the accumulator + // that change with every loop iteration and are passed between them. + // LHS and RHS can use same HLO computation parameters, but because they use + // different pointers they have to be stored separately for each scope. 
+ SmallVector iter_args; + iter_args.reserve( + analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).size() + + analysis.ScopeParameters(DotFusionAnalysis::Scope::RHS).size() + 1); + + Value lhs_offset_batch = b.create( + convert_scalar(pid_batch), CreateConst(b, int_ty, stride_lhs_batch)); + for (const HloInstruction* parameter : + analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS)) { + Value base = fn.getArgument(parameter->parameter_number()); + const int64_t stride_lhs_m = + analysis + .IterSpec(DotFusionAnalysis::Scope::LHS, parameter, + lhs_noncontracting_dim_idx) + ->at(0) + .stride; + const int64_t stride_lhs_k = + analysis + .IterSpec(DotFusionAnalysis::Scope::LHS, parameter, + dims.lhs_contracting_dimensions(0)) + ->at(0) + .stride; + Value ptrs = b.create( + /*base=*/AddPtr(b, base, lhs_offset_batch), + /*shape=*/ + mlir::ValueRange{CreateConst(b, i64_ty, m), CreateConst(b, i64_ty, k)}, + /*strides=*/ + mlir::ValueRange{CreateConst(b, i64_ty, stride_lhs_m), + CreateConst(b, i64_ty, stride_lhs_k)}, + /*offsets=*/mlir::ValueRange{pid_m_offset, pid_k_offset}, + /*tensorShape=*/std::vector{block_m, block_k}, + /*order=*/std::vector{1, 0}); + CHECK(iter_args_to_parameters.insert({iter_args.size(), parameter}).second) + << parameter->ToString(); + iter_args.push_back(ptrs); + } + + Value rhs_offset_batch = b.create( + convert_scalar(pid_batch), CreateConst(b, int_ty, stride_rhs_batch)); + for (const HloInstruction* parameter : + analysis.ScopeParameters(DotFusionAnalysis::Scope::RHS)) { + Value base = fn.getArgument(parameter->parameter_number()); + const IndexT stride_rhs_k = + analysis + .IterSpec(DotFusionAnalysis::Scope::RHS, parameter, + dims.rhs_contracting_dimensions(0)) + ->at(0) + .stride; + const IndexT stride_rhs_n = + analysis + .IterSpec(DotFusionAnalysis::Scope::RHS, parameter, + rhs_noncontracting_dim_idx) + ->at(0) + .stride; + Value ptrs = b.create( + /*base=*/AddPtr(b, base, rhs_offset_batch), + /*shape=*/ + mlir::ValueRange{CreateConst(b, i64_ty, k), CreateConst(b, i64_ty, n)}, + /*strides=*/ + mlir::ValueRange{CreateConst(b, i64_ty, stride_rhs_k), + CreateConst(b, i64_ty, stride_rhs_n)}, + /*offsets=*/mlir::ValueRange{pid_k_offset, pid_n_offset}, + /*tensorShape=*/std::vector{block_k, block_n}, + /*order=*/std::vector{1, 0}); + CHECK(iter_args_to_parameters.insert({iter_args.size(), parameter}).second) + << parameter->ToString(); + iter_args.push_back(ptrs); + } + + iter_args.push_back(accumulator_init); Value acc_final = b.create( /*lowerBound=*/b.create(0, /*width=*/32), @@ -1127,43 +1137,85 @@ StatusOr MatMulImpl( /*step=*/ b.create(block_k * config.split_k(), /*width=*/32), - /*iterArgs=*/ - mlir::ValueRange{lhs_offsets_init, rhs_offsets_init, - accumulator_init}, - body_builder) - .getResult(2); + /*iterArgs=*/iter_args, body_builder) + .getResult(iter_args.size() - 1); absl::flat_hash_map values_out; values_out[dot_instr] = Cast(b, acc_final, TritonType(b, dot_instr->shape().element_type())); - // Output tile offsets. 
- auto out_offset_batch = b.create( - convert_scalar(pid_batch), CreateConst(b, int_ty, stride_out_batch)); - auto out_offsets_m = b.create( - b.create(convert_range(range_m), 1), - CreateConst(b, int_ty, stride_out_m, shape_m_1)); - - auto out_offsets_n = b.create( - b.create(convert_range(range_n), 0), - CreateConst(b, int_ty, stride_out_n, shape_1_n)); - auto out_offsets = b.create(Splat(b, out_offset_batch, shape_m_1), - out_offsets_m); - out_offsets = b.create( - Broadcast(b, out_offsets.getResult().template cast(), - shape_m_n), - Broadcast(b, out_offsets_n.getResult().template cast(), - shape_m_n)); - - // Output tile mask: check that the indices are within [M, N]. - auto rm_cmp = b.create(ma::CmpIPredicate::slt, - b.create(range_m, 1), - CreateConst(b, i32_ty, m, shape_m_1)); - auto rn_cmp = b.create(ma::CmpIPredicate::slt, - b.create(range_n, 0), - CreateConst(b, i32_ty, n, shape_1_n)); - auto out_mask = b.create( - Broadcast(b, rm_cmp.getResult().template cast(), shape_m_n), - Broadcast(b, rn_cmp.getResult().template cast(), shape_m_n)); + // Generate tensor pointer for a parameter load or output store within the + // dot's output scope. + auto output_scope_tensor_pointer = [&](const HloInstruction* hlo, Value base, + bool add_split_k_offset) { + const IndexT stride_m = + analysis + .IterSpec(DotFusionAnalysis::Scope::OUTPUT, hlo, lhs_nc_out_idx) + ->at(0) + .stride; + { + IndexT stride_batch = 0; + if (have_batch) { + stride_batch = + analysis + .IterSpec(DotFusionAnalysis::Scope::OUTPUT, hlo, batch_out_idx) + ->at(0) + .stride; + CHECK_GE(stride_batch, 1); + } + { + const TensorIterationSpec::DimIterationSpec* spec = analysis.IterSpec( + DotFusionAnalysis::Scope::OUTPUT, hlo, lhs_nc_out_idx); + if (spec->size() > 1) { + CHECK_EQ(spec->size(), 2); + // Support one specific kind of output transpose that splits the + // dimension originating from the split LHS non-contracting one. + CHECK(!have_batch); + CHECK(lhs_nc_split); + CHECK_EQ(spec->at(1).count, batch_size); + stride_batch = spec->at(1).stride; + } else if (lhs_nc_split) { + // Dimension of the output produced by the non-contracting LHS one + // is physically contiguous though the producing LHS one is split. + // Because the major part of the split is implemented using the batch + // logic stride_out_batch is populated here as the stride of the minor + // part times its size. + stride_batch = stride_m * m; + } + } + Value offset_batch = b.create( + convert_scalar(pid_batch), CreateConst(b, int_ty, stride_batch)); + base = AddPtr(b, base, offset_batch); + } + if (add_split_k_offset) { + IndexT stride_split_k = 0; + if (have_split_k) { + stride_split_k = analysis + .IterSpec(DotFusionAnalysis::Scope::OUTPUT, hlo, + split_k_out_idx) + ->at(0) + .stride; + CHECK_GE(stride_split_k, 1); + } + Value offset_split_k = b.create( + convert_scalar(pid_k), CreateConst(b, int_ty, stride_split_k)); + base = AddPtr(b, base, offset_split_k); + } + const IndexT stride_n = + analysis + .IterSpec(DotFusionAnalysis::Scope::OUTPUT, hlo, rhs_nc_out_idx) + ->at(0) + .stride; + return b.create( + /*base=*/base, + /*shape=*/ + mlir::ValueRange{CreateConst(b, i64_ty, m), CreateConst(b, i64_ty, n)}, + /*strides=*/ + mlir::ValueRange{CreateConst(b, i64_ty, stride_m), + CreateConst(b, i64_ty, stride_n)}, + /*offsets=*/mlir::ValueRange{pid_m_offset, pid_n_offset}, + /*tensorShape=*/std::vector{block_m, block_n}, + /*order=*/std::vector{1, 0}); + }; // Collect all instructions of the dot's output scope. 
absl::flat_hash_set to_order; @@ -1193,23 +1245,34 @@ StatusOr MatMulImpl( to_emit.push_back(hlo); } } + // Emit the output scope. if (!to_emit.empty()) { - EmitScope(b, libdevice_path, fn, to_emit, values_out, out_offsets, - out_mask); + for (const HloInstruction* parameter : + analysis.ScopeParameters(DotFusionAnalysis::Scope::OUTPUT)) { + Value tensor_pointer = output_scope_tensor_pointer( + parameter, fn.getArgument(parameter->parameter_number()), + /*add_split_k_offset=*/false); + CHECK(values_out + .insert({parameter, EmitParameterLoad(b, tensor_pointer, + boundary_checks_out)}) + .second); + } + TF_RETURN_IF_ERROR(EmitScope(b, libdevice_path, to_emit, values_out, + {block_m, block_n}, /*tile_mask=*/{}) + .status()); } - auto out_offset_split_k = b.create( - convert_scalar(pid_k), CreateConst(b, int_ty, stride_out_split_k)); - out_offsets = b.create(out_offsets, - Splat(b, out_offset_split_k, shape_m_n)); + // Emit tensor store operations for all outputs. for (int i = 0; i < fn.getNumArguments() - dot_instr->parent()->num_parameters(); ++i) { - Value out = fn.getArgument(i + dot_instr->parent()->num_parameters()); const HloInstruction* producer = root->shape().IsTuple() ? root->operand(i) : root; - b.create(AddPtr(b, Splat(b, out, shape_m_n), out_offsets), - values_out[producer], out_mask, - mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); + Value tensor_pointer = output_scope_tensor_pointer( + producer, fn.getArgument(i + dot_instr->parent()->num_parameters()), + /*add_split_k_offset=*/true); + b.create(tensor_pointer, values_out[producer], + boundary_checks_out, mt::CacheModifier::NONE, + mt::EvictionPolicy::NORMAL); } return launch_dimensions; } @@ -1287,35 +1350,37 @@ StatusOr SoftMax(mlir::OpBuilder builder, for (int minor_axis = 1; minor_axis < reduce_input_shape.rank(); ++minor_axis) num_rows *= reduce_input_shape.dimensions_minor(minor_axis); - // softmax_kernel(input_ptr, output_ptr, num_rows, row_len, block_row) { - // row_index = tl.program_id(0) - // row_stride = row_len - // offset = row_index * row_stride Value row_index = b.create(mt::ProgramIDDim::X); - Value row_stride = b.create(row_len, /*width=*/32); - Value offset = b.create(row_index, row_stride); - - // row_tile = tl.arange(0, block_row) + offset - Value splat_offsets = Splat(b, offset, block_row); - Value row_tile = b.create(splat_offsets, Range(b, block_row)); - - // mask = row_tile < row_stride - Value splat_row_stride = Splat(b, row_stride, block_row); - Value mask = b.create(ma::CmpIPredicate::slt, Range(b, block_row), - splat_row_stride); + Value row_stride = CreateConst(b, b.getI32Type(), row_len); absl::flat_hash_map values_out; - Value result = - EmitScope(b, libdevice_path, fn, computation->MakeInstructionPostOrder(), - values_out, row_tile, mask); - - // tl.store(output_ptr + row_tile, result, mask=mask) - Value splat_output_ptr = Splat(b, fn.getArgument(1), block_row); - Value store_ptrs = AddPtr(b, splat_output_ptr, row_tile); + auto make_tensor_pointer = [&](Value base) { + Value offset = b.create(row_index, row_stride); + return b.create( + /*base=*/AddPtr(b, base, offset), + /*shape=*/mlir::ValueRange{CreateConst(b, b.getI64Type(), row_len)}, + /*strides=*/mlir::ValueRange{CreateConst(b, b.getI64Type(), 1)}, + /*offsets=*/mlir::ValueRange{CreateConst(b, b.getI32Type(), 0)}, + /*tensorShape=*/std::vector{block_row}, + /*order=*/std::vector{0}); + }; - b.create(store_ptrs, result, mask, mt::CacheModifier::NONE, + std::vector boundary_checks; + if (block_row != row_len) { + 
boundary_checks.push_back(0); + } + values_out[computation->parameter_instruction(0)] = EmitParameterLoad( + b, make_tensor_pointer(fn.getArgument(0)), boundary_checks); + Value mask = b.create(ma::CmpIPredicate::slt, Range(b, block_row), + Splat(b, row_stride, block_row)); + TF_ASSIGN_OR_RETURN( + Value result, + EmitScope(b, libdevice_path, computation->MakeInstructionPostOrder(), + values_out, {block_row}, mask)); + + b.create(make_tensor_pointer(fn.getArgument(1)), result, + std::vector{0}, mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); - // } const LaunchDimensions launch_dimensions{ {num_rows, 1, 1}, {config.num_warps() * WarpSize(), 1, 1}}; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc index 3b38a412b8b27b..272b3520ae7af7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc @@ -18,19 +18,25 @@ limitations under the License. #include #include #include +#include #include "llvm/IR/LLVMContext.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/autotuning.pb.h" #include "tensorflow/compiler/xla/error_spec.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" +#include "tensorflow/tsl/lib/core/status_test_util.h" +#include "tensorflow/tsl/platform/path.h" #include "tensorflow/tsl/platform/status_matchers.h" #include "tensorflow/tsl/platform/statusor.h" #include "tensorflow/tsl/platform/tensor_float_32_utils.h" @@ -39,6 +45,8 @@ namespace xla { namespace gpu { namespace { +namespace m = ::xla::match; + class TritonGemmNoTF32Test : public GpuCodegenTest { public: void SetUp() override { @@ -92,6 +100,35 @@ class TritonGemmTest : public GpuCodegenTest { } }; +TEST_F(TritonGemmTest, DebugOptionsArePropagated) { + const std::string kHloText = R"( +ENTRY e { + p0 = f16[30,30] parameter(0) + p1 = s8[30,30] parameter(1) + cp1 = f16[30,30] convert(p1) + ROOT _ = f16[30,30] dot(p0, cp1), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(kHloText)); + std::string output_directory; + if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) { + output_directory = tsl::testing::TmpDir(); + } + DebugOptions debug_options = verified_module->config().debug_options(); + debug_options.set_xla_dump_to(output_directory); + debug_options.set_xla_gpu_dump_llvmir(true); + verified_module->config().set_debug_options(debug_options); + + EXPECT_TRUE(RunAndCompare(std::move(verified_module), + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + + std::vector paths; + TF_EXPECT_OK(tsl::Env::Default()->GetMatchingPaths( + tsl::io::JoinPath(output_directory, "*.triton-passes.log"), &paths)); + EXPECT_EQ(paths.size(), 1); +} + 
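Aside: the matmul and softmax rewrites above replace explicit offset/mask arithmetic with tensor pointers plus per-dimension boundary checks — the LHS tile is checked in dimension 0 when m % block_m != 0 and in dimension 1 when k % (block_k * split_k) != 0, the RHS and output tiles analogously, and the softmax row only when block_row != row_len. A minimal sketch of that rule, using a hypothetical helper name purely for illustration (not part of this change):

#include <cstdint>
#include <vector>

// Illustrative helper: a tile dimension needs a boundary check on loads and
// stores only when the logical extent is not a multiple of the tile step,
// i.e. the last tile hangs over the edge of the tensor.
std::vector<int32_t> ComputeBoundaryChecks(const std::vector<int64_t>& sizes,
                                           const std::vector<int64_t>& tiles) {
  std::vector<int32_t> dims_to_check;
  for (int32_t dim = 0; dim < static_cast<int32_t>(sizes.size()); ++dim) {
    if (sizes[dim] % tiles[dim] != 0) {
      dims_to_check.push_back(dim);
    }
  }
  return dims_to_check;
}

Applied to the LHS tile with sizes {m, k} and tile steps {block_m, block_k * split_k}, this yields {}, {0}, {1} or {0, 1}, exactly how boundary_checks_lhs is populated in the emitter above.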
TEST_F(TritonGemmTest, UseTensorCoresForF32OnAmpere) { const std::string kHloText = R"( HloModule t, is_scheduled=true @@ -611,7 +648,9 @@ ENTRY entry { .status()); } -TEST_F(TritonGemmTest, TritonCompilerCanFailOnConstants) { +// Triton compiler used to have an issue with reordering constants: +// https://github.com/openai/triton/issues/1864 +TEST_F(TritonGemmTest, TritonCompilerDoesNotFailOnConstants) { TF_CHECK_OK(GetOptimizedModule(R"( HloModule m, is_scheduled=true @@ -683,6 +722,337 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } +class TritonGemmLevel2Test : public TritonGemmTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_triton_fusion_level(2); + return debug_options; + } +}; + +TEST_F(TritonGemmLevel2Test, BinaryOperationWithSmallInputsIsFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[7,3] parameter(0) + p1 = f32[3,16] parameter(1) + p2 = f32[3,16] parameter(2) + e = f32[3,16] exponential(p1) + a = f32[3,16] add(e, p2) + c = f32[7,3] convert(p0) + ROOT d = f32[7,16] dot(c, a), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2Test, BinaryOperationWithLargeInputsIsNotFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[333,1000] parameter(0) + p1 = f32[1000,333] parameter(1) + p1n = f32[1000,333] negate(p1) + p2 = f32[1000,333] parameter(2) + p2n = f32[1000,333] negate(p2) + s = f32[1000,333] subtract(p1n, p2n) + c = f32[333,1000] convert(p0) + ROOT d = f32[1000,1000] dot(s, c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fused_computation +; CHECK: negate +; CHECK: negate +; CHECK: ROOT +; CHECK-SAME: subtract +; CHECK: ENTRY +; CHECK: kLoop +; CHECK: kCustom +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2Test, BinaryOperationOnLargeParametersIsFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[1000,111] parameter(0) + p1 = f32[111,10000] parameter(1) + p2 = f32[111,10000] parameter(2) + s = f32[111,10000] subtract(p1, p2) + c = f32[1000,111] convert(p0) + ROOT d = f32[10000,1000] dot(s, c), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2Test, LinkingLibdeviceTwiceWorks) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[7,3] parameter(0) + c0 = f32[7,3] convert(p0) + e0 = f32[7,3] exponential(c0) + p1 = f32[3,16] parameter(1) + e1 = f32[3,16] exponential(p1) + d0 = f32[7,16] dot(c0, e1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + d1 = f32[7,16] dot(e0, p1), + 
lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT a = f32[7,16] add(d0, d1) +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: kCustom +; CHECK-NEXT: kCustom +; CHECK-NEXT: ROOT +; CHECK-SAME: add +)"); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Add( + m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom), + m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom)))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmLevel2Test, BroadcastOfScalarConstantIsFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[70,30] parameter(0) + p0c = f32[70,30] convert(p0) + constant_3663 = f32[] constant(4321) + bc0 = f32[30,5] broadcast(constant_3663) + ROOT d = f32[70,5] dot(p0c, bc0), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3})); +} + +TEST_F(TritonGemmTest, SineOutputIsNotFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[7,101] parameter(0) + p1 = f32[101,16] parameter(1) + c = f32[7,101] convert(p0) + d = f32[7,16] dot(c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT r = f32[7,16] sine(d) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Sin( + m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom)))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmLevel2Test, NarrowingConvertOutputIsFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[22,80] parameter(0) + p1 = f32[80,54] parameter(1) + c = f32[22,80] convert(p0) + d = f32[54,22] dot(p1, c), + lhs_contracting_dims={0}, rhs_contracting_dims={1} + ROOT r = f16[54,22] convert(d) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/3e-2, /*arel=*/3e-2})); +} + +TEST_F(TritonGemmLevel2Test, ParameterAfterDotIsFused) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = bf16[350,1280]{1,0} parameter(0) + p1 = s16[1280,690]{0,1} parameter(1) + p1c = bf16[1280,690]{0,1} convert(p1) + dot.21 = bf16[350,690]{1,0} dot(p0, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + p2 = bf16[350,690]{1,0} parameter(2) + ROOT r = bf16[350,690]{1,0} multiply(p2, dot.21) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + const HloInstruction* instr = module->entry_computation()->root_instruction(); + if 
(!instr->IsCustomFusion()) { + instr = instr->operand(0); + ASSERT_TRUE(instr->IsCustomFusion()); + } + EXPECT_THAT( + instr, + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-2, /*arel=*/2e-2})); +} + +TEST_F(TritonGemmLevel2Test, OutputFusionExecutesCorrectly) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[350,1280]{1,0} parameter(0) + p0c = bf16[350,1280]{1,0} convert(p0) + p1 = bf16[1280,690]{0,1} parameter(1) + d = bf16[350,690]{1,0} dot(p0c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + p3 = bf16[350,690]{1,0} parameter(3) + multiply.8811 = bf16[350,690]{1,0} multiply(d, p3) + neg.484 = bf16[350,690]{1,0} negate(multiply.8811) + p2 = bf16[350,690]{1,0} parameter(2) + ROOT multiply.8808 = bf16[350,690]{1,0} multiply(neg.484, p2) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + const HloInstruction* instr = module->entry_computation()->root_instruction(); + if (!instr->IsCustomFusion()) { + instr = instr->operand(0); + ASSERT_TRUE(instr->IsCustomFusion()); + } + EXPECT_THAT( + instr, + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(), + m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-2, /*arel=*/2e-2})); +} + +TEST_F(TritonGemmLevel2Test, SplitLHSOutputTransposeAloneIsNotFused) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[18,15000] parameter(0) + p0c = bf16[18,15000] convert(p0) + p1 = bf16[42,18] parameter(1) + d = bf16[15000,42] dot(p0c, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={1} + r1 = bf16[5,200,15,42] reshape(d) + ROOT t1 = bf16[5,42,200,15] transpose(r1), dimensions={0,3,1,2} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Transpose( + m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom)))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2Test, SplitLHSInputOutputIsFused) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[5,18,20,150] parameter(0) + p0c = bf16[5,18,20,150] convert(p0) + t0 = bf16[18,5,20,150] transpose(p0c), dimensions={1,0,2,3} + r0 = bf16[18,15000] reshape(t0) + p1 = bf16[42,18] parameter(1) + d = bf16[15000,42] dot(r0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={1} + r1 = bf16[5,20,150,42] reshape(d) + ROOT t1 = bf16[5,42,20,150] transpose(r1), dimensions={0,3,1,2} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + 
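Several of the tests above (ParameterAfterDotIsFused, OutputFusionExecutesCorrectly) tolerate the Triton fusion being either the entry root itself or its first operand. That lookup, written as a hypothetical test helper rather than anything added by this change:

// Returns the kCustom (Triton) fusion the tests assert on: the entry root if
// it is a custom fusion, otherwise its first operand. Illustrative only.
const HloInstruction* FindTritonFusion(const HloModule& module) {
  const HloInstruction* instr = module.entry_computation()->root_instruction();
  if (!instr->IsCustomFusion() && instr->operand_count() > 0) {
    instr = instr->operand(0);
  }
  return instr->IsCustomFusion() ? instr : nullptr;
}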
TEST_F(TritonGemmTest, Naming) { const char* hlo_text = R"( HloModule t @@ -1600,6 +1970,13 @@ class TritonSoftmaxTest : public GpuCodegenTest { debug_options.set_xla_gpu_enable_triton_softmax_fusion(true); return debug_options; } + + se::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } }; TEST_F(TritonSoftmaxTest, CanFuseAndEmitExactSoftmaxF32) { @@ -1640,43 +2017,6 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(1e-6, 1e-6))); } -TEST_F(TritonSoftmaxTest, CanFuseAndEmitExactSoftmaxF32WithShortRows) { - const std::string hlo_text = R"( -HloModule softmax -max_computation { - arg_0 = f32[] parameter(0) - arg_1 = f32[] parameter(1) - ROOT maximum = f32[] maximum(arg_0, arg_1) -} -add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = f32[127,5]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,5]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,5]{1,0} subtract(param_0, broadcast) - exponential = f32[127,5]{1,0} exponential(subtract) - constant_zero = f32[] constant(0) - second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation - second_broadcast = f32[127,5]{1,0} broadcast(second_reduce), dimensions={0} - ROOT divide = f32[127,5]{1,0} divide(exponential, second_broadcast) -} -)"; - - MatchOptimizedHlo(hlo_text, R"( -; CHECK: ENTRY -; CHECK: %[[P0:.*]] = f32[127,5]{1,0} parameter(0) -; CHECK: ROOT -; CHECK-SAME: fusion(%[[P0]]) -; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax -)"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(1e-6, 1e-6))); -} TEST_F(TritonSoftmaxTest, CanFuseAndEmitFirstSoftmaxDiamondF16) { const std::string hlo_text = R"( @@ -1995,6 +2335,49 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(1e-6, 1e-6))); } +TEST_F( + TritonSoftmaxTest, + CanFuseAndEmitConvertInvolvingBF16InputIntoSoftmaxDiamondCorrectlyForAmpereAndVoltaComputeCapability) { // NOLINT(whitespace/line_length) + const std::string hlo_text = R"( +HloModule softmax +max_computation { + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = bf16[127,125]{1,0} parameter(0) + param_0_f32 = f32[127,125]{1,0} convert(param_0) + constant_neg_inf = f32[] constant(-inf) + reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast) +} +)"; + + if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { + MatchOptimizedHlo(hlo_text, R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = bf16[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton_softmax +)"); + } else { + MatchOptimizedHlo(hlo_text, R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = bf16[127,125]{1,0} parameter(0) +; CHECK: %[[CONVERT:.*]] = f32[127,125]{1,0} convert(%[[P0]]) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[CONVERT]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton_softmax +)"); + } + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(1e-6, 1e-6))); +} + } // namespace } // namespace gpu } // namespace xla 
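The final softmax test above encodes a capability-dependent expectation: the bf16-to-f32 convert feeding the diamond is absorbed into the __triton_softmax fusion on Ampere and newer, and stays in the entry computation otherwise. Expressed as a one-line predicate (the name is invented here for illustration and is not the production fusion heuristic):

// Expectation checked by the test, not the real fusion decision code:
// fuse the bf16 convert into the softmax fusion only on Ampere or newer.
bool ExpectBf16ConvertFusedIntoSoftmax(const se::CudaComputeCapability& cc) {
  return cc.IsAtLeast(se::CudaComputeCapability::AMPERE);
}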
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index bdce5db243a0ee..c66438cb26eff3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -75,7 +75,6 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" -#include "tensorflow/compiler/xla/hlo/utils/hlo_query.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" @@ -100,7 +99,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fused_mha_runner.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" @@ -139,7 +137,6 @@ limitations under the License. #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/location_exporter.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" -#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" @@ -1086,6 +1083,8 @@ Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) { descriptor.kind = CudnnConvKind::kForwardActivation; fill_conv_descriptor(conv); TF_RETURN_IF_ERROR(set_activation_mode(conv)); + descriptor.backend_config.set_leakyrelu_alpha( + conv.getLeakyreluAlpha().convertToDouble()); } else if (auto conv = dyn_cast(op)) { descriptor.kind = CudnnConvKind::kForwardActivation; fill_conv_descriptor(conv); @@ -1141,11 +1140,12 @@ Status IrEmitterUnnested::EmitCublasLtMatmulThunk(mlir::Operation* op) { TF_ASSIGN_OR_RETURN(aux, GetAllocationSlice(matmul.getAux())); } - TF_ASSIGN_OR_RETURN(cublas_lt::MatmulPlan plan, - cublas_lt::MatmulPlan::For(matmul)); + TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, GemmConfig::For(matmul)); + TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue epilogue, + cublas_lt::AsBlasLtEpilogue(matmul.getEpilogue())); auto thunk = std::make_unique( - GetThunkInfo(op), std::move(plan), matmul.getAlgorithm(), a, b, c, d, - bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax); + GetThunkInfo(op), std::move(gemm_config), epilogue, matmul.getAlgorithm(), + a, b, c, d, bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax); AddThunkToThunkSequence(std::move(thunk)); return OkStatus(); @@ -1183,12 +1183,12 @@ Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(mlir::Operation* op) { BufferAllocation::Slice aux; // Not used. 
- TF_ASSIGN_OR_RETURN(cublas_lt::MatmulPlan plan, - cublas_lt::MatmulPlan::For(matmul)); - + TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, GemmConfig::For(matmul)); + TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::Epilogue epilogue, + cublas_lt::AsBlasLtEpilogue(matmul.getEpilogue())); auto thunk = std::make_unique( - GetThunkInfo(op), std::move(plan), matmul.getAlgorithm(), a, b, c, d, - bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax); + GetThunkInfo(op), std::move(gemm_config), epilogue, matmul.getAlgorithm(), + a, b, c, d, bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax); AddThunkToThunkSequence(std::move(thunk)); return OkStatus(); @@ -1989,8 +1989,7 @@ Status IrEmitterUnnested::EmitLoopFusion(mlir::lmhlo::FusionOp fusion, Status IrEmitterUnnested::EmitUnnestedTranspose( mlir::lmhlo::FusionOp fusion, HloFusionAnalysis& fusion_analysis) { - TF_ASSIGN_OR_RETURN(auto tiling_scheme, - fusion_analysis.GetTransposeTilingScheme()); + auto* tiling_scheme = fusion_analysis.GetTransposeTilingScheme(); TF_ASSIGN_OR_RETURN(auto launch_dimensions, fusion_analysis.GetLaunchDimensions()); @@ -2042,12 +2041,13 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { // Create HloFusionAnalysis instance. GpuDeviceInfo device_info = ir_emitter_context_->gpu_device_info(); - HloFusionAnalysis fusion_analysis( - &fusion, &device_info, ir_emitter_context_->cuda_compute_capability()); + TF_ASSIGN_OR_RETURN(auto fusion_analysis, + HloFusionAnalysis::Create( + &fusion, &device_info, + ir_emitter_context_->cuda_compute_capability())); // Dispatch to the fusion specific emitter. - TF_ASSIGN_OR_RETURN(auto emitter_fusion_kind, - fusion_analysis.GetEmitterFusionKind()); + auto emitter_fusion_kind = fusion_analysis.GetEmitterFusionKind(); switch (emitter_fusion_kind) { case HloFusionAnalysis::EmitterFusionKind::kTriton: { #if GOOGLE_CUDA @@ -2079,9 +2079,9 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { case HloFusionAnalysis::EmitterFusionKind::kTranspose: return EmitUnnestedTranspose(fusion_op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kInputSlices: - return EmitInputFusibleNonStridedSlices(op); + return EmitInputFusibleNonStridedSlices(op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kScatter: - return EmitScatter(fusion_op, fused_computation); + return EmitScatter(fusion_op, fused_computation, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: { // Special case: DUS bool is_single = IsSingleInstructionFusion(fusion_op); @@ -2639,6 +2639,10 @@ Status IrEmitterUnnested::EmitScatter( index.GetType()); int64_t operand_dim = desc.dim_numbers.getScatterDimsToOperandDims()[i]; + if (operand_dim > rank) { + return absl::OutOfRangeError( + "The provided scatter_dims_to_operand_dims was out of range."); + } TF_ASSIGN_OR_RETURN( llvm::Value* const loaded_scatter_index, desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape( @@ -2727,6 +2731,7 @@ IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region, TF_ASSIGN_OR_RETURN( module, HloModule::CreateFromProto(xla_computation.proto(), HloModuleConfig(program_shape))); + module->config().set_debug_options(hlo_module_config_.debug_options()); if (is_fusion) { HloComputation* fused_computation = module->entry_computation(); @@ -4870,8 +4875,7 @@ Status IrEmitterUnnested::EmitIRForReduction( Status IrEmitterUnnested::EmitUnnestedReduction( mlir::lmhlo::FusionOp fusion, HloFusionAnalysis& fusion_analysis) { - TF_ASSIGN_OR_RETURN(auto reduction_codegen_info, - 
fusion_analysis.GetReductionCodegenInfo()); + auto* reduction_codegen_info = fusion_analysis.GetReductionCodegenInfo(); TF_ASSIGN_OR_RETURN(auto launch_dimensions, fusion_analysis.GetLaunchDimensions()); @@ -5041,25 +5045,19 @@ Status IrEmitterUnnested::EmitElementForInputFusibleSlices( } Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices( - mlir::Operation* op) { + mlir::Operation* op, HloFusionAnalysis& fusion_analysis) { auto fusion = mlir::cast(op); - constexpr int unroll_factor = 1; - TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation, GetOrCreateSubComputationFromRegion(&fusion.getRegion(), /*is_fusion=*/true)); - TF_ASSIGN_OR_RETURN(Shape element_shape, - GetConsistentInputShapeForRootSlices(fused_computation)); bool use_experimental_block_size = hlo_module_config_.debug_options() .xla_gpu_enable_experimental_block_size(); - - TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, - CalculateLaunchDimensions( - element_shape, ir_emitter_context_->gpu_device_info(), - use_experimental_block_size, {unroll_factor})); + TF_ASSIGN_OR_RETURN( + LaunchDimensions launch_dimensions, + fusion_analysis.GetLaunchDimensions(use_experimental_block_size)); TF_ASSIGN_OR_RETURN( std::optional> opt_ir_arrays, @@ -5070,6 +5068,8 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices( } std::vector& ir_arrays = opt_ir_arrays.value(); + TF_ASSIGN_OR_RETURN(Shape element_shape, + GetConsistentInputShapeForRootSlices(fused_computation)); return ParallelLoopEmitter( [&](const llvm_ir::IrArray::Index index) -> Status { return EmitElementForInputFusibleSlices(fused_computation, @@ -5176,7 +5176,8 @@ Status IrEmitterUnnested::EmitDynamicUpdateSlice( } Status IrEmitterUnnested::EmitScatter(mlir::lmhlo::FusionOp fusion_op, - const HloComputation* fused_computation) { + const HloComputation* fused_computation, + HloFusionAnalysis& fusion_analysis) { auto* root = fused_computation->root_instruction(); // The initialization from 'operand' is using different loop bounds, so @@ -5188,12 +5189,9 @@ Status IrEmitterUnnested::EmitScatter(mlir::lmhlo::FusionOp fusion_op, TF_RETURN_IF_ERROR([&] { auto unroll_factor = ComputeMaxUnrollFactor(fusion_op); - const Shape& element_shape = root->shape(); TF_ASSIGN_OR_RETURN( LaunchDimensions launch_dimensions, - CalculateLaunchDimensions( - element_shape, ir_emitter_context_->gpu_device_info(), - use_experimental_block_size, {unroll_factor, /*few_waves=*/false})); + fusion_analysis.GetLaunchDimensions(use_experimental_block_size)); TF_ASSIGN_OR_RETURN( std::optional> opt_ir_arrays, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index d4a39d19a35f01..2440ede416299c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -507,7 +507,8 @@ class IrEmitterUnnested : public IrEmitter { // different. On the other hand, the input ranges of slices can be // overlapping. Further generalization/specialization when the needs are seen // in the future. 
- Status EmitInputFusibleNonStridedSlices(mlir::Operation* op); + Status EmitInputFusibleNonStridedSlices(mlir::Operation* op, + HloFusionAnalysis& fusion_analysis); Status EmitElementForInputFusibleSlices( const HloComputation* fused_computation, @@ -558,7 +559,8 @@ class IrEmitterUnnested : public IrEmitter { const LaunchDimensions& launch_dimensions); Status EmitScatter(mlir::lmhlo::FusionOp fusion_op, - const HloComputation* fused_computation); + const HloComputation* fused_computation, + HloFusionAnalysis& fusion_analysis); Status EmitDynamicUpdateSlice(mlir::lmhlo::FusionOp fusion_op, const HloComputation* fused_computation); diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc index 3c4b24b596d6d3..90ead8a3f7c815 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc @@ -17,12 +17,8 @@ limitations under the License. #include #include -#include -#include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/tsl/platform/logging.h" namespace xla { namespace gpu { @@ -37,7 +33,7 @@ std::ostream& operator<<(std::ostream& out, return out; } -static int64_t ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info) { +static int64_t ThreadsPerBlockLimit(const GpuDeviceInfo& gpu_device_info) { int64_t threads_per_block = gpu_device_info.threads_per_block_limit; if (threads_per_block <= 0) { static std::atomic log_count{0}; @@ -57,7 +53,7 @@ static int64_t ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info) { } int64_t ThreadsPerBlockRowVectorized(const Shape& shape, - GpuDeviceInfo gpu_device_info, + const GpuDeviceInfo& gpu_device_info, LaunchDimensionsConfig dim_config) { if (shape.dimensions().empty()) { return -1; @@ -79,8 +75,8 @@ int64_t ThreadsPerBlockRowVectorized(const Shape& shape, } StatusOr CalculateLaunchDimensionsImplExperimental( - const Shape& shape, GpuDeviceInfo gpu_device_info, - LaunchDimensionsConfig dim_config, mlir::Operation* op) { + const Shape& shape, const GpuDeviceInfo& gpu_device_info, + LaunchDimensionsConfig dim_config) { int64_t num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { return LaunchDimensions(); @@ -101,7 +97,7 @@ StatusOr CalculateLaunchDimensionsImplExperimental( } StatusOr CalculateLaunchDimensionsImpl( - const Shape& shape, GpuDeviceInfo gpu_device_info, + const Shape& shape, const GpuDeviceInfo& gpu_device_info, LaunchDimensionsConfig dim_config) { int64_t num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { @@ -204,13 +200,12 @@ StatusOr CalculateLaunchDimensionsImpl( } StatusOr CalculateLaunchDimensions( - const Shape& shape, GpuDeviceInfo gpu_device_info, - bool use_experimental_block_size, LaunchDimensionsConfig dim_config, - mlir::Operation* op) { - if (use_experimental_block_size && op != nullptr) { + const Shape& shape, const GpuDeviceInfo& gpu_device_info, + bool use_experimental_block_size, LaunchDimensionsConfig dim_config) { + if (use_experimental_block_size) { VLOG(2) << "Experimental block size is enabled"; return CalculateLaunchDimensionsImplExperimental(shape, gpu_device_info, - dim_config, op); + dim_config); } return CalculateLaunchDimensionsImpl(shape, gpu_device_info, dim_config); } diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h index 
2140219fafee6d..95228825403b8f 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/shape.h" @@ -131,14 +130,13 @@ struct LaunchDimensionsConfig { // Returns -1 if the shape doesn't allow the row vectorization code path. // If supported, return the number of threads to use in that case. int64_t ThreadsPerBlockRowVectorized(const Shape& shape, - GpuDeviceInfo gpu_device_info, + const GpuDeviceInfo& gpu_device_info, LaunchDimensionsConfig dim_config); // Calculates the launch dimensions used to invoke `hlo`. StatusOr CalculateLaunchDimensions( - const Shape& shape, GpuDeviceInfo gpu_device_info, - bool use_experimental_block_size, LaunchDimensionsConfig dim_config = {}, - mlir::Operation* op = nullptr); + const Shape& shape, const GpuDeviceInfo& gpu_device_info, + bool use_experimental_block_size, LaunchDimensionsConfig dim_config = {}); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc index e237aa2e73ecf1..88fc88b8b97c22 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc @@ -26,7 +26,6 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape.h" @@ -36,6 +35,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/stream_executor/blas.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" #include "tensorflow/compiler/xla/stream_executor/numeric_options.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.h b/tensorflow/compiler/xla/service/gpu/matmul_utils.h index a708943037e773..52b12ecb665f04 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.h +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.h @@ -120,6 +120,42 @@ struct GemmConfig { double alpha_imag, double beta, std::optional algorithm, int64_t compute_precision); + template ::value || + std::is_same::value>> + static StatusOr For(CublasLtMatmulMaybeF8Op op) { + mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); + + int64_t compute_precision = 0; // Default + if (op.getPrecisionConfig().has_value()) { + auto precision_config = op.getPrecisionConfig(); + for (auto attr : precision_config.value()) { + int64_t value = static_cast( + attr.template cast().getValue()); + if (value > compute_precision) { + compute_precision = value; + } + } + } + + Shape bias_shape; + if (op.getBias() != nullptr) { + bias_shape = GetShape(op.getBias()); + } + return GemmConfig::For( + GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), + dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), + dot_dims.getRhsBatchingDimensions(), + dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), + op.getBias() == nullptr ? nullptr : &bias_shape, GetShape(op.getD()), + op.getAlphaReal().convertToDouble(), + op.getAlphaImag().convertToDouble(), op.getBeta().convertToDouble(), + op.getAlgorithm(), compute_precision); + } + MatrixLayout lhs_layout; MatrixLayout rhs_layout; MatrixLayout c_layout; @@ -171,48 +207,6 @@ StatusOr AsBlasLtEpilogue( class MatmulPlan { public: - template ::value || - std::is_same::value>> - static StatusOr For(CublasLtMatmulMaybeF8Op op) { - mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); - - int64_t compute_precision = 0; // Default - if (op.getPrecisionConfig().has_value()) { - auto precision_config = op.getPrecisionConfig(); - for (auto attr : precision_config.value()) { - int64_t value = static_cast( - attr.template cast().getValue()); - if (value > compute_precision) { - compute_precision = value; - } - } - } - - Shape bias_shape; - if (op.getBias() != nullptr) { - bias_shape = GetShape(op.getBias()); - } - TF_ASSIGN_OR_RETURN( - GemmConfig config, - GemmConfig::For( - GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), - dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), - dot_dims.getRhsBatchingDimensions(), - dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), - op.getBias() == nullptr ? 
nullptr : &bias_shape, - GetShape(op.getD()), op.getAlphaReal().convertToDouble(), - op.getAlphaImag().convertToDouble(), op.getBeta().convertToDouble(), - op.getAlgorithm(), compute_precision)); - - TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue epilogue, - AsBlasLtEpilogue(op.getEpilogue())); - return From(config, epilogue); - } - static StatusOr From(const GemmConfig& config, se::gpu::BlasLt::Epilogue epilogue); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index ed050a66028ee8..2bccdceadb15c9 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -15,8 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" -#include - +#include #include #include #include @@ -38,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/float_normalization.h" #include "tensorflow/compiler/xla/service/float_support.h" #include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" +#include "tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h" #include "tensorflow/compiler/xla/service/gpu/cublas_pad_for_gemms.h" #include "tensorflow/compiler/xla/service/gpu/cublas_padding_requirements.h" @@ -49,7 +49,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -625,7 +624,8 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( return cache_value->cubin_data; } -static bool UseNvlink(const std::string& preferred_cuda_dir) { +static std::optional> GetNvLinkVersion( + const std::string& preferred_cuda_dir) { const bool use_nvlink_by_default = #ifdef TF_DISABLE_NVLINK_BY_DEFAULT false; @@ -638,13 +638,66 @@ static bool UseNvlink(const std::string& preferred_cuda_dir) { use_nvlink_by_default, &use_nvlink)); if (!use_nvlink) { - return false; + return std::nullopt; } // Make sure nvlink exists and is executable. 
const std::string bin_path = se::FindCudaExecutable("nvlink", preferred_cuda_dir); - return se::GetToolVersion(bin_path).ok(); + auto version = se::GetToolVersion(bin_path); + if (!version.ok()) { + return std::nullopt; + } + return *version; +} + +StatusOr NVPTXCompiler::ChooseLinkingMethod( + const std::string& preferred_cuda_dir) { + { + absl::MutexLock lock(&mutex_); + auto it = linking_methods_.find(preferred_cuda_dir); + if (it != linking_methods_.end()) { + return it->second; + } + } + + LinkingMethod linking_method = LinkingMethod::kNone; + TF_ASSIGN_OR_RETURN(auto ptxas_version_tuple, + se::GetAsmCompilerVersion(preferred_cuda_dir)); + + static const std::optional> nvlink_version = + GetNvLinkVersion(preferred_cuda_dir); + if (nvlink_version && *nvlink_version >= ptxas_version_tuple) { + linking_method = LinkingMethod::kNvLink; + } else { + int ptxas_version = std::get<0>(ptxas_version_tuple) * 1000 + + std::get<1>(ptxas_version_tuple) * 10; + int driver_version; + if (!se::gpu::GpuDriver::GetDriverVersion(&driver_version)) { + return FailedPrecondition("Unable to get CUDA driver version"); + } + bool ok = driver_version >= ptxas_version; + if (!ok) { + LOG_FIRST_N(WARNING, 1) + << "The NVIDIA driver's CUDA version is " + << absl::StrFormat("%d.%d", driver_version / 1000, + (driver_version % 1000) / 10) + << " which is older than the ptxas CUDA version " + << absl::StrFormat("(%d.%d.%d)", std::get<0>(ptxas_version_tuple), + std::get<1>(ptxas_version_tuple), + std::get<2>(ptxas_version_tuple)) + << ". Because the driver is older than the ptxas version, XLA is " + "disabling parallel compilation, which may slow down compilation. " + "You should update your NVIDIA driver or use the NVIDIA-provided " + "CUDA forward compatibility packages."; + } + linking_method = LinkingMethod::kDriver; + } + { + absl::MutexLock lock(&mutex_); + linking_methods_[preferred_cuda_dir] = linking_method; + } + return linking_method; } StatusOr NVPTXCompiler::CanUseLinkModules( @@ -653,37 +706,9 @@ StatusOr NVPTXCompiler::CanUseLinkModules( // robust if we simply tried to link something the first time we compile. auto ptxas_config = PtxOptsFromDebugOptions(hlo_module_config.debug_options()); - - static const bool use_nvlink = UseNvlink(ptxas_config.preferred_cuda_dir); - if (use_nvlink) { - return true; - } - - TF_ASSIGN_OR_RETURN( - auto ptxas_version_tuple, - se::GetAsmCompilerVersion(ptxas_config.preferred_cuda_dir)); - int ptxas_version = std::get<0>(ptxas_version_tuple) * 1000 + - std::get<1>(ptxas_version_tuple) * 10; - int driver_version; - if (!se::gpu::GpuDriver::GetDriverVersion(&driver_version)) { - return FailedPrecondition("Unable to get CUDA driver version"); - } - bool ok = driver_version >= ptxas_version; - if (!ok) { - LOG_FIRST_N(WARNING, 1) - << "The NVIDIA driver's CUDA version is " - << absl::StrFormat("%d.%d", driver_version / 1000, - (driver_version % 1000) / 10) - << " which is older than the ptxas CUDA version " - << absl::StrFormat("(%d.%d.%d)", std::get<0>(ptxas_version_tuple), - std::get<1>(ptxas_version_tuple), - std::get<2>(ptxas_version_tuple)) - << ". Because the driver is older than the ptxas version, XLA is " - "disabling parallel compilation, which may slow down compilation. 
" - "You should update your NVIDIA driver or use the NVIDIA-provided " - "CUDA forward compatibility packages."; - } - return ok; + TF_ASSIGN_OR_RETURN(LinkingMethod linking_method, + ChooseLinkingMethod(ptxas_config.preferred_cuda_dir)); + return linking_method != LinkingMethod::kNone; } StatusOr> NVPTXCompiler::LinkModules( @@ -698,7 +723,10 @@ StatusOr> NVPTXCompiler::LinkModules( } auto context = static_cast( stream_exec->implementation()->GpuContextHack()); - if (UseNvlink(ptxas_config.preferred_cuda_dir)) { + + TF_ASSIGN_OR_RETURN(LinkingMethod linking_method, + ChooseLinkingMethod(ptxas_config.preferred_cuda_dir)); + if (linking_method == LinkingMethod::kNvLink) { return LinkUsingNvlink(debug_options.xla_gpu_cuda_data_dir(), context, images); } diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index d4bde5bca73d31..500037d083ec35 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -16,12 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_ -#include -#include #include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" #include "tensorflow/compiler/xla/statusor.h" @@ -87,15 +86,16 @@ class NVPTXCompiler : public GpuCompiler { absl::Mutex mutex_; - // When compiling an HLO module, we need to find a path to the nvvm libdevice - // files. We search in the module's config.debug_options().cuda_data_dir() - // and in tensorflow::LibdeviceRoot(), the latter of which is a constant. - // - // We cache the cuda_data_dir() and the result of our search, so that if the - // next module we have to compile has the same cuda_data_dir(), we can skip - // the search. - std::string cached_cuda_data_dir_ ABSL_GUARDED_BY(mutex_); - std::string cached_libdevice_dir_ ABSL_GUARDED_BY(mutex_); + enum class LinkingMethod { + kNone, + kNvLink, + kDriver, + }; + absl::flat_hash_map linking_methods_ + ABSL_GUARDED_BY(mutex_); + + StatusOr ChooseLinkingMethod( + const std::string& preferred_cuda_dir); // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin. If compilation was unsuccessful, returns an empty vector. 
diff --git a/tensorflow/compiler/xla/service/gpu/runtime/BUILD b/tensorflow/compiler/xla/service/gpu/runtime/BUILD index 812d193b66f91d..cda878e91f68ef 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/BUILD +++ b/tensorflow/compiler/xla/service/gpu/runtime/BUILD @@ -81,7 +81,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Support", ] + if_cuda_is_configured([ - "//tensorflow/compiler/xla/service/gpu:gpu_conv_algorithm_picker", + "//tensorflow/compiler/xla/service/gpu:conv_algorithm_picker", ]), ) diff --git a/tensorflow/compiler/xla/service/gpu/runtime/collectives.cc b/tensorflow/compiler/xla/service/gpu/runtime/collectives.cc index e1ea335ba179fd..47780a2613e578 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/collectives.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/collectives.cc @@ -193,10 +193,10 @@ absl::Status P2PImplCommon(const ServiceExecutableRunOptions* run_options, NcclCollectiveThunk::GetDeviceString(params); auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets, replica_group_values, is_async); - if (!comm.ok()) return ToAbslStatus(comm.status()); + if (!comm.ok()) return comm.status(); auto device_buffers = device_buffers_getter(args); - if (!device_buffers.ok()) return ToAbslStatus(device_buffers.status()); + if (!device_buffers.ok()) return device_buffers.status(); if (device_buffers->size() != 1) { return absl::InternalError(absl::StrFormat( "Expected device buffer size: 1, got %d", device_buffers->size())); diff --git a/tensorflow/compiler/xla/service/gpu/runtime/concurrent_region.cc b/tensorflow/compiler/xla/service/gpu/runtime/concurrent_region.cc index 3fed1c6ed07687..fb11172d7f702f 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/concurrent_region.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/concurrent_region.cc @@ -43,14 +43,22 @@ ConcurrentRegionStatus::~ConcurrentRegionStatus() { DCHECK(!IsInConcurrentRegion()); } +// Assign a stream in a round-robin fashion. Either the capture stream or one of +// the borrowed streams is returned. se::Stream* ConcurrentRegionStatus::GetNextStream() { DCHECK(IsInConcurrentRegion()); if (borrowed_streams_.empty()) { return nullptr; } - int index = stream_index_ % borrowed_streams_.size(); + + int index = stream_index_ % (borrowed_streams_.size() + 1); stream_index_++; - return borrowed_streams_[index].get(); + + if (index == 0) { + return capture_stream_; + } + + return borrowed_streams_[index - 1].get(); } absl::Status ConcurrentRegionStatus::StartConcurrentRegion( @@ -68,10 +76,9 @@ absl::Status ConcurrentRegionStatus::StartConcurrentRegion( } } - // Switch borrowed streams into capture mode. If the number of kernel launches - // in the region is less than the number of borrowed streams, only synchronize - // enough streams to run the kernels. - for (int i = 0; i < std::min(size, num_borrowed_streams_); ++i) { + // Switch borrowed streams into capture mode. We only synchronize enough + // streams to run the kernels. + for (int i = 0; i < std::min(size - 1, num_borrowed_streams_); ++i) { borrowed_streams_[i]->ThenWaitFor(capture_stream); } @@ -84,7 +91,7 @@ void ConcurrentRegionStatus::EndConcurrentRegion() { DCHECK(IsInConcurrentRegion()); // Synchronize main capture stream with all borrowed streams in capture mode. 
- for (int i = 0; i < std::min(region_size_, num_borrowed_streams_); + for (int i = 0; i < std::min(region_size_ - 1, num_borrowed_streams_); ++i) { capture_stream_->ThenWaitFor(borrowed_streams_[i].get()); } diff --git a/tensorflow/compiler/xla/service/gpu/runtime/conv.cc b/tensorflow/compiler/xla/service/gpu/runtime/conv.cc index 13a093c4831fa4..ccc8bbf79cb1ec 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/conv.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/conv.cc @@ -38,7 +38,7 @@ limitations under the License. #if GOOGLE_CUDA #include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h" #endif namespace xla { @@ -228,6 +228,10 @@ struct SideInputAttrs { double side_input_scale; }; +struct LeakyReluAlphaAttrs { + double leaky_relu_alpha; +}; + } // namespace static GpuConvDescriptor GetConvDescriptor( @@ -239,7 +243,8 @@ static GpuConvDescriptor GetConvDescriptor( ConvDimensionNumbers dims, Window w, ConvBackendConfig b, ConvAttrs attrs, // Conv-specific arguments and attributes std::optional fused = std::nullopt, - std::optional side_input = std::nullopt) { + std::optional side_input = std::nullopt, + std::optional leakyrelu_alpha = std::nullopt) { // Build a convolution descriptor from the attributes. GpuConvDescriptor descriptor; descriptor.kind = kind; @@ -313,6 +318,11 @@ static GpuConvDescriptor GetConvDescriptor( if (fused.has_value()) descriptor.backend_config.set_activation_mode(fused->activation_mode); + // Set attributes specific for fused convolutions with leaky_relu_alpha. + if (leakyrelu_alpha.has_value()) + descriptor.backend_config.set_leakyrelu_alpha( + leakyrelu_alpha->leaky_relu_alpha); + // Set attributes specific for convolutions with side input. if (side_input.has_value()) descriptor.backend_config.set_side_input_scale( @@ -344,7 +354,8 @@ static absl::Status ConvImpl( int64_t feature_group_count, double result_scale, // Optional attributes for fused convolutions. std::optional activation_mode = std::nullopt, - std::optional side_input_scale = std::nullopt) { + std::optional side_input_scale = std::nullopt, + std::optional leakyrelu_alpha = std::nullopt) { // Build config for optional attributes. 
std::optional fused_attrs = std::nullopt; if (activation_mode.has_value()) fused_attrs = {*activation_mode}; @@ -352,6 +363,9 @@ static absl::Status ConvImpl( std::optional side_input_attrs = std::nullopt; if (side_input_scale.has_value()) side_input_attrs = {*side_input_scale}; + std::optional leakyrelu_alpha_attrs = std::nullopt; + if (leakyrelu_alpha.has_value()) leakyrelu_alpha_attrs = {*leakyrelu_alpha}; + bool runtime_autotuning = false; if (backend_config.algorithm == -1) { // Set the algorithm back to the default algorithm to avoid error from @@ -369,7 +383,7 @@ static absl::Status ConvImpl( {window_strides, padding, lhs_dilation, rhs_dilation, window_reversal}, backend_config, {feature_group_count, result_scale}, fused_attrs, - side_input_attrs); + side_input_attrs, leakyrelu_alpha_attrs); TF_ASSIGN_OR_RETURN(GpuConvConfig conv_config, GetGpuConvConfig(descriptor, "")); @@ -404,7 +418,7 @@ static absl::Status ConvImpl( TF_ASSIGN_OR_RETURN( AutotuneResult best_algo, conv_algorithm_picker.PickBestAlgorithmWithAllocatedBuffer( - gpu_conv_config, run_options, debug_options, buffers, + config, gpu_conv_config, run_options, *debug_options, buffers, result_buffer)); // Set algorithm in the convolution runner state. @@ -495,6 +509,7 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL_TEMPLATE( ) .Value(std::optional()) // activation_mode .Value(std::optional()) // side_input_scale + .Value(std::optional()) // leaky_relu_alpha ); XLA_RUNTIME_DEFINE_CUSTOM_CALL( @@ -513,7 +528,8 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( .Arg() // scratch ) .Attr("activation_mode") - .Value(std::optional()) // side_input_scale + .Value(std::optional()) // side_input_scale + .Attr("leakyrelu_alpha") // leaky_relu_alpha ); XLA_RUNTIME_DEFINE_CUSTOM_CALL( @@ -532,7 +548,8 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( .Arg() // scratch ) .Attr("activation_mode") - .Attr("side_input_scale")); + .Attr("side_input_scale") + .Value(std::optional())); // leaky_relu_alpha //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc index 80e328df306f06..13e5a6fb6da717 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/runtime/custom_call.h" #include "tensorflow/compiler/xla/service/gpu/runtime/fft.h" #include "tensorflow/compiler/xla/service/gpu/runtime/gemm.h" +#include "tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h" #include "tensorflow/compiler/xla/service/gpu/runtime/io_feed.h" #include "tensorflow/compiler/xla/service/gpu/runtime/memcpy.h" #include "tensorflow/compiler/xla/service/gpu/runtime/memset.h" @@ -137,13 +138,22 @@ void RegisterXlaGpuAttrEncoding(CustomCallAttrEncodingSet& encoding) { //===----------------------------------------------------------------------===// +// Executable can have only one "main" function and only graph capture function. 
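The conv runtime changes above thread an optional leaky_relu_alpha attribute from the custom call through GetConvDescriptor into the backend config. A simplified sketch of that optional-attribute plumbing follows; the types are hypothetical stand-ins, not the real CudnnConvBackendConfig proto.

#include <optional>

struct LeakyReluAlphaAttrs {
  double leaky_relu_alpha;
};

// Stand-in for the backend config proto; the real one is generated code.
struct BackendConfig {
  double leakyrelu_alpha = 0.0;
  bool has_leakyrelu_alpha = false;
};

// Only touches the config when the custom call actually carried the
// attribute, mirroring how the descriptor is filled from the optional
// parameters above.
void SetOptionalLeakyReluAlpha(std::optional<LeakyReluAlphaAttrs> attrs,
                               BackendConfig& config) {
  if (attrs.has_value()) {
    config.leakyrelu_alpha = attrs->leaky_relu_alpha;
    config.has_leakyrelu_alpha = true;
  }
}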
+static int64_t GetNumGraphs(const runtime::Executable& executable) { + return executable.num_functions() - 1; +} + GpuRuntimeExecutable::GpuRuntimeExecutable( - std::vector buffer_sizes, + std::string module_name, std::vector buffer_sizes, std::unique_ptr jit_executable, DebugOptions debug_options, ModulesState modules_state, FfiModulesState ffi_modules_state) - : buffer_sizes_(std::move(buffer_sizes)), + : module_name_(std::move(module_name)), + buffer_sizes_(std::move(buffer_sizes)), executable_(std::move(jit_executable)), debug_options_(std::move(debug_options)), +#if GOOGLE_CUDA + graph_instances_(module_name_, GetNumGraphs(executable())), +#endif // GOOGLE_CUDA modules_state_(std::move(modules_state)), ffi_modules_state_(std::move(ffi_modules_state)) { ExportModules(dynamic_custom_calls_); // export runtime modules @@ -151,12 +161,16 @@ GpuRuntimeExecutable::GpuRuntimeExecutable( - std::vector buffer_sizes, + std::string module_name, std::vector buffer_sizes, std::unique_ptr aot_executable, DebugOptions debug_options, ModulesState modules_state, FfiModulesState ffi_modules_state) - : buffer_sizes_(std::move(buffer_sizes)), + : module_name_(std::move(module_name)), + buffer_sizes_(std::move(buffer_sizes)), executable_(std::move(aot_executable)), debug_options_(std::move(debug_options)), +#if GOOGLE_CUDA + graph_instances_(module_name_, GetNumGraphs(executable())), +#endif // GOOGLE_CUDA modules_state_(std::move(modules_state)), ffi_modules_state_(std::move(ffi_modules_state)) { ExportModules(dynamic_custom_calls_); // export runtime modules @@ -168,8 +182,9 @@ GpuRuntimeExecutable::GpuRuntimeExecutable( //===----------------------------------------------------------------------===// /*static*/ StatusOr> -GpuRuntimeExecutable::Create(std::unique_ptr program) { - // Options for the default XLA Runtim compilation pipeline. +GpuRuntimeExecutable::Create(std::string module_name, + std::unique_ptr program) { + // Options for the default XLA Runtime compilation pipeline. runtime::CompilationPipelineOptions copts; // Populate mapping from XLA (SE) enums/structs type id to symbol names. @@ -222,7 +237,7 @@ GpuRuntimeExecutable::Create(std::unique_ptr program) { ffi_modules_state.status().message()); return std::unique_ptr(new GpuRuntimeExecutable( - std::move(program->buffer_sizes), + std::move(module_name), std::move(program->buffer_sizes), std::make_unique(std::move(*jit_executable)), std::move(program->debug_options), std::move(*modules_state), std::move(*ffi_modules_state))); @@ -233,7 +248,8 @@ GpuRuntimeExecutable::Create(std::unique_ptr program) { //===----------------------------------------------------------------------===// /*static*/ StatusOr> -GpuRuntimeExecutable::Create(absl::Span buffer_sizes, +GpuRuntimeExecutable::Create(std::string module_name, + absl::Span buffer_sizes, Executable executable, DebugOptions debug_options) { // Instantiate state for all registered runtime modules.
@@ -249,6 +265,7 @@ GpuRuntimeExecutable::Create(absl::Span buffer_sizes, ffi_modules_state.status().message()); return std::unique_ptr(new GpuRuntimeExecutable( + std::move(module_name), std::vector(buffer_sizes.begin(), buffer_sizes.end()), std::make_unique(std::move(executable)), std::move(debug_options), std::move(*modules_state), @@ -274,7 +291,7 @@ static void InitializeCallFrame(runtime::Executable::CallFrame& call_frame, assert(ptrs.empty() && "pointers storage must be empty"); ptrs.resize_for_overwrite(num_allocations); - // Each buffer allocation pased as 1d memref to the compiled function: + // Each buffer allocation passed as 1d memref to the compiled function: // {basePtr, dataPtr, offset, [sizes, ...], [strides, ...]} size_t num_args_ptrs = 1 + num_allocations * 5; call_frame.args.resize_for_overwrite(num_args_ptrs); @@ -413,6 +430,35 @@ Status GpuRuntimeExecutable::Execute( return InternalError("Failed to initialize runtime modules state: %s", state_ref.status().message()); +#if GOOGLE_CUDA + // Instantiate all CUDA graphs before executing the main function. + if (debug_options_.xla_gpu_cuda_graph_num_runs_to_instantiate() < 0 && + !graph_instances_.InstantiatedAllGraphs(run_options, executable)) { + // To instantiate all Gpu graphs we have to pass a valid device pointer + // because some device operations in XLA (e.g. memcpy) query device + // information from a pointer. We have to find the largest allocation + // available, to guarantee that all memref slices are within bounds, + // otherwise we might get crashes from a Gpu driver. + void* device_ptr = temp_buffer.opaque(); + size_t device_ptr_size = temp_buffer.size(); + + for (unsigned i = 0; i < buffer_allocations.size(); ++i) { + auto mem = buffer_allocations.GetDeviceAddress(i); + if (mem.size() > device_ptr_size) { + device_ptr = mem.opaque(); + device_ptr_size = mem.size(); + } + } + + if (auto instantiated = graph_instances_.InstantiateAllGraphs( + run_options, executable, user_data, device_ptr); + !instantiated.ok()) { + return InternalError("Failed to instantiate CUDA graphs: %s", + instantiated.message()); + } + } +#endif // GOOGLE_CUDA + // Collect all emitted diagnostic messages. std::string diagnostic; runtime::DiagnosticEngine diagnostic_engine; diff --git a/tensorflow/compiler/xla/service/gpu/runtime/executable.h b/tensorflow/compiler/xla/service/gpu/runtime/executable.h index 114c3655711ccd..0405c86db315b6 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/executable.h +++ b/tensorflow/compiler/xla/service/gpu/runtime/executable.h @@ -93,12 +93,12 @@ class GpuRuntimeExecutable { public: // Creates GpuRuntimeExecutable from the Xla Gpu Program. static StatusOr> Create( - std::unique_ptr program); + std::string module_name, std::unique_ptr program); // Creates GpuRuntimeExecutable from the AOT compiled binary. static StatusOr> Create( - absl::Span buffer_sizes, runtime::Executable executable, - DebugOptions debug_options); + std::string module_name, absl::Span buffer_sizes, + runtime::Executable executable, DebugOptions debug_options); // Executes entry function with the given buffer arguments. Status Execute(const ServiceExecutableRunOptions* run_options, @@ -115,17 +115,23 @@ class GpuRuntimeExecutable { // Returns MLIR module behind this executable if it is available. 
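The Execute change above instantiates CUDA graphs ahead of time and needs one valid device pointer that keeps every fake memref slice in bounds, so it scans the buffer allocations for the largest one. A standalone sketch of that selection, with a placeholder Buffer type instead of se::DeviceMemoryBase:

#include <cstddef>
#include <vector>

struct Buffer {        // placeholder for a device allocation
  void* opaque;
  size_t size;
};

// Returns the largest allocation, starting from the temp buffer, so that any
// placeholder memref argument built on top of it stays within bounds.
Buffer LargestAllocation(const Buffer& temp_buffer,
                         const std::vector<Buffer>& allocations) {
  Buffer best = temp_buffer;
  for (const Buffer& b : allocations) {
    if (b.size > best.size) best = b;
  }
  return best;
}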
StatusOr GetMlirModule() const; + std::string_view module_name() const { return module_name_; } + private: - GpuRuntimeExecutable(std::vector buffer_sizes, + GpuRuntimeExecutable(std::string module_name, + std::vector buffer_sizes, std::unique_ptr jit_executable, DebugOptions debug_options, ModulesState modules_state, FfiModulesState ffi_modules_state); - GpuRuntimeExecutable(std::vector buffer_sizes, + GpuRuntimeExecutable(std::string module_name, + std::vector buffer_sizes, std::unique_ptr aot_executable, DebugOptions debug_options, ModulesState modules_state, FfiModulesState ffi_modules_state); + std::string module_name_; + // Depending on the state of `executable_` returns a reference to active // Xla runtime executable. runtime::Executable& executable(); diff --git a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc index 5b18bf5ca6120c..f9df7a72503163 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h" +#include #include #include #include @@ -52,18 +53,178 @@ using xla::runtime::Arguments; using xla::runtime::AsyncTaskRunner; using xla::runtime::CustomCall; using xla::runtime::Executable; +using xla::runtime::FunctionRef; +using xla::runtime::FunctionType; using xla::runtime::MemrefDesc; -using xla::runtime::ScalarArg; +using xla::runtime::MemrefType; using xla::runtime::StridedMemrefView; +#if GOOGLE_CUDA +using se::gpu::OwnedCudaGraph; + +// Captures Gpu graph by running given function in capture mode. +static absl::StatusOr CaptureGraph( + const ServiceExecutableRunOptions* run_options, + runtime::FunctionRef function_ref, Arguments& args, + CustomCall::UserData user_data); +#endif // GOOGLE_CUDA + //===----------------------------------------------------------------------===// // CUDA graphs caching. //===----------------------------------------------------------------------===// +static absl::Mutex* GetGraphInstancesMutex() { + static auto* mu = new absl::Mutex(); + return mu; +} + +// Keep track of instantiated graphs on each StreamExecutor, we use this +// information in the graph eviction policy. 
+using GraphInstancesState = absl::flat_hash_map; + +static GraphInstancesState& GetGraphInstancesState() { + static auto* state = new GraphInstancesState(); + return *state; +} + +static int64_t NotifyGraphInstancesCreated(se::StreamExecutor* executor, + int64_t num_graphs) { + absl::MutexLock lock(GetGraphInstancesMutex()); + return GetGraphInstancesState()[executor] += num_graphs; +} + +static int64_t NotifyGraphInstancesDestroyed(se::StreamExecutor* executor, + int64_t num_graphs) { + absl::MutexLock lock(GetGraphInstancesMutex()); + return GetGraphInstancesState()[executor] -= num_graphs; +} + +GraphInstances::GraphInstances(std::string module_name, int64_t num_graphs) + : impl_(std::make_shared()) { + impl_->module_name = std::move(module_name); + impl_->num_graphs = num_graphs; + VLOG(3) << "Construct graph instances cache for: @" << impl_->module_name + << " (num_graphs = " << impl_->num_graphs << ")"; +} + +GraphInstances::~GraphInstances() { + VLOG(3) << "Destroy graph instances cache for: @" << impl_->module_name + << " (num_graphs = " << impl_->num_graphs << ")"; + + absl::MutexLock lock(&impl_->mu); + for (auto& [executor, state] : impl_->graphs) { + VLOG(3) << "Destroy " << impl_->num_graphs << " graphs for: @" + << impl_->module_name << " at executor: " << executor + << ". Total remaining graphs at given executor: " + << NotifyGraphInstancesDestroyed(executor, impl_->num_graphs); + } +} + StreamExecutorGraphInstances* GraphInstances::operator()( se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &graphs_[executor]; + absl::MutexLock lock(&impl_->mu); + + auto it = impl_->graphs.try_emplace(executor); + if (it.second && impl_->num_graphs > 0) { + VLOG(3) << "Instantiate " << impl_->num_graphs << " graphs for: @" + << impl_->module_name << " at executor: " << executor + << ". Total graphs at given executor: " + << NotifyGraphInstancesCreated(executor, impl_->num_graphs); + } + + State& state = it.first->second; + state.last_use_micros = tsl::Env::Default()->NowMicros(); + return &state.instances; +} + +bool GraphInstances::InstantiatedAllGraphs( + const ServiceExecutableRunOptions* run_options, + const Executable& executable) { + if (executable.num_functions() == 1) return true; + + absl::MutexLock lock(&impl_->mu); + return impl_->graphs[run_options->stream()->parent()].instantiated; +} + +Status GraphInstances::InstantiateAllGraphs( + const ServiceExecutableRunOptions* run_options, + const Executable& executable, const CustomCall::UserData& user_data, + void* ptr) { + // We have only "main" function in the executable. + if (executable.num_functions() == 1) return OkStatus(); + + absl::MutexLock lock(&impl_->mu); + se::StreamExecutor* executor = run_options->stream()->parent(); + + State& state = impl_->graphs[executor]; + + // All Gpu graphs are already instantiated for a given executor. + if (state.instantiated) return OkStatus(); + + TraceMe trace("cuda.graph.instantiate_all"); + + // Initialize graph instances snapshot for a given executor. + StreamExecutorGraphInstances::Snapshot instances = state.instances.snapshot(); + + // Instantiate all Gpu graphs by calling graph capture functions with fake + // arguments. Once we'll execute them first time for real, they'll be updated + // with correct pointers. 
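InstantiateAllGraphs above builds fake rank-1 buffer arguments that all alias one scratch allocation; the real pointers are patched in by a graph update on the first real run. A simplified sketch of that argument construction, using a hypothetical FakeMemref struct rather than runtime::MemrefDesc:

#include <cstdint>
#include <vector>

struct FakeMemref {    // stand-in for runtime::MemrefDesc, rank 1 only
  void* data;
  int64_t offset;
  int64_t size;
  int64_t stride;
};

std::vector<FakeMemref> MakePlaceholderArgs(void* scratch,
                                            const std::vector<int64_t>& sizes) {
  std::vector<FakeMemref> args;
  args.reserve(sizes.size());
  for (int64_t size : sizes) {
    // Every argument points at the same scratch buffer with a unit stride.
    args.push_back(FakeMemref{scratch, /*offset=*/0, size, /*stride=*/1});
  }
  return args;
}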
+ for (unsigned ordinal = 1; ordinal < executable.num_functions(); ++ordinal) { + if (!absl::StartsWith(executable.function_name(ordinal), + "xla.gpu.cuda.graph.capture")) + continue; + + VLOG(3) << "Instantiate Gpu graph defined by capture function @" + << executable.function_name(ordinal) << " (ordinal = " << ordinal + << ")"; + + TraceMe trace_instantiation([&] { + return TraceMeEncode("cuda.graph.instantiate", {{"ordinal", ordinal}}); + }); + + FunctionRef function_ref = executable.function_ref(ordinal); + + const FunctionType& signature = executable.signature(ordinal); + assert(signature.num_results() == 0 && "unexpected number of results"); + Arguments args(signature.num_operands()); + + // Prepare arguments for the graph capture function. + for (size_t j = 0; j < signature.num_operands(); ++j) { + auto* memref = llvm::dyn_cast(signature.operand(j)); + + if (!memref) + return absl::InternalError(absl::StrFormat( + "Unsupported capture function argument type #%d", j)); + + if (memref->sizes().size() != 1) + return absl::InternalError( + absl::StrFormat("Unsupported capture function memref rank #%d: %d", + j, memref->sizes().size())); + + std::array sizes = {memref->size(0)}; + std::array strides = {1}; + + args.emplace_back(memref->element_type(), ptr, + /*offset=*/0, sizes, strides); + } + +#if GOOGLE_CUDA + // Instantiate a Gpu graph with fake arguments. + auto instantiate = [&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN( + auto g, CaptureGraph(run_options, function_ref, args, user_data)); + TF_ASSIGN_OR_RETURN(auto e, se::gpu::InstantiateCudaGraph(std::move(g))); + return GraphInstance(0, std::move(e)); + }; + + TF_ASSIGN_OR_RETURN(GraphInstance * instance, + instances.GetOrCreate(ordinal, instantiate)); + (void)instance; +#endif // GOOGLE_CUDA + } + + state.instantiated = true; + return OkStatus(); } CapturedFunctionExecutionCount* CapturedFunctionExecutionCounts::operator()( @@ -99,8 +260,6 @@ H AbslHashValue(H h, const RemainingArgsPtrs& m) { #if GOOGLE_CUDA -using se::gpu::OwnedCudaGraph; - static bool InDebugMode() { #ifdef NDEBUG return false; @@ -108,9 +267,26 @@ static bool InDebugMode() { return true; } +// Forwards custom call arguments to an arguments container that can be passed +// to an executable function. +static absl::Status ForwardArguments(CustomCall::RemainingArgs fwd_args, + Arguments& args) { + for (size_t i = 0; i < fwd_args.size(); ++i) { + if (auto memref = fwd_args.get(i); succeeded(memref)) { + args.emplace_back(memref->dtype, memref->data, /*offset=*/0, + memref->sizes, memref->strides); + continue; + } + + return absl::InvalidArgumentError("Unsupported argument type"); + } + + return OkStatus(); +} + static absl::StatusOr CaptureGraph( const ServiceExecutableRunOptions* run_options, - runtime::FunctionRef function_ref, CustomCall::RemainingArgs fwd_args, + runtime::FunctionRef function_ref, Arguments& args, CustomCall::UserData user_data) { // We capture graph on a borrowed stream because we do not want to // accidentally record any concurrent kernel launches from other XLA @@ -162,29 +338,6 @@ static absl::StatusOr CaptureGraph( // Graph capture function should not launch any async tasks. opts.async_task_runner = reinterpret_cast(0XDEADBEEF); - // Graph capture functions can only have index arguments for launch - // dimensions, or memrefs for passing buffers. We need to re-package custom - // call arguments into a container that can be passed to an executable - // function. 
- Arguments args(fwd_args.size()); - - for (size_t i = 0; i < fwd_args.size(); ++i) { - // `index` argument passed as int64_t. - if (auto idx = fwd_args.get(i); succeeded(idx)) { - args.emplace_back(*idx); - continue; - } - - // Pass `memref` argument as a MemrefDesc. - if (auto memref = fwd_args.get(i); succeeded(memref)) { - args.emplace_back(memref->dtype, memref->data, /*offset=*/0, - memref->sizes, memref->strides); - continue; - } - - return absl::InvalidArgumentError("Unsupported argument type"); - } - // Create a graph from running the graph capture function. auto captured = se::gpu::CaptureCudaGraph(capture_stream->get(), [&]() { return function_ref(args, runtime::NoResultConverter{}, opts, @@ -223,32 +376,15 @@ static absl::Status RunGraphWithoutCapture( // Graph capture function should not launch any async tasks. opts.async_task_runner = reinterpret_cast(0XDEADBEEF); - Arguments args(fwd_args.size()); - - for (size_t i = 0; i < fwd_args.size(); ++i) { - // `index` argument passed as int64_t. - if (auto idx = fwd_args.get(i); succeeded(idx)) { - args.emplace_back(*idx); - continue; - } - - // Pass `memref` argument as a MemrefDesc. - if (auto memref = fwd_args.get(i); succeeded(memref)) { - args.emplace_back(memref->dtype, memref->data, /*offset=*/0, - memref->sizes, memref->strides); - continue; - } - - return absl::InvalidArgumentError("Unsupported argument type"); - } + Arguments args(fwd_args.size()); + TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); - auto status = - function_ref(args, runtime::NoResultConverter{}, opts, InDebugMode()) - .status(); - if (!status.ok()) { + auto executed = + function_ref(args, runtime::NoResultConverter{}, opts, InDebugMode()); + if (!executed.ok()) { return InternalError("RunGraphWithoutCapture failed (%s): %s", diagnostic.empty() ? "" : diagnostic, - status.ToString()); + executed.status().ToString()); } return absl::OkStatus(); } @@ -272,7 +408,7 @@ static absl::Status LaunchGraph( ConcurrentRegionStatus* region_status, CustomCall::RemainingArgs fwd_args, CustomCall::FunctionOrdinal capture) { #if GOOGLE_CUDA - VLOG(1) << "Launch Cuda Graph: capture=" << capture.ordinal; + VLOG(1) << "Launch Cuda Graph: ordinal = " << capture.ordinal; // Get a reference to exported function that captures the cuda graph. runtime::FunctionRef function_ref = executable->function_ref(capture.ordinal); @@ -287,15 +423,13 @@ static absl::Status LaunchGraph( gemm_config, gpu_lock, region_status); }; - TF_ASSIGN_OR_RETURN( - std::unique_ptr> * get_count, - counts->GetOrCreate( - capture.ordinal, - []() -> absl::StatusOr>> { - return std::make_unique>(0); - })); - uint64_t count = (*get_count)->fetch_add(1); - uint64_t instantiation_threshold = + TF_ASSIGN_OR_RETURN(std::unique_ptr> * get_count, + counts->GetOrCreate(capture.ordinal, [] { + return std::make_unique>(0); + })); + + int64_t count = (*get_count)->fetch_add(1); + int64_t num_runs_to_instantiate = debug_options->xla_gpu_cuda_graph_num_runs_to_instantiate(); // TODO(ezhulenev): Cupti tracing leads to deadlocks in CUDA 11. Always fall @@ -306,23 +440,27 @@ static absl::Status LaunchGraph( bool is_profiling = tsl::profiler::ScopedAnnotationStack::IsEnabled(); #endif - if (count < instantiation_threshold || is_profiling) { - // Run captured graph directly. 
+ if (count < num_runs_to_instantiate || is_profiling) { + VLOG(3) << "Run gpu graph in op-by-op mode: ordinal = " << capture.ordinal; return RunGraphWithoutCapture(run_options, function_ref, fwd_args, user_data()); } - TF_ASSIGN_OR_RETURN( - GraphInstance * instance, - instances->GetOrCreate( - capture.ordinal, [&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(auto g, CaptureGraph(run_options, function_ref, - fwd_args, user_data())); + // Instantiate Gpu graph by running graph capture function. + auto instantiate = [&]() -> absl::StatusOr { + Arguments args(fwd_args.size()); + TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); - TF_ASSIGN_OR_RETURN(auto e, - se::gpu::InstantiateCudaGraph(std::move(g))); - return GraphInstance(ptrs_hash, std::move(e)); - })); + TF_ASSIGN_OR_RETURN( + auto g, CaptureGraph(run_options, function_ref, args, user_data())); + + TF_ASSIGN_OR_RETURN(auto e, se::gpu::InstantiateCudaGraph(std::move(g))); + + return GraphInstance(ptrs_hash, std::move(e)); + }; + + TF_ASSIGN_OR_RETURN(GraphInstance * instance, + instances->GetOrCreate(capture.ordinal, instantiate)); { // Lock graph instance for read only access. If we'll have to update the @@ -343,9 +481,13 @@ static absl::Status LaunchGraph( // Otherwise we have to re-capture the graph and update the graph instance. VLOG(3) << "Update cached graph instance"; + + Arguments args(fwd_args.size()); + TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); + // Capture CUDA graph by running capture function. TF_ASSIGN_OR_RETURN( - auto g, CaptureGraph(run_options, function_ref, fwd_args, user_data())); + auto g, CaptureGraph(run_options, function_ref, args, user_data())); // At this point we have to grab a writer lock, because we might potentially // have concurrent execution of the cached graph instance. diff --git a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h index 7171862da99a23..6a183409f17855 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h +++ b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h @@ -18,14 +18,15 @@ limitations under the License. #include #include -#include -#include -#include +#include #include #include "absl/container/node_hash_map.h" #include "tensorflow/compiler/xla/runtime/custom_call_registry.h" +#include "tensorflow/compiler/xla/runtime/executable.h" +#include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #if GOOGLE_CUDA #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h" @@ -79,14 +80,56 @@ class StreamExecutorGraphInstances #endif // #if GOOGLE_CUDA // Xla executable keeps a mapping from stream executors to graph instances. +// +// Graph instances allocate on-device memory, so we periodically destroy +// them to free up some space on device. JAX for example keeps all XLA +// executables alive, and destroys them when the process shuts down, so we can +// end up with thousands of unused (or rarely used) graphs in device memory. class GraphInstances { public: + GraphInstances(std::string module_name, int64_t num_graphs); + ~GraphInstances(); + StreamExecutorGraphInstances* operator()(se::StreamExecutor* executor); + // Instantiates all Gpu graphs defined by the given executable using user + // provided run options. 
This guarantees that once we start execution, all Gpu + // graphs are ready, and will only require cheap update operation and will not + // require allocating new resources (we avoid non deterministic OOM errors). + Status InstantiateAllGraphs(const ServiceExecutableRunOptions* run_options, + const runtime::Executable& executable, + const runtime::CustomCall::UserData& user_data, + void* ptr); + + // Returns true if all Gpu graphs were already instantiated. + bool InstantiatedAllGraphs(const ServiceExecutableRunOptions* run_options, + const runtime::Executable& executable); + private: - mutable absl::Mutex mutex_; - absl::node_hash_map graphs_ - ABSL_GUARDED_BY(mutex_); + struct State { + // A flag signalling if `InstantiateAllGraphs` was already called and we + // have all Gpu graph instantiated ahead of time. + bool instantiated = false; + + // Last time graph instances were used by a particular stream executor. + uint64_t last_use_micros = 0; + + StreamExecutorGraphInstances instances; + }; + + struct Impl { + // XLA module name that owns graph instances. We use it only to produce logs + // that can be attributed back to XLA executables. + std::string module_name; + + // Number of graphs in the parent module. + int64_t num_graphs; + + mutable absl::Mutex mu; + absl::node_hash_map graphs ABSL_GUARDED_BY(mu); + }; + + std::shared_ptr impl_; }; // Xla executable keeps a mapping from stream executors to execution counts. diff --git a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc index ee203f7a1bb98b..3b686af91a2631 100644 --- a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -42,29 +43,23 @@ bool HasDefaultLayout(const Shape& shape) { LayoutUtil::IsMonotonicWithDim0Major(shape.layout()); } -bool IsSupportedReductionComputation(HloComputation* computation) { - static const absl::flat_hash_set* const kSupportedOpcodes = - new absl::flat_hash_set{HloOpcode::kAdd, HloOpcode::kMaximum}; - - HloInstruction* root = computation->root_instruction(); - if (root->operand_count() != 2 || - root->operand(0)->opcode() != HloOpcode::kParameter || - root->operand(1)->opcode() != HloOpcode::kParameter) { - return false; - } - return kSupportedOpcodes->contains(root->opcode()); -} - -bool IsTritonSupportedInstruction(const HloInstruction* instr) { +bool IsTritonSupportedInstruction(const HloInstruction* instr, + const GpuVersion& gpu_version) { // TODO(bchetioui): expand with non-trivial instructions. 
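The rewriter change above admits a convert touching BF16 only when the device is at least Ampere. A simplified sketch of that gate, assuming plain structs in place of stream_executor::CudaComputeCapability and the GpuVersion variant:

#include <variant>

struct CudaCC {        // stand-in for se::CudaComputeCapability
  int major;
  int minor;
  bool IsAtLeast(int other_major) const { return major >= other_major; }
};
struct RocmCC {};      // stand-in for the ROCm alternative
using GpuVersionLike = std::variant<CudaCC, RocmCC>;

constexpr int kAmpereMajor = 8;  // Ampere is compute capability 8.x

// Returns true if a bf16 convert may be handed to Triton on this device.
bool Bf16ConvertSupported(const GpuVersionLike& gpu_version) {
  const CudaCC* cc = std::get_if<CudaCC>(&gpu_version);
  return cc != nullptr && cc->IsAtLeast(kAmpereMajor);
}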
if (instr->IsElementwise()) { + if (instr->opcode() == HloOpcode::kConvert && + (instr->operand(0)->shape().element_type() == BF16 || + instr->shape().element_type() == BF16) && + !std::get(gpu_version) + .IsAtLeast(stream_executor::CudaComputeCapability::AMPERE)) { + return false; + } return IsTritonSupportedElementwise(instr->opcode(), instr->shape().element_type()); } switch (instr->opcode()) { case HloOpcode::kBitcast: - case HloOpcode::kConvert: case HloOpcode::kParameter: return true; default: @@ -77,11 +72,16 @@ bool IsTritonSupportedInstruction(const HloInstruction* instr) { // set to it. The definition of "trivial" operations is as given in // 'IsTriviallyFusible'. bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer, - HloOpcode opcode); + HloOpcode opcode, const GpuVersion& gpu_version); -bool BitcastIsTilingNoop(HloInstruction* bitcast) { +bool BitcastIsTilingNoop(HloInstruction* bitcast, + const GpuVersion& gpu_version) { CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast); + if (ShapeUtil::IsEffectiveScalar(bitcast->shape())) { + return true; + } + // In the Softmax rewriter for now, tiling is derived from a hero reduction // operation, which should be reducing its input on the last axis. Therefore, // a bitcast is always a no-op with regards to a tile if @@ -97,7 +97,8 @@ bool BitcastIsTilingNoop(HloInstruction* bitcast) { }; HloInstruction* reduce = nullptr; - TrivialEdge(&reduce, bitcast->mutable_operand(0), HloOpcode::kReduce); + TrivialEdge(&reduce, bitcast->mutable_operand(0), HloOpcode::kReduce, + gpu_version); return (HasDefaultLayout(bitcast->shape()) && HasDefaultLayout(bitcast->operand(0)->shape()) && @@ -105,7 +106,8 @@ bool BitcastIsTilingNoop(HloInstruction* bitcast) { last_dimension(bitcast->operand(0)) == last_dimension(bitcast))); } -bool IsTriviallyFusible(HloInstruction* instr, int num_allowed_users = 1) { +bool IsTriviallyFusible(HloInstruction* instr, const GpuVersion& gpu_version, + int num_allowed_users = 1) { // Checks whether an op is trivially fusible. 
An op is said to be trivially // fusible if it does not increase the amount of memory read/written by the // resulting fusion, is compatible with any chosen tiling, and can be @@ -116,21 +118,22 @@ bool IsTriviallyFusible(HloInstruction* instr, int num_allowed_users = 1) { return false; } - if (instr->opcode() == HloOpcode::kBitcast && BitcastIsTilingNoop(instr)) { + if (instr->opcode() == HloOpcode::kBitcast && + BitcastIsTilingNoop(instr, gpu_version)) { return true; } if (instr->IsElementwise() && instr->operand_count() == 1) { - return IsTritonSupportedInstruction(instr); + return IsTritonSupportedInstruction(instr, gpu_version); } return false; } bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer, - HloOpcode opcode) { + HloOpcode opcode, const GpuVersion& gpu_version) { while (consumer->opcode() != opcode) { - if (IsTriviallyFusible(consumer)) { + if (IsTriviallyFusible(consumer, gpu_version)) { consumer = consumer->mutable_operand(0); } else { return false; @@ -142,18 +145,20 @@ bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer, } bool IsTriviallyConnectedProducerOf(HloInstruction* producer, - HloInstruction* consumer) { + HloInstruction* consumer, + const GpuVersion& gpu_version) { if (producer == consumer) { return true; } HloInstruction* found_producer = consumer; - while (TrivialEdge(&found_producer, consumer, producer->opcode())) { + while ( + TrivialEdge(&found_producer, consumer, producer->opcode(), gpu_version)) { if (found_producer == producer) { return true; } - if (!IsTriviallyFusible(found_producer)) { + if (!IsTriviallyFusible(found_producer, gpu_version)) { return false; } @@ -167,9 +172,10 @@ inline bool HasOneUse(const HloInstruction* instr) { return instr->user_count() == 1; } -bool IsTritonSupportedComputation(const HloComputation* computation) { +bool IsTritonSupportedComputation(const HloComputation* computation, + const GpuVersion& gpu_version) { for (const HloInstruction* instr : computation->instructions()) { - if (!IsTritonSupportedInstruction(instr)) { + if (!IsTritonSupportedInstruction(instr, gpu_version)) { return false; } } @@ -177,7 +183,7 @@ bool IsTritonSupportedComputation(const HloComputation* computation) { } std::optional MatchesTritonCompatibleClosedReductionDiamond( - HloInstruction* instr) { + HloInstruction* instr, const GpuVersion& gpu_version) { // Return the producer of the following pattern: // // producer @@ -197,7 +203,8 @@ std::optional MatchesTritonCompatibleClosedReductionDiamond( // array. 
std::optional match_failure = std::nullopt; - if (!instr->IsElementwiseBinary() || !IsTritonSupportedInstruction(instr)) { + if (!instr->IsElementwiseBinary() || + !IsTritonSupportedInstruction(instr, gpu_version)) { return match_failure; } @@ -206,13 +213,13 @@ std::optional MatchesTritonCompatibleClosedReductionDiamond( HloInstruction* reduce; if (!(TrivialEdge(&broadcast, instr->mutable_operand(1), - HloOpcode::kBroadcast) && - TrivialEdge(&reduce, broadcast->mutable_operand(0), - HloOpcode::kReduce) && + HloOpcode::kBroadcast, gpu_version) && + TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce, + gpu_version) && HasDefaultLayout(broadcast->shape()) && HasDefaultLayout(reduce->shape()) && reduce->operand_count() == 2 && reduce->operand(1)->opcode() == HloOpcode::kConstant && - IsTritonSupportedComputation(reduce->to_apply()))) { + IsTritonSupportedComputation(reduce->to_apply(), gpu_version))) { return match_failure; } @@ -229,12 +236,19 @@ std::optional MatchesTritonCompatibleClosedReductionDiamond( return match_failure; } - while (IsTriviallyFusible(producer)) { + // TODO(b/291204753): remove this filter. This heuristic enables flipping the + // default flag while filtering out cases that could result in regressions. + if (reduce->operand(0)->shape().dimensions().back() < 64) { + return match_failure; + } + + while (IsTriviallyFusible(producer, gpu_version)) { producer = producer->mutable_operand(0); } if (!HasDefaultLayout(producer->shape()) || - !IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0)) || + !IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0), + gpu_version) || !(producer == instr->operand(0) || instr->operand(0)->user_count() == 1)) { return match_failure; @@ -250,10 +264,11 @@ std::optional MatchesTritonCompatibleClosedReductionDiamond( // that instruction is used more than once, and/or is not trivially // fusible. HloInstruction* FindFirstNonFusibleDiamondProducer( - HloInstruction* diamond_producer) { - if (IsTriviallyFusible(diamond_producer, /*num_allowed_users=*/2)) { + HloInstruction* diamond_producer, const GpuVersion& gpu_version) { + if (IsTriviallyFusible(diamond_producer, gpu_version, + /*num_allowed_users=*/2)) { diamond_producer = diamond_producer->mutable_operand(0); - while (IsTriviallyFusible(diamond_producer)) { + while (IsTriviallyFusible(diamond_producer, gpu_version)) { diamond_producer = diamond_producer->mutable_operand(0); } } @@ -343,8 +358,8 @@ StatusOr SoftmaxRewriterTriton::Run( continue; } - if (auto producer = - MatchesTritonCompatibleClosedReductionDiamond(instr)) { + if (auto producer = MatchesTritonCompatibleClosedReductionDiamond( + instr, gpu_version_)) { matched_diamonds.push_back(DiamondDescriptor{instr, producer.value()}); } } @@ -367,9 +382,9 @@ StatusOr SoftmaxRewriterTriton::Run( return instr->operand(0)->shape().dimensions(operand_rank - 1); }; - auto last_trivially_fusible_user = [](HloInstruction* instr) { + auto last_trivially_fusible_user = [&](HloInstruction* instr) { while (HasOneUse(instr) && !instr->IsRoot() && - IsTriviallyFusible(instr->users().front())) { + IsTriviallyFusible(instr->users().front(), gpu_version_)) { instr = instr->users().front(); } @@ -378,7 +393,7 @@ StatusOr SoftmaxRewriterTriton::Run( // restriction. 
if (HasOneUse(instr) && !instr->IsRoot() && IsTriviallyFusible( - instr->users().front(), + instr->users().front(), gpu_version_, /*num_allowed_users=*/instr->users().front()->user_count())) { instr = instr->users().front(); } @@ -402,8 +417,8 @@ StatusOr SoftmaxRewriterTriton::Run( // Crucially, this approach relies on a diamond root never being considered a // trivially fusible operation. std::vector diamond_chains; - HloInstruction* current_fusion_producer = - FindFirstNonFusibleDiamondProducer(matched_diamonds.front().producer); + HloInstruction* current_fusion_producer = FindFirstNonFusibleDiamondProducer( + matched_diamonds.front().producer, gpu_version_); int current_reduce_dimension_size = reduction_dimension_size_from_diamond_root(matched_diamonds.front().root); @@ -414,7 +429,7 @@ StatusOr SoftmaxRewriterTriton::Run( matched_diamonds[diamond_idx - 1].root; HloInstruction* first_non_fusible_diamond_producer = - FindFirstNonFusibleDiamondProducer(diamond_producer); + FindFirstNonFusibleDiamondProducer(diamond_producer, gpu_version_); int diamond_reduce_dimension_size = reduction_dimension_size_from_diamond_root(diamond_root); diff --git a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc index 77f49d8a423b6c..4c55e6908734ae 100644 --- a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc @@ -14,6 +14,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -576,17 +577,17 @@ add_computation { ROOT add = f32[] add(arg_0.1, arg_1.1) } ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) + param_0 = f32[127,625]{1,0} parameter(0) constant_neg_inf = f32[] constant(-inf) reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - bitcasted_subtract = f32[127,5,25] bitcast(subtract) - exponential = f32[127,5,25] exponential(bitcasted_subtract) + broadcast = f32[127,625]{1,0} broadcast(reduce), dimensions={0} + subtract = f32[127,625]{1,0} subtract(param_0, broadcast) + bitcasted_subtract = f32[127,5,125] bitcast(subtract) + exponential = f32[127,5,125] exponential(bitcasted_subtract) constant_zero = f32[] constant(0) second_reduce = f32[127,5] reduce(exponential, constant_zero), dimensions={2}, to_apply=add_computation - second_broadcast = f32[127,5,25] broadcast(second_reduce), dimensions={0,1} - ROOT divide = f32[127,5,25] divide(exponential, second_broadcast) + second_broadcast = f32[127,5,125] broadcast(second_reduce), dimensions={0,1} + ROOT divide = f32[127,5,125] divide(exponential, second_broadcast) } )"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); @@ -788,6 +789,68 @@ ENTRY main { EXPECT_FALSE(fusion_rewriter.Run(module.get()).value()); } +TEST_F(SoftmaxRewriterTritonTest, DoNotFuseSoftmaxWithSmallRows) { + const std::string hlo_string = R"( +HloModule softmax +max_computation { + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = f32[127,50]{1,0} parameter(0) + constant_neg_inf = f32[] constant(-inf) + reduce = f32[127]{0} reduce(param_0, constant_neg_inf), 
dimensions={1}, to_apply=max_computation + broadcast = f32[127,50]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f32[127,50]{1,0} subtract(param_0, broadcast) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + SoftmaxRewriterTriton fusion_rewriter(gpu_version_); + EXPECT_FALSE(fusion_rewriter.Run(module.get()).value()); +} + +TEST_F( + SoftmaxRewriterTritonTest, + CanOnlyFuseConvertInvolvingBF16InputIntoSoftmaxDiamondWithAtLeastAmpereComputeCapability) { // NOLINT(whitespace/line_length) + const std::string hlo_string = R"( +HloModule softmax +max_computation { + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = bf16[127,125]{1,0} parameter(0) + param_0_f32 = f32[127,125]{1,0} convert(param_0) + constant_neg_inf = f32[] constant(-inf) + reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast) +} +)"; + auto ampere_module = ParseAndReturnVerifiedModule(hlo_string).value(); + auto volta_module = ampere_module->Clone(); + + // Ampere + SoftmaxRewriterTriton fusion_rewriter_ampere( + se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}); + EXPECT_TRUE(fusion_rewriter_ampere.Run(ampere_module.get()).value()); + EXPECT_TRUE(verifier().Run(ampere_module.get()).status().ok()); + VLOG(2) << ampere_module->ToString(); + EXPECT_THAT(ampere_module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter()))); + + // Volta (pre-Ampere) + SoftmaxRewriterTriton fusion_rewriter_volta( + se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}); + EXPECT_TRUE(fusion_rewriter_volta.Run(volta_module.get()).value()); + EXPECT_TRUE(verifier().Run(volta_module.get()).status().ok()); + VLOG(2) << volta_module->ToString(); + EXPECT_THAT(volta_module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Convert(m::Parameter())))); +} + } // anonymous namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 795da88570d84e..4249b66b6c8d66 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -99,13 +99,9 @@ xla_cc_test( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/hlo/utils:hlo_matchers", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_pass_pipeline", - "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service/gpu:gpu_reduce_scatter_creator", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/tsl/lib/core:status_test_util", "//tensorflow/tsl/platform:test", ], ) @@ -787,9 +783,9 @@ glob_lit_tests( "calling_convention_amdgcn.hlo": ["no_cuda_asan"], "copy_amdgcn.hlo": ["no_cuda_asan"], "copy_nested_amdgcn.hlo": ["no_cuda_asan", "no_rocm"], - "dynamic_update_slice_inplace_amdgcn.hlo": ["no_cuda_asan"], + "dynamic_update_slice_inplace_amdgcn.hlo": ["no_cuda_asan", "no_rocm"], "fused_scatter_amdgcn.hlo": ["no_cuda_asan", "no_rocm"], - "fused_slice_amdgcn.hlo": ["no_cuda_asan"], + "fused_slice_amdgcn.hlo": ["no_cuda_asan", "no_rocm"], "fused_slice_different_operands_amdgcn.hlo": ["no_cuda_asan"], 
"fusion_amdgcn.hlo": ["no_cuda_asan"], "launch_dimensions_amdgcn.hlo": ["no_cuda_asan", "no_rocm"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo b/tensorflow/compiler/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo index 660ae9b3f7eee3..71fe8a85899746 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo @@ -22,7 +22,7 @@ // CHECK: %[[VAL_16:.*]] = zext i32 %[[VAL_15]] to i64 // CHECK: %[[VAL_17:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 // CHECK: %[[VAL_18:.*]] = zext i32 %[[VAL_17]] to i64 -// CHECK: %[[VAL_19:.*]] = mul nuw nsw i64 %[[VAL_16]], 1024 +// CHECK: %[[VAL_19:.*]] = mul nuw nsw i64 %[[VAL_16]], 128 // CHECK: %[[VAL_20:.*]] = add nuw nsw i64 %[[VAL_19]], %[[VAL_18]] // CHECK: %[[VAL_21:.*]] = icmp ult i64 %[[VAL_20]], 98304 // CHECK: call void @llvm.assume(i1 %[[VAL_21]]) diff --git a/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo b/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo index deaec101253fbe..71d571f965d4ce 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo @@ -5,7 +5,7 @@ // CHECK-LABEL: entry: // CHECK: %[[VAL_0:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 // CHECK: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 -// CHECK: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 1024 +// CHECK: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 128 // CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] // CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 2048 // CHECK: call void @llvm.assume(i1 %[[VAL_4]]) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 2da870489e4d70..24be7fdbccd2b0 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -1267,6 +1267,7 @@ class LegacyCublasGemmRewriteTest : public GemmRewriteTest { public: DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_triton_gemm(false); debug_options.set_xla_gpu_enable_cublaslt(false); return debug_options; } diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc index a80f897735fd89..c2a4bf7213a0d9 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc @@ -15,18 +15,16 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/gpu_reduce_scatter_creator.h" +#include +#include + #include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { @@ -98,6 +96,44 @@ ENTRY %AllReduce { EXPECT_EQ(AllReduceCount(module), 0); } +TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithOffsetReshape) { + absl::string_view hlo_string = R"( +HloModule AllReduce + +%sum { + %a = f32[] parameter(0) + %b = f32[] parameter(1) + ROOT %add = f32[] add(%a, %b) +} + +ENTRY %AllReduce { + %param = f32[32,8,128]{2,1,0} parameter(0) + %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param), + replica_groups={}, to_apply=%sum + %table = s32[8]{0} constant({0,1,2,3,4,5,6,7}) + %rid = u32[] replica-id() + %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1} + %slice_size = s32[1] constant({4}) + %offset = s32[1] multiply(%id, %slice_size) + %reshape = s32[] reshape(%offset) + %zero = s32[] constant(0) + ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %reshape, %zero, %zero), + dynamic_slice_sizes={4,8,128} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/8, + /*num_partitions=*/1, + /*expect_change=*/true)); + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::ReduceScatter(op::Parameter(0))); + const auto *rs = Cast( + module->entry_computation()->root_instruction()); + EXPECT_EQ(rs->scatter_dimension(), 0) << rs->ToString(); + EXPECT_EQ(AllReduceCount(module), 0); +} + TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshape) { absl::string_view hlo_string = R"( HloModule AllReduce diff --git a/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc b/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc index ccf274bd2a0f56..b2f5b48fcb17e7 100644 --- a/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc @@ -99,7 +99,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { bool is_row_reduction = reduction_dimensions.is_row_reduction; // Base case: everything fits. - if (ReductionIsRaceFree(reduction_dimensions)) { + if (ReductionIsRaceFree(hlo->GetModule()->config(), reduction_dimensions)) { VLOG(3) << "Base case: dimensions fit"; return OkStatus(); } diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index 440a9611a8fe27..934d550238e1c1 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -46,7 +45,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/gpu_float_support.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" @@ -70,6 +68,8 @@ limitations under the License. namespace xla { namespace gpu { +using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput; + namespace { // Constructs an autotuning key for a gemm performed in Triton. @@ -148,20 +148,11 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { const DebugOptions& debug_opts = fusion.parent()->config().debug_options(); - se::RedzoneAllocator rz_allocator( - stream, allocator, PtxOptsFromDebugOptions(debug_opts), - /*memory_limit=*/std::numeric_limits::max(), - /*redzone_size=*/config_.should_check_correctness() - ? debug_opts.xla_gpu_redzone_padding_bytes() - : 0); - - se::DeviceMemoryBase reference_buffer; - if (config_.should_check_correctness()) { - TF_ASSIGN_OR_RETURN( - reference_buffer, - rz_allocator.AllocateBytes(ShapeUtil::ByteSizeOf(root->shape()))); - } + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator rz_allocator, + AutotunerUtil::CreateRedzoneAllocator(config_, debug_opts)); + std::optional reference_buffer; BufferComparator comparator(root->shape(), fusion.parent()->config()); const std::vector configurations = @@ -180,7 +171,7 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { AutotuneResult config; *config.mutable_triton() = conf; StatusOr res = - autotuner_compile_util_->Compile(fusion, config, cache_key, [&] { + autotuner_compile_util_->Compile(config, cache_key, [&] { return TritonGemmAutotuneExtractor(conf, gpu_device_info, fusion.FusionInstruction()); }); @@ -204,14 +195,11 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { } if (config_.should_check_correctness()) { - TF_RETURN_IF_ERROR(RunMatmulWithCublas(fusion, stream, allocator, inputs, - reference_buffer, cache_key)); + TF_ASSIGN_OR_RETURN( + reference_buffer, + RunMatmulWithCublas(fusion, stream, allocator, inputs, cache_key)); } - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase output_buffer, - rz_allocator.AllocateBytes(ShapeUtil::ByteSizeOf(root->shape()))); - std::vector results; for (const AutotuneResult::TritonGemmKey& conf : configurations) { VLOG(1) << "Trying triton tiling: " << conf.ShortDebugString(); @@ -219,17 +207,18 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { AutotuneResult res; *res.mutable_triton() = conf; - TF_ASSIGN_OR_RETURN(std::optional duration, - RunMatmulWithConfig(fusion, conf, stream, inputs, - output_buffer, cache_key)); + TF_ASSIGN_OR_RETURN( + std::optional profiling_output, + RunMatmulWithConfig(fusion, conf, stream, inputs, cache_key)); - if (!duration) { + if (!profiling_output) { VLOG(1) << "Skipping this tiling."; continue; } - VLOG(1) << "Running the kernel took: " << *duration; - *res.mutable_run_time() = tsl::proto_utils::ToDurationProto(*duration); + VLOG(1) << "Running the kernel took: " << profiling_output->duration; + *res.mutable_run_time() = + tsl::proto_utils::ToDurationProto(profiling_output->duration); if (config_.should_check_correctness()) { TF_ASSIGN_OR_RETURN( @@ -246,8 +235,9 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { TF_ASSIGN_OR_RETURN( bool outputs_match, - 
comparator.CompareEqual(stream, /*current=*/output_buffer, - /*expected=*/reference_buffer)); + comparator.CompareEqual( + stream, /*current=*/profiling_output->output.root_buffer(), + /*expected=*/reference_buffer->root_buffer())); if (!outputs_match) { LOG(ERROR) << "Results do not match the reference. " << "This is likely a bug/unexpected loss of precision."; @@ -272,19 +262,16 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { // // `cache_key`: The cache key corresponding to the code of the fusion and the // device type. Passing it to avoid recalculating it everywhere it's needed. - StatusOr> RunMatmulWithConfig( + StatusOr> RunMatmulWithConfig( const HloComputation& hlo_computation, const AutotuneResult::TritonGemmKey& autotune_config, se::Stream* stream, absl::Span input_buffers, - se::DeviceMemoryBase output_buffer, const AutotuneCacheKey& cache_key) { + const AutotuneCacheKey& cache_key) { AutotuneResult config; *config.mutable_triton() = autotune_config; - std::vector used_buffers; - absl::c_copy(input_buffers, std::back_inserter(used_buffers)); return autotuner_compile_util_->GenerateAndProfileExecutable( - hlo_computation, config, cache_key, stream, used_buffers, output_buffer, - [&] { + config, cache_key, stream, input_buffers, [&] { return TritonGemmAutotuneExtractor( autotune_config, GetGpuDeviceInfo(config_.GetExecutor()), hlo_computation.FusionInstruction()); @@ -336,11 +323,11 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { // // `cache_key`: The cache key corresponding to the code of the fusion and the // device type. Passing it to avoid recalculating it everywhere it's needed. - Status RunMatmulWithCublas( + StatusOr RunMatmulWithCublas( const HloComputation& original_computation, se::Stream* stream, se::DeviceMemoryAllocator* allocator, absl::Span input_buffers, - se::DeviceMemoryBase output_buffer, const AutotuneCacheKey& cache_key) { + const AutotuneCacheKey& cache_key) { AutotuneResult res; // We need some value to cache compilation. We associate the compiled module @@ -349,16 +336,15 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { gemm.set_algorithm(0); *res.mutable_gemm() = gemm; - TF_ASSIGN_OR_RETURN(std::optional duration, + TF_ASSIGN_OR_RETURN(std::optional output, autotuner_compile_util_->GenerateAndProfileExecutable( - original_computation, res, cache_key, stream, - input_buffers, output_buffer, [&] { + res, cache_key, stream, input_buffers, [&] { return CublasGemmAutotuneExtractor( GetGpuDeviceInfo(config_.GetExecutor()), &original_computation); })); - TF_RET_CHECK(duration.has_value()); - return OkStatus(); + TF_RET_CHECK(output.has_value()); + return std::move(output->output); } StatusOr> CublasGemmAutotuneExtractor( @@ -471,20 +457,9 @@ StatusOr TritonAutotuner::Run( return false; } - std::optional autotuner_compile_util; - if (!config_.IsDeviceless()) { - // TODO(cheshire): The ones below should not be needed. - se::StreamExecutor* stream_exec = config_.GetExecutor(); - se::DeviceMemoryAllocator* allocator = config_.GetAllocator() - ? 
config_.GetAllocator() - : stream_exec->GetAllocator(); - TF_ASSIGN_OR_RETURN(se::Stream* const stream, - allocator->GetStream(stream_exec->device_ordinal())); - TF_ASSIGN_OR_RETURN(AutotunerCompileUtil util, - AutotunerCompileUtil::Create(*stream, *allocator)); - autotuner_compile_util.emplace(util); - } - + TF_ASSIGN_OR_RETURN( + std::optional autotuner_compile_util, + AutotunerCompileUtil::Create(config_, module->config().debug_options())); return TritonAutotunerVisitor{config_, thread_pool_, autotuner_compile_util} .RunOnModule(module, execution_threads); } diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 0b5398220a8364..cb27cdc1eda6a7 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -111,7 +111,7 @@ enum CustomCallApiVersion { } // Serialization of HloInstruction. -// Next ID: 82 +// Next ID: 83 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -363,6 +363,10 @@ message HloInstructionProto { // Represents the K value for top-k. int64 k = 81; + + // Represents the information for tracking propagation of values within HLO + // graph. + xla.StatisticsViz statistics_viz = 82; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 1d258ee5deeff1..eb0caab8c81cfb 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -202,6 +202,45 @@ NodeColors NodeColorsForScheme(ColorScheme color) { } } +// Given a Statistic object, returns a hex string for the fill color of the node +// with that statistic. +const char* NodeFillColorForStatistic(const Statistic& statistic) { + auto stat_val = statistic.stat_val(); + if (stat_val == 0) { + return "#f5f5f5"; + } else if (stat_val < 10) { + return "#f7d4cc"; + } else if (stat_val < 20) { + return "#f8b2a3"; + } else if (stat_val < 30) { + return "#f9a28f"; + } else if (stat_val < 40) { + return "#fa917b"; + } else if (stat_val < 50) { + return "#fb8066"; + } else if (stat_val < 60) { + return "#fc7052"; + } else if (stat_val < 70) { + return "#fd5f3d"; + } else if (stat_val < 80) { + return "#fd4e29"; + } else if (stat_val < 90) { + return "#fe3e14"; + } else { + return "#ff2d00"; + } +} + +// Given a Statistic object, returns a hex string for the font color of the node +// with that statistic. +const char* NodeFontColorForStatistic(const Statistic& statistic) { + if (statistic.stat_val() < 60) { + return "black"; + } else { + return "white"; + } +} + // Given a ColorScheme, returns an attribute string for a node of that color. // Sets the node's style and fill/stroke/text colors. // @@ -658,7 +697,13 @@ std::string HloDotDumper::DumpSubcomputation( bool highlight = filter_.Highlight(parent_instr); const char* fillcolor; const char* strokecolor; - if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) { + + if (!highlight && parent_instr->has_statistics()) { + // Use color from the statistic + fillcolor = + NodeFillColorForStatistic(parent_instr->statistic_to_visualize()); + strokecolor = "#c2c2c2"; + } else if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) { // Use the sharding color, if the node isn't highlighted. 
NodeColors node_colors = NodeColorsForScheme(GetInstructionColor(parent_instr)); @@ -837,6 +882,22 @@ std::string HloDotDumper::DumpInstruction(const HloInstruction* instr) { color = kDarkRed; } } + + NodeColors node_colors = NodeColorsForScheme(color); + if (instr->has_statistics()) { + // override node's color to show statistics + const auto& statistic_to_visualize = instr->statistic_to_visualize(); + node_colors.fill_color = NodeFillColorForStatistic(statistic_to_visualize); + node_colors.stroke_color = "#c2c2c2"; + node_colors.font_color = NodeFontColorForStatistic(statistic_to_visualize); + } + + // Build the node style + std::string node_style = + StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", + node_colors.style, node_colors.font_color, + node_colors.stroke_color, node_colors.fill_color); + // Build the text that will be displayed inside the node. std::string node_body = node_label; for (const std::string& s : {trivial_subcomputation, extra_info, @@ -849,7 +910,7 @@ std::string HloDotDumper::DumpInstruction(const HloInstruction* instr) { return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" "\n", InstructionId(instr), node_body, node_shape, node_metadata, - NodeColorAttributes(color)); + node_style); } std::string HloDotDumper::GetInstructionNodeInlinedOperands( @@ -1205,6 +1266,9 @@ ExtractCudnnConvBackendConfigProps(const gpu::CudnnConvBackendConfig& config) { if (config.side_input_scale() != 0 && config.side_input_scale() != 1) { props.emplace_back("side_input_scale", StrCat(config.side_input_scale())); } + if (config.activation_mode() == se::dnn::ActivationMode::kLeakyRelu) { + props.emplace_back("leakyrelu_alpha", StrCat(config.leakyrelu_alpha())); + } props.emplace_back( "activation_mode", se::dnn::ActivationModeString( @@ -2072,10 +2136,11 @@ void RegisterFusionState(const HloComputation& computation, fusion_progress.AddState(dot_txt, label, producer_to_highlight); } -StatusOr RenderGraph( - const HloComputation& computation, absl::string_view label, - const DebugOptions& debug_options, RenderedGraphFormat format, - HloRenderOptions hlo_render_options) { +StatusOr RenderGraph(const HloComputation& computation, + absl::string_view label, + const DebugOptions& debug_options, + RenderedGraphFormat format, + HloRenderOptions hlo_render_options) { absl::MutexLock lock(&url_renderer_mu); if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { return Unavailable("Can't render as URL; no URL renderer was registered."); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 901a5aabf302c0..e97f2495edf1fc 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -164,6 +164,25 @@ TEST_F(HloGraphDumperTest, Compare) { EXPECT_THAT(graph, HasSubstr("direction=LT")); } +TEST_F(HloGraphDumperTest, HasStatisticsViz) { + const char* hlo_string = R"( + HloModule comp + + ENTRY comp { + param.0 = f32[10] parameter(0), statistics={visualizing_index=0,stat-0=0.5} + param.1 = f32[10] parameter(1), statistics={visualizing_index=1,stat-0=55.5,stat-1=44.4} + ROOT lt = pred[10] compare(param.0, param.1), direction=LT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Just check that it doesn't crash. 
+ TF_ASSERT_OK_AND_ASSIGN( + std::string graph, + RenderGraph(*module->entry_computation(), /*label=*/"tuple_constant", + DebugOptions(), RenderedGraphFormat::kDot)); +} + TEST_F(HloGraphDumperTest, RootIsConstant) { const char* hlo_string = R"( HloModule indexed_conditional diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index c9dc160999c2cc..ab72e6aef292d9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1582,6 +1582,88 @@ TEST_F(HloInstructionTest, Stringification) { "true_computation=%TransposeDot, false_computation=%TransposeDot"); } +TEST_F(HloInstructionTest, GetSetStatisticsViz) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 10}); + + HloComputation::Builder builder(TestName()); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + + StatisticsViz statistics_viz; + statistics_viz.set_stat_index_to_visualize(-1); + + x->set_statistics_viz(statistics_viz); + + EXPECT_FALSE(x->has_statistics()); + EXPECT_EQ(x->statistics_viz().stat_index_to_visualize(), -1); + + Statistic statistic; + statistic.set_stat_name("stat-1"); + statistic.set_stat_val(30.0); + + x->add_single_statistic(statistic); + x->set_stat_index_to_visualize(0); + + EXPECT_TRUE(x->has_statistics()); + EXPECT_TRUE( + protobuf_util::ProtobufEquals(x->statistic_to_visualize(), statistic)); + + statistic.set_stat_val(40.0); + *statistics_viz.add_statistics() = statistic; + + x->set_statistics_viz(statistics_viz); + + EXPECT_TRUE( + protobuf_util::ProtobufEquals(x->statistics_viz(), statistics_viz)); +} + +TEST_F(HloInstructionTest, StringifyStatisticsViz) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 10}); + + HloComputation::Builder builder(TestName()); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "y")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y)); + + // Empty statistics viz must not print "statistics={}" + add->set_statistics_viz({}); + EXPECT_EQ(add->ToString(), + "%add = f32[5,10]{1,0} add(f32[5,10]{1,0} %x, f32[5,10]{1,0} %y)"); + + auto CreateStatisticsVizWithStatistics = + [](int64_t stat_index_to_visualize, + std::initializer_list> statistics) + -> StatisticsViz { + StatisticsViz statistics_viz; + statistics_viz.set_stat_index_to_visualize(stat_index_to_visualize); + + auto create_statistic = [](absl::string_view statistic_name, + double statistic_value) { + Statistic statistic; + statistic.set_stat_name(std::string(statistic_name)); + statistic.set_stat_val(statistic_value); + return statistic; + }; + + for (const auto& [statistic_name, statistic_value] : statistics) { + *statistics_viz.add_statistics() = + create_statistic(statistic_name, statistic_value); + } + + return statistics_viz; + }; + + add->set_statistics_viz(CreateStatisticsVizWithStatistics( + 1, {{"stat-1", 33.0}, {"stat-2", 44.0}})); + + EXPECT_EQ(add->ToString(), + "%add = f32[5,10]{1,0} add(f32[5,10]{1,0} %x, f32[5,10]{1,0} %y), " + "statistics={visualizing_index=1,stat-1=33,stat-2=44}"); +} + TEST_F(HloInstructionTest, StringifyGather_0) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); Shape start_indices_tensor_shape = diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h 
b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index a345920fd1ca33..eabf659879ba87 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -16,19 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ -#include - #include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" namespace xla { @@ -142,8 +138,8 @@ class HloMemoryScheduler : public HloModulePass { // size_function is the function returning the number of bytes required for a // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not // specified, then DefaultMemoryScheduler is used. - HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, - const ModuleSchedulerAlgorithm& algorithm = {}); + explicit HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, + const ModuleSchedulerAlgorithm& algorithm = {}); ~HloMemoryScheduler() override = default; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 845368579ee490..e1293baba97728 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -31,6 +31,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/tsl/platform/protobuf.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 2e18b6c9c226cd..bbfd2c83ba065b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -234,6 +234,7 @@ class HloParserImpl : public HloParser { StatusOr ParseShapeOnly(); StatusOr ParseShardingOnly(); StatusOr ParseFrontendAttributesOnly(); + StatusOr ParseStatisticsVizOnly(); StatusOr> ParseParameterReplicationOnly(); StatusOr ParseBooleanListOrSingleBooleanOnly(); StatusOr ParseWindowOnly(); @@ -262,6 +263,7 @@ class HloParserImpl : public HloParser { kConvolutionDimensionNumbers, kSharding, kFrontendAttributes, + kStatisticsViz, kBracedBoolListOrBool, kParameterReplication, kInstructionList, @@ -467,6 +469,7 @@ class HloParserImpl : public HloParser { bool ParseListShardingType(std::vector* types); bool ParseSharding(OpSharding* sharding); bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes); + bool ParseStatisticsViz(StatisticsViz* statistics_viz); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); bool ParseParameterReplication(ParameterReplication* parameter_replication); bool ParseBooleanListOrSingleBoolean(BoolList* boolean_list); @@ -1204,9 +1207,12 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, absl::flat_hash_map attrs; optional sharding; optional frontend_attributes; + optional statistics_viz; attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; attrs["frontend_attributes"] = { /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes}; + attrs["statistics"] = {/*required=*/false, AttrTy::kStatisticsViz, + &statistics_viz}; optional parameter_replication; attrs["parameter_replication"] = {/*required=*/false, AttrTy::kParameterReplication, @@ -1289,6 +1295,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, if (frontend_attributes) { instruction->set_frontend_attributes(*frontend_attributes); } + if (statistics_viz) { + instruction->set_statistics_viz(*statistics_viz); + } return AddInstruction(name, instruction, name_loc); } @@ -3109,6 +3118,52 @@ bool HloParserImpl::ParseFrontendAttributes( "expects '}' at the end of frontend attributes"); } +// statistics +// ::= '{' /*empty*/ '}' +// ::= '{' index, single_statistic '}' +// index ::= 'visualizing_index=' value +// single_statistic ::= statistic '=' value (',' statistic '=' value)* +bool HloParserImpl::ParseStatisticsViz(StatisticsViz* statistics_viz) { + CHECK(statistics_viz != nullptr); + if (!ParseToken(TokKind::kLbrace, "expected '{' to start statistics")) { + return false; + } + if (lexer_.GetKind() == TokKind::kRbrace) { + // empty + } else { + // index must exist + std::string visualizing_index_attr_name; + if (!ParseAttributeName(&visualizing_index_attr_name)) { + return false; + } + if (lexer_.GetKind() != TokKind::kInt) { + return false; + } + statistics_viz->set_stat_index_to_visualize(lexer_.GetInt64Val()); + lexer_.Lex(); + + // then process statistics + while (EatIfPresent(TokKind::kComma)) { + std::string stat_name; + if (!ParseAttributeName(&stat_name)) { + return false; + } + if (lexer_.GetKind() != TokKind::kDecimal && + lexer_.GetKind() != TokKind::kInt) { + return false; + } + Statistic statistic; + 
statistic.set_stat_name(stat_name); + statistic.set_stat_val(lexer_.GetKind() == TokKind::kDecimal + ? lexer_.GetDecimalVal() + : lexer_.GetInt64Val()); + lexer_.Lex(); + *statistics_viz->add_statistics() = std::move(statistic); + } + } + return ParseToken(TokKind::kRbrace, "expects '}' at the end of statistics"); +} + // ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape? // ('devices=' ('[' dims ']')* device_list)? // ('metadata=' metadata)* '}' @@ -4458,6 +4513,15 @@ bool HloParserImpl::ParseAttributeHelper( ->emplace(frontend_attributes); return true; } + case AttrTy::kStatisticsViz: { + StatisticsViz statistics_viz; + if (!ParseStatisticsViz(&statistics_viz)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(statistics_viz); + return true; + } case AttrTy::kParameterReplication: { ParameterReplication parameter_replication; if (!ParseParameterReplication(¶meter_replication)) { @@ -6206,6 +6270,18 @@ StatusOr HloParserImpl::ParseFrontendAttributesOnly() { return attributes; } +StatusOr HloParserImpl::ParseStatisticsVizOnly() { + lexer_.Lex(); + StatisticsViz statistics_viz; + if (!ParseStatisticsViz(&statistics_viz)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after statistics"); + } + return statistics_viz; +} + StatusOr> HloParserImpl::ParseParameterReplicationOnly() { lexer_.Lex(); ParameterReplication parameter_replication; @@ -6366,6 +6442,11 @@ StatusOr ParseFrontendAttributes(absl::string_view str) { return parser.ParseFrontendAttributesOnly(); } +StatusOr ParseStatisticsViz(absl::string_view str) { + HloParserImpl parser(str); + return parser.ParseStatisticsVizOnly(); +} + StatusOr> ParseParameterReplication(absl::string_view str) { HloParserImpl parser(str); return parser.ParseParameterReplicationOnly(); diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 0ab47a4d276755..f295beb606310f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -57,6 +57,11 @@ StatusOr ParseSharding(absl::string_view str); // "{attr_a=a,attr_b=b}". StatusOr ParseFrontendAttributes(absl::string_view str); +// Parses statistics viz from str. str is supposed to contain the body of the +// statistics visualization, i.e. just the rhs of the "statistics={...}" +// attribute string, e.g., "{visualizing_index=1,nan_percent=50}". +StatusOr ParseStatisticsViz(absl::string_view str); + // Parses parameter replication from str. str is supposed to contain the body of // the parameter replication, i.e. just the rhs of the // "parameter_replication={...}" attribute string, e.g., "{true, false}". 
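// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): exercising the new
// ParseStatisticsViz() entry point declared above. The attribute body mirrors
// the examples used elsewhere in this change; the wrapper function name is
// hypothetical, and only the hlo_parser.h header modified here is assumed.
#include "tensorflow/compiler/xla/service/hlo_parser.h"

namespace xla {
// Parses just the braced rhs of a "statistics=..." attribute, analogous to
// ParseFrontendAttributes for "frontend_attributes=...". In HLO text the same
// payload appears as, e.g.:
//   %constant = s32[] constant(-42), statistics={visualizing_index=1,stat-1=33,stat-2=44}
StatusOr<StatisticsViz> ExampleParseStatisticsViz() {
  return ParseStatisticsViz("{visualizing_index=1,stat-1=33,stat-2=44}");
}
}  // namespace xla
// ---------------------------------------------------------------------------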
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 075c60bf08310d..332dd413835fbb 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -132,7 +132,6 @@ ENTRY %constant_pred_array () -> pred[2,3] { )" }, - // s32 constant { "ConstantS32", @@ -142,6 +141,17 @@ ENTRY %constant_s32 () -> s32[] { ROOT %constant = s32[] constant(-42) } +)" +}, +// s32 constant with statistics +{ +"ConstantS32WithStatistics", +R"(HloModule constant_s32_module, entry_computation_layout={()->s32[]} + +ENTRY %constant_s32 () -> s32[] { + ROOT %constant = s32[] constant(-42), statistics={visualizing_index=1,stat-1=33,stat-2=44} +} + )" }, // f32 constant, but the value is not a decimal and there is a backend diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index ce804404efb9ea..d89df1a4419f3f 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -40,18 +40,12 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_query.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/service/buffer_value.h" -#include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/tsl/platform/logging.h" namespace xla { namespace { @@ -498,13 +492,10 @@ UsesList GetUsers(const InstructionList& instruction_list, // (LogicalBuffers) at the current point in the instruction sequence. class MemoryUsageTracker { public: - MemoryUsageTracker( - const HloComputation* computation, - const HloRematerialization::ShapeSizeFunction& size_function, - const HloRematerialization::CompactShapeFunction& compact_shape_function, - const TuplePointsToAnalysis& points_to_analysis, - const InstructionList& instruction_list, - HloRematerialization::RematerializationMode mode); + MemoryUsageTracker(const HloRematerialization::Options& options, + const HloComputation* computation, + const TuplePointsToAnalysis& points_to_analysis, + const InstructionList& instruction_list); // Starts the placement of the given instruction. This adds the sizes of the // LogicalBuffers defined by the instruction to the current memory @@ -589,7 +580,7 @@ class MemoryUsageTracker { bool HasUnplacedUsers(Item* item) const; // Returns the list of uses for a specific 'item'. - const UsesList GetItemUses(Item* item) const; + UsesList GetItemUses(Item* item) const; // Returns whether 'item' is currently in progress. bool IsInProgressItem(Item* item) const { return item == in_progress_item_; } @@ -609,6 +600,8 @@ class MemoryUsageTracker { const HloComputation* computation() const { return computation_; } + const HloRematerialization::Options& options() const { return options_; } + // Check invariants of the data structure. 
This is expensive to call. bool Check() const; @@ -758,12 +751,15 @@ class MemoryUsageTracker { } return users_set.size(); }; - buffers_.push_back(Buffer{ - buffer_id, defining_instruction, size_function_(shape), shape, live_out, - has_indirect_uses, index, uses, get_num_of_unique_users(uses)}); + buffers_.push_back(Buffer{buffer_id, defining_instruction, + options_.size_function(shape), shape, live_out, + has_indirect_uses, index, uses, + get_num_of_unique_users(uses)}); return buffers_.back(); } + const HloRematerialization::Options& options_; + const HloComputation* computation_; // Instruction list containing the ordering of instructions in @@ -771,13 +767,6 @@ class MemoryUsageTracker { // (BeginInstruction/EndInstruction calls). const InstructionList& instruction_list_; - // Size function returns the bytes of a given buffer. - const HloRematerialization::ShapeSizeFunction& size_function_; - - // Converts a shape into compact form, returns the same shape if a shape is - // already considered compact. - const HloRematerialization::CompactShapeFunction& compact_shape_function_; - // A map that caches existing known compact shape for each instruction. absl::flat_hash_map compact_shape_; @@ -788,23 +777,18 @@ class MemoryUsageTracker { // between the calling of BeginInstruction and EndInstruction. Item* in_progress_item_ = nullptr; - HloRematerialization::RematerializationMode mode_; // All buffers in the computation. std::vector buffers_; }; MemoryUsageTracker::MemoryUsageTracker( + const HloRematerialization::Options& options, const HloComputation* computation, - const HloRematerialization::ShapeSizeFunction& size_function, - const HloRematerialization::CompactShapeFunction& compact_shape_function, const TuplePointsToAnalysis& points_to_analysis, - const InstructionList& instruction_list, - HloRematerialization::RematerializationMode mode) - : computation_(computation), - instruction_list_(instruction_list), - size_function_(size_function), - compact_shape_function_(compact_shape_function), - mode_(mode) { + const InstructionList& instruction_list) + : options_(options), + computation_(computation), + instruction_list_(instruction_list) { PointsToSet::BufferSet live_out_set = points_to_analysis.GetPointsToSet(computation_->root_instruction()) .CreateFlattenedSet(); @@ -958,7 +942,7 @@ int64_t MemoryUsageTracker::MemoryReducedIfCompressed( const Buffer& buffer = buffers_.at(buffer_id); memory_reduced += buffer.size; - int64_t compact_shape_size = size_function_(compact_shape); + int64_t compact_shape_size = options_.size_function(compact_shape); // Account for buffers that are compressed after instruction. memory_reduced -= compact_shape_size; } @@ -1027,9 +1011,10 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item, Item* compressed_item, Item* uncompressed_item) { // Original buffer is now dead. - memory_usage_ -= size_function_(original_item->instruction->shape()); + memory_usage_ -= options_.size_function(original_item->instruction->shape()); // Compressed buffer is now alive. 
- memory_usage_ += size_function_(compressed_item->instruction->shape()); + memory_usage_ += + options_.size_function(compressed_item->instruction->shape()); UsesList placed_users; UsesList unplaced_users; @@ -1261,7 +1246,8 @@ StatusOr MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) { return it->second; } const Shape& original_shape = hlo->shape(); - TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape)); + TF_ASSIGN_OR_RETURN(Shape min_shape, + options_.compact_shape_function(original_shape)); compact_shape_[hlo] = min_shape; return min_shape; } @@ -1424,10 +1410,10 @@ MemoryUsageTracker::PickRematerializationCandidates( auto* item = block[0]; auto* candidate = item->instruction; if (item->buffers_output.size() == 1 && - (mode_ == + (options_.mode == HloRematerialization::RematerializationMode::kCompressOnly || - mode_ == HloRematerialization::RematerializationMode:: - kRecomputeAndCompress)) { + options_.mode == HloRematerialization::RematerializationMode:: + kRecomputeAndCompress)) { // Only consider compressing single output instruction. const Buffer& output_buffer = buffers_.at(item->buffers_output[0]); @@ -1442,8 +1428,10 @@ MemoryUsageTracker::PickRematerializationCandidates( // while performing the compression/uncompression, only perform // the compression if the sum of the two sizes is less than the // peak memory. - const int64_t size = size_function_(item->instruction->shape()); - const int64_t reduced_size = size_function_(compact_shape); + const int64_t size = + options_.size_function(item->instruction->shape()); + const int64_t reduced_size = + options_.size_function(compact_shape); effort++; if (memory_reduced > 0 && size + reduced_size < peak_memory_bytes) { @@ -1464,7 +1452,8 @@ MemoryUsageTracker::PickRematerializationCandidates( } } // Do not consider recomputation in compress-only mode. - if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) { + if (options_.mode == + HloRematerialization::RematerializationMode::kCompressOnly) { // break out of this loop. Move on to the next start_item. 
break; } @@ -1537,7 +1526,7 @@ bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const { return false; } -const UsesList MemoryUsageTracker::GetItemUses(Item* item) const { +UsesList MemoryUsageTracker::GetItemUses(Item* item) const { UsesList combined_users; for (BufferId buffer_id : item->buffers_defined) { const Buffer& buffer = buffers_.at(buffer_id); @@ -1861,9 +1850,8 @@ StatusOr HloRematerialization::ComputePeakMemory( const HloComputation* computation, const HloInstructionSequence& order, const absl::flat_hash_set& execution_threads) const { InstructionList instruction_list(order); - MemoryUsageTracker tracker(computation, size_function_, - compact_shape_function_, *points_to_analysis_, - instruction_list, mode_); + MemoryUsageTracker tracker(options_, computation, *points_to_analysis_, + instruction_list); int64_t peak_memory = tracker.memory_usage(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { @@ -1927,9 +1915,8 @@ StatusOr HloRematerialization::RematerializeComputation( CHECK(!ContainsKey(rematerialized_computations_, computation)); InstructionList instruction_list(schedule->sequence(computation)); - MemoryUsageTracker memory_tracker( - computation, size_function_, compact_shape_function_, - *points_to_analysis_, instruction_list, mode_); + MemoryUsageTracker memory_tracker(options_, computation, *points_to_analysis_, + instruction_list); instruction_list.PromoteNodesToSkip([&](Item* item) { return memory_tracker.AllocatedSize(item) >= min_remat_size; @@ -2028,9 +2015,9 @@ StatusOr HloRematerialization::RematerializeComputation( min_block_size = 1; max_block_size = 1; } - if (max_block_size > block_size_limit_ || + if (max_block_size > options_.block_size_limit || second_phase_effort > - block_rematerialization_factor_ * first_phase_effort) { + options_.block_rematerialization_factor * first_phase_effort) { break; } } @@ -2112,7 +2099,7 @@ StatusOr HloRematerialization::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(1) << "HloRematerialization() with memory limit of " - << HumanReadableNumBytes(memory_limit_bytes_); + << HumanReadableNumBytes(options_.memory_limit_bytes); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Initialize pass object state. @@ -2132,13 +2119,12 @@ StatusOr HloRematerialization::Run( int64_t module_output_size = 0; ShapeUtil::ForEachSubshape( module->result_shape(), - [&module_output_size, module, this](const Shape& subshape, - const ShapeIndex& output_index) { - module_output_size += size_function_(subshape); + [&](const Shape& subshape, const ShapeIndex& output_index) { + module_output_size += options_.size_function(subshape); }); const int64_t adjusted_memory_limit_bytes = - memory_limit_bytes_ - module_output_size; + std::max(0, options_.memory_limit_bytes - module_output_size); VLOG(1) << "Adjusted memory limit accounting for output (" << HumanReadableNumBytes(module_output_size) << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); @@ -2175,8 +2161,8 @@ StatusOr HloRematerialization::Run( TF_ASSIGN_OR_RETURN( bool changed, RematerializeComputation(module->entry_computation(), &module->schedule(), - adjusted_memory_limit_bytes, min_remat_size_, - execution_threads)); + adjusted_memory_limit_bytes, + options_.min_remat_size, execution_threads)); // Rematerialization can introduce dead code. This occurs if all uses of an // instruction are replaced with rematerializations of the instruction. 
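// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): constructing the pass through
// the new Options struct and the by-reference RematerializationSizes, matching
// the hlo_rematerialization.h changes further below. The size function and the
// numeric limits are placeholders; unspecified Options fields keep their
// defaults (identity compact-shape function, kRecomputeAndCompress mode,
// min_remat_size of 0).
void ExampleConfigureRematerialization(HloModule* module) {
  HloRematerialization::RematerializationSizes sizes;
  HloRematerialization::Options options(
      /*size_function=*/
      [](const Shape& shape) {
        return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
      },
      /*memory_limit_bytes=*/16 * 1024 * 1024,
      /*block_size_limit=*/1,
      /*block_rematerialization_factor=*/1);
  HloRematerialization remat(std::move(options), sizes);
  // After remat.Run(module, /*execution_threads=*/{}) succeeds,
  // sizes.before_bytes and sizes.after_bytes hold the peak memory before and
  // after rematerialization; there is no longer a nullptr "don't record" mode.
  (void)module;
}
// ---------------------------------------------------------------------------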
@@ -2207,19 +2193,19 @@ StatusOr HloRematerialization::Run( << HumanReadableNumBytes(reduced_peak_memory) << " (" << reduced_peak_memory << " bytes)"; - if (sizes_ != nullptr) { - sizes_->before_bytes = before_peak_memory; - sizes_->after_bytes = current_peak_memory; - } + sizes_.before_bytes = before_peak_memory; + sizes_.after_bytes = current_peak_memory; XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString()); - if (current_peak_memory > memory_limit_bytes_) { + if (current_peak_memory > options_.memory_limit_bytes) { LOG(WARNING) << absl::StrFormat( "Can't reduce memory use below %s (%d bytes) by rematerialization; " - "only reduced to %s (%d bytes)", - HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_, - HumanReadableNumBytes(current_peak_memory), current_peak_memory); + "only reduced to %s (%d bytes), down from %s (%d bytes) originally", + HumanReadableNumBytes(options_.memory_limit_bytes), + options_.memory_limit_bytes, HumanReadableNumBytes(current_peak_memory), + current_peak_memory, HumanReadableNumBytes(before_peak_memory), + before_peak_memory); } return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index e5237451956767..618027423de37e 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -15,6 +15,8 @@ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ +#include + #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" @@ -23,7 +25,7 @@ #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/statusor.h" @@ -57,48 +59,60 @@ class HloRematerialization : public HloModulePass { kRecomputeAndCompress // Consider both kRecompute and kRemat. }; - // Enum to specify whether this rematerialization pass occurs before or after - // multi-output fusion. - enum class RematerializationPass { - kPreFusion, // Rematerialization pass before multi-output fusion. - kPostFusion // Rematerialization pass after multi-output fusion. + static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; } + + struct Options { + explicit Options(const ShapeSizeFunction& size_function, + int64_t memory_limit_bytes, int block_size_limit, + int block_rematerialization_factor, + CompactShapeFunction compact_shape_function = nullptr, + RematerializationMode mode = + RematerializationMode::kRecomputeAndCompress, + int64_t min_remat_size = 0) + : size_function(size_function), + memory_limit_bytes(memory_limit_bytes), + block_size_limit(block_size_limit), + block_rematerialization_factor(block_rematerialization_factor), + compact_shape_function(compact_shape_function == nullptr + ? DefaultCompactShapeFunction + : std::move(compact_shape_function)), + mode(mode), + min_remat_size(min_remat_size) {} + + // Function which computes the size of the top-level buffer of a shape. 
+ const ShapeSizeFunction size_function; + + // The threshold number of bytes to reduce memory use to via + // rematerialization. Size of aliased outputs should be subtracted + // from this. + int64_t memory_limit_bytes; + + // Maximum number of consecutive instructions to consider for + // rematerialization. + int block_size_limit; + + // Controls the amount of effort spent trying to find large blocks for + // rematerialization. Larger values lead to longer compilation times in + // return for potentially reduced memory consumption. + int block_rematerialization_factor; + + // Converts a shape into compact form, returns the same shape if a shape is + // already considered compact. + const CompactShapeFunction compact_shape_function; + + // Holds the rematerialization strategy configuration to be used by the + // pass. + RematerializationMode mode; + + // The minimum size, in bytes, of a tensor to be considered for + // rematerialization. All tensors smaller than this size will be skipped + // over. + int64_t min_remat_size; + }; - static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; } + explicit HloRematerialization(Options options, RematerializationSizes& sizes) + : options_(std::move(options)), sizes_(sizes) {} - // Constructor parameters: - // - // size_function: Function which returns the size in bytes of the top-level - // buffer of the given shape. - // - // memory_limit_bytes: The threshold number of bytes to reduce memory use to - // via rematerialization. Size of aliased outputs should be subtracted - // from this. - // - // sizes: Pointer to data structure which records the peak memory usage of - // the HLO module before/after rematerialization. Value are set during - // Run(). Can be nullptr. - // - // compact_shape_function: Function which returns the compact form of a - // shape. If nullptr is provided, an default identity function is used. - explicit HloRematerialization( - const ShapeSizeFunction& size_function, int64_t memory_limit_bytes, - RematerializationSizes* sizes, RematerializationPass pass_location, - int block_size_limit, int block_rematerialization_factor, - CompactShapeFunction compact_shape_function = nullptr, - RematerializationMode mode = RematerializationMode::kRecomputeAndCompress, - int64_t min_remat_size = 0) - : size_function_(size_function), - memory_limit_bytes_(memory_limit_bytes), - sizes_(sizes), - pass_location_(pass_location), - block_size_limit_(block_size_limit), - block_rematerialization_factor_(block_rematerialization_factor), - compact_shape_function_(compact_shape_function == nullptr - ? DefaultCompactShapeFunction - : std::move(compact_shape_function)), - mode_(mode), - min_remat_size_(min_remat_size) {} ~HloRematerialization() override = default; absl::string_view name() const override { return "rematerialization"; } @@ -160,36 +174,11 @@ class HloRematerialization : public HloModulePass { const absl::flat_hash_set& execution_threads, absl::string_view thread) const; - // Selects an algorithm to use for HLO scheduling. - MemorySchedulerAlgorithm scheduler_algorithm_; - - // Function which computes the size of the top-level buffer of a shape. - const ShapeSizeFunction size_function_; - - // The threshold number of bytes to reduce memory use to via - // rematerialization.
- const int64_t memory_limit_bytes_; + const Options options_; - // Pointer to data structure which records the peak memory usage of the HLO - // module before/after rematerialization - RematerializationSizes* sizes_; - - // Specifies whether this rematerialization pass occurs before or after - // multi-output fusion. - RematerializationPass pass_location_; - - // Maximum number of consecutive instructions to consider for - // rematerialization. - int block_size_limit_; - - // Controls the amount of effort spent trying to find large blocks for - // rematerialization. Larger values leads to longer compilation times in - // return for potentially reduced memory consumption. - int block_rematerialization_factor_ = 1; - - // Converts a shape into compact form, returns the same shape if a shape is - // already considered compact. - const CompactShapeFunction compact_shape_function_; + // Reference to data structure which records the peak memory usage of the HLO + // module before/after rematerialization. + RematerializationSizes& sizes_; // Call graph of the hlo_module. std::unique_ptr call_graph_; @@ -221,10 +210,6 @@ class HloRematerialization : public HloModulePass { // upper bound (within a factor of 2) on the block size. int max_rematerialized_block_size_ = 0; - RematerializationMode mode_; - - int64_t min_remat_size_; - // Tracking available channel id numbers to use to apply to rematerialized // channel instructions int64_t next_channel_id_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index c5a17ae983d6bb..e487d840259e45 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -23,11 +23,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_rematerialization_test_utils.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/tsl/lib/core/status_test_util.h" namespace xla { @@ -51,13 +49,14 @@ class HloRematerializationTest : public RematerializationTestBase { ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler)); TF_EXPECT_OK(scheduler.Run(module).status()); } - HloRematerialization remat( + + HloRematerialization::Options options( ByteSizeOf, memory_limit_bytes, - /*sizes=*/nullptr, - HloRematerialization::RematerializationPass::kPreFusion, /*block_size_limit=*/1, /*block_rematerialization_factor=*/1, nullptr, HloRematerialization::RematerializationMode::kRecomputeAndCompress, min_remat_size); + HloRematerialization::RematerializationSizes sizes; + HloRematerialization remat(options, sizes); return remat.Run(module); } }; @@ -607,14 +606,14 @@ class CompressingRematerializationTest : public RematerializationTestBase { HloModule* module, int64_t min_remat_size = 0) { TF_EXPECT_OK(verifier().Run(module).status()); - HloRematerialization remat( + HloRematerialization::Options options( ShapeSizePadMinorTo64, memory_limit_bytes, - /*sizes=*/nullptr, - HloRematerialization::RematerializationPass::kPreFusion, /*block_size_limit=*/1, /*block_rematerialization_factor=*/1, ChooseCompactLayoutForShape, HloRematerialization::RematerializationMode::kCompressOnly, min_remat_size); + HloRematerialization::RematerializationSizes sizes; + HloRematerialization remat(options, sizes); return remat.Run(module); } }; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils.h b/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils.h index aac74366baa7e2..88f637249fa3f8 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils.h @@ -24,14 +24,9 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" -#include "tensorflow/compiler/xla/hlo/utils/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_rematerialization.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/tsl/lib/core/status_test_util.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils_test.cc index b3dc0861f421c1..a448e74b8c9e36 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils_test.cc @@ -19,14 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" -#include "tensorflow/compiler/xla/hlo/utils/hlo_matchers.h" -#include "tensorflow/compiler/xla/iterator_util.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 0f0a98bbb33f4e..ec99f4b580f457 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -849,6 +849,26 @@ Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction, return OkStatus(); } +Status ShapeVerifier::CheckShardedParameter( + const HloInstruction* operand, const HloInstruction* sharded_parameter, + int64_t num_shards) { + TF_RET_CHECK(num_shards > 0); + Shape unsharded_parameter_shape = + ShapeUtil::GetUnshardedShape(sharded_parameter->shape(), num_shards); + + if (!ShapesSame(operand->shape(), unsharded_parameter_shape)) { + return InternalError( + "Operand %s shape: %s does not match sharded parameter %s expected " + "shape: %s, actual shape: %s " + "num shards: %d", + operand->name(), operand->shape().ToString(), sharded_parameter->name(), + operand->shape().ToString(), unsharded_parameter_shape.ToString(), + num_shards); + } + + return OkStatus(); +} + Status ShapeVerifier::CheckOperandAndParameter( const HloInstruction* instruction, int64_t operand_number, const HloComputation* computation, int64_t parameter_number) { @@ -863,6 +883,19 @@ Status ShapeVerifier::CheckOperandAndParameter( return OkStatus(); } +Status ShapeVerifier::CheckOperandAndShardedParameter( + const HloInstruction* instruction, int64_t operand_number, + const HloComputation* computation, int64_t parameter_number, + int64_t num_shards) { + TF_RET_CHECK(num_shards > 0); + const HloInstruction* operand = instruction->operand(operand_number); + const HloInstruction* parameter = + computation->parameter_instruction(parameter_number); + // In the case of verifying a sharded called computation parameter, check that + // the parameter is correctly sharded amongst the specified number of shards. + return CheckShardedParameter(operand, parameter, num_shards); +} + Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -1312,6 +1345,24 @@ Status ShapeVerifier::HandleCall(HloInstruction* call) { return CheckShape(call, call->to_apply()->root_instruction()->shape()); } +Status ShapeVerifier::VerifyShardedCall(const HloInstruction* call, + int64_t num_shards) { + TF_RET_CHECK(num_shards > 0); + TF_RETURN_IF_ERROR( + CheckParameterCount(call, call->to_apply(), call->operand_count())); + for (int64_t i = 0; i < call->to_apply()->num_parameters(); ++i) { + TF_RETURN_IF_ERROR(CheckOperandAndShardedParameter( + call, i, call->to_apply(), i, num_shards)); + } + // The shape of kCall should match the shape of the computation it calls. + // In the case of verifying a sharded called computation, check that the + // output is correctly sharded amongst the specified number of shards. 
+ const HloComputation* to_apply_computation = call->to_apply(); + Shape unsharded_output_shape = ShapeUtil::GetUnshardedShape( + to_apply_computation->root_instruction()->shape(), num_shards); + return CheckShape(call, unsharded_output_shape); +} + Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { const HloCustomCallInstruction* custom_call = DynCast(instruction); @@ -2682,6 +2733,19 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return OkStatus(); } + Status HandleScatter(HloInstruction* scatter) override { + int64_t rank = scatter->operand(0)->shape().rank(); + for (int64_t operand_dim : + scatter->scatter_dimension_numbers().scatter_dims_to_operand_dims()) { + if (operand_dim > rank) { + return absl::OutOfRangeError(absl::StrCat( + "The provided scatter_dims_to_operand_dim was out of range.", + " (operand_dim: ", operand_dim, ", rank: ", rank, ")")); + } + } + return OkStatus(); + } + Status Preprocess(HloInstruction* instruction) override { auto [it, inserted] = instructions_by_name_.emplace(instruction->name(), instruction); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index bfcd46ef3d3d98..a55f10e10e3778 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -252,6 +252,14 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckTernaryShape(const HloInstruction* instruction); Status CheckVariadicShape(const HloInstruction* instruction); + Status VerifyShardedCall(const HloInstruction* call, int64_t num_shards); + + Status CheckOperandAndShardedParameter(const HloInstruction* instruction, + int64_t operand_number, + const HloComputation* computation, + int64_t parameter_number, + int64_t num_shards); + private: bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b, bool minor_to_major_only = false) { @@ -289,6 +297,17 @@ class ShapeVerifier : public DfsHloVisitor { const HloComputation* computation, int64_t parameter_number); + // Checks that the shape of `operand` is compatible with `sharded_parameter` + // which resides within a "sharded" computation. An `operand` and + // `sharded_parameter` shape are compatible if for all of `operand` + // sub-shapes, the major dimension of the non-dynamic tensors in + // `sharded_parameter` are partitioned among `num_shards`. + // + // Precondition: `num_shards` > 1. + Status CheckShardedParameter(const HloInstruction* operand, + const HloInstruction* sharded_parameter, + int64_t num_shards); + // Checks that the shape of async op operands and results match the called // computation parameters and root. 
Status CheckAsyncOpComputationShapes(const HloInstruction* async_op, diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index f351eaed4ca022..1fa4f1a4a3a430 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -2364,6 +2364,33 @@ TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) { HasSubstr("Replica groups expected to be of uniform size")); } +TEST_F(HloVerifierTest, ScatterInvalidScatterDim) { + const char* const hlo_string = R"( + HloModule Module + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY CRS { + Arg_0 = s8[11,6]{1,0} parameter(0) + constant = s32[] constant(1) + broadcast = s32[1,7,9,2,16,2]{5,4,3,2,1,0} broadcast(constant), dimensions={} + Arg_1 = s8[1,7,9,2,9,4,16]{6,5,4,3,2,1,0} parameter(1) + scatter = s8[11,6]{1,0} scatter(Arg_0, broadcast, Arg_1), update_window_dims={4,5}, inserted_window_dims={}, scatter_dims_to_operand_dims={1094795585,1}, index_vector_dim=5, to_apply=add + abs = s8[11,6]{1,0} abs(scatter) + ROOT tuple = (s8[11,6]{1,0}) tuple(abs) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("Invalid scatter_dims_to_operand_dims mapping")); +} + TEST_F(HloVerifierTest, VerifyBroadcastDimensionsOrder) { const char* const hlo = R"( HloModule module diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 6fd0645a9a45b5..5b2168c1fa0c8e 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -3715,6 +3715,18 @@ std::vector AlternateMemoryBestFitHeap::GetInefficientAllocationSites( absl::Span allocation_values) const { + // The logic below is used mostly for testing, allowing a test case to inject + // some custom logic for this method. + if (options_.get_inefficient_allocation_sites_fn) { + std::vector defining_positions; + defining_positions.reserve(allocation_values.size()); + for (const AllocationValue& value : allocation_values) { + defining_positions.push_back(value.defining_position()); + } + return options_.get_inefficient_allocation_sites_fn( + absl::MakeSpan(defining_positions)); + } + if (!options_.cost_analysis || options_.inefficient_use_to_copy_ratio == 0.0) { return {}; @@ -5313,9 +5325,7 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( auto prev_allocation_it = std::find_if( allocation_sequence->rbegin(), allocation_sequence->rend(), [&](const auto& allocation) { - return allocation->memory_space() == - required_memory_space_at_start && - allocation->defining_position() == defining_position; + return allocation->memory_space() == required_memory_space_at_start; }); if (prev_allocation_it != allocation_sequence->rend()) { (*prev_allocation_it)->Extend(request.start_time); diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index be40b81099f187..fd3f3f449949df 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -1488,6 +1488,12 @@ struct Options { // case copy_bytes would be twice the size of the tensor. 
float inefficient_use_to_copy_ratio = 0.0; + // This is mostly used for testing, it allows a test case to inject its own + // logic for AlternateMemoryBestFitHeap::GetInefficientAllocationSites. + std::function>( + absl::Span)> + get_inefficient_allocation_sites_fn = nullptr; + // The window size used to calculate the pipeline overhead when HLO accesses // the default memory, in MiB. float pipeline_overhead_window_size_mib = 0; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 28169ac0205ccd..86f02ec4aaffe8 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -5390,6 +5391,114 @@ TEST_P(MemorySpaceAssignmentTest, } } +TEST_P(MemorySpaceAssignmentTest, + WhileRedundantEvictionWithInefficientAllocationBug) { + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + while_cond { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + ROOT gte = pred[] get-tuple-element(p0), index=2 + } + + while_body { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + tanh = f32[3]{0} tanh(gte1) + gte2 = pred[] get-tuple-element(p0), index=2 + negate0 = f32[3]{0} negate(gte0) + negate1 = f32[3]{0} negate(negate0) + add = f32[3]{0} add(negate1, tanh) + ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2) + } + + while_cond1 { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + ROOT gte = pred[] get-tuple-element(p0), index=2 + } + + while_body1 { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + gte2 = pred[] get-tuple-element(p0), index=2 + negate0 = f32[3]{0} negate(gte0) + negate1 = f32[3]{0} negate(negate0) + negate2 = f32[3]{0} negate(negate1) + negate3 = f32[3]{0} negate(negate2) + negate4 = f32[3]{0} negate(negate3) + negate5 = f32[3]{0} negate(negate4) + negate6 = f32[3]{0} negate(negate5) + negate7 = f32[3]{0} negate(negate6) + negate8 = f32[3]{0} negate(negate7) + negate9 = f32[3]{0} negate(negate8) + negate10 = f32[3]{0} negate(negate9) + negate11 = f32[3]{0} negate(negate10) + negate12 = f32[3]{0} negate(negate11) + negate13 = f32[3]{0} negate(negate12) + negate14 = f32[3]{0} negate(negate13) + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + tanh = f32[3]{0} tanh(gte1) + add = f32[3]{0} add(negate14, tanh) + ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + p2 = f32[3]{0} parameter(2) + copy = f32[3]{0} copy(p0) + tuple1 = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1) + while1 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple1), condition=while_cond, body=while_body + gte0 = f32[3]{0} get-tuple-element(while1), index=0 + gte1 = f32[3]{0} get-tuple-element(while1), index=1 + negate0_entry = f32[3]{0} negate(gte1) + gte2 = pred[] get-tuple-element(while1), index=2 + tuple2 = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, gte1, gte2) + while2 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple2), condition=while_cond1, body=while_body1 + negate1 = f32[3]{0} negate(negate0_entry) + negate2 = f32[3]{0} negate(negate1) + negate3 = f32[3]{0} negate(negate2) + negate4 = f32[3]{0} negate(negate3) + negate5 = 
f32[3]{0} negate(negate4) + negate6 = f32[3]{0} negate(negate5) + negate7 = f32[3]{0} negate(negate6) + negate8 = f32[3]{0} negate(negate7) + negate9 = f32[3]{0} negate(negate8) + negate10 = f32[3]{0} negate(negate9) + negate11 = f32[3]{0} negate(negate10) + negate12 = f32[3]{0} negate(negate11) + negate13 = f32[3]{0} negate(negate12) + negate14 = f32[3]{0} negate(negate13) + gte = f32[3]{0} get-tuple-element(while2), index=1 + ROOT add = f32[3]{0} add(gte, negate14) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + // Inject GetInefficientAllocationSites to mark negate0_entry use as + // inefficient. This triggers a corner case bug where allocating for while2{1} + // in the retry allocation fails to find the previous required allocation in + // default memory, and creates a new one which is wrong. + bool marked_inefficient = false; + options.get_inefficient_allocation_sites_fn = + [&](absl::Span defining_positions) + -> std::vector> { + if (absl::c_find(defining_positions, + HloPosition{FindInstruction(module.get(), "while1"), + {1}}) != defining_positions.end() && + !marked_inefficient) { + LOG(INFO) << "Marking the use inefficient."; + marked_inefficient = true; + return {HloUse{FindInstruction(module.get(), "negate0_entry"), 0}}; + } + return {}; + }; + AssignMemorySpace(module.get(), options); +} + TEST_P(MemorySpaceAssignmentTest, BitcastRoot) { // Tests against a bug where the root of entry computation is a bitcast // instruction and it ends up getting an allocation in the alternate memory. diff --git a/tensorflow/compiler/xla/service/reduce_scatter_utils.cc b/tensorflow/compiler/xla/service/reduce_scatter_utils.cc index ae689d2a96cff6..d6508c93c875a4 100644 --- a/tensorflow/compiler/xla/service/reduce_scatter_utils.cc +++ b/tensorflow/compiler/xla/service/reduce_scatter_utils.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" @@ -139,17 +141,22 @@ bool IsPerIdOffsets(absl::Span offsets, // Returns if `offset` == shard_size * id. bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, const MapIdToTableOffset& map_id, int64_t group_size, - const HloAllReduceInstruction* ar) { + const HloAllReduceInstruction* ar, + bool true_scalar_for_offset_computation) { const bool iota_group = ar->replica_groups().empty() || (ar->IsCrossModuleAllReduce() && !ar->use_global_device_ids()); if (offset->opcode() == HloOpcode::kMultiply) { // Check if it's constant * IsPerIdOffset(..., shard_size / constant, ...) 
- if (offset->shape().rank() != 0) { + if (!ShapeUtil::IsEffectiveScalar(offset->shape())) { VLOG(2) << "Offset is not a scalar " << offset->ToString(); return false; } + if (true_scalar_for_offset_computation && offset->shape().rank() != 0) { + VLOG(2) << "Offset is not a true scalar " << offset->ToString(); + return false; + } int64_t const_operand = -1; if (offset->operand(0)->IsConstant()) { const_operand = 0; @@ -166,7 +173,8 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, return false; } return IsPerIdOffset(offset->operand(1 - const_operand), - shard_size / *multiplier, map_id, group_size, ar); + shard_size / *multiplier, map_id, group_size, ar, + true_scalar_for_offset_computation); } if (shard_size == 1 && iota_group) { bool id_mapping_is_identity = true; @@ -184,16 +192,16 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, if (offset->opcode() == HloOpcode::kBitcast || offset->opcode() == HloOpcode::kReshape || offset->opcode() == HloOpcode::kCopy) { - return IsPerIdOffset(offset->operand(0), shard_size, map_id, group_size, - ar); + return IsPerIdOffset(offset->operand(0), shard_size, map_id, group_size, ar, + true_scalar_for_offset_computation); } if (offset->opcode() == HloOpcode::kConvert && offset->operand(0)->shape().IsInteger() && primitive_util::BitWidth(offset->operand(0)->shape().element_type()) <= primitive_util::BitWidth(offset->shape().element_type())) { - return IsPerIdOffset(offset->operand(0), shard_size, map_id, group_size, - ar); + return IsPerIdOffset(offset->operand(0), shard_size, map_id, group_size, ar, + true_scalar_for_offset_computation); } if (offset->opcode() == HloOpcode::kClamp) { @@ -205,8 +213,8 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, << offset->ToString(); return false; } - return IsPerIdOffset(offset->operand(1), shard_size, map_id, group_size, - ar); + return IsPerIdOffset(offset->operand(1), shard_size, map_id, group_size, ar, + true_scalar_for_offset_computation); } const int64_t num_groups = iota_group ? 
1 : ar->replica_groups().size(); @@ -260,23 +268,12 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, } // namespace -std::optional MatchReduceScatter( - const HloAllReduceInstruction* ar, int64_t num_partitions, - int64_t num_replicas, bool allow_multiple_split_dims, - bool allow_intervening_reshape, int64_t min_rank) { - HloPredicate match_partition_id = HloPredicateIsOp; - HloPredicate match_replica_id = HloPredicateIsOp; - return MatchReduceScatter(ar, num_partitions, num_replicas, - allow_multiple_split_dims, - allow_intervening_reshape, min_rank, - match_partition_id, match_replica_id); -} - std::optional MatchReduceScatter( const HloAllReduceInstruction* ar, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims, bool allow_intervening_reshape, int64_t min_rank, - HloPredicate match_partition_id, HloPredicate match_replica_id) { + HloPredicate match_partition_id, HloPredicate match_replica_id, + bool true_scalar_for_offset_computation) { if (!ar->shape().IsArray() || ar->constrain_layout() || (ar->IsCrossModuleAllReduce() && !ar->GetModule()->config().use_spmd_partitioning())) { @@ -480,8 +477,8 @@ std::optional MatchReduceScatter( } else { if (!IsPerIdOffset(user->operand(spec.split_dim + 1), user->dynamic_slice_sizes()[spec.split_dim], map_id, - group_size, ar)) { - VLOG(2) << "IsPerIdOffsets() failed " << ar->ToString(); + group_size, ar, true_scalar_for_offset_computation)) { + VLOG(2) << "IsPerIdOffset() failed " << ar->ToString(); return std::nullopt; } } diff --git a/tensorflow/compiler/xla/service/reduce_scatter_utils.h b/tensorflow/compiler/xla/service/reduce_scatter_utils.h index 5ed64fc864b603..711ddd3456f019 100644 --- a/tensorflow/compiler/xla/service/reduce_scatter_utils.h +++ b/tensorflow/compiler/xla/service/reduce_scatter_utils.h @@ -16,7 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_SCATTER_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_SCATTER_UTILS_H_ -#include +#include +#include #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" @@ -35,14 +36,10 @@ struct ReduceScatterSpec { std::optional MatchReduceScatter( const HloAllReduceInstruction* ar, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims = false, - bool allow_intervening_reshape = false, int64_t min_rank = 1); - -// Matches the given all-reduce operation to a reduce-scatter pattern. 
-std::optional MatchReduceScatter( - const HloAllReduceInstruction* ar, int64_t num_partitions, - int64_t num_replicas, bool allow_multiple_split_dims, - bool allow_intervening_reshape, int64_t min_rank, - HloPredicate match_partition_id, HloPredicate match_replica_id); + bool allow_intervening_reshape = false, int64_t min_rank = 1, + HloPredicate match_partition_id = HloPredicateIsOp, + HloPredicate match_replica_id = HloPredicateIsOp, + bool true_scalar_for_offset_computation = false); } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index b6cfb614d66849..dff4a2e5188be1 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -1666,8 +1666,8 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll( HloInstruction* zero = CreateZero( ShapeUtil::MakeShape(hlo_->shape().element_type(), {}), state_.b); HloSharding sharding_copy = sharding(); - auto padded_phlo = ReshardDataForPad(zero, pc, p_hlo, padded_base_shape, - sharding_copy, state_.b); + auto padded_phlo = + ReshardDataForPad(zero, pc, p_hlo, sharding_copy, state_.b); CHECK(padded_phlo.has_value()); VLOG(5) << "Resharded: " << padded_phlo->sharded_input->ToString(); VLOG(5) << "Padded Window: " << padded_phlo->shard_window.DebugString(); @@ -3870,9 +3870,8 @@ Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) { auto replicated_rhs = GetPartitionedHlo(hlo->operand(1)) .Reshard(HloSharding::Replicate()) .hlo(); - auto reshard_operand = - ReshardDataForPad(replicated_rhs, hlo->padding_config(), lhs, - hlo->shape(), hlo->sharding(), &b_); + auto reshard_operand = ReshardDataForPad( + replicated_rhs, hlo->padding_config(), lhs, hlo->sharding(), &b_); if (!reshard_operand.has_value()) { return DefaultAction(hlo); } diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 3fa98496e3091c..5a6b1dee645e4b 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -13539,6 +13539,25 @@ ENTRY %entry (p0: f32[8], p1: f32[1]) -> (f32[1], token[]) { EXPECT_THAT(outfeed->operand(0), op::Shape("(u32[2]{0})")); } +TEST_F(SpmdPartitioningTest, PadUneven) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,13,257] parameter(0), sharding={devices=[1,2,1]0,1} + %const = f32[] constant(0) + ROOT %pad = f32[128,14,257] pad(%param0, %const), padding=0_0x0_1x0_0, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Select(), op::Shape("f32[128,7,257]"))); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc index 740bc394e2fd7a..8a98ec6d76f138 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -2350,8 +2350,7 @@ HloInstruction* SliceDataFromWindowReshard( std::optional ReshardDataForPad( HloInstruction* pad_value, PaddingConfig pc, PartitionedHlo to_reshard, - const Shape& target_shape, const HloSharding& 
target_sharding, - SpmdBuilder* b) { + const HloSharding& target_sharding, SpmdBuilder* b) { // Create a window config to represent the pad. Window window; bool needs_masking = false; @@ -2371,11 +2370,11 @@ std::optional ReshardDataForPad( // Need masking only if there is non-zero padding value or the operand is // unevenly partitioned. Halo exchange fills 0 in collective permute result // for non-destination cores. - needs_masking |= - shard_count > 1 && - (pd.edge_padding_low() > 0 || pd.edge_padding_high() > 0 || - pd.interior_padding() > 0) && - (!pad_value_is_zero || target_shape.dimensions(i) % shard_count != 0); + needs_masking |= shard_count > 1 && + (pd.edge_padding_low() > 0 || pd.edge_padding_high() > 0 || + pd.interior_padding() > 0) && + (!pad_value_is_zero || + to_reshard.base_shape().dimensions(i) % shard_count != 0); } // In compact halo exchange, we can't skip masking. return to_reshard.ReshardAsWindowedInput( diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h index 57413f865693f0..51448b8f2f2036 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h @@ -555,8 +555,7 @@ HloInstruction* SliceDataFromWindowReshard( // parameters. std::optional ReshardDataForPad( HloInstruction* pad_value, PaddingConfig pc, PartitionedHlo to_reshard, - const Shape& target_shape, const HloSharding& target_sharding, - SpmdBuilder* b); + const HloSharding& target_sharding, SpmdBuilder* b); // Performs padding of data based on the windowed sharding passed as input. HloInstruction* PadDataFromWindowReshard( diff --git a/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc index c2472c66b21535..ca543015af2296 100644 --- a/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc @@ -1050,6 +1050,11 @@ StatusOr WhileLoopAllReduceCodeMotion::Run( TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( all_reduce, all_reduce->mutable_operand(0))); } + // Needs to rebuild the call graph or we could access removed + // instructions. 
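The ReshardDataForPad change above keys the masking decision off the operand's base (unpartitioned) shape instead of the pad's target shape. A worked example using the values from the PadUneven test added earlier in this patch:

    const int64_t shard_count = 2;        // dim 1 is split across 2 devices
    const int64_t base_dim_size = 13;     // to_reshard.base_shape().dimensions(1)
    const bool has_padding = true;        // padding=0_0x0_1x0_0 pads dim 1
    const bool pad_value_is_zero = true;  // pad value is constant 0
    const bool needs_masking =
        shard_count > 1 && has_padding &&
        (!pad_value_is_zero || base_dim_size % shard_count != 0);
    // needs_masking == true: 13 % 2 != 0, so the unevenly partitioned shard must
    // be masked even though the pad value is zero.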
+ if (run_next_pass) { + break; + } } } VLOG(2) << "Hoisted " << count_all_reduce << " all-reduce and " diff --git a/tensorflow/compiler/xla/service/xla_aot_compile_test_autotune_results.prototxt b/tensorflow/compiler/xla/service/xla_aot_compile_test_autotune_results.prototxt index 7db99942aaec54..6b2ad5fab19bdf 100644 --- a/tensorflow/compiler/xla/service/xla_aot_compile_test_autotune_results.prototxt +++ b/tensorflow/compiler/xla/service/xla_aot_compile_test_autotune_results.prototxt @@ -23,7 +23,7 @@ results { } results { device: "sm_6.0 with 17071734784B RAM, 56 cores, 1480500KHz clock, 715000KHz mem clock, 4194304B L2$" - hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[2,1,4,4]{3,2,1,0}, f32[2,1,3,2]{3,2,1,0}), window={size=2x3}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convBackwardFilter\", backend_config={\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"side_input_scale\":0}" + hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[2,1,4,4]{3,2,1,0}, f32[2,1,3,2]{3,2,1,0}), window={size=2x3}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convBackwardFilter\", backend_config={\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"side_input_scale\":0,\"leakyrelu_alpha\":0}" result { run_time { nanos: 45408 diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager.cc b/tensorflow/compiler/xla/service/xla_debug_info_manager.cc index e59922a9a569f1..467d8364876d18 100644 --- a/tensorflow/compiler/xla/service/xla_debug_info_manager.cc +++ b/tensorflow/compiler/xla/service/xla_debug_info_manager.cc @@ -18,17 +18,18 @@ limitations under the License. #include #include #include +#include #include "tensorflow/compiler/xla/service/hlo_proto_util.h" namespace xla { void XlaDebugInfoManager::RegisterModule( - ModuleIdentifier module_id, std::shared_ptr hlo_module, + std::shared_ptr hlo_module, std::shared_ptr buffer_assignment) { - CHECK(hlo_module != nullptr && module_id == hlo_module->unique_id()); + CHECK(hlo_module != nullptr); absl::MutexLock lock(&mutex_); - auto result = modules_.try_emplace(module_id); + auto result = modules_.try_emplace(hlo_module->unique_id()); CHECK(result.second); XlaModuleEntry& m = result.first->second; m.hlo_module = std::move(hlo_module); @@ -69,12 +70,12 @@ void XlaDebugInfoManager::StopTracing( modules_to_serialize.reserve(modules_.size()); for (auto it = modules_.begin(); it != modules_.end();) { auto& m = it->second; + auto cur_it = it++; if (!m.active) { modules_to_serialize.emplace_back(std::move(m)); - modules_.erase(it++); + modules_.erase(cur_it); } else { modules_to_serialize.emplace_back(m); - ++it; } } } @@ -94,4 +95,9 @@ void XlaDebugInfoManager::StopTracing( } } +bool XlaDebugInfoManager::TracksModule(ModuleIdentifier module_id) const { + absl::MutexLock lock(&mutex_); + return modules_.find(module_id) != modules_.end(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager.h b/tensorflow/compiler/xla/service/xla_debug_info_manager.h index d18a7cf35dabb4..08d7e8eb54b552 100644 --- a/tensorflow/compiler/xla/service/xla_debug_info_manager.h +++ b/tensorflow/compiler/xla/service/xla_debug_info_manager.h @@ -16,8 +16,10 @@ limitations under the License. 
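A hypothetical caller-side sketch of the reworked XlaDebugInfoManager registration flow shown above; the singleton accessor Get() and the shared_ptr arguments are assumed from the existing class, and `module` / `assignment` are placeholder variables:

    auto* manager = xla::XlaDebugInfoManager::Get();
    manager->RegisterModule(module, assignment);  // id now read from module->unique_id()
    if (manager->TracksModule(module->unique_id())) {
      // Registered: tracing consumers may serialize debug info for this module.
    }
    manager->UnregisterModule(module->unique_id());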
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_XLA_DEBUG_INFO_MANAGER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_XLA_DEBUG_INFO_MANAGER_H_ +#include #include #include +#include #include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" @@ -40,9 +42,9 @@ class XlaDebugInfoManager { } // Registers an active module to XlaDebugInfoManager. - // The module_id is expected to be unique per process. + // The module_id of the module is expected to be unique per process. void RegisterModule( - ModuleIdentifier module_id, std::shared_ptr hlo_module, + std::shared_ptr hlo_module, std::shared_ptr buffer_assignment); // Unregisters an active module. @@ -58,6 +60,9 @@ class XlaDebugInfoManager { void StopTracing( std::vector>* module_debug_info = nullptr); + // Returns whether 'module_id' is tracked by XlaDebugInfoManager. + bool TracksModule(ModuleIdentifier module_id) const; + friend class XlaDebugInfoManagerTestPeer; private: diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc b/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc index 1aa459e29cb1ac..2fbb876e242802 100644 --- a/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc +++ b/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc @@ -27,9 +27,9 @@ namespace xla { class XlaDebugInfoManagerTestPeer { public: void RegisterModule( - ModuleIdentifier module_id, std::shared_ptr hlo_module, + std::shared_ptr hlo_module, std::shared_ptr buffer_assignment) { - return xla_debug_info_manager_.RegisterModule(module_id, hlo_module, + return xla_debug_info_manager_.RegisterModule(hlo_module, buffer_assignment); } @@ -85,7 +85,7 @@ class XlaDebugInfoManagerTest : public HloTestBase { debug_info.buffer_assignment = nullptr; ModuleIdentifier unique_id = debug_info.module->unique_id(); debug_info.unique_id = unique_id; - xla_debug_info_manager_.RegisterModule(unique_id, debug_info.module, + xla_debug_info_manager_.RegisterModule(debug_info.module, debug_info.buffer_assignment); external_references_.push_back(std::move(debug_info)); return unique_id; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 250c5d7bd32e5a..51d7892eee177f 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -38,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/tsl/platform/cpu_info.h" #include "tensorflow/tsl/platform/threadpool.h" namespace xla { @@ -515,6 +515,22 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( TF_DCHECK_OK(ValidateShape(*shape)); } +// Prepend new major-most dimension sized `bound` to the shape. 
+Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { + Shape new_shape(shape.element_type(), {}, {}, {}); + new_shape.add_dimensions(bound); + for (const int64_t dim : shape.dimensions()) { + new_shape.add_dimensions(dim); + } + if (shape.has_layout()) { + for (const int64_t dim : shape.layout().minor_to_major()) { + new_shape.mutable_layout()->add_minor_to_major(dim + 1); + } + new_shape.mutable_layout()->add_minor_to_major(0); + } + return new_shape; +} + /* static */ void ShapeUtil::AppendMinorDimension(int bound, Shape* shape) { CHECK(LayoutUtil::IsDenseArray(*shape)); shape->add_dimensions(bound); @@ -1715,9 +1731,8 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, absl::Span count, absl::Span incr, const ForEachParallelVisitorFunction& visitor_function) { // The parallel version of ForEachIndexInternal can never fail. - CHECK( - ForEachIndexParallelWithStatus(shape, base, count, incr, visitor_function) - .ok()); + TF_CHECK_OK(ForEachIndexParallelWithStatus(shape, base, count, incr, + visitor_function)); } /* static */ Status ShapeUtil::ForEachIndexParallelWithStatus( @@ -1732,7 +1747,7 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, /* static */ void ShapeUtil::ForEachIndexParallel( const Shape& shape, const ForEachParallelVisitorFunction& visitor_function) { - CHECK(ForEachIndexParallelWithStatus(shape, visitor_function).ok()); + TF_CHECK_OK(ForEachIndexParallelWithStatus(shape, visitor_function)); } /* static */ Status ShapeUtil::ForEachIndexParallelWithStatus( @@ -2108,4 +2123,26 @@ int64_t ShapeUtil::ForEachState::CalculateNumSteps() const { return size; } +Shape ShapeUtil::GetUnshardedShape(const Shape& sharded_shape, + int64_t num_shards) { + if (ShapeUtil::IsScalar(sharded_shape)) { + return sharded_shape; + } + + Shape unsharded_shape = sharded_shape; + + ShapeUtil::ForEachMutableSubshape( + &unsharded_shape, + [sharded_shape, num_shards](Shape* subshape, const ShapeIndex& index) { + if (subshape->IsArray() && subshape->rank() >= 1 && + !subshape->is_dynamic()) { + const Shape& sharded_subshape = + ShapeUtil::GetSubshape(sharded_shape, index); + subshape->set_dimensions(0, + sharded_subshape.dimensions(0) * num_shards); + } + }); + return unsharded_shape; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index adc93dc5408b1c..a08ab7703571ab 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -19,19 +19,16 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ -#include #include #include #include #include #include #include -#include #include #include #include -#include "absl/base/macros.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/types/span.h" @@ -40,9 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/tsl/platform/cpu_info.h" -#include "tensorflow/tsl/platform/env.h" -#include "tensorflow/tsl/platform/threadpool.h" namespace xla { @@ -321,6 +315,9 @@ class ShapeUtil { // Appends a major dimension to the shape with the given bound. static void AppendMajorDimension(int bound, Shape* shape); + // Prepends a major dimension sized `bound` to the shape. 
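A short usage sketch of the two ShapeUtil helpers introduced above; the first result restates what the new PrependMajorDimension test below checks, while the GetUnshardedShape values are illustrative only:

    xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout(
        xla::F32, {10, 20, 30}, {0, 2, 1});
    xla::Shape batched = xla::ShapeUtil::PrependMajorDimension(40, shape);
    // batched is f32[40,10,20,30] with minor_to_major {1,3,2,0}: each existing
    // layout index shifts by one and the new dimension 0 becomes the major-most.

    xla::Shape sharded = xla::ShapeUtil::MakeShape(xla::F32, {4, 16});
    xla::Shape full = xla::ShapeUtil::GetUnshardedShape(sharded, /*num_shards=*/8);
    // full is f32[32,16]: dimension 0 of each static array subshape is
    // multiplied by the shard count.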
+ static Shape PrependMajorDimension(int64_t bound, Shape shape); + // Appends a minor dimension to the shape with the given bound. static void AppendMinorDimension(int bound, Shape* shape); @@ -891,6 +888,11 @@ class ShapeUtil { // due to the tiling requirement. static int64_t ArrayDataSize(const Shape& shape); + // Returns the unsharded shape for an input `sharded_shape` that is + // partitioned among `num_shards`. + static Shape GetUnshardedShape(const Shape& sharded_shape, + int64_t num_shards); + private: // Fills *shape. Returns true on success. // REQUIRES: *shape is empty. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 84e4cee1c13a7c..a079c93af35771 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -973,6 +973,22 @@ TEST(ShapeUtilTest, PermuteDynamicDimensions) { } while (std::next_permutation(permutation.begin(), permutation.end())); } +TEST(ShapeUtilTest, PrependMajorDimension) { + Shape shape = ShapeUtil::MakeShape(F32, {10, 20, 30}); + EXPECT_EQ(ShapeUtil::PrependMajorDimension(40, shape), + ShapeUtil::MakeShape(F32, {40, 10, 20, 30})); + + shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {10, 20, 30}, {0, 2, 1}); + EXPECT_EQ( + ShapeUtil::PrependMajorDimension(40, shape), + ShapeUtil::MakeShapeWithDenseLayout(F32, {40, 10, 20, 30}, {1, 3, 2, 0})); + + shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {10, 20, 30}, {2, 1, 0}); + EXPECT_EQ( + ShapeUtil::PrependMajorDimension(40, shape), + ShapeUtil::MakeShapeWithDenseLayout(F32, {40, 10, 20, 30}, {3, 2, 1, 0})); +} + TEST(ShapeUtilTest, AppendMinorDimension) { Shape shape = ShapeUtil::MakeShape(F32, {10, 20, 30}); ShapeUtil::AppendMinorDimension(40, &shape); diff --git a/tensorflow/compiler/xla/status.h b/tensorflow/compiler/xla/status.h index fe7fbef419c54e..ec1d45034a5ccc 100644 --- a/tensorflow/compiler/xla/status.h +++ b/tensorflow/compiler/xla/status.h @@ -20,10 +20,8 @@ limitations under the License. namespace xla { // NOLINTBEGIN(misc-unused-using-decls) -using tsl::FromAbslStatus; using tsl::OkStatus; using tsl::Status; // TENSORFLOW_STATUS_OK -using tsl::ToAbslStatus; // NOLINTEND(misc-unused-using-decls) } // namespace xla diff --git a/tensorflow/compiler/xla/stream_executor/cuda/BUILD b/tensorflow/compiler/xla/stream_executor/cuda/BUILD index 12df1c9484b55c..7ebd075c58e322 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/BUILD +++ b/tensorflow/compiler/xla/stream_executor/cuda/BUILD @@ -2,7 +2,6 @@ # CUDA-platform specific StreamExecutor support code. 
load("//tensorflow/tsl:tsl.bzl", "if_google", "set_external_visibility", "tsl_copts") -load("//tensorflow/tsl:tsl.default.bzl", "tsl_gpu_cc_test") load( "//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "stream_executor_friends", @@ -25,6 +24,10 @@ load( "//tensorflow/tsl/platform:rules_cc.bzl", "cc_library", ) +load( + "//tensorflow/compiler/xla:xla.bzl", + "xla_cc_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -127,49 +130,49 @@ cc_library( ], ) -tsl_gpu_cc_test( +xla_cc_test( name = "stream_search_test", size = "small", srcs = ["stream_search_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), tags = tf_cuda_tests_tags(), deps = [ + ":cuda_platform", "//tensorflow/compiler/xla/stream_executor", - "//tensorflow/compiler/xla/stream_executor:stream_executor_impl", "//tensorflow/compiler/xla/stream_executor/host:host_platform", - "//tensorflow/tsl/platform:env_impl", "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_main", ], ) -tsl_gpu_cc_test( +xla_cc_test( name = "cuda_driver_test", srcs = ["cuda_driver_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), tags = tf_cuda_tests_tags() + [ "no_cuda_asan", # TODO(b/171512140): re-enable. "no_rocm", ], deps = [ ":cuda_driver", - "//tensorflow/tsl/platform:env_impl", "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_main", "@local_config_cuda//cuda:cuda_headers", ], ) -tsl_gpu_cc_test( +xla_cc_test( name = "memcpy_test", srcs = ["memcpy_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), tags = tf_cuda_tests_tags() + [ "no_cuda_asan", # TODO(b/171512140): re-enable. ], deps = [ + ":cuda_platform", "//tensorflow/compiler/xla/stream_executor", "//tensorflow/compiler/xla/stream_executor:device_memory", "//tensorflow/compiler/xla/stream_executor:multi_platform_manager", - "//tensorflow/compiler/xla/stream_executor:stream_executor_impl", - "//tensorflow/tsl/platform:env_impl", "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_main", ], @@ -612,30 +615,24 @@ cc_library( ), ) -tsl_gpu_cc_test( +xla_cc_test( name = "redzone_allocator_test", srcs = ["redzone_allocator_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), tags = tf_cuda_tests_tags() + [ "no_cuda_asan", # TODO(b/171512140): re-enable. 
], deps = [ ":cuda_activation", ":cuda_gpu_executor", - ":stream_executor_cuda", "//tensorflow/compiler/xla/stream_executor", "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", "//tensorflow/compiler/xla/stream_executor:event", "//tensorflow/compiler/xla/stream_executor:kernel", - "//tensorflow/compiler/xla/stream_executor:stream_executor_impl", "//tensorflow/compiler/xla/stream_executor/gpu:gpu_asm_opts", "//tensorflow/compiler/xla/stream_executor/gpu:redzone_allocator", - "//tensorflow/tsl/framework:allocator", - "//tensorflow/tsl/framework:allocator_registry_impl", "//tensorflow/tsl/lib/core:status_test_util", - "//tensorflow/tsl/platform:env_impl", "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_main", - "//tensorflow/tsl/profiler/backends/cpu:traceme_recorder_impl", - "//tensorflow/tsl/profiler/utils:time_utils_impl", ], ) diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc index 6249b0ee176135..5396d565d00028 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc @@ -15,12 +15,18 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h" +#include +#include #include #include +#include +#include #include #include #include +#include #include +#include #include "absl/base/optimization.h" #include "absl/base/thread_annotations.h" @@ -429,7 +435,7 @@ tsl::Status CudnnSupport::Init() { return tsl::Status(absl::StatusCode::kInternal, error); } - cudnn_.reset(new CudnnAccess(cudnn_handle)); + cudnn_ = std::make_unique(cudnn_handle); LOG(INFO) << "Loaded cuDNN version " << cudnnGetVersion(); return ::tsl::OkStatus(); @@ -1689,13 +1695,11 @@ class CudnnRnnSequenceTensorDescriptor : public dnn::RnnSequenceTensorDescriptor { CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, - cudnnDataType_t data_type, RNNDataDescriptor data_handle, TensorDescriptor handle) : max_seq_length_(max_seq_length), batch_size_(batch_size), data_size_(data_size), - data_type_(data_type), handle_(std::move(handle)), rnn_data_handle_(std::move(data_handle)), handles_(max_seq_length, handle_.get()) {} @@ -1719,13 +1723,13 @@ class CudnnRnnSequenceTensorDescriptor /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims, /*strideA=*/strides)); return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size, - data_size, data_type, nullptr, + data_size, nullptr, std::move(tensor_desc)); } static tsl::StatusOr Create( GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, - const absl::Span& seq_lengths, bool time_major, + absl::Span seq_lengths, bool time_major, cudnnDataType_t data_type) { if (max_seq_length <= 0) { return tsl::Status(absl::StatusCode::kInvalidArgument, @@ -1754,13 +1758,13 @@ class CudnnRnnSequenceTensorDescriptor /*batchSize=*/batch_size, /*vectorSize=*/data_size, /*seqLengthArray=*/seq_lengths_array, /*paddingFill*/ (void*)&padding_fill)); - return CudnnRnnSequenceTensorDescriptor( - parent, max_seq_length, batch_size, data_size, data_type, - std::move(data_desc), std::move(tensor_desc)); + return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size, + data_size, std::move(data_desc), + std::move(tensor_desc)); } const cudnnTensorDescriptor_t* handles() const { return handles_.data(); } - const cudnnRNNDataDescriptor_t data_handle() const { + cudnnRNNDataDescriptor_t 
data_handle() const { return rnn_data_handle_.get(); } @@ -1773,7 +1777,6 @@ class CudnnRnnSequenceTensorDescriptor int max_seq_length_; int batch_size_; int data_size_; - cudnnDataType_t data_type_; TensorDescriptor handle_; RNNDataDescriptor rnn_data_handle_; std::vector handles_; // Copies of handle_. @@ -1788,8 +1791,7 @@ class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor { : handle_(CreateTensorDescriptor()), num_layers_(num_layers), batch_size_(batch_size), - data_size_(data_size), - data_type_(data_type) { + data_size_(data_size) { int dims[] = {num_layers, batch_size, data_size}; int strides[] = {dims[1] * dims[2], dims[2], 1}; CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor( @@ -1809,7 +1811,6 @@ class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor { int num_layers_; int batch_size_; int data_size_; - cudnnDataType_t data_type_; SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor); }; @@ -4136,8 +4137,7 @@ GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, << "\nConv: " << conv_desc.describe() << "\nOp: " << op.describe() << "\nOpGraph: " << opGraph.describe(); - return std::unique_ptr( - new cudnn_frontend::OperationGraph(std::move(opGraph))); + return std::make_unique(std::move(opGraph)); } bool SideInputNeeded(dnn::ActivationMode activation_mode, double conv_scale, @@ -4465,8 +4465,7 @@ GetCudnnFusedOperationGraph( << (act_op.has_value() ? act_op->describe() : "(identity)") << "\nOpGraph: " << op_graph.describe(); - return std::unique_ptr( - new cudnn_frontend::OperationGraph(std::move(op_graph))); + return std::make_unique(std::move(op_graph)); } tsl::StatusOr> @@ -6210,7 +6209,7 @@ class CudnnExecutionPlanRunner size_t workspace_size = plan_.getWorkspaceSize(); RETURN_MSG_IF_CUDNN_ERROR(plan_); bool should_add_scalars = - scalar_input_uids_.size() > 0 && scalar_input_values_.size() > 0; + !scalar_input_uids_.empty() && !scalar_input_values_.empty(); CHECK(scalar_input_uids_.size() == scalar_input_values_.size()); std::array data_ptrs = {inputs.opaque()...}; @@ -6223,7 +6222,7 @@ class CudnnExecutionPlanRunner data_ptrs_vec.erase(data_ptrs_vec.begin() + 2); } - if (data_ptrs_vec[sizeof...(Args) - 1] == nullptr && + if (!data_ptrs_vec.empty() && data_ptrs_vec.back() == nullptr && !has_activation_output_) { data_ptrs_vec.pop_back(); } @@ -6426,7 +6425,7 @@ tsl::Status CreateOpRunners( // Frontend, but instead they get filtered out here. VLOG(4) << "Failed building runner from ExecutionPlan (i.e. failed " "getting its workspace size): " - << runner_or.status().ToString(); + << runner_or.status(); continue; } diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc index 2540b4003cb353..6607a9adbee136 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc @@ -1234,7 +1234,11 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return false; } - if (gpu_dst == 0 || gpu_src == 0) { + // In graph capture mode we never have operations that access peer memory, so + // we can always make a call to cuMemcpyDtoDAsync. + bool is_capturing = stream_capture_status == cudaStreamCaptureStatusActive; + + if ((gpu_dst == 0 || gpu_src == 0) || is_capturing) { // CreatedContexts::GetAnyContext() doesn't works when ptr == 0. // This happens when the size is 0. 
result = cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream); diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc index 23de42ff1763ae..c6ea5e27bf1f0b 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h" #include -#include #include "absl/strings/str_format.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" @@ -45,6 +44,10 @@ std::atomic CudaGraphSupport::alive_cuda_graph_execs_; return allocated_cuda_graph_execs_.fetch_add(1, std::memory_order_relaxed); } +/*static*/ size_t CudaGraphSupport::NotifyGraphExecDestroyed() { + return alive_cuda_graph_execs_.fetch_sub(1, std::memory_order_relaxed) - 1; +} + /*static*/ size_t CudaGraphSupport::allocated_cuda_graph_execs() { return allocated_cuda_graph_execs_.load(std::memory_order_relaxed); } @@ -61,16 +64,13 @@ void CudaGraphSupport::DestroyGraph::operator()(cudaGraph_t graph) { void CudaGraphSupport::DestroyGraphExec::operator()(cudaGraphExec_t instance) { cudaError_t err = cudaGraphExecDestroy(instance); - alive_cuda_graph_execs_.fetch_sub(1, std::memory_order_relaxed); - VLOG(5) << "Destroy CUDA graph exec (remaining alive instances: " - << CudaGraphSupport::alive_cuda_graph_execs() << ")"; CHECK(err == cudaSuccess) << "Failed to destroy CUDA graph instance: " << cudaGetErrorString(err); } tsl::Status OwnedCudaGraphExec::Update(OwnedCudaGraph graph) { VLOG(3) << "Update CUDA graph exec with a new graph after " << num_launches_ - << " launches since last update " + << " launches since last update" << " #" << num_updates_++; num_launches_ = 0; @@ -109,6 +109,13 @@ tsl::Status OwnedCudaGraphExec::Launch(stream_executor::Stream* stream) { return tsl::OkStatus(); } +OwnedCudaGraphExec::~OwnedCudaGraphExec() { + if (*this) // do not log for moved-from instances + VLOG(5) << "Destroy CUDA graph exec #" << id_ + << " (remaining alive instances: " + << CudaGraphSupport::NotifyGraphExecDestroyed() << ")"; +} + //===----------------------------------------------------------------------===// // CUDA Graph Helpers. //===----------------------------------------------------------------------===// @@ -196,7 +203,7 @@ tsl::StatusOr InstantiateCudaGraph(OwnedCudaGraph graph) { VLOG(5) << "Instantiated CUDA graph exec instance #" << id << " (alive instances: " << CudaGraphSupport::alive_cuda_graph_execs() << ")"; - return OwnedCudaGraphExec(exec); + return OwnedCudaGraphExec(id, exec); } tsl::StatusOr IsStreamCapturing(stream_executor::Stream* stream) { diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h b/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h index 0b851440126dcc..ad56554c0ad300 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include "absl/functional/any_invocable.h" @@ -41,6 +40,7 @@ class CudaGraphSupport { }; static size_t NotifyGraphExecCreated(); + static size_t NotifyGraphExecDestroyed(); static size_t allocated_cuda_graph_execs(); static size_t alive_cuda_graph_execs(); @@ -67,11 +67,16 @@ class OwnedCudaGraph class OwnedCudaGraphExec : public std::unique_ptr, CudaGraphSupport::DestroyGraphExec> { - // Bring std::unique_ptr constructors in scope. 
- using std::unique_ptr, - CudaGraphSupport::DestroyGraphExec>::unique_ptr; + using Base = std::unique_ptr, + CudaGraphSupport::DestroyGraphExec>; public: + OwnedCudaGraphExec(uint64_t id, cudaGraphExec_t exec) : Base(exec), id_(id) {} + ~OwnedCudaGraphExec(); + + OwnedCudaGraphExec(OwnedCudaGraphExec&&) = default; + OwnedCudaGraphExec& operator=(OwnedCudaGraphExec&&) = default; + // Updates executable graph instance with a newly captured graph. Returns an // error if the new graph is not compatible (see `cudaGraphExecUpdate`). tsl::Status Update(OwnedCudaGraph graph); @@ -79,7 +84,10 @@ class OwnedCudaGraphExec // Launches captured graph on a given stream. tsl::Status Launch(stream_executor::Stream* stream); + uint64_t id() const { return id_; } + private: + uint64_t id_; uint64_t num_updates_ = 0; uint64_t num_launches_ = 0; }; diff --git a/tensorflow/compiler/xla/stream_executor/dnn.cc b/tensorflow/compiler/xla/stream_executor/dnn.cc index 03815ff24df596..592f354c5addf0 100644 --- a/tensorflow/compiler/xla/stream_executor/dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/dnn.cc @@ -367,6 +367,8 @@ std::string ActivationModeString(ActivationMode mode) { return "bandpass"; case ActivationMode::kElu: return "elu"; + case ActivationMode::kLeakyRelu: + return "leakyrelu"; default: return absl::StrCat("unknown: ", static_cast(mode)); } diff --git a/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.cc b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.cc index 1ab21ed78506ab..cfd3de8bb50f21 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.cc +++ b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.cc @@ -52,7 +52,7 @@ using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus; RedzoneAllocator::RedzoneAllocator(Stream* stream, DeviceMemoryAllocator* memory_allocator, - GpuAsmOpts ptx_compilation_opts, + GpuAsmOpts gpu_compilation_opts, int64_t memory_limit, int64_t redzone_size, uint8_t redzone_pattern) : device_ordinal_(stream->parent()->device_ordinal()), @@ -63,7 +63,7 @@ RedzoneAllocator::RedzoneAllocator(Stream* stream, static_cast(tsl::Allocator::kAllocatorAlignment))), redzone_pattern_(redzone_pattern), memory_allocator_(memory_allocator), - gpu_compilation_opts_(ptx_compilation_opts) {} + gpu_compilation_opts_(gpu_compilation_opts) {} tsl::StatusOr> RedzoneAllocator::AllocateBytes( int64_t byte_size) { @@ -223,6 +223,10 @@ static tsl::Status RunRedzoneChecker( const ComparisonKernelT& comparison_kernel) { StreamExecutor* executor = stream->parent(); + if (redzone.size() == 0) { + return tsl::OkStatus(); + } + int64_t num_elements = redzone.size(); int64_t threads_per_block = std::min( executor->GetDeviceDescription().threads_per_block_limit(), num_elements); diff --git a/tensorflow/compiler/xla/stream_executor/multi_platform_manager.cc b/tensorflow/compiler/xla/stream_executor/multi_platform_manager.cc index 429f52216fdef4..137018f17c8876 100644 --- a/tensorflow/compiler/xla/stream_executor/multi_platform_manager.cc +++ b/tensorflow/compiler/xla/stream_executor/multi_platform_manager.cc @@ -76,7 +76,7 @@ class MultiPlatformManagerImpl { tsl::StatusOr LookupByIdLocked(const Platform::Id& id) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Returns the names of the initialied platforms satisfying the given filter. + // Returns the names of the initialized platforms satisfying the given filter. // By default, it will return all initialized platform names. 
std::vector InitializedPlatformNamesWithFilter( const std::function& filter = [](const Platform*) { diff --git a/tensorflow/compiler/xla/stream_executor/tpu/BUILD b/tensorflow/compiler/xla/stream_executor/tpu/BUILD index a0114e09f4e3c5..958c4430d34792 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/BUILD +++ b/tensorflow/compiler/xla/stream_executor/tpu/BUILD @@ -2,6 +2,7 @@ load("//tensorflow/tsl:tsl.bzl", "set_external_visibility") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -58,18 +59,38 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/stream_executor:device_memory", "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/types:span", ], ) +xla_cc_test( + name = "c_api_conversions_test", + srcs = ["c_api_conversions_test.cc"], + deps = [ + ":c_api_conversions", + ":c_api_decl", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/tsl/platform:protobuf", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "libtftpu_header", hdrs = ["libtftpu.h"], diff --git a/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.cc b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.cc index a7b8f793cca724..518fd6b0dd4c8d 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.cc +++ b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.cc @@ -16,11 +16,14 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h" #include +#include #include #include #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_defn.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_api.h" @@ -29,6 +32,90 @@ limitations under the License. namespace ApiConverter { +// Helper functions for copying data to possibly-inlined C arrays. + +// 'Src' and 'Dst' are allowed to be different types to make this usable with +// memory-identical types, e.g. int64_t and int64_t. This should not be used +// with types that require a static_cast. 
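A small usage sketch for the copy/view helpers defined just below (inside namespace ApiConverter); with only three elements the data stays in the inlined buffer, so no heap cleanup is needed in this particular case:

    std::vector<int64_t> dims = {2, 3, 4};
    Int64List c_list;
    CreateVector(absl::MakeConstSpan(dims), &c_list);
    absl::Span<const int64_t> view = MakeSpan(c_list);
    CHECK(std::equal(view.begin(), view.end(), dims.begin()));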
+template +static void CreateVectorBase(const absl::Span src, DstList* dst) { + dst->size = src.size(); + if (dst->size > TPU_C_API_MAX_INLINED) { + dst->heap = new Dst[dst->size]; + std::copy(src.begin(), src.end(), dst->heap); + } else { + std::copy(src.begin(), src.end(), dst->inlined); + } +} + +void CreateVector(const absl::Span src, IntList* dst) { + return CreateVectorBase(src, dst); +} + +void CreateVector(const absl::Span src, Int64List* dst) { + return CreateVectorBase(src, dst); +} + +void CreateVector(const absl::Span src, FloatList* dst) { + return CreateVectorBase(src, dst); +} + +void CreateVector(const absl::Span src, BoolList* dst) { + return CreateVectorBase(src, dst); +} + +void CreateVector(const absl::Span src, IntList* dst) { + CreateVectorBase(src, dst); +} + +static void CreateVector(const absl::Span src, IntList* dst) { + CreateVectorBase(src, dst); +} + +static void CreateVector(const absl::Span src, TileList* dst) { + dst->size = src.size(); + XLA_Tile* c_tiles; + if (dst->size > TPU_C_API_MAX_INLINED) { + dst->heap = new XLA_Tile[dst->size]; + c_tiles = dst->heap; + } else { + c_tiles = dst->inlined; + } + for (int i = 0; i < dst->size; ++i) { + ToC(src[i], &c_tiles[i]); + } +} + +// Helper functions for creating a view of possibly-inlined C arrays. + +// 'Src' and 'Dst' are allowed to be different types to make this usable with +// memory-identical types, e.g. int64_t and int64_t. This should not be used +// with types that require a static_cast. +template +static absl::Span MakeSpanBase(const SrcList& src_list) { + static_assert(sizeof(Src) == sizeof(Dst), "Mismatched types"); + const Src* src = src_list.size > TPU_C_API_MAX_INLINED ? src_list.heap + : &src_list.inlined[0]; + return absl::Span(reinterpret_cast(src), + src_list.size); +} + +absl::Span MakeSpan(const IntList& src_list) { + return MakeSpanBase(src_list); +} + +absl::Span MakeSpan(const Int64List& src_list) { + return MakeSpanBase(src_list); +} + +absl::Span MakeSpan(const FloatList& src_list) { + return MakeSpanBase(src_list); +} + +absl::Span MakeSpan(const BoolList& src_list) { + return MakeSpanBase(src_list); +} + xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) { xla::Shape xla_on_device_shape = ApiConverter::FromC(&c_buffer->on_device_shape); @@ -154,85 +241,6 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base) { return base; } -// Helper functions for copying data to possibly-inlined C arrays. - -// 'Src' and 'Dst' are allowed to be different types to make this usable with -// memory-identical types, e.g. int64_t and int64_t. This should not be used -// with types that require a static_cast. 
-template -static void CreateVectorBase(const absl::Span src, DstList* dst) { - dst->size = src.size(); - if (dst->size > TPU_C_API_MAX_INLINED) { - dst->heap = new Dst[dst->size]; - std::copy(src.begin(), src.end(), dst->heap); - } else { - std::copy(src.begin(), src.end(), dst->inlined); - } -} - -void CreateVector(const absl::Span src, IntList* dst) { - return CreateVectorBase(src, dst); -} -void CreateVector(const absl::Span src, Int64List* dst) { - return CreateVectorBase(src, dst); -} -void CreateVector(const absl::Span src, FloatList* dst) { - return CreateVectorBase(src, dst); -} -void CreateVector(const absl::Span src, BoolList* dst) { - return CreateVectorBase(src, dst); -} -static void CreateVector(const absl::Span src, - IntList* dst) { - CreateVectorBase(src, dst); -} -static void CreateVector(const absl::Span src, IntList* dst) { - CreateVectorBase(src, dst); -} - -static void CreateVector(const absl::Span src, TileList* dst) { - dst->size = src.size(); - XLA_Tile* c_tiles; - if (dst->size > TPU_C_API_MAX_INLINED) { - dst->heap = new XLA_Tile[dst->size]; - c_tiles = dst->heap; - } else { - c_tiles = dst->inlined; - } - for (int i = 0; i < dst->size; ++i) { - ToC(src[i], &c_tiles[i]); - } -} - -// Helper functions for creating a view of possibly-inlined C arrays. - -// 'Src' and 'Dst' are allowed to be different types to make this usable with -// memory-identical types, e.g. int64_t and int64_t. This should not be used -// with types that require a static_cast. -template -static absl::Span MakeSpanBase(const SrcList& src_list) { - static_assert(sizeof(Src) == sizeof(Dst), "Mismatched types"); - const Src* src = src_list.size > TPU_C_API_MAX_INLINED ? src_list.heap - : &src_list.inlined[0]; - return absl::Span(reinterpret_cast(src), - src_list.size); -} - -absl::Span MakeSpan(const IntList& src_list) { - return MakeSpanBase(src_list); -} - -absl::Span MakeSpan(const Int64List& src_list) { - return MakeSpanBase(src_list); -} - -absl::Span MakeSpan(const FloatList& src_list) { - return MakeSpanBase(src_list); -} -absl::Span MakeSpan(const BoolList& src_list) { - return MakeSpanBase(src_list); -} - void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape) { c_shape->element_type = xla_shape.element_type(); diff --git a/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h index e1d846739f1fd0..9e4aa9ab2e1b90 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h +++ b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h @@ -16,18 +16,20 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_ #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_ -#include "absl/container/inlined_vector.h" +#include +#include +#include + #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" -#include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" -#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_c_api.h" @@ -36,14 +38,19 @@ limitations under the License. namespace ApiConverter { absl::Span MakeSpan(const FloatList& src_list); -void CreateVector(const absl::Span src, FloatList* dst); +void CreateVector(absl::Span src, FloatList* dst); void Destroy(FloatList* float_list); absl::Span MakeSpan(const Int64List& src_list); -void CreateVector(const absl::Span src, Int64List* dst); +void CreateVector(absl::Span src, Int64List* dst); + +absl::Span MakeSpan(const IntList& src_list); +void CreateVector(absl::Span src, IntList* dst); absl::Span MakeSpan(const BoolList& src_list); -void CreateVector(const absl::Span src, BoolList* dst); +void CreateVector(absl::Span src, BoolList* dst); + +void CreateVector(absl::Span src, IntList* dst); // se::DeviceMemoryBase SE_DeviceMemoryBase ToC(const stream_executor::DeviceMemoryBase& base); @@ -52,20 +59,20 @@ void ToC(const stream_executor::DeviceMemoryBase& base, stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base); void Destroy(SE_DeviceMemoryBase*); -// xla::Shape -xla::Shape FromC(const XLA_Shape* c_shape); -void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape); -void Destroy(XLA_Shape* c_shape); +// xla::Tile +xla::Tile FromC(const XLA_Tile* c_tile); +void ToC(const xla::Tile& xla_tile, XLA_Tile* c_tile); +void Destroy(XLA_Tile* c_tile); // xla::Layout xla::Layout FromC(const XLA_Layout* c_layout); void ToC(const xla::Layout& xla_layout, XLA_Layout* c_layout); void Destroy(XLA_Layout* c_layout); -// xla::Tile -xla::Tile FromC(const XLA_Tile* c_tile); -void ToC(const xla::Tile& xla_tile, XLA_Tile* c_tile); -void Destroy(XLA_Tile* c_tile); +// xla::Shape +xla::Shape FromC(const XLA_Shape* c_shape); +void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape); +void Destroy(XLA_Shape* c_shape); // xla::ShapeIndex XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape); diff --git a/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions_test.cc b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions_test.cc new file mode 100644 index 00000000000000..333cb4066b534e --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions_test.cc @@ -0,0 +1,350 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h" + +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" +#include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/tsl/platform/protobuf.h" + +namespace ApiConverter { + +namespace { + +constexpr absl::string_view kHloString = + R"( +HloModule TupleCreate_module: +ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { + %v1 = f32[] parameter(0) + %v2 = f32[3]{0} parameter(1) + %v3 = f32[2,3]{1,0} parameter(2) + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) +} +)"; + +TEST(XlaTile, ToCInlined) { + std::vector tile_dimensions{2, 3, 4, 5}; + xla::Tile cpp_tile(tile_dimensions); + XLA_Tile c_tile; + ToC(cpp_tile, &c_tile); + + absl::Span cpp_tile_dimensions = cpp_tile.dimensions(); + ASSERT_EQ(cpp_tile_dimensions, tile_dimensions); + absl::Span c_tile_dimensions = MakeSpan(c_tile.dimensions); + EXPECT_EQ(cpp_tile_dimensions, c_tile_dimensions); + + Destroy(&c_tile); +} + +TEST(XlaTile, ToCDynamic) { + std::vector tile_dimensions{2, 3, 4, 5, 6, 7, 8, 9}; + xla::Tile cpp_tile(tile_dimensions); + XLA_Tile c_tile; + ToC(cpp_tile, &c_tile); + + absl::Span cpp_tile_dimensions = cpp_tile.dimensions(); + ASSERT_EQ(cpp_tile_dimensions, tile_dimensions); + absl::Span c_tile_dimensions = MakeSpan(c_tile.dimensions); + EXPECT_EQ(cpp_tile_dimensions, c_tile_dimensions); + + Destroy(&c_tile); +} + +TEST(XlaTile, FromCInlined) { + constexpr size_t kInlinedSize = 4; + Int64List tile_dimensions; + tile_dimensions.size = kInlinedSize; + for (int i = 0; i < kInlinedSize; ++i) { + tile_dimensions.inlined[i] = i + 2; + } + XLA_Tile c_tile{tile_dimensions}; + xla::Tile cpp_tile = FromC(&c_tile); + auto cpp_dimensions = cpp_tile.dimensions(); + EXPECT_EQ(cpp_dimensions.size(), kInlinedSize); + for (int i = 0; i < kInlinedSize; ++i) { + EXPECT_EQ(cpp_dimensions[i], i + 2); + } + Destroy(&c_tile); +} + +TEST(XlaTile, FromCDynamic) { + constexpr size_t kDynamicSize = 8; + int64_t* dynamic = new int64_t[kDynamicSize]; + for (int i = 0; i < kDynamicSize; ++i) { + dynamic[i] = i + 2; + } + Int64List tile_dimensions; + tile_dimensions.size = kDynamicSize; + tile_dimensions.heap = dynamic; + XLA_Tile c_tile{tile_dimensions}; + xla::Tile cpp_tile = FromC(&c_tile); + auto cpp_dimensions = cpp_tile.dimensions(); + EXPECT_EQ(cpp_dimensions.size(), kDynamicSize); + for (int i = 0; i < kDynamicSize; ++i) 
{ + EXPECT_EQ(cpp_dimensions[i], i + 2); + } + Destroy(&c_tile); +} + +namespace TestImpl { + +void XlaLayout_ToC(const xla::Layout& cpp_layout) { + XLA_Layout c_layout; + ToC(cpp_layout, &c_layout); + + absl::Span cpp_minor_to_major = cpp_layout.minor_to_major(); + absl::Span c_minor_to_major = + MakeSpan(c_layout.minor_to_major); + EXPECT_EQ(cpp_minor_to_major, c_minor_to_major); + + absl::Span cpp_dim_level_types = + cpp_layout.dim_level_types(); + absl::Span c_dim_level_types = MakeSpan(c_layout.dim_level_types); + EXPECT_EQ(cpp_dim_level_types.size(), c_dim_level_types.size()); + for (int i = 0; i < c_dim_level_types.size(); ++i) { + EXPECT_EQ(static_cast(cpp_dim_level_types[i]), c_dim_level_types[i]); + } + + absl::Span cpp_dim_unique = cpp_layout.dim_unique(); + absl::Span c_dim_unique = MakeSpan(c_layout.dim_unique); + EXPECT_EQ(cpp_dim_unique.size(), c_dim_unique.size()); + for (int i = 0; i < c_dim_unique.size(); ++i) { + EXPECT_EQ(cpp_dim_unique[i], static_cast(c_dim_unique[i])); + } + + absl::Span cpp_dim_ordered = cpp_layout.dim_ordered(); + absl::Span c_dim_ordered = MakeSpan(c_layout.dim_ordered); + EXPECT_EQ(cpp_dim_ordered.size(), c_dim_ordered.size()); + for (int i = 0; i < c_dim_ordered.size(); ++i) { + EXPECT_EQ(cpp_dim_ordered[i], static_cast(c_dim_ordered[i])); + } + + absl::Span cpp_tiles = cpp_layout.tiles(); + TileList c_tiles = c_layout.tiles; + EXPECT_EQ(cpp_tiles.size(), c_tiles.size); + XLA_Tile* tile_base = + (c_tiles.size > TPU_C_API_MAX_INLINED) ? c_tiles.heap : c_tiles.inlined; + for (int i = 0; i < c_tiles.size; ++i) { + xla::Tile converted_c_tile = FromC(&tile_base[i]); + EXPECT_EQ(cpp_tiles[i], converted_c_tile); + } + + EXPECT_EQ(cpp_layout.index_primitive_type(), c_layout.index_primitive_type); + EXPECT_EQ(cpp_layout.pointer_primitive_type(), + c_layout.pointer_primitive_type); + EXPECT_EQ(cpp_layout.element_size_in_bits(), c_layout.element_size_in_bits); + EXPECT_EQ(cpp_layout.memory_space(), c_layout.memory_space); + EXPECT_EQ(cpp_layout.dynamic_shape_metadata_prefix_bytes(), + c_layout.dynamic_shape_metadata_prefix_bytes); + + Destroy(&c_layout); +} + +} // namespace TestImpl + +TEST(XlaLayout, ToCScalar) { + xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4}); + xla::Layout cpp_layout = cpp_shape.layout(); + TestImpl::XlaLayout_ToC(cpp_layout); +} + +TEST(XlaLayout, ToCNested) { + xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4, 3, 2}); + xla::Layout cpp_layout = cpp_shape.layout(); + TestImpl::XlaLayout_ToC(cpp_layout); +} + +TEST(XlaLayout, FromCScalar) { + xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4}); + xla::Layout in_layout = cpp_shape.layout(); + XLA_Layout c_layout; + ToC(in_layout, &c_layout); + xla::Layout out_layout = FromC(&c_layout); + EXPECT_EQ(in_layout, out_layout); + Destroy(&c_layout); +} + +TEST(XlaLayout, FromCNested) { + xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4, 3, 2}); + xla::Layout in_layout = cpp_shape.layout(); + XLA_Layout c_layout; + ToC(in_layout, &c_layout); + xla::Layout out_layout = FromC(&c_layout); + EXPECT_EQ(in_layout, out_layout); + Destroy(&c_layout); +} + +TEST(XlaShape, ToCScalar) { + xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4}); + XLA_Shape c_shape; + ToC(cpp_shape, &c_shape); + + EXPECT_EQ(cpp_shape.element_type(), c_shape.element_type); + + absl::Span cpp_dimensions = cpp_shape.dimensions(); + absl::Span c_dimensions = MakeSpan(c_shape.dimensions); + EXPECT_EQ(cpp_dimensions, c_dimensions); + + absl::Span cpp_dynamic_dimensions = + 
cpp_shape.dynamic_dimensions(); + absl::Span c_dynamic_dimensions = + MakeSpan(c_shape.dynamic_dimensions); + EXPECT_EQ(cpp_dynamic_dimensions, c_dynamic_dimensions); + + int cpp_ntuple_shapes = cpp_shape.tuple_shapes_size(); + int c_ntuple_shapes = c_shape.ntuple_shapes; + EXPECT_EQ(cpp_ntuple_shapes, c_ntuple_shapes); + EXPECT_EQ(cpp_ntuple_shapes, 0); + + bool cpp_has_layout = cpp_shape.has_layout(); + bool c_has_layout = c_shape.has_layout; + EXPECT_EQ(cpp_has_layout, c_has_layout); + + Destroy(&c_shape); +} + +TEST(XlaShape, ToCNested) { + xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4, 3, 2}); + XLA_Shape c_shape; + ToC(cpp_shape, &c_shape); + + EXPECT_EQ(cpp_shape.element_type(), c_shape.element_type); + + absl::Span cpp_dimensions = cpp_shape.dimensions(); + absl::Span c_dimensions = MakeSpan(c_shape.dimensions); + EXPECT_EQ(cpp_dimensions, c_dimensions); + + absl::Span cpp_dynamic_dimensions = + cpp_shape.dynamic_dimensions(); + absl::Span c_dynamic_dimensions = + MakeSpan(c_shape.dynamic_dimensions); + EXPECT_EQ(cpp_dynamic_dimensions, c_dynamic_dimensions); + + int cpp_ntuple_shapes = cpp_shape.tuple_shapes_size(); + int c_ntuple_shapes = c_shape.ntuple_shapes; + EXPECT_EQ(cpp_ntuple_shapes, c_ntuple_shapes); + + const std::vector& cpp_tuple_shapes = cpp_shape.tuple_shapes(); + absl::Span c_tuple_shapes(c_shape.tuple_shapes, + c_ntuple_shapes); + for (int i = 0; i < c_ntuple_shapes; ++i) { + xla::Shape converted_c_shape = FromC(&c_tuple_shapes[i]); + EXPECT_EQ(cpp_tuple_shapes[i], converted_c_shape); + } + + bool cpp_has_layout = cpp_shape.has_layout(); + bool c_has_layout = c_shape.has_layout; + EXPECT_EQ(cpp_has_layout, c_has_layout); + + if (c_has_layout) { + xla::Layout converted_c_layout = FromC(&c_shape.layout); + EXPECT_EQ(cpp_shape.layout(), converted_c_layout); + } + + Destroy(&c_shape); +} + +TEST(XlaShape, FromCScalar) { + xla::Shape in_shape = xla::ShapeUtil::MakeShapeWithType({4}); + XLA_Shape c_shape; + ToC(in_shape, &c_shape); + xla::Shape out_shape = FromC(&c_shape); + EXPECT_EQ(in_shape, out_shape); + Destroy(&c_shape); +} + +TEST(XlaShape, FromCNested) { + xla::Shape in_shape = xla::ShapeUtil::MakeShapeWithType({4, 3, 2}); + XLA_Shape c_shape; + ToC(in_shape, &c_shape); + xla::Shape out_shape = FromC(&c_shape); + EXPECT_EQ(in_shape, out_shape); + Destroy(&c_shape); +} + +// TODO(b/290654348): xla::ShapeIndex, xla::Literal, xla::ShapedBuffer + +TEST(XlaHloModuleConfig, ToAndFromC) { + xla::StatusOr> hlo_module = + xla::ParseAndReturnUnverifiedModule(kHloString); + ASSERT_TRUE(hlo_module.ok()); + xla::HloModule& cpp_module = *hlo_module.value(); + xla::HloModuleConfig in_config = cpp_module.config(); + + XLA_HloModuleConfig c_config = ToC(in_config); + xla::HloModuleConfig out_config = FromC(c_config); + + TF_ASSERT_OK_AND_ASSIGN(xla::HloModuleConfigProto in_config_proto, + in_config.ToProto()); + TF_ASSERT_OK_AND_ASSIGN(xla::HloModuleConfigProto out_config_proto, + out_config.ToProto()); + + tsl::protobuf::util::MessageDifferencer diff; + diff.set_message_field_comparison( + tsl::protobuf::util::MessageDifferencer::EQUIVALENT); + EXPECT_TRUE(diff.Equals(in_config_proto, out_config_proto)); + + Destroy(&c_config); +} + +TEST(XlaHloModule, ToAndFromC) { + xla::StatusOr> hlo_module = + xla::ParseAndReturnUnverifiedModule(kHloString); + ASSERT_TRUE(hlo_module.ok()); + xla::HloModule& in_module = *hlo_module.value(); + + XLA_HloModule c_module = ToC(in_module); + xla::StatusOr> out_module_ptr = + FromC(c_module); + ASSERT_TRUE(out_module_ptr.ok()); + 
xla::HloModule& out_module = *out_module_ptr.value(); + + TF_ASSERT_OK_AND_ASSIGN(xla::HloModuleProtoWithConfig in_module_proto, + in_module.ToProtoWithConfig()); + TF_ASSERT_OK_AND_ASSIGN(xla::HloModuleProtoWithConfig out_module_proto, + out_module.ToProtoWithConfig()); + + tsl::protobuf::util::MessageDifferencer diff; + diff.set_message_field_comparison( + tsl::protobuf::util::MessageDifferencer::EQUIVALENT); + const auto* ignore_unique_id = + xla::HloModuleProto::GetDescriptor()->FindFieldByName("id"); + diff.IgnoreField(ignore_unique_id); + EXPECT_TRUE(diff.Compare(in_module_proto, out_module_proto)); + + Destroy(&c_module); +} + +// TODO(b/290654348): SE_DeviceMemoryBase, SE_DeviceMemoryAllocator, +// SE_MaybeOwningDeviceMemory + +} // namespace + +} // namespace ApiConverter diff --git a/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.cc b/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.cc index 429dc6a1ce3756..35b75ac9e95d1a 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.cc +++ b/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.cc @@ -15,8 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.h" -#include -#include #include #include "tensorflow/compiler/xla/status.h" @@ -24,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/tpu/proto_helper.h" #include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_api.h" -#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -111,8 +108,8 @@ xla::Status TpuOpExecutable::LoadProgramAndEnqueueToStream( } absl::string_view TpuOpExecutable::fingerprint() const { - TpuProgramFingerprint fingerprint = TpuProgram_GetFingerprint(core_program_); - return absl::string_view(fingerprint.bytes, fingerprint.size); + // TODO(skye): the fingerprint can be plumbed through via core_program_ + LOG(FATAL) << "TpuOpExecutable::fingerprint() unimplemented"; } } // namespace tensorflow diff --git a/tensorflow/compiler/xla/strict.default.bzl b/tensorflow/compiler/xla/strict.default.bzl new file mode 100644 index 00000000000000..2042d4a98d05fb --- /dev/null +++ b/tensorflow/compiler/xla/strict.default.bzl @@ -0,0 +1,13 @@ +"""Default (OSS) build versions of Python strict rules.""" + +# Placeholder to use until bazel supports py_strict_binary. +def py_strict_binary(name, **kwargs): + native.py_binary(name = name, **kwargs) + +# Placeholder to use until bazel supports py_strict_library. +def py_strict_library(name, **kwargs): + native.py_library(name = name, **kwargs) + +# Placeholder to use until bazel supports py_strict_test. 
+def py_strict_test(name, **kwargs): + native.py_test(name = name, **kwargs) diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 61aa2f7b4ebe67..e20493f698e99b 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -114,6 +114,7 @@ cc_library( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:path", "//tensorflow/tsl/platform:test", diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 210422b6009571..425c144d1a9033 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -136,8 +136,8 @@ def xla_test( name = test_name, srcs = srcs, tags = tags + backend_tags.get(backend, []) + this_backend_tags, - extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + - this_backend_copts, + copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + this_backend_copts, args = args + this_backend_args, deps = deps + backend_deps, data = data + this_backend_data, diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index ea81983429479a..be96a4b2982aad 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_comparison.h" +#include "tensorflow/tsl/platform/env.h" #include "tensorflow/tsl/platform/path.h" #include "tensorflow/tsl/platform/test.h" diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 9e8e8256c25a88..aa4a8ff25e303d 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -426,6 +426,7 @@ cc_library( "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/tsl/platform:path", "//tensorflow/tsl/platform:status", diff --git a/tensorflow/compiler/xla/tools/run_hlo_module.cc b/tensorflow/compiler/xla/tools/run_hlo_module.cc index 3cb0bbdeac9fa4..09d7e316ceb6fa 100644 --- a/tensorflow/compiler/xla/tools/run_hlo_module.cc +++ b/tensorflow/compiler/xla/tools/run_hlo_module.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/tools/hlo_control_flow_flattening.h" #include "tensorflow/compiler/xla/tools/hlo_module_loader.h" @@ -254,6 +255,10 @@ Status RunAndCompare( options.use_buffer_assignment_from_proto ? 
&buffer_assignment_proto : nullptr)); + HloVerifier verifier( + HloVerifierOpts{}.WithLayoutSensitive(false).WithAllowMixedPrecision( + true)); + TF_RETURN_IF_ERROR(verifier.Run(test_module.get()).status()); if (compilation_env_modifier_hook) { TF_CHECK_OK(compilation_env_modifier_hook(options, *test_module)) << "Could not adjust the compilation environment for user provided " diff --git a/tensorflow/compiler/xla/translate/mhlo_to_hlo/BUILD b/tensorflow/compiler/xla/translate/mhlo_to_hlo/BUILD index dedfb49b68f29b..b80ab89e28e199 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_hlo/BUILD +++ b/tensorflow/compiler/xla/translate/mhlo_to_hlo/BUILD @@ -1,7 +1,7 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_binary", "cc_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@bazel_skylib//rules:build_test.bzl", "build_test") package( @@ -122,7 +122,7 @@ cc_binary( gentbl_cc_library( name = "operator_writer_inc", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [([], "operator_writers.inc")], tblgen = ":operator_writer_gen", td_file = "//tensorflow/compiler/xla/mlir_hlo:mhlo/IR/hlo_ops.td", diff --git a/tensorflow/compiler/xla/translate/mhlo_to_hlo/attribute_exporter.cc b/tensorflow/compiler/xla/translate/mhlo_to_hlo/attribute_exporter.cc index 2f5161e5f5c4a3..8796c517915fb1 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_hlo/attribute_exporter.cc +++ b/tensorflow/compiler/xla/translate/mhlo_to_hlo/attribute_exporter.cc @@ -76,6 +76,8 @@ StatusOr ConvertConvActivationMode( return stream_executor::dnn::kBandPass; case mlir::lmhlo_gpu::Activation::Elu: return stream_executor::dnn::kElu; + case mlir::lmhlo_gpu::Activation::LeakyRelu: + return stream_executor::dnn::kLeakyRelu; default: return InternalError("Unexpected activation"); } diff --git a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc index 40babb062f280a..6b5622727e2d23 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc @@ -1104,6 +1104,8 @@ static tsl::StatusOr GetLHLOActivation( return mlir::lmhlo_gpu::Activation::BandPass; case stream_executor::dnn::kElu: return mlir::lmhlo_gpu::Activation::Elu; + case stream_executor::dnn::kLeakyRelu: + return mlir::lmhlo_gpu::Activation::LeakyRelu; default: return xla::InternalError("Unknown activation"); } @@ -1235,6 +1237,8 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnConvolution( auto cnn_fused, CreateOpWithoutAttrs(custom_call)); TF_RETURN_IF_ERROR(set_activation(cnn_fused)); + cnn_fused.setLeakyreluAlphaAttr( + builder_.getF64FloatAttr(backend_config.leakyrelu_alpha())); return set_common_conv_attributes(cnn_fused); } diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 6259d8a540a098..603469afa22efb 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -79,11 +79,9 @@ def xla_cc_binary(deps = None, copts = tsl_copts(), **kwargs): def xla_cc_test( name, deps = [], - extra_copts = [], **kwargs): native.cc_test( name = name, - copts = extra_copts, deps = deps + if_tsl_link_protobuf( [], [ @@ 
-102,6 +100,7 @@ def xla_cc_test( clean_dep("//tensorflow/tsl/profiler/utils:time_utils_impl"), clean_dep("//tensorflow/tsl/profiler/backends/cpu:annotation_stack_impl"), clean_dep("//tensorflow/tsl/profiler/backends/cpu:traceme_recorder_impl"), + clean_dep("//tensorflow/tsl/profiler/protobuf:xplane_proto_cc_impl"), clean_dep("//tensorflow/compiler/xla:autotuning_proto_cc_impl"), clean_dep("//tensorflow/tsl/protobuf:dnn_proto_cc_impl"), clean_dep("//tensorflow/tsl/protobuf:protos_all_cc_impl"), diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 3eb4ae20db045d..f41923aeae68b6 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -570,7 +570,9 @@ message DebugOptions { bool xla_gpu_triton_gemm_disable_reduced_precision_reduction = 226; - // Next id: 229 + int32 xla_gpu_triton_fusion_level = 229; + + // Next id: 230 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index ce89c5fc61b2f4..42d6766439d0e8 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -759,6 +759,23 @@ message FrontendAttributes { map map = 1; } +// Represents a single statistic to track. +message Statistic { + // Must be a single word consisting of any alphanumeric characters + string stat_name = 1; + // Must be within a range of [0, 100], in order for the graph dumper to + // properly render the statistic onto the graph. + double stat_val = 2; +} + +// Represents the information needed to visualize propagation statistics when +// rendering an HLO graph. This includes an array of statistics as well as the +// index of the statistic to render. +message StatisticsViz { + int64 stat_index_to_visualize = 1; + repeated Statistic statistics = 2; +} + // LINT.IfChange message OpSharding { enum Type { diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc index 27c77a6a9e4fb0..1287646bb1736f 100644 --- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc +++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc @@ -63,6 +63,20 @@ limitations under the License. #endif // !IS_MOBILE_PLATFORM namespace tensorflow { + +// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the +// server object (which currently CHECK-fails) and we miss the error, instead, +// we log the error, and then return to allow the user to see the error +// message. +#define LOG_AND_RETURN_IF_ERROR(...) \ + do { \ + const tensorflow::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + LOG(ERROR) << _status.message(); \ + return _status; \ + } \ + } while (0); + #if !defined(IS_MOBILE_PLATFORM) namespace { @@ -641,7 +655,6 @@ Status UpdateContextWithServerDef(EagerContext* context, added_workers, removed_workers)); LOG_AND_RETURN_IF_ERROR(sg.as_summary_status()); } -#undef LOG_AND_RETURN_IF_ERROR return OkStatus(); } @@ -682,21 +695,76 @@ Status EagerContextDistributedManager::SetOrUpdateServerDef( return s; } +Status EagerContextDistributedManager::InitializeLocalOnlyContext( + const ServerDef& server_def, int keep_alive_secs) { + string worker_name = + strings::StrCat("/job:", server_def.job_name(), + "/replica:0/task:", server_def.task_index()); + // New server created for new server_def. 
Unused if updating server_def. + std::unique_ptr new_server; + ServerInterface* server; + DeviceMgr* device_mgr = AreLocalDevicesCompatible(context_, server_def) + ? context_->local_device_mgr() + : nullptr; + LOG_AND_RETURN_IF_ERROR( + NewServerWithOptions(server_def, {device_mgr}, &new_server)); + server = new_server.get(); + uint64 context_id = EagerContext::NewContextId(); + // Make master eager context accessible by local eager service, which might + // receive send tensor requests from remote workers. + LOG_AND_RETURN_IF_ERROR( + server->AddMasterEagerContextToEagerService(context_id, context_)); + + std::vector local_device_attributes; + server->worker_env()->device_mgr->ListDeviceAttributes( + &local_device_attributes); + + auto session_name = strings::StrCat("eager_", context_id); + auto* session_mgr = server->worker_env()->session_mgr; + tsl::core::RefCountPtr r = + server->worker_env()->rendezvous_mgr->Find(context_id); + std::shared_ptr worker_session; + protobuf::RepeatedPtrField device_attributes( + local_device_attributes.begin(), local_device_attributes.end()); + LOG_AND_RETURN_IF_ERROR(session_mgr->CreateSession( + session_name, server_def, device_attributes, + context_->session_options().config.isolate_session_state())); + LOG_AND_RETURN_IF_ERROR(server->SetCoordinationServiceAgentInstance( + session_mgr->GetCoordinationServiceAgent())); + LOG_AND_RETURN_IF_ERROR( + session_mgr->WorkerSessionForSession(session_name, &worker_session)); + + // Initialize remote tensor communication based on worker session. + LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get())); + + DistributedFunctionLibraryRuntime* cluster_flr = + eager::CreateClusterFLR(context_id, context_, worker_session.get()); + auto remote_mgr = std::make_unique( + /*is_master=*/true, context_); + + // The remote workers and device manager are ignored since this initialization + // is local only. + LOG_AND_RETURN_IF_ERROR(context_->InitializeRemoteMaster( + std::move(new_server), server->worker_env(), worker_session, + /*remote_eager_workers=*/nullptr, /*remote_device_manager=*/nullptr, + /*remote_contexts=*/{}, context_id, std::move(r), + server->worker_env()->device_mgr, keep_alive_secs, cluster_flr, + std::move(remote_mgr))); + + // NOTE: We start the server after all other initialization, because the + // GrpcServer cannot be destroyed after it is started. + LOG_AND_RETURN_IF_ERROR(server->Start()); + + // If context is reset, make sure pointer is set to the new agent. + coordination_service_agent_ = + context_->GetServer() + ->worker_env() + ->session_mgr->GetCoordinationServiceAgent(); + return OkStatus(); +} + Status EagerContextDistributedManager::EnableCollectiveOps( const ServerDef& server_def) { - // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the - // server object (which currently CHECK-fails) and we miss the error, instead, - // we log the error, and then return to allow the user to see the error - // message. -#define LOG_AND_RETURN_IF_ERROR(...) 
\ - do { \ - const tensorflow::Status _status = (__VA_ARGS__); \ - if (TF_PREDICT_FALSE(!_status.ok())) { \ - LOG(ERROR) << _status.message(); \ - return _status; \ - } \ - } while (0); - ServerInterface* server = context_->GetServer(); if (server == nullptr) { std::unique_ptr new_server; @@ -789,7 +857,6 @@ Status EagerContextDistributedManager::EnableCollectiveOps( /*new_server=*/nullptr, server->worker_env()->device_mgr, server->worker_env()->collective_executor_mgr.get())); } -#undef LOG_AND_RETURN_IF_ERROR return OkStatus(); } diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.h b/tensorflow/core/common_runtime/eager/context_distributed_manager.h index 279a792a87b522..f13c01c842ca1a 100644 --- a/tensorflow/core/common_runtime/eager/context_distributed_manager.h +++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.h @@ -44,6 +44,9 @@ class EagerContextDistributedManager Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context, int keep_alive_secs) override; + Status InitializeLocalOnlyContext(const ServerDef& server_def, + int keep_alive_secs) override; + Status EnableCollectiveOps(const ServerDef& server_def) override; Status CheckRemoteAlive(const std::string& remote_task_name, diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index e4df4068ec6848..c3ea9c2c562d49 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -1,5 +1,6 @@ load( "//tensorflow:tensorflow.bzl", + "clean_dep", "if_cuda_or_rocm", "if_google", "if_linux_x86_64", @@ -192,15 +193,21 @@ tf_cuda_library( ] + if_google( # TODO(b/282068262): PJRT pulls in TFRT components that are incompatible with ARM platform. # Clean up so that PJRT can run on ARM. - if_linux_x86_64([ - "//tensorflow/compiler/tf2xla:layout_util", - "//tensorflow/compiler/jit:flags", - "//tensorflow/compiler/jit:pjrt_device_context", - "//tensorflow/compiler/xla/pjrt/gpu:gpu_helpers", - "//tensorflow/compiler/xla/pjrt/gpu:se_gpu_pjrt_client", - "//tensorflow/compiler/xla/stream_executor:tf_allocator_adapter", - "//tensorflow/core/tfrt/common:pjrt_util", - ]) + if_cuda_or_rocm([ + # Also it won't build with WeightWatcher which tracks OSS build binaries. + # TODO(b/290533709): Clean up this build rule. + select({ + clean_dep("//tensorflow:linux_x86_64_with_weightwatcher"): [], + clean_dep("//tensorflow:linux_x86_64"): [ + "//tensorflow/compiler/tf2xla:layout_util", + "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/jit:pjrt_device_context", + "//tensorflow/compiler/xla/pjrt/gpu:gpu_helpers", + "//tensorflow/compiler/xla/pjrt/gpu:se_gpu_pjrt_client", + "//tensorflow/compiler/xla/stream_executor:tf_allocator_adapter", + "//tensorflow/core/tfrt/common:pjrt_util", + ], + "//conditions:default": [], + }) + if_cuda_or_rocm([ "//tensorflow/compiler/xla/service:gpu_plugin_impl", # for registering cuda compiler. 
]), ), diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc index fdc56d0a35f9f9..b699239fdb979b 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc @@ -219,8 +219,13 @@ void GPUUtil::DeviceToDeviceCopy( DeviceMemoryBase gpu_src_ptr(src_ptr, total_bytes); void* dst_ptr = GetBase(output); DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); - auto recv_stream = - static_cast(recv_dev_context)->stream(); + // For GpuDevice, always gets receive stream from + // dst->tensorflow_accelerator_device_info()->default_context which is + // GPUDeviceContext. + stream_executor::Stream* recv_stream = + static_cast( + dst->tensorflow_accelerator_device_info()->default_context) + ->stream(); if (recv_stream == nullptr) { done(errors::Internal("No recv gpu stream is available.")); return; diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 8aeb877e86881e..dae6ce57bde069 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -396,7 +396,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { #endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool), - CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); + CopyAttrsAll, RewriteIfX86, GetRewriteCause()}); rinfo_.push_back({csinfo_.avg_pool_grad, mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad), CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); @@ -722,7 +722,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { #endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax), - CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); + CopyAttrsAll, RewriteIfX86, GetRewriteCause()}); #ifndef ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.squared_difference, @@ -735,7 +735,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { #endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.transpose, mkl_op_registry::GetMklOpName(csinfo_.transpose), - CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); + CopyAttrsAll, RewriteIfX86, kRewriteForOpNameChange}); // Add info about which ops to add workspace edge to and the slots. wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); @@ -1463,6 +1463,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Default rewrite rule to be used in scenario 1 for rewrite. // @return - true (since we want to always rewrite) static bool AlwaysRewrite(const Node* n) { return true; } + static bool RewriteIfX86(const Node* n) { +#ifdef DNNL_AARCH64_USE_ACL + return false; +#else + return true; +#endif + } // Rewrite rule which considers "context" of the current node to decide if we // should rewrite. 
By "context" we currently mean all the inputs of current diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index 71bfed7bb1d8ae..7332de6026402f 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -977,6 +977,8 @@ REGISTER_DATASET_EXPERIMENT("stage_based_autotune_v2", RandomJobSamplePercentage<0>, IndependentHostTasks); REGISTER_DATASET_EXPERIMENT("data_transfer", RandomJobSamplePercentage<50>, AllTasks); +REGISTER_DATASET_EXPERIMENT("file_locality", RandomJobSamplePercentage<0>, + IndependentHostTasks); } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/root_dataset.cc b/tensorflow/core/data/root_dataset.cc index 9b03486e80552d..f010348410effb 100644 --- a/tensorflow/core/data/root_dataset.cc +++ b/tensorflow/core/data/root_dataset.cc @@ -156,16 +156,6 @@ class RootDataset::Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) : DatasetIterator(params) { - if (dataset()->params_.autotune) { - model_ = std::make_shared(); - auto experiments = GetExperiments(); - if (experiments.contains("stage_based_autotune_v2")) { - model_->AddExperiment("stage_based_autotune_v2"); - } - if (experiments.contains("autotune_buffer_optimization")) { - model_->AddExperiment("autotune_buffer_optimization"); - } - } if (dataset()->params_.max_intra_op_parallelism >= 0) { max_intra_op_parallelism_ = value_or_default(dataset()->params_.max_intra_op_parallelism, 0, @@ -187,6 +177,17 @@ class RootDataset::Iterator : public DatasetIterator { bool SymbolicCheckpointCompatible() const override { return true; } Status Initialize(IteratorContext* ctx) override { + if (dataset()->params_.autotune) { + model_ = ctx->model() != nullptr ? 
ctx->model() + : std::make_shared(); + absl::flat_hash_set experiments = GetExperiments(); + if (experiments.contains("stage_based_autotune_v2")) { + model_->AddExperiment("stage_based_autotune_v2"); + } + if (experiments.contains("autotune_buffer_optimization")) { + model_->AddExperiment("autotune_buffer_optimization"); + } + } IteratorContext iter_ctx(CreateParams(ctx)); TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(&iter_ctx, this, prefix(), &input_impl_)); diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 4855381b7f8e74..c3eaaf3f60da1b 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -6,7 +6,7 @@ load( "tf_proto_library", "tf_protos_profiler_service", ) -load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "get_compatible_with_cloud", "tf_grpc_cc_dependencies") +load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "get_compatible_with_portable", "tf_grpc_cc_dependencies") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", @@ -331,7 +331,7 @@ tf_cc_test( cc_grpc_library( name = "dispatcher_cc_grpc_proto", srcs = [":dispatcher_proto"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), # copybara:uncomment copts = ["-Wthread-safety-analysis"], generate_mocks = True, grpc_only = True, @@ -982,7 +982,7 @@ tf_cc_test( cc_grpc_library( name = "worker_cc_grpc_proto", srcs = [":worker_proto"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), # copybara:uncomment copts = ["-Wthread-safety-analysis"], generate_mocks = True, grpc_only = True, diff --git a/tensorflow/core/data/service/snapshot/BUILD b/tensorflow/core/data/service/snapshot/BUILD index cdb6e3d7ffb184..b4d5fa11433571 100644 --- a/tensorflow/core/data/service/snapshot/BUILD +++ b/tensorflow/core/data/service/snapshot/BUILD @@ -226,6 +226,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/data:name_utils", "//tensorflow/core/data:snapshot_utils", + "//tensorflow/core/data:utils", "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc index 78d24dfcb9ff27..0b4c179560b7a2 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/snapshot_utils.h" +#include "tensorflow/core/data/utils.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -116,8 +117,8 @@ class SnapshotChunkDatasetOp::Dataset : public DatasetBase { Status Initialize(IteratorContext* ctx) override { reader_ = std::make_unique( - dataset()->chunk_file_, dataset()->compression_, dataset()->dtypes_, - kTFRecordReaderOutputBufferSize); + TranslateFileName(dataset()->chunk_file_), dataset()->compression_, + dataset()->dtypes_, kTFRecordReaderOutputBufferSize); return reader_->Initialize(ctx->env()); } diff --git a/tensorflow/core/data/standalone.cc b/tensorflow/core/data/standalone.cc index a1c702b0c77b66..ab3eb810af6b98 100644 --- a/tensorflow/core/data/standalone.cc +++ b/tensorflow/core/data/standalone.cc @@ -195,6 +195,7 @@ Status Dataset::MakeIterator( std::back_inserter(params.split_providers)); params.thread_factory = unbounded_thread_pool_.get_thread_factory(); params.thread_pool = &unbounded_thread_pool_; + params.model = std::make_shared(); ctx = std::make_unique(std::move(params)); SerializationContext::Params serialization_params(&op_ctx); auto serialization_ctx = diff --git a/tensorflow/core/data/utils.cc b/tensorflow/core/data/utils.cc index a8f72ce1773bac..4e5d1211644ec8 100644 --- a/tensorflow/core/data/utils.cc +++ b/tensorflow/core/data/utils.cc @@ -33,5 +33,7 @@ std::string TranslateFileName(const std::string& fname) { return fname; } std::string DefaultDataTransferProtocol() { return "grpc"; } +std::string LocalityOptimizedPath(const std::string& path) { return path; } + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/utils.h b/tensorflow/core/data/utils.h index 710a480e927a3d..43ae1263490580 100644 --- a/tensorflow/core/data/utils.h +++ b/tensorflow/core/data/utils.h @@ -34,6 +34,10 @@ std::string TranslateFileName(const std::string& fname); // user. std::string DefaultDataTransferProtocol(); +// Returns a path pointing to the same file as `path` with a potential locality +// optimization. 
+std::string LocalityOptimizedPath(const std::string& path); + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index d97ccae98de998..846aa8f8472c6d 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -117,6 +117,7 @@ exports_files( "api_def.proto", "attr_value.proto", "cost_graph.proto", + "cpp_shape_inference.proto", "dataset_metadata.proto", "dataset_options.proto", "device_attributes.proto", @@ -1695,6 +1696,18 @@ tf_proto_library( ], ) +tf_proto_library( + name = "cpp_shape_inference_proto", + srcs = ["cpp_shape_inference.proto"], + cc_api_version = 2, + make_default_target_header_only = True, + protodeps = [ + ":full_type_proto", + ":tensor_shape_proto", + ":types_proto", + ], +) + tf_proto_library( name = "variable_proto", srcs = ["variable.proto"], @@ -1833,6 +1846,7 @@ tf_proto_library( protodeps = [ ":allocation_description_proto", ":api_def_proto", + ":cpp_shape_inference_proto", ":attr_value_proto", ":cost_graph_proto", ":dataset_proto", diff --git a/tensorflow/core/framework/cpp_shape_inference.proto b/tensorflow/core/framework/cpp_shape_inference.proto new file mode 100644 index 00000000000000..4cdbf5dd5c80ca --- /dev/null +++ b/tensorflow/core/framework/cpp_shape_inference.proto @@ -0,0 +1,36 @@ +syntax = "proto3"; + +package tensorflow.core; + +import "tensorflow/core/framework/full_type.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +option cc_enable_arenas = true; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/python/framework/cpp_shape_inference_go_proto"; + +message CppShapeInferenceResult { + message HandleShapeAndType { + reserved 3; + + TensorShapeProto shape = 1; + DataType dtype = 2; + FullTypeDef type = 4; + } + message HandleData { + bool is_set = 1; + + // Only valid if . + repeated HandleShapeAndType shape_and_type = 2; + } + TensorShapeProto shape = 1; + + reserved 2; // was handle_shape + reserved 3; // was handle_dtype + HandleData handle_data = 4; +} + +message CppShapeInferenceInputsNeeded { + repeated int32 input_tensors_needed = 1; + repeated int32 input_tensors_as_shapes_needed = 2; +} diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index a197e73ed8ed47..02b6953d4ea1ab 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -545,7 +545,8 @@ Status IteratorBase::InitializeBase(IteratorContext* ctx, auto factory = [ctx, this](model::Node::Args args) { return CreateNode(ctx, std::move(args)); }; - model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_); + model->AddNode(std::move(factory), prefix(), + parent == nullptr ? nullptr : parent->model_node(), &node_); cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); }); } return OkStatus(); diff --git a/tensorflow/core/framework/kernel_shape_util.cc b/tensorflow/core/framework/kernel_shape_util.cc index 2d52a5022bd486..c03540a9dad034 100644 --- a/tensorflow/core/framework/kernel_shape_util.cc +++ b/tensorflow/core/framework/kernel_shape_util.cc @@ -14,13 +14,15 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/framework/kernel_shape_util.h" +#include +#include + #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { -Status GetWindowedOutputSizeVerboseV2(int64_t input_size, int64_t filter_size, +Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, int64_t dilation_rate, int64_t stride, - Padding padding_type, - int64_t* output_size, + Padding padding_type, int64_t* output_size, int64_t* padding_before, int64_t* padding_after) { if (stride <= 0) { @@ -64,17 +66,6 @@ Status GetWindowedOutputSizeVerboseV2(int64_t input_size, int64_t filter_size, return OkStatus(); } -Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, - int64_t stride, Padding padding_type, - int64_t* output_size, - int64_t* padding_before, - int64_t* padding_after) { - return GetWindowedOutputSizeVerboseV2(input_size, filter_size, - /*dilation_rate=*/1, stride, - padding_type, output_size, - padding_before, padding_after); -} - Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, int dilation_rate, int64_t stride, Padding padding_type, int64_t* output_size, @@ -82,12 +73,12 @@ Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, if (padding_type == Padding::EXPLICIT) { return errors::Internal( "GetWindowedOutputSize does not handle EXPLICIT padding; call " - "GetWindowedOutputSizeVerboseV2 instead"); + "GetWindowedOutputSizeVerbose instead"); } int64_t padding_after_unused; - return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate, - stride, padding_type, output_size, - padding_size, &padding_after_unused); + return GetWindowedOutputSizeVerbose(input_size, filter_size, dilation_rate, + stride, padding_type, output_size, + padding_size, &padding_after_unused); } Status Get3dOutputSizeV2(const std::array& input, diff --git a/tensorflow/core/framework/kernel_shape_util.h b/tensorflow/core/framework/kernel_shape_util.h index 6ffda766ca449c..551a863e3d38e5 100644 --- a/tensorflow/core/framework/kernel_shape_util.h +++ b/tensorflow/core/framework/kernel_shape_util.h @@ -85,14 +85,6 @@ Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, // excess padding (caused by an odd padding size value) is added to the // 'padding_after' dimension. Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, - int64_t stride, Padding padding_type, - int64_t* output_size, - int64_t* padding_before, - int64_t* padding_after); - -// The V2 version computes the same outputs with arbitrary dilation_rate. For -// detailed equations, refer to the comments for GetWindowedOutputSize(). 
-Status GetWindowedOutputSizeVerboseV2(int64_t input_size, int64_t filter_size, int64_t dilation_rate, int64_t stride, Padding padding_type, int64_t* output_size, diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index f613e96f7bd8b7..f6341ed75a2162 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -2007,7 +2007,12 @@ std::shared_ptr Node::SnapshotHelper( cloned_current->processing_time_.store(processing_time_); { mutex_lock l2(cloned_current->mu_); - cloned_current->parameters_ = parameters_; + cloned_current->parameters_ = + absl::flat_hash_map>(); + for (const auto& [parameter_name, parameter_ptr] : parameters_) { + cloned_current->parameters_[parameter_name] = + std::make_shared(parameter_ptr); + } cloned_current->previous_processing_time_ = previous_processing_time_; cloned_current->processing_time_ema_ = processing_time_ema_; } diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index 58569f9773bef2..8e815eafb1abe9 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -107,6 +107,13 @@ struct Parameter { max(max), state(std::move(state)) {} + explicit Parameter(const std::shared_ptr parameter) + : name(parameter->name), + value(parameter->value), + min(parameter->min), + max(parameter->max), + state(parameter->state) {} + // Human-readable name of the parameter. const string name; diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index 513f49aee578c7..eacd62a62492cc 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -380,17 +380,27 @@ string ContainerInfo::DebugString() const { "]"); } -// TODO(b/228388547) users of this method should be migrated to the one below. +// TODO(b/228388547) users of this method should be migrated to the ones below. const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) { return ctx->input(input).flat()(0); } +Status HandleFromInput(OpKernelContext* ctx, int input, + ResourceHandle* handle) { + TF_ASSIGN_OR_RETURN(const Tensor* tensor, ctx->get_input(input)); + if (tensor->NumElements() == 0) { + return absl::InvalidArgumentError("Empty resource handle"); + } + *handle = tensor->flat()(0); + return OkStatus(); +} + Status HandleFromInput(OpKernelContext* ctx, StringPiece input, ResourceHandle* handle) { const Tensor* tensor; TF_RETURN_IF_ERROR(ctx->input(input, &tensor)); if (tensor->NumElements() == 0) { - return errors::InvalidArgument("Empty resouce handle"); + return absl::InvalidArgumentError("Empty resource handle"); } *handle = tensor->flat()(0); return OkStatus(); diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index ffdcb9aebfe038..fc043dd84bad1d 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -366,6 +366,12 @@ Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, // Returns a resource handle from a numbered op input. const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input); + +// Safely returns a resource handle from a numbered op input. +// Prevents segfault by checking for empty resource handle. +Status HandleFromInput(OpKernelContext* ctx, int input, ResourceHandle* handle); +// Returns a resource handle by name, as defined in the OpDef. +// Also prevents segfault by checking for empty resource handle. 
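// Illustrative usage (a sketch, not part of this patch): inside an
// OpKernel::Compute, prefer the new Status-returning overload so an empty or
// missing resource input surfaces as an InvalidArgument error instead of a
// crash. MyKernel is a hypothetical name.
//
//   void MyKernel::Compute(OpKernelContext* ctx) {
//     ResourceHandle handle;
//     OP_REQUIRES_OK(ctx, HandleFromInput(ctx, /*input=*/0, &handle));
//     // ... use handle ...
//   }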
Status HandleFromInput(OpKernelContext* ctx, StringPiece input, ResourceHandle* handle); diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index e5ad8f0b094c25..5c079cb2ac7318 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -16,10 +16,12 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include +#include #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/resource_handle.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -371,6 +373,63 @@ TEST(ResourceHandleTest, CRUD) { } } +TEST(ResourceHandleTest, ResourceFromValidIntInput) { + ResourceMgr resource_mgr(""); + OpKernelContext::Params params; + params.resource_manager = &resource_mgr; + StubDevice device("device_name"); + params.device = &device; + OpKernelContext ctx(¶ms, 1); + + ResourceHandleProto proto; + proto.set_device("cpu:0"); + proto.set_container("test_container"); + proto.set_name("test_var"); + auto handle = std::make_unique(proto); + auto expected_summary = + "ResourceHandle(name=\"test_var\", device=\"cpu:0\", " + "container=\"test_container\", type=\"\", dtype and shapes : \"[ ]\")"; + EXPECT_EQ(handle->SummarizeValue(), expected_summary); + + Tensor arg0(DT_RESOURCE, TensorShape({2})); + arg0.flat()(0) = *handle; + std::vector inputs{TensorValue(new Tensor(arg0))}; + params.inputs = inputs; + + ResourceHandle get_int_handle; + TF_ASSERT_OK(HandleFromInput(&ctx, 0, &get_int_handle)); + EXPECT_EQ(get_int_handle.SummarizeValue(), expected_summary); + delete inputs.at(0).tensor; +} + +TEST(ResourceHandleTest, ResourceFromInvalidIntInput) { + ResourceMgr resource_mgr(""); + OpKernelContext::Params params; + params.resource_manager = &resource_mgr; + StubDevice device("device_name"); + params.device = &device; + OpKernelContext ctx(¶ms, 0); + + ResourceHandle get_int_handle; + EXPECT_FALSE(HandleFromInput(&ctx, 0, &get_int_handle).ok()); +} + +TEST(ResourceHandleTest, ResourceFromIntInputWithoutResource) { + ResourceMgr resource_mgr(""); + OpKernelContext::Params params; + params.resource_manager = &resource_mgr; + StubDevice device("device_name"); + params.device = &device; + OpKernelContext ctx(¶ms, 1); + + std::vector inputs{TensorValue(new Tensor())}; + params.inputs = inputs; + + ResourceHandle get_int_handle; + EXPECT_FALSE(HandleFromInput(&ctx, 0, &get_int_handle).ok()); + delete inputs.at(0).tensor; +} + TEST(ResourceHandleTest, LookupDeleteGenericResource) { ResourceMgr resource_mgr(""); OpKernelContext::Params params; diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 6d34e333d65cb5..0f37ceb05aa9db 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" +#include #include #include #include @@ -617,7 +618,7 @@ TensorBuffer* FromProtoField(Allocator* a, if (!s.ok()) { LOG(ERROR) << "Could not decode resource handle from proto \"" << in.resource_handle_val(i).ShortDebugString() - << "\", returned status: " << s.ToString(); + << "\", returned status: " << s; buf->Unref(); return nullptr; } @@ -741,7 +742,8 @@ void UnrefIfNonNull(core::RefCounted* buf) { Tensor::Tensor() : Tensor(DT_FLOAT) {} -Tensor::Tensor(DataType type) : shape_(type), buf_(nullptr) {} +// Note: TensorShape has a valid constructor that takes DataType. +Tensor::Tensor(DataType type) : shape_(type), buf_(nullptr) { set_dtype(type); } Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf) : shape_(shape), buf_(buf) { diff --git a/tensorflow/core/function/polymorphism/function_type.py b/tensorflow/core/function/polymorphism/function_type.py index 06c7e1047c44a9..033afdef6d9b0b 100644 --- a/tensorflow/core/function/polymorphism/function_type.py +++ b/tensorflow/core/function/polymorphism/function_type.py @@ -355,9 +355,10 @@ def placeholder_arguments( def flat_inputs(self) -> List[trace.TraceType]: """Flat tensor inputs accepted by this FunctionType.""" if not hasattr(self, "_cached_flat_inputs"): - self._cached_flat_inputs = [] + cached_flat_inputs = [] for p in self.parameters.values(): - self._cached_flat_inputs.extend(p.type_constraint._flatten()) # pylint: disable=protected-access + cached_flat_inputs.extend(p.type_constraint._flatten()) # pylint: disable=protected-access + self._cached_flat_inputs = cached_flat_inputs return self._cached_flat_inputs @@ -399,9 +400,10 @@ def unpack_inputs( def flat_captures(self) -> List[trace.TraceType]: """Flat tensor captures needed by this FunctionType.""" if not hasattr(self, "_cached_flat_captures"): - self._cached_flat_captures = [] + cached_flat_captures = [] for t in self.captures.values(): - self._cached_flat_captures.extend(t._flatten()) # pylint: disable=protected-access + cached_flat_captures.extend(t._flatten()) # pylint: disable=protected-access + self._cached_flat_captures = cached_flat_captures return self._cached_flat_captures diff --git a/tensorflow/core/function/transform/BUILD b/tensorflow/core/function/transform/BUILD index 8bedf764edb48d..91e7302db8229a 100644 --- a/tensorflow/core/function/transform/BUILD +++ b/tensorflow/core/function/transform/BUILD @@ -27,7 +27,7 @@ pytype_strict_library( "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:function_def_to_graph", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:custom_gradient", "//tensorflow/python/ops:default_gradient", "//tensorflow/python/ops:handle_data_util", diff --git a/tensorflow/core/function/transform/transform.py b/tensorflow/core/function/transform/transform.py index a4a0ecd77a407f..2d259ddae0d2d0 100644 --- a/tensorflow/core/function/transform/transform.py +++ b/tensorflow/core/function/transform/transform.py @@ -25,13 +25,14 @@ from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import function_def_to_graph as function_def_lib from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import custom_gradient as custom_gradient_lib from tensorflow.python.ops import default_gradient from tensorflow.python.ops import handle_data_util from 
tensorflow.python.platform import tf_logging from tensorflow.python.util import compat -_TensorType = Union[ops.EagerTensor, ops.Tensor] +_TensorType = Union[ops.EagerTensor, tensor.Tensor] _FunctionDefTransformerType = Callable[[function_pb2.FunctionDef], None] @@ -233,8 +234,8 @@ def add(x, y): # Set handle data. for i, output in enumerate(cf.outputs): func_graph_output = func_graph.outputs[i] - if isinstance(output, ops.Tensor) and isinstance( - func_graph_output, ops.Tensor + if isinstance(output, tensor.Tensor) and isinstance( + func_graph_output, tensor.Tensor ): func_graph_output.set_shape(output.shape) handle_data_util.copy_handle_data(output, func_graph_output) diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index c54f73e6e7a074..2477c3cb863de3 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -826,8 +826,10 @@ void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { } // namespace -void Graph::ToGraphDef(GraphDef* graph_def, bool include_flib_def) const { - ToGraphDefSubRange(graph_def, /*from_node_id=*/0, include_flib_def); +void Graph::ToGraphDef(GraphDef* graph_def, bool include_flib_def, + bool include_debug_info) const { + ToGraphDefSubRange(graph_def, /*from_node_id=*/0, include_flib_def, + include_debug_info); } GraphDef Graph::ToGraphDefDebug() const { @@ -837,13 +839,17 @@ GraphDef Graph::ToGraphDefDebug() const { } void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id, - bool include_flib_def) const { + bool include_flib_def, + bool include_debug_info) const { graph_def->Clear(); *graph_def->mutable_versions() = versions(); if (include_flib_def) { *graph_def->mutable_library() = ops_.ToProto(); } + if (include_debug_info) { + *graph_def->mutable_debug_info() = BuildDebugInfo(); + } graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id)); diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 617f071a6333a3..0c83580b15134e 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -677,8 +677,14 @@ class Graph { // contain references to functions whose definition is not included. It can // make sense to do this in cases where the caller already has a copy of the // function library. + // If `include_debug_info` is true, the `debug_info` field of the GraphDef + // will be populated with stack traces from the nodes and the function + // library. Note that if `include_debug_info` is true and `include_flib_def` + // is false, then `debug_info` will contain stack traces for nodes in the + // function library, which will not itself be included in the GraphDef. void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id, - bool include_flib_def = true) const; + bool include_flib_def = true, + bool include_debug_info = false) const; // Serialize to a GraphDef. `include_flib_def` indicates whether the function // library will be populated in the `graph_def`. `include_flib_def` should be @@ -687,7 +693,13 @@ class Graph { // `graph_def` is incomplete and may contain references to functions whose // definition is not included. It can make sense to do this in cases where the // caller already has a copy of the function library. - void ToGraphDef(GraphDef* graph_def, bool include_flib_def = true) const; + // If `include_debug_info` is true, the `debug_info` field of the GraphDef + // will be populated with stack traces from the nodes and the function + // library. 
Note that if `include_debug_info` is true and `include_flib_def` + // is false, then `debug_info` will contain stack traces for nodes in the + // function library, which will not itself be included in the GraphDef. + void ToGraphDef(GraphDef* graph_def, bool include_flib_def = true, + bool include_debug_info = false) const; // This version can be called from debugger to inspect the graph content. // Use the previous version outside debug context for efficiency reasons. diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index a4f09383c63b5d..1a3f5216c6b26b 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_partition.h" #include +#include #include #include #include @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/costmodel.h" +#include "tensorflow/core/graph/graph_debug_info_builder.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/tensor_id.h" @@ -980,7 +982,10 @@ void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) { Status Partition(const PartitionOptions& opts, Graph* g, std::unordered_map* partitions) { + // TODO(b/290689453) Refactor this into smaller functions Status status; + absl::flat_hash_map> + debug_info_builders; partitions->clear(); GraphInfo g_info; @@ -1219,6 +1224,19 @@ Status Partition(const PartitionOptions& opts, Graph* g, Graph::kControlSlot); } } + + // For each partition, lazily create a GraphDebugInfoBuilder. Gather stack + // traces for the nodes in that partition into the builder. + const std::shared_ptr& stack_trace = + dst->GetStackTrace(); + if (stack_trace != nullptr) { + std::unique_ptr& builder = + debug_info_builders[dstp]; + if (!builder) { + builder = std::make_unique(); + } + builder->AccumulateStackTrace(*stack_trace, dst->name()); + } } const FunctionLibraryDefinition* flib_def = opts.flib_def; @@ -1250,6 +1268,15 @@ Status Partition(const PartitionOptions& opts, Graph* g, VLOG(1) << "Added send/recv: controls=" << num_control << ", data=" << num_data; + // For each partition, build the GraphDebugInfo for all of its nodes' stack + // traces, and add it to the GraphDef for that partition. + for (auto& it : *partitions) { + const auto& builder_iter = debug_info_builders.find(it.first); + if (builder_iter != debug_info_builders.end()) { + GraphDef& gdef = it.second; + *gdef.mutable_debug_info() = builder_iter->second->Build(); + } + } if (VLOG_IS_ON(2)) { for (auto& it : *partitions) { GraphDef* gdef = &it.second; diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index e62d20456edfa0..51f59c02897a28 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -15,9 +15,14 @@ limitations under the License. 
#include "tensorflow/core/graph/graph_partition.h" +#include +#include #include #include +#include +#include +#include "absl/strings/str_cat.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/control_flow_ops.h" @@ -58,9 +63,33 @@ using ops::Const; using ops::Identity; using ops::LoopCond; using ops::NextIteration; +using ::testing::Eq; +using ::testing::Ne; const char gpu_device[] = "/job:a/replica:0/task:0/device:GPU:0"; +class TestStackTrace : public AbstractStackTrace { + public: + explicit TestStackTrace(const std::vector frames) + : frames_(std::move(frames)) {} + + absl::Span ToFrames() const override { return frames_; } + + std::vector GetUserFrames(int limit) const override { + return frames_; + } + + StackFrame LastUserFrame() const override { return frames_.back(); } + + std::string ToString(const TracePrintingOptions& opts) const override { + auto frame = LastUserFrame(); + return absl::StrCat(frame.file_name, ":", frame.line_number, ":", + frame.function_name); + } + + std::vector frames_; +}; + string SplitByDevice(const Node* node) { return node->assigned_device_name(); } string DeviceName(const Node* node) { @@ -194,6 +223,18 @@ Output Combine(const Scope& scope, Input a, Input b) { return ConstructOp(scope, "Combine", {std::move(a), std::move(b)}); } +std::string FormatStackTrace(const GraphDebugInfo::StackTrace& stack_trace, + const GraphDebugInfo& debug_info) { + std::string result; + for (const GraphDebugInfo::FileLineCol& file_line_col : + stack_trace.file_line_cols()) { + const std::string& file = debug_info.files(file_line_col.file_index()); + absl::StrAppend(&result, file_line_col.func(), "@", file, ":", + file_line_col.line(), ".", file_line_col.col(), "\n"); + } + return result; +} + class GraphPartitionTest : public ::testing::Test { protected: GraphPartitionTest() @@ -203,8 +244,8 @@ class GraphPartitionTest : public ::testing::Test { scope_b_(Scope::NewRootScope().ExitOnError().WithDevice( "/job:a/replica:0/task:0/cpu:1")) {} - const GraphDef& ToGraphDef() { - TF_EXPECT_OK(in_.ToGraphDef(&in_graph_def_)); + const GraphDef& ToGraphDef(bool include_debug_info = false) { + TF_EXPECT_OK(in_.ToGraphDef(&in_graph_def_, include_debug_info)); return in_graph_def_; } @@ -465,7 +506,6 @@ TEST_F(GraphPartitionTest, Functions) { *fdef_lib.add_function() = test::function::XTimesFour(); TF_ASSERT_OK(in_.graph()->AddFunctionLibrary(fdef_lib)); - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) auto a1 = FloatInput(in_.WithOpName("A1")); auto b1 = FloatInput(in_.WithOpName("B1")); ConstructOp(in_.WithOpName("A2"), "XTimesTwo", {a1}); @@ -523,6 +563,71 @@ TEST_F(GraphPartitionTest, SetIncarnation) { } } +TEST_F(GraphPartitionTest, GraphDebugInfo) { + GraphDef graph_def; + Output a1 = FloatInput(in_.WithOpName("A1")); + Output b1 = FloatInput(in_.WithOpName("B1")); + Combine(in_.WithOpName("B2"), a1, b1); + + Node *a1_node = nullptr, *b1_node = nullptr, *b2_node = nullptr; + for (Node* node : in_.graph()->op_nodes()) { + if (node->name() == "A1") { + a1_node = node; + } else if (node->name() == "B1") { + b1_node = node; + } else if (node->name() == "B2") { + b2_node = node; + } + } + EXPECT_NE(a1_node, nullptr); + EXPECT_NE(b1_node, nullptr); + EXPECT_NE(b2_node, nullptr); + + TestStackTrace a1_stack_trace( + std::vector{{"main.cc", 20, "x"}, {"alpha.cc", 30, "a1"}}); + TestStackTrace b1_stack_trace( + std::vector{{"window.cc", 21, "y"}, {"beta.cc", 35, "b1"}}); + TestStackTrace b2_stack_trace( + 
std::vector{{"cache.cc", 22, "bar"}, {"beta.cc", 39, "b2"}}); + a1_node->SetStackTrace(std::make_shared(a1_stack_trace)); + b1_node->SetStackTrace(std::make_shared(b1_stack_trace)); + b2_node->SetStackTrace(std::make_shared(b2_stack_trace)); + + TF_EXPECT_OK(in_.ToGraphDef(&graph_def, /*include_debug_info=*/true)); + + // `Partition()` uses the first letter of the op name ('A' or 'B') to choose a + // device for each node. It calls the function under test, also named + // `Partition()`, to do the actual partitioning. + Partition(ToGraphDef(/*include_debug_info=*/true), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + // Expect each partitioned graph to contain the stack traces for its nodes. + // A stack trace for A1 should be in the A partition (".../cpu:0"). + string a = "/job:a/replica:0/task:0/cpu:0"; + const GraphDebugInfo& a_debug_info = partitions_[a].debug_info(); + const auto& a_it = a_debug_info.traces().find("A1"); + EXPECT_EQ(1, a_debug_info.traces().size()); + EXPECT_THAT(a_it, Ne(a_debug_info.traces().end())); + EXPECT_THAT(FormatStackTrace(a_it->second, a_debug_info), + Eq("x@main.cc:20.0\n" + "a1@alpha.cc:30.0\n")); + + // Stack traces for B1 and B2 should be in the B partition (".../cpu:1"). + string b = "/job:a/replica:0/task:0/cpu:1"; + const GraphDebugInfo& b_debug_info = partitions_[b].debug_info(); + const auto& b1_it = b_debug_info.traces().find("B1"); + const auto& b2_it = b_debug_info.traces().find("B2"); + EXPECT_EQ(2, b_debug_info.traces().size()); + EXPECT_THAT(b1_it, Ne(b_debug_info.traces().end())); + EXPECT_THAT(b2_it, Ne(b_debug_info.traces().end())); + EXPECT_THAT(FormatStackTrace(b1_it->second, b_debug_info), + Eq("y@window.cc:21.0\n" + "b1@beta.cc:35.0\n")); + EXPECT_THAT(FormatStackTrace(b2_it->second, b_debug_info), + Eq("bar@cache.cc:22.0\n" + "b2@beta.cc:39.0\n")); +} + TEST(TopologicalSortNodesWithTimePriorityTest, NoDependencies) { // Create placeholders, shuffle them so the order in the graph is not strictly // increasing. diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc index 6de0c5d29d6d34..c4f4fdc1f3ced5 100644 --- a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include +#include "absl/base/call_once.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -111,8 +112,11 @@ PluginGraphOptimizerRegistry::CreateOptimizers( for (auto it = GetPluginRegistrationMap()->begin(); it != GetPluginRegistrationMap()->end(); ++it) { if (device_types.find(it->first) == device_types.end()) continue; - LOG(INFO) << "Plugin optimizer for device_type " << it->first - << " is enabled."; + static absl::once_flag plugin_optimizer_flag; + absl::call_once(plugin_optimizer_flag, [&]() { + LOG(INFO) << "Plugin optimizer for device_type " << it->first + << " is enabled."; + }); optimizer_list.emplace_back( std::unique_ptr(it->second())); } diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc index 9c757fa333a53a..bcea37d586529f 100644 --- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc +++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc @@ -784,10 +784,12 @@ Status ApplyAutoShard(const NodeDef& sink_node, int64_t num_workers, default: Status s = ShardByFile(sink_node, num_workers, index, &flib, graph); if (absl::IsNotFound(s)) { - LOG(WARNING) << "AUTO sharding policy will apply DATA sharding policy " - "as it failed to apply FILE sharding policy because of " - "the following reason: " - << s.message(); + if (VLOG_IS_ON(2)) { + VLOG(2) << "AUTO sharding policy will apply DATA sharding policy " + "as it failed to apply FILE sharding policy because of " + "the following reason: " + << s.message(); + } *policy_applied = AutoShardPolicy::DATA; return ShardByData(sink_node, num_workers, index, num_replicas, graph); } diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index c894f57cd8a2f2..97e59383d23e5f 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -758,6 +758,10 @@ bool FindContractionWithBias(const RemapperContext& ctx, int node_index, IsMatMul(*contraction_node_def) || IsDepthwiseConv2dNative(*contraction_node_def); +#ifdef DNNL_AARCH64_USE_ACL + if (IsDepthwiseConv2dNative(*contraction_node_def)) is_contraction = false; +#endif + if (!is_contraction || !HaveSameDataType(node_def, contraction_node_def) || HasControlFaninOrFanout(*contraction_node_view) || !HasAtMostOneFanoutAtPort0(*contraction_node_view) || @@ -4439,6 +4443,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, continue; } +#ifndef DNNL_AARCH64_USE_ACL // Remap {Conv2D,Conv3D}+BiasAdd+Add into the _FusedConv2D/3D. if (FindContractionWithBiasAddAndAdd(ctx, i, &contract_with_bias_and_add)) { @@ -4447,6 +4452,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, &invalidated_nodes, &nodes_to_delete)); continue; } +#endif PadWithConv3D pad_with_conv3d; // Remap Pad+{Conv3D,_FusedConv3D} into the _FusedConv3D. @@ -4483,6 +4489,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, continue; } +#ifndef DNNL_AARCH64_USE_ACL // Fuse Conv2d + BiasAdd/FusedBatchNorm + Swish. std::map fusedconv2dSwish_matched_nodes_map; std::set fusedconv2dSwish_remove_node_indices; @@ -4494,6 +4501,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, &invalidated_nodes, &nodes_to_delete)); continue; } +#endif + // Remap Maximum(x, alpha * x) pattern, fuse them into the LeakyRelu(x). 
std::map mulmax_matched_nodes_map; std::set mulmax_remove_node_indices; diff --git a/tensorflow/core/grappler/utils/pattern_utils.cc b/tensorflow/core/grappler/utils/pattern_utils.cc index 2d4c0a9b5a1c05..1bf827fcc6f98e 100644 --- a/tensorflow/core/grappler/utils/pattern_utils.cc +++ b/tensorflow/core/grappler/utils/pattern_utils.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/utils/pattern_utils.h" #include +#include #include "absl/container/flat_hash_set.h" @@ -171,7 +172,7 @@ bool SubGraphMatcher::GetMatchedNodes( MutableNodeView* node_view, std::map* matched_nodes_map, std::set* remove_node_indices) { bool found_match = false; - match_.reset(new NodeViewMatch()); + match_ = std::make_unique(); if (DoesOpTypePatternMatch(pattern, node_view, match_.get())) { if (IsSafeNodesToRemove(nodes_to_preserve)) { found_match = true; diff --git a/tensorflow/core/grappler/utils/scc_test.cc b/tensorflow/core/grappler/utils/scc_test.cc index b5fa76ef8bf4fc..b43fc1c40fdf1a 100644 --- a/tensorflow/core/grappler/utils/scc_test.cc +++ b/tensorflow/core/grappler/utils/scc_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/utils/scc.h" + +#include + #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -31,7 +34,7 @@ class SCCTest : public ::testing::Test { std::unordered_map devices; DeviceProperties unknown_device; devices["MY_DEVICE"] = unknown_device; - cluster_.reset(new VirtualCluster(devices)); + cluster_ = std::make_unique(devices); TF_CHECK_OK(cluster_->Provision()); } diff --git a/tensorflow/core/ir/BUILD b/tensorflow/core/ir/BUILD index c1618ee0cdf169..3900892c21d85f 100644 --- a/tensorflow/core/ir/BUILD +++ b/tensorflow/core/ir/BUILD @@ -1,10 +1,10 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/compiler/mlir/tensorflow:__subpackages__", "//tensorflow/core:__subpackages__", diff --git a/tensorflow/core/ir/importexport/tests/BUILD b/tensorflow/core/ir/importexport/tests/BUILD index 71fa743ef9b557..f25d004408d86f 100644 --- a/tensorflow/core/ir/importexport/tests/BUILD +++ b/tensorflow/core/ir/importexport/tests/BUILD @@ -1,9 +1,9 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [":__subpackages__"], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/core/ir/importexport/tests/graphdef_to_mlir/BUILD b/tensorflow/core/ir/importexport/tests/graphdef_to_mlir/BUILD index 71fa743ef9b557..f25d004408d86f 100644 --- 
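// Illustrative sketch (not part of the patch): several hunks above replace
// `ptr.reset(new T(...))` with `ptr = std::make_unique<T>(...)`, the idiomatic
// spelling that avoids a bare new. Minimal before/after with a hypothetical
// stand-in type:
#include <memory>

struct VirtualClusterLike {  // stand-in, not the real grappler class
  explicit VirtualClusterLike(int num_devices) : num_devices_(num_devices) {}
  int num_devices_;
};

void Before(std::unique_ptr<VirtualClusterLike>& p) {
  p.reset(new VirtualClusterLike(1));           // old style
}

void After(std::unique_ptr<VirtualClusterLike>& p) {
  p = std::make_unique<VirtualClusterLike>(1);  // style used in the patch
}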
a/tensorflow/core/ir/importexport/tests/graphdef_to_mlir/BUILD +++ b/tensorflow/core/ir/importexport/tests/graphdef_to_mlir/BUILD @@ -1,9 +1,9 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [":__subpackages__"], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/core/ir/importexport/tests/mlir_to_graphdef/BUILD b/tensorflow/core/ir/importexport/tests/mlir_to_graphdef/BUILD index 71fa743ef9b557..f25d004408d86f 100644 --- a/tensorflow/core/ir/importexport/tests/mlir_to_graphdef/BUILD +++ b/tensorflow/core/ir/importexport/tests/mlir_to_graphdef/BUILD @@ -1,9 +1,9 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [":__subpackages__"], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/core/ir/tests/BUILD b/tensorflow/core/ir/tests/BUILD index 315309f01fdce1..14304dfdee7d5e 100644 --- a/tensorflow/core/ir/tests/BUILD +++ b/tensorflow/core/ir/tests/BUILD @@ -1,10 +1,10 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [":__subpackages__"], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/core/ir/types/BUILD b/tensorflow/core/ir/types/BUILD index dcaa5ed2b765f4..f638202aa752c0 100644 --- a/tensorflow/core/ir/types/BUILD +++ b/tensorflow/core/ir/types/BUILD @@ -1,10 +1,10 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = ["//tensorflow/core:__subpackages__"], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index 4657f2c18d3bec..fc98d646c003dd 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -157,6 +157,27 @@ class BatchResource : public serving::BatchResourceBase { const std::vector& allowed_batch_sizes, bool enable_large_batch_splitting, std::unique_ptr* resource) { + return Create(has_process_batch_function, num_batch_threads, + 
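// Illustrative sketch (not part of the patch): the new BatchResource::Create
// overload above keeps the old signature working by delegating to the extended
// one with "disabled" low-priority values (zero sizes and timeouts, an empty
// allowed-sizes list). The delegation pattern, with simplified stand-in types:
#include <cstdint>
#include <vector>

struct Options {
  int32_t max_batch_size = 0;
  int32_t low_priority_max_batch_size = 0;
  std::vector<int32_t> low_priority_allowed_batch_sizes;
};

Options MakeOptions(int32_t max_batch_size,
                    int32_t low_priority_max_batch_size,
                    const std::vector<int32_t>& low_priority_allowed_batch_sizes) {
  Options o;
  o.max_batch_size = max_batch_size;
  o.low_priority_max_batch_size = low_priority_max_batch_size;
  o.low_priority_allowed_batch_sizes = low_priority_allowed_batch_sizes;
  return o;
}

// Back-compat overload: callers unaware of low-priority batching get neutral
// values, mirroring the /*low_priority_...=*/0 arguments in the diff.
Options MakeOptions(int32_t max_batch_size) {
  return MakeOptions(max_batch_size,
                     /*low_priority_max_batch_size=*/0,
                     /*low_priority_allowed_batch_sizes=*/{});
}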
max_execution_batch_size, batch_timeout_micros, + max_enqueued_batches, allowed_batch_sizes, + /*low_priority_max_batch_size=*/0, + /*low_priority_batch_timeout_micros=*/0, + /*low_priority_max_enqueued_batches=*/0, + /*low_priority_allowed_batch_sizes=*/{}, + enable_large_batch_splitting, resource); + } + + static Status Create( + bool has_process_batch_function, int32_t num_batch_threads, + int32_t max_execution_batch_size, int32_t batch_timeout_micros, + int32_t max_enqueued_batches, + const std::vector& allowed_batch_sizes, + int32_t low_priority_max_batch_size, + int32_t low_priority_batch_timeout_micros, + int32_t low_priority_max_enqueued_batches, + const std::vector& low_priority_allowed_batch_sizes, + bool enable_large_batch_splitting, + std::unique_ptr* resource) { BatcherT::Options batcher_options; batcher_options.num_batch_threads = num_batch_threads; std::shared_ptr batcher; @@ -167,7 +188,11 @@ class BatchResource : public serving::BatchResourceBase { GetBatcherQueueOptions( num_batch_threads, max_execution_batch_size, batch_timeout_micros, max_enqueued_batches, allowed_batch_sizes, - enable_large_batch_splitting, /*disable_padding=*/false), + enable_large_batch_splitting, + /*disable_padding=*/false, low_priority_max_batch_size, + low_priority_batch_timeout_micros, + low_priority_max_enqueued_batches, + low_priority_allowed_batch_sizes), allowed_batch_sizes)); return OkStatus(); } @@ -393,7 +418,10 @@ void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) { TF_RETURN_IF_ERROR(BatchResource::Create( /*has_process_batch_function=*/true, num_batch_threads_, max_batch_size_, batch_timeout_micros_, max_enqueued_batches_, - allowed_batch_sizes_, enable_large_batch_splitting_, &new_resource)); + allowed_batch_sizes_, low_priority_max_batch_size_, + low_priority_batch_timeout_micros_, + low_priority_max_enqueued_batches_, low_priority_allowed_batch_sizes_, + enable_large_batch_splitting_, &new_resource)); if (session_metadata) { new_resource->set_session_metadata(*session_metadata); } diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index b7e6bfd2f6cf45..ad431a0956ec90 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -363,6 +363,7 @@ cc_library( "//tensorflow/core/profiler/lib:traceme_encode", "//tensorflow/core/protobuf:for_core_protos_cc", "//tensorflow/core/util:incremental_barrier", + "//tensorflow/tsl/platform:criticality", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index fcdf0e46fc8a6a..4f2c00811a7732 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -272,6 +272,11 @@ Status BatchResourceBase::RegisterInput( batch_components->start_time = EnvTime::NowNanos(); batch_components->guid = guid; batch_components->propagated_context = Context(ContextKind::kThread); + + if (batcher_queue_options_.enable_priority_queue) { + batch_components->criticality = tsl::criticality::GetCriticality(); + } + OpInputList tensors; TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors)); batch_components->inputs.reserve(tensors.size()); @@ -381,10 +386,44 @@ BatchResourceBase::GetBatcherQueueOptions( int32_t batch_timeout_micros, int32_t max_enqueued_batches, const 
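// Illustrative sketch (not part of the patch): RegisterInput above records the
// request's criticality only when the priority queue is enabled, so the lookup
// is skipped on the default path. A simplified model of that guard; the enum,
// task struct, and CurrentRequestCriticality() are stand-ins for the tsl
// criticality API, not the real types:
#include <optional>

enum class Criticality { kCritical, kSheddable };

struct TaskLike {
  std::optional<Criticality> criticality;  // unset unless priority queue is on
};

Criticality CurrentRequestCriticality() { return Criticality::kCritical; }

void FillTask(TaskLike& task, bool enable_priority_queue) {
  if (enable_priority_queue) {
    task.criticality = CurrentRequestCriticality();
  }
}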
std::vector& allowed_batch_sizes, bool enable_large_batch_splitting, bool disable_padding) { + return GetBatcherQueueOptions( + num_batch_threads, max_batch_size, batch_timeout_micros, + max_enqueued_batches, allowed_batch_sizes, enable_large_batch_splitting, + disable_padding, /*low_priority_max_batch_size=*/0, + /*low_priority_batch_timeout_micros=*/0, + /*low_priority_max_enqueued_batches=*/0, + /*low_priority_allowed_batch_sizes=*/{}); +} + +/*static*/ BatchResourceBase::BatcherT::QueueOptions +BatchResourceBase::GetBatcherQueueOptions( + int32_t num_batch_threads, int32_t max_batch_size, + int32_t batch_timeout_micros, int32_t max_enqueued_batches, + const std::vector& allowed_batch_sizes, + bool enable_large_batch_splitting, bool disable_padding, + int32_t low_priority_max_batch_size, + int32_t low_priority_batch_timeout_micros, + int32_t low_priority_max_enqueued_batches, + const std::vector& low_priority_allowed_batch_sizes) { BatcherT::QueueOptions batcher_queue_options; batcher_queue_options.input_batch_size_limit = max_batch_size; batcher_queue_options.max_enqueued_batches = max_enqueued_batches; batcher_queue_options.batch_timeout_micros = batch_timeout_micros; + if (low_priority_max_batch_size > 0) { + batcher_queue_options.enable_priority_queue = true; + } + batcher_queue_options.high_priority_queue_options.input_batch_size_limit = + max_batch_size; + batcher_queue_options.high_priority_queue_options.max_enqueued_batches = + max_enqueued_batches; + batcher_queue_options.high_priority_queue_options.batch_timeout_micros = + batch_timeout_micros; + batcher_queue_options.low_priority_queue_options.input_batch_size_limit = + low_priority_max_batch_size; + batcher_queue_options.low_priority_queue_options.max_enqueued_batches = + low_priority_max_enqueued_batches; + batcher_queue_options.low_priority_queue_options.batch_timeout_micros = + low_priority_batch_timeout_micros; batcher_queue_options.enable_large_batch_splitting = enable_large_batch_splitting; if (enable_large_batch_splitting) { @@ -398,9 +437,21 @@ BatchResourceBase::GetBatcherQueueOptions( if (allowed_batch_sizes.empty()) { batcher_queue_options.max_execution_batch_size = max_batch_size; + batcher_queue_options.high_priority_queue_options + .max_execution_batch_size = max_batch_size; } else { batcher_queue_options.max_execution_batch_size = *allowed_batch_sizes.rbegin(); + batcher_queue_options.high_priority_queue_options + .max_execution_batch_size = *allowed_batch_sizes.rbegin(); + } + if (low_priority_allowed_batch_sizes.empty()) { + batcher_queue_options.low_priority_queue_options + .max_execution_batch_size = low_priority_max_batch_size; + } else { + batcher_queue_options.low_priority_queue_options + .max_execution_batch_size = + *low_priority_allowed_batch_sizes.rbegin(); } } batcher_queue_options.disable_padding = disable_padding; diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h index 446186472f5a62..a2fbfbcb7d9756 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.h +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h @@ -35,6 +35,7 @@ limitations under the License. 
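// Illustrative sketch (not part of the patch): the hunk above derives
// max_execution_batch_size per priority level: the raw limit when no allowed
// sizes are listed, otherwise the largest allowed size (the list is expected
// to be sorted ascending, so *rbegin() is the last and largest entry).
#include <cstdint>
#include <vector>

int32_t MaxExecutionBatchSize(int32_t max_batch_size,
                              const std::vector<int32_t>& allowed_batch_sizes) {
  if (allowed_batch_sizes.empty()) return max_batch_size;
  return *allowed_batch_sizes.rbegin();  // largest allowed size
}
// e.g. MaxExecutionBatchSize(64, {})           -> 64
//      MaxExecutionBatchSize(64, {8, 16, 32})  -> 32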
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/tsl/platform/criticality.h" namespace tensorflow { namespace serving { @@ -107,6 +108,8 @@ class BatchResourceBase : public ResourceBase { // this task's processing costs. RequestCost* request_cost = nullptr; + tsl::criticality::Criticality criticality; + protected: virtual std::unique_ptr CreateDerivedTask() { return std::make_unique(); @@ -166,6 +169,16 @@ class BatchResourceBase : public ResourceBase { const std::vector& allowed_batch_sizes, bool enable_large_batch_splitting, bool disable_padding); + static BatcherT::QueueOptions GetBatcherQueueOptions( + int32_t num_batch_threads, int32_t max_batch_size, + int32_t batch_timeout_micros, int32_t max_enqueued_batches, + const std::vector& allowed_batch_sizes, + bool enable_large_batch_splitting, bool disable_padding, + int32_t low_priority_max_batch_size, + int32_t low_priority_batch_timeout_micros, + int32_t low_priority_max_enqueued_batches, + const std::vector& low_priority_allowed_batch_sizes); + static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions( int32_t max_batch_size, int32_t batch_timeout_micros, int32_t max_enqueued_batches, bool enable_large_batch_splitting, diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h index 0e5c0b2f210709..df7866c5f3c473 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -221,6 +221,27 @@ class SharedBatchScheduler // If true, the padding will not be appended. bool disable_padding = false; + + // If true, queue implementation would split high priority and low priority + // inputs into two sub queues. + bool enable_priority_queue = false; + + // A separate set of queue options for different priority inputs. + // Use iff `enable_priority_queue` is true. + struct PriorityQueueOptions { + // See QueueOptions.max_execution_batch_size + size_t max_execution_batch_size = 0; + // See QueueOptions.batch_timeout_micros + int64_t batch_timeout_micros = 0; + // See QueueOptions.input_batch_size_limit + size_t input_batch_size_limit = 0; + // See QueueOptions.max_enqueued_batches + size_t max_enqueued_batches = 0; + }; + // A subset of queue options for high priority input. + PriorityQueueOptions high_priority_queue_options; + // A subset of queue options for low priority input. + PriorityQueueOptions low_priority_queue_options; }; Status AddQueue(const QueueOptions& options, std::function>)> @@ -465,6 +486,14 @@ class Queue { std::deque>>> task_handle_batches_ TF_GUARDED_BY(mu_); + // The enqueued batches for low priority input + std::deque>> low_priority_batches_ + TF_GUARDED_BY(mu_); + + // The enqueued batches for high priority input + std::deque>> high_priority_batches_ + TF_GUARDED_BY(mu_); + // The counter of the TraceMe context ids. 
uint64 traceme_context_id_counter_ TF_GUARDED_BY(mu_) = 0; diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 57e7fb63c59827..2c1025a7cab5ef 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -314,13 +314,13 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { context, GetWindowedOutputSizeVerbose( dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, - dims.spatial_dims[0].stride, padding_, + /*dilation_rate=*/1, dims.spatial_dims[0].stride, padding_, &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom)); OP_REQUIRES_OK( context, GetWindowedOutputSizeVerbose( dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, - dims.spatial_dims[1].stride, padding_, + /*dilation_rate=*/1, dims.spatial_dims[1].stride, padding_, &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); // The total dimension size of each kernel. diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc index 0cb52b8419da8c..69fc08bb3d364f 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc @@ -16,7 +16,10 @@ limitations under the License. #define USE_EIGEN_TENSOR #define EIGEN_USE_THREADS +#include +#include #include +#include #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/numeric_op.h" @@ -375,24 +378,27 @@ class Conv3DCustomBackpropFilterOp : public OpKernel { int64_t top_pad_rows, bottom_pad_rows; int64_t left_pad_cols, right_pad_cols; - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - dims.spatial_dims[0].input_size, - dims.spatial_dims[0].filter_size, - dims.spatial_dims[0].stride, padding_, - &dims.spatial_dims[0].output_size, - &top_pad_planes, &bottom_pad_planes)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - dims.spatial_dims[1].input_size, - dims.spatial_dims[1].filter_size, - dims.spatial_dims[1].stride, padding_, - &dims.spatial_dims[1].output_size, - &top_pad_rows, &bottom_pad_rows)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - dims.spatial_dims[2].input_size, - dims.spatial_dims[2].filter_size, - dims.spatial_dims[2].stride, padding_, - &dims.spatial_dims[2].output_size, - &left_pad_cols, &right_pad_cols)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, &top_pad_planes, + &bottom_pad_planes)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, &top_pad_rows, + &bottom_pad_rows)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[2].input_size, dims.spatial_dims[2].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[2].stride, padding_, + &dims.spatial_dims[2].output_size, &left_pad_cols, + &right_pad_cols)); // TODO(ezhulenev): Extract work size and shard estimation to shared // functions in conv_grad_ops, and update 2d convolution backprop. 
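// Illustrative sketch (not part of the patch): this and the following conv
// kernels switch to the GetWindowedOutputSizeVerbose overload that takes an
// explicit dilation rate, passed as /*dilation_rate=*/1 on the non-dilated
// paths. The underlying arithmetic is the usual windowed-output computation;
// a stand-alone version, assuming the conventional SAME/VALID definitions:
#include <algorithm>
#include <cstdint>

enum class Pad { kValid, kSame };

void WindowedOutputSize(int64_t input, int64_t filter, int64_t dilation,
                        int64_t stride, Pad padding, int64_t* output,
                        int64_t* pad_before, int64_t* pad_after) {
  const int64_t effective_filter = (filter - 1) * dilation + 1;
  if (padding == Pad::kValid) {
    *output = (input - effective_filter + stride) / stride;
    *pad_before = *pad_after = 0;
  } else {  // SAME: output is ceil(input / stride), pad the smaller half first
    *output = (input + stride - 1) / stride;
    const int64_t needed = std::max<int64_t>(
        0, (*output - 1) * stride + effective_filter - input);
    *pad_before = needed / 2;
    *pad_after = needed - *pad_before;
  }
}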
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc index 1738c9413c9d84..403a6122d7f273 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc @@ -100,12 +100,12 @@ struct LaunchConv2DBackpropFilterOp { int64_t expected_out_rows, expected_out_cols; // The function is guaranteed to succeed because we checked the output and // padding was valid earlier. - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, row_dilation, row_stride, padding, &expected_out_rows, &padding_top, &padding_bottom)); DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows); - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, col_dilation, col_stride, padding, &expected_out_cols, &padding_left, &padding_right)); @@ -206,12 +206,12 @@ void LaunchConv2DBackpropFilterOpImpl( int64_t expected_out_rows, expected_out_cols; // The function is guaranteed to succeed because we checked the output and // padding was valid earlier. - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, row_dilation, row_stride, padding, &expected_out_rows, &padding_top, &padding_bottom)); DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows); - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, col_dilation, col_stride, padding, &expected_out_cols, &padding_left, &padding_right)); diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index f2b487f4aeab28..bf6b7dc986e3fb 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -114,12 +114,12 @@ void LaunchConv2DBackpropInputOpGpuImpl( int64_t expected_out_rows, expected_out_cols; // The function is guaranteed to succeed because we checked the output and // padding was valid earlier. - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, row_dilation, row_stride, padding, &expected_out_rows, &padding_top, &padding_bottom)); DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows); - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, col_dilation, col_stride, padding, &expected_out_cols, &padding_left, &padding_right)); diff --git a/tensorflow/core/kernels/conv_grad_input_ops.h b/tensorflow/core/kernels/conv_grad_input_ops.h index f330ee672a66b2..b7e4e9c6837c41 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.h +++ b/tensorflow/core/kernels/conv_grad_input_ops.h @@ -144,13 +144,13 @@ struct LaunchConv2DBackpropInputOpImpl { int64_t expected_out_rows, expected_out_cols; // The function is guaranteed to succeed because we checked the output and // padding was valid earlier. 
- TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, row_dilation, row_stride, padding, &expected_out_rows, &padding_top, &padding_bottom)); DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows); - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, col_dilation, col_stride, padding, &expected_out_cols, &padding_left, &padding_right)); @@ -525,13 +525,13 @@ class Conv2DCustomBackpropInputOp : public OpKernel { context, GetWindowedOutputSizeVerbose( dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, - dims.spatial_dims[0].stride, padding_, + /*dilation_rate=*/1, dims.spatial_dims[0].stride, padding_, &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom)); OP_REQUIRES_OK( context, GetWindowedOutputSizeVerbose( dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, - dims.spatial_dims[1].stride, padding_, + /*dilation_rate=*/1, dims.spatial_dims[1].stride, padding_, &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); // The total dimension size of each kernel. diff --git a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc index 677c8c3b2b2654..5e15e72a66eaa0 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc @@ -16,7 +16,10 @@ limitations under the License. #define USE_EIGEN_TENSOR #define EIGEN_USE_THREADS +#include +#include #include +#include #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/numeric_op.h" @@ -360,24 +363,27 @@ class Conv3DCustomBackpropInputOp : public OpKernel { int64_t top_pad_rows, bottom_pad_rows; int64_t left_pad_cols, right_pad_cols; - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - dims.spatial_dims[0].input_size, - dims.spatial_dims[0].filter_size, - dims.spatial_dims[0].stride, padding_, - &dims.spatial_dims[0].output_size, - &top_pad_planes, &bottom_pad_planes)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - dims.spatial_dims[1].input_size, - dims.spatial_dims[1].filter_size, - dims.spatial_dims[1].stride, padding_, - &dims.spatial_dims[1].output_size, - &top_pad_rows, &bottom_pad_rows)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - dims.spatial_dims[2].input_size, - dims.spatial_dims[2].filter_size, - dims.spatial_dims[2].stride, padding_, - &dims.spatial_dims[2].output_size, - &left_pad_cols, &right_pad_cols)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, &top_pad_planes, + &bottom_pad_planes)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, &top_pad_rows, + &bottom_pad_rows)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[2].input_size, dims.spatial_dims[2].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[2].stride, padding_, + &dims.spatial_dims[2].output_size, &left_pad_cols, + &right_pad_cols)); // TODO(ezhulenev): Extract work size and shard estimation to shared // functions in conv_grad_ops, and update 2d 
convolution backprop. diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.cc b/tensorflow/core/kernels/conv_grad_shape_utils.cc index f2686e1dd6cc60..9560a37fd6eea6 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.cc +++ b/tensorflow/core/kernels/conv_grad_shape_utils.cc @@ -63,7 +63,7 @@ Status ConvBackpropExtractAndVerifyDimension( dim->stride = strides[spatial_dim]; dim->dilation = dilations[spatial_dim]; int64_t out_size = 0; - TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( dim->input_size, dim->filter_size, dim->dilation, dim->stride, padding, &out_size, &padding_before, &padding_after)); if (dim->output_size != out_size) { diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index f66b25674c3704..96fa2185f3fe01 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -194,10 +194,10 @@ Status ComputeConv2DDimension(const Conv2DParameters& params, // Compute windowed output sizes for rows and columns. int64_t out_rows = 0, out_cols = 0; - TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( input_rows, filter_rows, dilation_rows, stride_rows, params.padding, &out_rows, &pad_rows_before, &pad_rows_after)); - TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( input_cols, filter_cols, dilation_cols, stride_cols, params.padding, &out_cols, &pad_cols_before, &pad_cols_after)); diff --git a/tensorflow/core/kernels/conv_ops_impl.h b/tensorflow/core/kernels/conv_ops_impl.h index be916e42b48d93..22b88454435f8c 100644 --- a/tensorflow/core/kernels/conv_ops_impl.h +++ b/tensorflow/core/kernels/conv_ops_impl.h @@ -479,7 +479,7 @@ class ConvOp : public BinaryOp { // Compute windowed output sizes for spatial dimensions. std::vector out_dims(spatial_dims); for (int i = 0; i < spatial_dims; ++i) { - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerboseV2( + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( input_dims[i], filter_dims[i], dilation_dims[i], stride_dims[i], padding_, &out_dims[i], &pad_before[i], &pad_after[i])); @@ -843,16 +843,16 @@ void LaunchConv2DOpImpl(OpKernelContext* ctx, bool use_cudnn, &padding_right); } int64_t out_rows_check, out_cols_check; - Status status = GetWindowedOutputSizeVerboseV2( + Status status = GetWindowedOutputSizeVerbose( in_rows, patch_rows, row_dilation, row_stride, padding, &out_rows_check, &padding_top, &padding_bottom); // The status is guaranteed to be OK because we checked the output and padding // was valid earlier. 
TF_CHECK_OK(status); DCHECK_EQ(out_rows, out_rows_check); - status = GetWindowedOutputSizeVerboseV2(in_cols, patch_cols, col_dilation, - col_stride, padding, &out_cols_check, - &padding_left, &padding_right); + status = GetWindowedOutputSizeVerbose(in_cols, patch_cols, col_dilation, + col_stride, padding, &out_cols_check, + &padding_left, &padding_right); TF_CHECK_OK(status); DCHECK_EQ(out_cols, out_cols_check); diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc index 1e97ef38b5b7a1..b16458aa8052ca 100644 --- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc @@ -121,12 +121,14 @@ typedef Eigen::GpuDevice GPUDevice; GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W', &pad_left, \ &pad_right); \ } \ - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( \ - input_rows, filter_rows, stride_, padding_, \ - &out_rows, &pad_top, &pad_bottom)); \ - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( \ - input_cols, filter_cols, stride_, padding_, \ - &out_cols, &pad_left, &pad_right)); \ + OP_REQUIRES_OK(context, \ + GetWindowedOutputSizeVerbose( \ + input_rows, filter_rows, /*dilation_rate=*/1, stride_, \ + padding_, &out_rows, &pad_top, &pad_bottom)); \ + OP_REQUIRES_OK(context, \ + GetWindowedOutputSizeVerbose( \ + input_cols, filter_cols, /*dilation_rate=*/1, stride_, \ + padding_, &out_cols, &pad_left, &pad_right)); \ OP_REQUIRES( \ context, output_rows == out_rows, \ errors::InvalidArgument( \ diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc index b282855666b4ae..a636759602d92d 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_op.cc @@ -411,12 +411,14 @@ class DepthwiseConv2dNativeOp : public BinaryOp { GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W', &pad_left, &pad_right); } - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - input_rows, filter_rows, stride_, padding_, - &out_rows, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - input_cols, filter_cols, stride_, padding_, - &out_cols, &pad_left, &pad_right)); + OP_REQUIRES_OK(context, + GetWindowedOutputSizeVerbose( + input_rows, filter_rows, /*dilation_rate=*/1, stride_, + padding_, &out_rows, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context, + GetWindowedOutputSizeVerbose( + input_cols, filter_cols, /*dilation_rate=*/1, stride_, + padding_, &out_cols, &pad_left, &pad_right)); TensorShape out_shape; OP_REQUIRES_OK(context, ShapeFromFormatWithStatus(data_format_, batch, out_rows, diff --git a/tensorflow/core/kernels/gpu_fusion_ops_convbiasactv.cc b/tensorflow/core/kernels/gpu_fusion_ops_convbiasactv.cc index 787e10009c96da..ea17709afeba3d 100644 --- a/tensorflow/core/kernels/gpu_fusion_ops_convbiasactv.cc +++ b/tensorflow/core/kernels/gpu_fusion_ops_convbiasactv.cc @@ -93,14 +93,14 @@ class ROCmFusionKernelConvolutionBiasActivation : public OpKernel { int64 output_rows = 0, padding_left = 0, padding_right = 0; OP_REQUIRES_OK( - ctx, GetWindowedOutputSizeVerboseV2( + ctx, GetWindowedOutputSizeVerbose( input_rows, filter_rows, dilation_rows, stride_rows, padding_type_, &output_rows, &padding_left, &padding_right)); int64 padding_rows = padding_left + padding_right; int64 output_cols = 0, padding_top = 0, padding_bottom = 0; OP_REQUIRES_OK( - ctx, GetWindowedOutputSizeVerboseV2( + ctx, GetWindowedOutputSizeVerbose( input_cols, filter_cols, 
dilation_cols, stride_cols, padding_type_, &output_cols, &padding_top, &padding_bottom)); int64 padding_cols = padding_top + padding_bottom; diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index 5d6561a201a25a..1ce48822ea2c20 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -457,7 +457,10 @@ REGISTER_KERNEL_BUILDER( Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint("T"), InplaceOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER); +REGISTER(int8_t); +REGISTER(uint8_t); REGISTER(int64_t); +REGISTER(uint64_t); REGISTER_EMPTY(int32, GPU); #undef REGISTER diff --git a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc index f2536da06bec24..06a9a0b56aefd0 100644 --- a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc @@ -177,10 +177,13 @@ Status DoInplace(const Device& d, InplaceOpType op, const Tensor& i, CASE(double) CASE(Eigen::half) CASE(Eigen::bfloat16) - CASE(int64) + CASE(uint8_t) + CASE(int8_t) + CASE(int64_t) + CASE(uint64_t) #undef CASE default: - return errors::InvalidArgument("Unsupported data type: ", + return errors::InvalidArgument("Unsupported data type from DoInplace: ", DataTypeString(v.dtype())); } return OkStatus(); @@ -202,10 +205,13 @@ Status DoCopy(const Device& d, const Tensor& x, Tensor* y) { CASE(Eigen::bfloat16) CASE(complex64) CASE(complex128) - CASE(int64) + CASE(uint8_t) + CASE(int8_t) + CASE(int64_t) + CASE(uint64_t) #undef CASE default: - return errors::InvalidArgument("Unsupported dtype: ", + return errors::InvalidArgument("Unsupported dtype from DoCopy: ", DataTypeString(x.dtype())); } return OkStatus(); diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h index 369a81bb1d5105..09814bcbb6026a 100644 --- a/tensorflow/core/kernels/list_kernels.h +++ b/tensorflow/core/kernels/list_kernels.h @@ -691,6 +691,12 @@ class TensorListGather : public OpKernel { if (!tensor_list->element_shape.IsFullyDefined()) { for (int index = 0; index < indices.NumElements(); ++index) { const int i = indices.flat()(index); + + OP_REQUIRES(c, 0 <= i && i < tensor_list->tensors().size(), + absl::InvalidArgumentError(absl::StrCat( + "Trying to gather element ", i, " in a list with ", + tensor_list->tensors().size(), " elements."))); + const Tensor& t = tensor_list->tensors()[i]; if (t.dtype() != DT_INVALID) { PartialTensorShape tmp = partial_element_shape; diff --git a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc index b1c7081e3dbe8a..c3ca762f54349e 100644 --- a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc @@ -123,7 +123,12 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { dnnl::algorithm::pooling_avg_exclude_padding, pooling_prop_kind, static_cast(this->data_format_mkldnn_), input_md, this->native_format_); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); pooling_fwd = MklPoolingFwdPrimitiveFactory::Get(fwdParams); // Allocate output tensor. 
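// Illustrative sketch (not part of the patch): the list_kernels.h hunk above
// validates every gather index against the list length before dereferencing,
// returning InvalidArgument instead of reading out of bounds. The same check
// in stand-alone form, with an exception standing in for the op error:
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

int GatherOne(const std::vector<int>& tensors, long long i) {
  if (i < 0 || static_cast<std::size_t>(i) >= tensors.size()) {
    throw std::invalid_argument("Trying to gather element " +
                                std::to_string(i) + " in a list with " +
                                std::to_string(tensors.size()) + " elements.");
  }
  return tensors[static_cast<std::size_t>(i)];
}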
@@ -340,7 +345,10 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { prop_kind::forward_training, static_cast(this->data_format_mkldnn_), src_md, this->native_format_); - MklDnnThreadPool eigen_tp(context); + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklPoolingBwdPrimitive* pooling_bwd = MklPoolingBwdPrimitiveFactory::Get(bwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc index 3562e2c83da56e..58ac1064e0d104 100644 --- a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc @@ -160,7 +160,12 @@ class BatchMatMulMkl : public OpKernel { out_shape, adj_x_, adj_y_); this->ExtendMklMatMulParams(ctx, *params); - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(ctx); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); // Create or retrieve matmul primitive from cache. MklMatMulPrimitive* matmul_prim = MklMatMulPrimitiveFactory::Get( diff --git a/tensorflow/core/kernels/mkl/mkl_concat_op.cc b/tensorflow/core/kernels/mkl/mkl_concat_op.cc index 07e758a5e77d61..06c51b4016fd8e 100644 --- a/tensorflow/core/kernels/mkl/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_concat_op.cc @@ -153,10 +153,10 @@ class EigenConcatBaseOp : public OpKernel { const OpInputList& input_mins, const OpInputList& input_maxes, bool quantized_input) { const Tensor* concat_dim_tensor; - const char* axis_attribute_name = - AxisArgName == NAME_IS_AXIS - ? "axis" - : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : ""; + const char* axis_attribute_name = AxisArgName == NAME_IS_AXIS ? "axis" + : AxisArgName == NAME_IS_CONCAT_DIM + ? "concat_dim" + : ""; OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor)); OP_REQUIRES(c, TensorShapeUtils::IsScalar(concat_dim_tensor->shape()), errors::InvalidArgument( @@ -760,7 +760,12 @@ class MklConcatOp : public OpKernel { // then since MklDnn order is NCHW, concat_dim needs to be 1. if (are_all_mkl_inputs) concat_dim = mkl_input_shapes[0].TfDimIdx(concat_dim); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); if (!inputs.empty()) { if (are_all_mkl_inputs) { auto concat_pd = diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc index 0b5519fc46b06f..8d8c9b40451f17 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc @@ -517,7 +517,12 @@ class MklConvCustomBackpropFilterOp // variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true. bool do_not_cache = MklPrimitiveFactory::IsPrimitiveMemOptEnabled(); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. 
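// Illustrative sketch (not part of the patch): the oneDNN kernels above switch
// from the MklDnnThreadPool(context) helper to fetching the Eigen threadpool
// interface from the op context and wrapping it in tsl::OneDnnThreadPool, with
// ThreadPoolUseCallerThread() supplying the use-caller-thread flag. The
// adapter wiring in simplified form; both types below are illustrative
// stand-ins, not the Eigen or TSL classes:
#include <functional>

struct ThreadPoolInterface {          // stand-in for the Eigen interface
  virtual void Schedule(std::function<void()> fn) = 0;
  virtual int NumThreads() const = 0;
  virtual ~ThreadPoolInterface() = default;
};

class OneDnnThreadPoolLike {          // stand-in for tsl::OneDnnThreadPool
 public:
  OneDnnThreadPoolLike(ThreadPoolInterface* pool, bool use_caller_thread)
      : pool_(pool), use_caller_thread_(use_caller_thread) {}
  int MaxConcurrency() const {
    return pool_->NumThreads() + (use_caller_thread_ ? 1 : 0);
  }

 private:
  ThreadPoolInterface* pool_;
  bool use_caller_thread_;
};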
+ Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklConvBwdFilterPrimitive* conv_bwd_filter = MklConvBwdFilterPrimitiveFactory::Get(convBwdFilterDims, do_not_cache); diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc index 3eab40d24ee07f..16a6db176843b1 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc @@ -470,7 +470,12 @@ class MklConvCustomBackpropInputOp (MklPrimitiveFactory::IsLegacyPlatform() || IsConv1x1StrideNot1(fwd_filter_dims, strides)); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklConvBwdInputPrimitive* conv_bwd_input = MklConvBwdInputPrimitiveFactory::Get(convBwdInputDims, do_not_cache); diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 8dae3705e0a811..c179fca517df3b 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -50,6 +50,13 @@ namespace tensorflow { #define SET_FUSE_ACTIVATION_FOR_RELU6 \ set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, 6.0) #define SET_MKL_LAYOUT(md) SetMklLayout(&md) +#define OUTPUT_SCALE_DCHECK (post_op_param.name == "output_scale") +#define TSCALED_BIAS Tbias +#define SCALE scales +#define SUMMAND_SCALE_U8(summand_range, output_range) \ + summand_range / output_range +#define SUMMAND_SCALE_S8(summand_range, output_range) \ + 255.0f * summand_range / (output_range * 127.0f) #else #define APPEND_DEPTHWISE(wei_dt, bias_dt, dst_dt, kernel, stride, padding, \ scales_mask, scales) \ @@ -59,8 +66,22 @@ namespace tensorflow { #define SET_FUSE_ACTIVATION_FOR_RELU6 \ set_fuse_activation(true, dnnl::algorithm::eltwise_clip, 0.0, 6.0) #define SET_MKL_LAYOUT(md) SetMklLayout(md) +#define OUTPUT_SCALE_DCHECK \ + (post_op_param.name == "src_scale") || \ + (post_op_param.name == "wei_scale") || \ + (post_op_param.name == "dst_scale") +#define TSCALED_BIAS float +#define SCALE wei_scale +#define SUMMAND_SCALE_U8(summand_range, output_range) summand_range / 255.0f +#define SUMMAND_SCALE_S8(summand_range, output_range) summand_range / 127.0f #endif // !ENABLE_ONEDNN_V3 +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#define FWD_STREAM , *fwd_stream +#else +#define FWD_STREAM +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 + // TODO(intel-tf) Remove this once old API of quantized ops is abandoned namespace quantized_fusions { string none[] = {""}; @@ -88,12 +109,14 @@ struct MklConvFwdParams { memory::dims fuse_bn_dims; MklTensorFormat tf_fmt; bool native_format; + bool is_depthwise; string dtypes = string(""); struct PostOpParam { string name; dnnl::algorithm alg; std::vector param; std::string partial_key; + DataType dtype = DT_INVALID; }; std::vector post_op_params; @@ -102,7 +125,7 @@ struct MklConvFwdParams { memory::dims strides, memory::dims dilations, memory::dims padding_left, memory::dims padding_right, memory::dims fuse_bn_dims, MklTensorFormat tf_fmt, - bool native_format) + bool native_format, bool is_depthwise) : src_dims(src_dims), filter_dims(filter_dims), bias_dims(bias_dims), 
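// Illustrative sketch (not part of the patch): the macro block above
// centralizes the quantization differences between the pre-v3 and v3 oneDNN
// paths, including the scale applied to the fused summand. Writing the four
// variants as plain functions makes the arithmetic in SUMMAND_SCALE_U8/S8
// explicit (255 and 127 are the u8/s8 quantization ranges):
inline float SummandScaleU8PreV3(float summand_range, float output_range) {
  return summand_range / output_range;
}
inline float SummandScaleS8PreV3(float summand_range, float output_range) {
  return 255.0f * summand_range / (output_range * 127.0f);
}
inline float SummandScaleU8V3(float summand_range) {
  return summand_range / 255.0f;
}
inline float SummandScaleS8V3(float summand_range) {
  return summand_range / 127.0f;
}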
@@ -113,7 +136,8 @@ struct MklConvFwdParams { padding_right(padding_right), fuse_bn_dims(fuse_bn_dims), tf_fmt(tf_fmt), - native_format(native_format) {} + native_format(native_format), + is_depthwise(is_depthwise) {} }; // With quantization, input, filter, and output can have different types @@ -140,16 +164,18 @@ class MklConvFwdPrimitive : public MklPrimitive { // bias_data: input data buffer of bias // dst_data: output data buffer of dst void Execute(const Tinput* src_data, const Tfilter* filter_data, - const Tbias* bias_data, const Toutput* dst_data, + const void* bias_data, const Toutput* dst_data, + const MklConvFwdParams& convFwdDims, std::shared_ptr fwd_stream, void* sp_data = nullptr) { Execute(src_data, filter_data, bias_data, dst_data, nullptr, nullptr, - nullptr, nullptr, fwd_stream, sp_data); + nullptr, nullptr, convFwdDims, fwd_stream, sp_data); } void Execute(const Tinput* src_data, const Tfilter* filter_data, - const Tbias* bias_data, const Toutput* dst_data, + const void* bias_data, const Toutput* dst_data, const Tinput* bn_scale_data, const Tinput* bn_mean_data, const Tinput* bn_offset_data, const Tinput* bn_rsqrt_data, + const MklConvFwdParams& convFwdDims, std::shared_ptr fwd_stream, void* sp_data) { #ifdef DNNL_AARCH64_USE_ACL // When we are using single global cache then in this case we can have @@ -157,54 +183,44 @@ class MklConvFwdPrimitive : public MklPrimitive { // should happen under the lock. mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) context_.src_mem->set_data_handle( - static_cast(const_cast(src_data)), *fwd_stream); + static_cast(const_cast(src_data)) FWD_STREAM); context_.filter_mem->set_data_handle( - static_cast(const_cast(filter_data)), *fwd_stream); + static_cast(const_cast(filter_data)) FWD_STREAM); if (bias_data != nullptr) { - context_.bias_mem->set_data_handle( - static_cast(const_cast(bias_data)), *fwd_stream); + context_.bias_mem->set_data_handle(const_cast(bias_data) + FWD_STREAM); + } + auto const& post_op_params = convFwdDims.post_op_params; + if (!post_op_params.empty()) { + for (auto const& post_op_param : post_op_params) { + if (post_op_param.name == "src_scale") { + context_.src_scale_mem->set_data_handle(static_cast( + const_cast(post_op_param.param.data())) FWD_STREAM); + } else if (post_op_param.name == "wei_scale") { + context_.wei_scale_mem->set_data_handle(static_cast( + const_cast(post_op_param.param.data())) FWD_STREAM); + } else if (post_op_param.name == "dst_scale") { + context_.dst_scale_mem->set_data_handle(static_cast( + const_cast(post_op_param.param.data())) FWD_STREAM); + } + } } if (bn_scale_data != nullptr) { context_.bn_scale_mem->set_data_handle( - static_cast(const_cast(bn_scale_data)), *fwd_stream); + static_cast(const_cast(bn_scale_data)) FWD_STREAM); context_.bn_mean_mem->set_data_handle( - static_cast(const_cast(bn_mean_data)), *fwd_stream); + static_cast(const_cast(bn_mean_data)) FWD_STREAM); context_.bn_rsqrt_mem->set_data_handle( - static_cast(const_cast(bn_rsqrt_data)), *fwd_stream); + static_cast(const_cast(bn_rsqrt_data)) FWD_STREAM); context_.bn_offset_mem->set_data_handle( - static_cast(const_cast(bn_offset_data)), *fwd_stream); + static_cast(const_cast(bn_offset_data)) FWD_STREAM); } context_.dst_mem->set_data_handle( - static_cast(const_cast(dst_data)), *fwd_stream); + static_cast(const_cast(dst_data)) FWD_STREAM); if (sp_data) { - context_.sp_mem->set_data_handle(static_cast(sp_data), - *fwd_stream); - } -#else - 
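// Illustrative sketch (not part of the patch): the Execute() rewrite above
// folds the two previous #if branches into one by making FWD_STREAM expand
// either to ", *fwd_stream" or to nothing, so each set_data_handle call is
// written once. The trick in isolation; the EXAMPLE_* macro and SetHandle
// overloads are hypothetical, not oneDNN APIs:
#include <iostream>

#ifdef EXAMPLE_PASS_STREAM
#define EXAMPLE_STREAM_ARG , stream
#else
#define EXAMPLE_STREAM_ARG
#endif

void SetHandle(void* data) { std::cout << data << " without stream\n"; }
void SetHandle(void* data, int stream) {
  std::cout << data << " on stream " << stream << "\n";
}

void Bind(void* data, int stream) {
  (void)stream;  // only consumed when EXAMPLE_PASS_STREAM is defined
  // Expands to SetHandle(data) or SetHandle(data, stream) at compile time.
  SetHandle(data EXAMPLE_STREAM_ARG);
}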
context_.src_mem->set_data_handle( - static_cast(const_cast(src_data))); - context_.filter_mem->set_data_handle( - static_cast(const_cast(filter_data))); - if (bias_data != nullptr) { - context_.bias_mem->set_data_handle( - static_cast(const_cast(bias_data))); - } - if (bn_scale_data != nullptr) { - context_.bn_scale_mem->set_data_handle( - static_cast(const_cast(bn_scale_data))); - context_.bn_mean_mem->set_data_handle( - static_cast(const_cast(bn_mean_data))); - context_.bn_rsqrt_mem->set_data_handle( - static_cast(const_cast(bn_rsqrt_data))); - context_.bn_offset_mem->set_data_handle( - static_cast(const_cast(bn_offset_data))); + context_.sp_mem->set_data_handle(static_cast(sp_data) FWD_STREAM); } - context_.dst_mem->set_data_handle( - static_cast(const_cast(dst_data))); - if (sp_data) context_.sp_mem->set_data_handle(static_cast(sp_data)); -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_primitives_args.size()); @@ -236,10 +252,10 @@ class MklConvFwdPrimitive : public MklPrimitive { // filter_data: input data buffer of filter (weights) // dst_data: output data buffer of dst void Execute(const Tinput* src_data, const Tfilter* filter_data, - const Toutput* dst_data, std::shared_ptr fwd_stream, - void* sp_data) { + const Toutput* dst_data, const MklConvFwdParams& convFwdDims, + std::shared_ptr fwd_stream, void* sp_data) { Execute(src_data, filter_data, nullptr, dst_data, nullptr, nullptr, nullptr, - nullptr, fwd_stream, sp_data); + nullptr, convFwdDims, fwd_stream, sp_data); } std::shared_ptr GetPrimitiveDesc() const { @@ -262,6 +278,11 @@ class MklConvFwdPrimitive : public MklPrimitive { std::shared_ptr bn_rsqrt_mem; std::shared_ptr bn_offset_mem; + // Quantization scale related memory + std::shared_ptr src_scale_mem; + std::shared_ptr wei_scale_mem; + std::shared_ptr dst_scale_mem; + // Desc & primitive desc #ifndef ENABLE_ONEDNN_V3 std::shared_ptr fwd_desc; @@ -280,6 +301,11 @@ class MklConvFwdPrimitive : public MklPrimitive { std::shared_ptr bn_rsqrt_md; std::shared_ptr bn_offset_md; + // Quantization scale related memory descriptors + std::shared_ptr src_scale_md; + std::shared_ptr wei_scale_md; + std::shared_ptr dst_scale_md; + // Convolution primitive std::shared_ptr conv_fwd; @@ -296,6 +322,9 @@ class MklConvFwdPrimitive : public MklPrimitive { bn_mean_mem(nullptr), bn_rsqrt_mem(nullptr), bn_offset_mem(nullptr), + src_scale_mem(nullptr), + wei_scale_mem(nullptr), + dst_scale_mem(nullptr), #ifndef ENABLE_ONEDNN_V3 fwd_desc(nullptr), #endif // !ENABLE_ONEDNN_V3 @@ -307,6 +336,9 @@ class MklConvFwdPrimitive : public MklPrimitive { bn_mean_md(nullptr), bn_rsqrt_md(nullptr), bn_offset_md(nullptr), + src_scale_md(nullptr), + wei_scale_md(nullptr), + dst_scale_md(nullptr), fwd_pd(nullptr), conv_fwd(nullptr) { } @@ -331,9 +363,15 @@ class MklConvFwdPrimitive : public MklPrimitive { {convFwdDims.dst_dims}, MklDnnType(), user_data_fmt)); if (!convFwdDims.bias_dims.empty()) { - context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims}, - MklDnnType(), - memory::format_tag::any)); + if (std::is_same::value) { + context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims}, + MklDnnType(), + memory::format_tag::any)); + } else { + context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims}, + MklDnnType(), + memory::format_tag::any)); + } #ifndef ENABLE_ONEDNN_V3 // Create a convolution descriptor context_.fwd_desc.reset(new convolution_forward::desc( @@ -371,6 +409,7 @@ class MklConvFwdPrimitive : public MklPrimitive { 
dnnl::primitive_attr post_ops_attr; dnnl::post_ops post_ops; post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + std::unordered_map is_scale_set; if (!post_op_params.empty()) { for (auto const& post_op_param : post_op_params) { if (post_op_param.name == "activation") { @@ -384,21 +423,54 @@ class MklConvFwdPrimitive : public MklPrimitive { } else if (post_op_param.name == "sum") { DCHECK_EQ(post_op_param.param.size(), 1); float op_scale = post_op_param.param[0]; +#ifndef ENABLE_ONEDNN_V3 post_ops.append_sum(op_scale); - } else if (post_op_param.name == "output_scale") { +#else + if (post_op_param.dtype != DT_INVALID) { + if (post_op_param.dtype == DT_FLOAT) { + post_ops.append_sum(op_scale, /*zero_point=*/0, + MklDnnType()); + } else { + TF_CHECK_OK(absl::FailedPreconditionError( + "Summand data type is expected to be float")); + } + } else { + post_ops.append_sum(op_scale); + } +#endif //! ENABLE_ONEDNN_V3 #ifndef ENABLE_ONEDNN_V3 + } else if (post_op_param.name == "output_scale") { if (post_op_param.param.size() == 1) { post_ops_attr.set_output_scales(0, post_op_param.param); } else { post_ops_attr.set_output_scales(2, post_op_param.param); } #else - // TODO(intel-tf): Enable this for int8 when using oneDNN v3.x - // and return a status instead of using DCHECK_EQ - DCHECK_EQ(post_op_param.param.size(), 1); + } else if (post_op_param.name == "src_scale") { + is_scale_set.insert({"src", true}); post_ops_attr.set_scales_mask(DNNL_ARG_SRC, 0); - post_ops_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + context_.src_scale_md.reset(new memory::desc({1}, MklDnnType(), + memory::format_tag::x)); + context_.src_scale_mem.reset( + new memory(*context_.src_scale_md, cpu_engine_, DummyData)); + } else if (post_op_param.name == "wei_scale") { + is_scale_set.insert({"wei", true}); + const int scale_size = post_op_param.param.size(); + const int mask = scale_size == 1 ? 0 + : convFwdDims.is_depthwise ? 
3 + : 1; + post_ops_attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask); + context_.wei_scale_md.reset(new memory::desc( + {scale_size}, MklDnnType(), memory::format_tag::x)); + context_.wei_scale_mem.reset( + new memory(*context_.wei_scale_md, cpu_engine_, DummyData)); + } else if (post_op_param.name == "dst_scale") { + is_scale_set.insert({"dst", true}); post_ops_attr.set_scales_mask(DNNL_ARG_DST, 0); + context_.dst_scale_md.reset(new memory::desc({1}, MklDnnType(), + memory::format_tag::x)); + context_.dst_scale_mem.reset( + new memory(*context_.dst_scale_md, cpu_engine_, DummyData)); #endif // !ENABLE_ONEDNN_V3 } else if (post_op_param.name == "fuse_bn") { post_ops.append_binary(dnnl::algorithm::binary_sub, @@ -411,8 +483,7 @@ class MklConvFwdPrimitive : public MklPrimitive { *context_.bn_offset_md); } else { DCHECK((post_op_param.name == "activation") || - (post_op_param.name == "sum") || - (post_op_param.name == "output_scale") || + (post_op_param.name == "sum") || OUTPUT_SCALE_DCHECK || (post_op_param.name == "fuse_bn")); } } @@ -451,16 +522,24 @@ class MklConvFwdPrimitive : public MklPrimitive { new dnnl::memory(scratchpad_md, cpu_engine_, DummyData)); // Create convolution primitive and add it to net + std::unordered_map net_args; if (!convFwdDims.bias_dims.empty()) { - context_.bias_mem.reset(new memory( - {{convFwdDims.bias_dims}, MklDnnType(), memory::format_tag::x}, - cpu_engine_, DummyData)); - context_.fwd_primitives_args.push_back( - {{DNNL_ARG_SRC, *context_.src_mem}, - {DNNL_ARG_WEIGHTS, *context_.filter_mem}, - {DNNL_ARG_BIAS, *context_.bias_mem}, - {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, - {DNNL_ARG_DST, *context_.dst_mem}}); + context_.bias_mem.reset(new memory(context_.fwd_pd.get()->bias_desc(), + cpu_engine_, DummyData)); + net_args = {{DNNL_ARG_SRC, *context_.src_mem}, + {DNNL_ARG_WEIGHTS, *context_.filter_mem}, + {DNNL_ARG_BIAS, *context_.bias_mem}, + {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, + {DNNL_ARG_DST, *context_.dst_mem}}; +#ifdef ENABLE_ONEDNN_V3 + if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { + net_args.insert( + {{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}, + { DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, + *context_.dst_scale_mem }}); + } +#endif // ENABLE_ONEDNN_V3 } else if (!convFwdDims.fuse_bn_dims.empty()) { context_.bn_scale_mem.reset( new memory(*context_.bn_scale_md, cpu_engine_, DummyData)); @@ -471,26 +550,34 @@ class MklConvFwdPrimitive : public MklPrimitive { context_.bn_rsqrt_mem.reset( new memory(*context_.bn_rsqrt_md, cpu_engine_, DummyData)); - context_.fwd_primitives_args.push_back( - {{DNNL_ARG_SRC, *context_.src_mem}, - {DNNL_ARG_WEIGHTS, *context_.filter_mem}, - {DNNL_ARG_DST, *context_.dst_mem}, - {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, - {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, - *context_.bn_mean_mem}, - {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, - *context_.bn_rsqrt_mem}, - {DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1, - *context_.bn_scale_mem}, - {DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1, - *context_.bn_offset_mem}}); + net_args = {{DNNL_ARG_SRC, *context_.src_mem}, + {DNNL_ARG_WEIGHTS, *context_.filter_mem}, + {DNNL_ARG_DST, *context_.dst_mem}, + {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, + *context_.bn_mean_mem}, + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, + *context_.bn_rsqrt_mem}, + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | 
DNNL_ARG_SRC_1, + *context_.bn_scale_mem}, + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1, + *context_.bn_offset_mem}}; } else { - context_.fwd_primitives_args.push_back( - {{DNNL_ARG_SRC, *context_.src_mem}, - {DNNL_ARG_WEIGHTS, *context_.filter_mem}, - {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, - {DNNL_ARG_DST, *context_.dst_mem}}); + net_args = {{DNNL_ARG_SRC, *context_.src_mem}, + {DNNL_ARG_WEIGHTS, *context_.filter_mem}, + {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, + {DNNL_ARG_DST, *context_.dst_mem}}; +#ifdef ENABLE_ONEDNN_V3 + if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { + net_args.insert( + {{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}, + { DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, + *context_.dst_scale_mem }}); + } +#endif // ENABLE_ONEDNN_V3 } + context_.fwd_primitives_args.push_back(net_args); context_.fwd_primitives.push_back(*context_.conv_fwd); } @@ -577,7 +664,13 @@ class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory { for (auto& param : post_op_param.param) { key_creator.AddAsKey(param); } +#ifndef ENABLE_ONEDNN_V3 } else if (post_op_param.name == "output_scale") { +#else + } else if (post_op_param.name == "src_scale" || + post_op_param.name == "wei_scale" || + post_op_param.name == "dst_scale") { +#endif // !ENABLE_ONEDNN_V3 key_creator.AddAsKey(post_op_param.partial_key); } else if (post_op_param.name == "fuse_bn") { key_creator.AddAsKey(post_op_param.name); @@ -620,7 +713,7 @@ class MklConvOp : public OpKernel { context, !(context->HasAttr("padding_list") && context->HasAttr("explicit_paddings")), - errors::InvalidArgument("Can only have 1 `padding` list at most")); + absl::InvalidArgumentError("Can only have 1 `padding` list at most")); if (context->HasAttr("padding_list")) { OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_)); } @@ -632,17 +725,17 @@ class MklConvOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str_)); OP_REQUIRES(context, FormatFromString(data_format_str_, &data_format_), - errors::InvalidArgument("Invalid data format")); + absl::InvalidArgumentError("Invalid data format")); OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5), - errors::InvalidArgument("Sliding window strides field must " - "specify 4 or 5 dimensions")); + absl::InvalidArgumentError("Sliding window strides field must " + "specify 4 or 5 dimensions")); const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); OP_REQUIRES( context, stride_n == 1 && stride_c == 1, - errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); + absl::UnimplementedError("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); is_filter_const_ = false; @@ -654,28 +747,29 @@ class MklConvOp : public OpKernel { } if (strides_.size() == 4) { - OP_REQUIRES(context, dilations_.size() == 4, - errors::InvalidArgument("Sliding window dilations field must " - "specify 4 dimensions")); + OP_REQUIRES( + context, dilations_.size() == 4, + absl::InvalidArgumentError("Sliding window dilations field must " + "specify 4 dimensions")); const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N'); const int64 
dilation_c = GetTensorDim(dilations_, data_format_, 'C'); const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H'); const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W'); OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Current implementation does not yet support " "dilations in the batch and depth dimensions.")); OP_REQUIRES( context, dilation_h > 0 && dilation_w > 0, - errors::InvalidArgument("Dilated rates should be larger than 0.")); + absl::InvalidArgumentError("Dilated rates should be larger than 0.")); } else if (strides_.size() == 5) { OP_REQUIRES(context, dilations_.size() == 5, - errors::InvalidArgument("Dilation rates field must " - "specify 5 dimensions")); + absl::InvalidArgumentError("Dilation rates field must " + "specify 5 dimensions")); OP_REQUIRES(context, (GetTensorDim(dilations_, data_format_, 'N') == 1 && GetTensorDim(dilations_, data_format_, 'C') == 1), - errors::InvalidArgument( + absl::InvalidArgumentError( "Current implementation does not yet support " "dilations rates in the batch and depth dimensions.")); OP_REQUIRES( @@ -683,7 +777,7 @@ class MklConvOp : public OpKernel { (GetTensorDim(dilations_, data_format_, '0') > 0 && GetTensorDim(dilations_, data_format_, '1') > 0 && GetTensorDim(dilations_, data_format_, '2') > 0), - errors::InvalidArgument("Dilated rates should be larger than 0.")); + absl::InvalidArgumentError("Dilated rates should be larger than 0.")); } } @@ -694,8 +788,8 @@ class MklConvOp : public OpKernel { const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter); OP_REQUIRES( context, filter_tensor.NumElements() > 0, - errors::InvalidArgument("filter must not have zero elements " - "(i.e. all dimensions must be non-zero)")); + absl::InvalidArgumentError("filter must not have zero elements " + "(i.e. all dimensions must be non-zero)")); if (std::is_same::value) { (void)SetFPMathMode(); @@ -707,8 +801,8 @@ class MklConvOp : public OpKernel { native_format); OP_REQUIRES(context, !filter_mkl_shape.IsMklTensor(), - errors::InvalidArgument("Filter should not be in " - "Mkl Layout")); + absl::InvalidArgumentError("Filter should not be in " + "Mkl Layout")); MklDnnData src(&cpu_engine_); MklDnnData filter(&cpu_engine_); @@ -780,18 +874,18 @@ class MklConvOp : public OpKernel { bool is_conv3d = (strides_.size() == 5); if (!is_conv2d && !is_conv3d) { - OP_REQUIRES( - context, !pad_enabled, - errors::InvalidArgument("Pad + Conv fusion only works for 2D/3D")); + OP_REQUIRES(context, !pad_enabled, + absl::InvalidArgumentError( + "Pad + Conv fusion only works for 2D/3D")); OP_REQUIRES( context, !fuse_pad_, - errors::InvalidArgument("Pad+Conv fusion only works for 2D/3D")); + absl::InvalidArgumentError("Pad+Conv fusion only works for 2D/3D")); } // TODO(intel-tf) 3-D support for Depthwise is not there if (is_depthwise) { OP_REQUIRES(context, is_conv2d, - errors::InvalidArgument( + absl::InvalidArgumentError( "Only 2D convolution is supported for depthwise.")); } @@ -804,7 +898,7 @@ class MklConvOp : public OpKernel { auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt); // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef, - errors::InvalidArgument("Invalid data format")); + absl::InvalidArgumentError("Invalid data format")); // If input is in MKL layout, then simply grab the layout; otherwise, // construct TF layout for input. 
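Editor's note on the hunks above: the single oneDNN 2.x "output_scale" attribute is replaced for oneDNN 3.x by separate src/weights/dst scales. set_scales_mask() only declares the mask on the primitive_attr; the actual values travel as small f32 memories in the execution argument map under DNNL_ARG_ATTR_SCALES | DNNL_ARG_*. The sketch below condenses that flow. It is illustrative only: the names are not the kernel's (the kernel keeps these memories in its context with DummyData and rebinds the handles at execution time), and it assumes `attr` is subsequently passed when building convolution_forward::primitive_desc.

// Illustrative sketch, not the kernel's code: attach runtime quantization
// scales to a convolution the oneDNN v3.x way. The scale vectors must stay
// alive until the primitive has executed.
#include <unordered_map>
#include <vector>
#include "dnnl.hpp"

void AddRuntimeScales(const dnnl::engine& eng, dnnl::primitive_attr& attr,
                      std::unordered_map<int, dnnl::memory>& conv_args,
                      std::vector<float>& src_scale,  // 1 element
                      std::vector<float>& wei_scale,  // 1 element or one per OC
                      std::vector<float>& dst_scale,  // 1 element
                      bool is_depthwise) {
  using dnnl::memory;
  // Mask 0 = one common scale. Per-output-channel weights use mask 1 (dim 0);
  // grouped/depthwise weights are laid out as (G, OC/G, ...), so the mask
  // covers dims 0 and 1, i.e. 3 -- the same choice as the "wei_scale" branch.
  const int wei_mask = wei_scale.size() == 1 ? 0 : (is_depthwise ? 3 : 1);
  attr.set_scales_mask(DNNL_ARG_SRC, 0);
  attr.set_scales_mask(DNNL_ARG_WEIGHTS, wei_mask);
  attr.set_scales_mask(DNNL_ARG_DST, 0);
  // `attr` must now be used when creating convolution_forward::primitive_desc.

  // Typical values (see ExtendConvFwdParams further below): src_scale is
  // max(|min_input|, |max_input|) / 255 for quint8 input (127 for qint8) and
  // each wei_scale[c] is max(|min_filter[c]|, |max_filter[c]|) / 127.
  auto vec_md = [](size_t n) {
    return dnnl::memory::desc({static_cast<dnnl::memory::dim>(n)},
                              dnnl::memory::data_type::f32,
                              dnnl::memory::format_tag::x);
  };
  conv_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC,
                    memory(vec_md(1), eng, src_scale.data())});
  conv_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS,
                    memory(vec_md(wei_scale.size()), eng, wei_scale.data())});
  conv_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST,
                    memory(vec_md(1), eng, dst_scale.data())});
}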
@@ -862,8 +956,9 @@ class MklConvOp : public OpKernel { // Inputs to FusedBatchNorm have same 1D shape fuse_bn_shape = MklGetInput(context, kInputIndex_BN_Mean).shape(); OP_REQUIRES(context, fuse_bn_shape.dims() == 1, - errors::InvalidArgument("FusedBatchNorm must be 1D, not: ", - fuse_bn_shape.DebugString())); + absl::InvalidArgumentError( + absl::StrCat("FusedBatchNorm must be 1D, not: ", + fuse_bn_shape.DebugString()))); // Note - MKL-DNN expects {1, C, 1, 1} for binary post-op even for NHWC fuse_bn_dims = {1, fuse_bn_shape.dim_size(0), 1, 1}; @@ -872,11 +967,16 @@ class MklConvOp : public OpKernel { MklConvFwdParams convFwdDims( src_dims, filter_dims, fuse_biasadd_ ? bias_dims : NONE_DIMS, dst_dims_mkl_order, strides, dilations, padding_left, padding_right, - fuse_bn_dims, tf_fmt, native_format); + fuse_bn_dims, tf_fmt, native_format, is_depthwise); // TODO(intel-tf): Extend the basic parameters for data types and fusions this->ExtendConvFwdParams(context, convFwdDims); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); conv_fwd = MklConvFwdPrimitiveFactory::Get( convFwdDims, do_not_cache); @@ -950,10 +1050,10 @@ class MklConvOp : public OpKernel { fwd_cpu_stream.reset(CreateStream(&eigen_tp, conv_fwd->GetEngine())); if (fuse_biasadd_) { const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias); - Tbias* bias_data = + void* bias_data = this->GetBiasHandle(context, conv_fwd_pd, bias_tensor); conv_fwd->Execute(src_data, filter_data, bias_data, dst_data, - fwd_cpu_stream, scratch_pad.Get()); + convFwdDims, fwd_cpu_stream, scratch_pad.Get()); } else if (fuse_bn_) { const Tensor& bn_scale_tensor = MklGetInput(context, kInputIndex_BN_Scale); @@ -978,10 +1078,11 @@ class MklConvOp : public OpKernel { bn_rsqrt_data); conv_fwd->Execute(src_data, filter_data, nullptr, dst_data, bn_scale_data, bn_mean_data, bn_offset_data, - bn_rsqrt_data, fwd_cpu_stream, scratch_pad.Get()); - } else { - conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream, + bn_rsqrt_data, convFwdDims, fwd_cpu_stream, scratch_pad.Get()); + } else { + conv_fwd->Execute(src_data, filter_data, dst_data, convFwdDims, + fwd_cpu_stream, scratch_pad.Get()); } // Delete primitive since it is not cached. @@ -991,9 +1092,9 @@ class MklConvOp : public OpKernel { string error_msg = tensorflow::strings::StrCat( "Status: ", e.status, ", message: ", string(e.message), ", in file ", __FILE__, ":", __LINE__); - OP_REQUIRES_OK( - context, - errors::Aborted("Operation received an exception:", error_msg)); + OP_REQUIRES_OK(context, + absl::AbortedError(absl::StrCat( + "Operation received an exception:", error_msg))); } } @@ -1006,8 +1107,9 @@ class MklConvOp : public OpKernel { } else { const Tensor& paddings_tf = MklGetInput(context, input_index_pad_); OP_REQUIRES(context, paddings_tf.dims() == 2, - errors::InvalidArgument("paddings must be 2-dimensional: ", - paddings_tf.shape().DebugString())); + absl::InvalidArgumentError( + absl::StrCat("paddings must be 2-dimensional: ", + paddings_tf.shape().DebugString()))); // Flatten tensor to get individual paddings. 
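Editor's note: a side-change running through this patch is the switch from the legacy errors::InvalidArgument / Unimplemented / Aborted helpers (which accepted StrCat-style variadic pieces) to the absl::*Error constructors, which take a single message, so formatted messages now go through absl::StrCat. A standalone illustration of the new style; the helper and its condition are only an example, not code from the patch.

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/tensor_shape.h"

// Returns the same kind of status the OP_REQUIRES calls above now construct.
absl::Status ValidatePaddingsRank(
    const tensorflow::TensorShape& paddings_shape) {
  if (paddings_shape.dims() != 2) {
    // errors::InvalidArgument(...) took variadic pieces; the absl constructor
    // takes one string, hence the absl::StrCat wrapper.
    return absl::InvalidArgumentError(absl::StrCat(
        "paddings must be 2-dimensional: ", paddings_shape.DebugString()));
  }
  return absl::OkStatus();
}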
paddings = static_cast( const_cast(paddings_tf.flat().data())); @@ -1102,9 +1204,9 @@ class MklConvOp : public OpKernel { virtual void ComputeBNScale(OpKernelContext* context, float epsilon, int bn_variance_index, Tinput* scale_buf_ptr) { - OP_REQUIRES( - context, false, - errors::Unimplemented("Compute BN scale not expected in base class")); + OP_REQUIRES(context, false, + absl::UnimplementedError( + "Compute BN scale not expected in base class")); return; } @@ -1143,9 +1245,9 @@ class MklConvOp : public OpKernel { } } - virtual Tbias* GetBiasHandle(OpKernelContext* context, - std::shared_ptr& conv2d_fwd_pd, - const Tensor& bias_tensor) { + virtual void* GetBiasHandle(OpKernelContext* context, + std::shared_ptr& conv2d_fwd_pd, + const Tensor& bias_tensor) { if (fuse_biasadd_) { return static_cast( const_cast(bias_tensor.flat().data())); @@ -1160,6 +1262,7 @@ class MklConvOp : public OpKernel { MklDnnShape* output_mkl_shape, Tensor** output_tensor) { DCHECK(output_tensor); +#ifndef ENABLE_ONEDNN_V3 auto dst_md = conv_prim_desc.dst_desc(); if (!std::is_same::value) { @@ -1174,6 +1277,14 @@ class MklConvOp : public OpKernel { MklTensorFormatToMklDnnDataFormat(output_tf_format)); #endif // !ENABLE_ONEDNN_V3 } +#else + auto dst_md = + std::is_same::value + ? conv_prim_desc.dst_desc() + : memory::desc(conv_prim_desc.dst_desc().get_dims(), + MklDnnType(), + MklTensorFormatToMklDnnDataFormat(output_tf_format)); +#endif // !ENABLE_ONEDNN_V3 // Allocate shape of MKL tensor output_mkl_shape->SetMklTensor(true); @@ -1215,7 +1326,7 @@ class MklConvOp : public OpKernel { auto output_format_tag = MklTensorFormatToMklDnnDataFormat( output_mkl_shape->GetTfDataFormat()); OP_REQUIRES(context, output_format_tag != memory::format_tag::undef, - errors::InvalidArgument( + absl::InvalidArgumentError( "MklConvOp: AddN fusion: Invalid data format")); auto add_md = add_mkl_shape.IsMklTensor() @@ -1493,14 +1604,14 @@ class MklFusedConvOp int num_args; OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); OP_REQUIRES(context, !fused_ops.empty(), - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have at least one fused op.")); // TODO(intel-tf): Compact the code for activation checking if (fused_ops == std::vector{"BiasAdd"}) { this->set_fuse_biasadd(true); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"Relu"}) { this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); @@ -1519,26 +1630,26 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); } else if (fused_ops == std::vector{"BiasAdd", "Relu"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "Relu6"}) { this->set_fuse_biasadd(true); this->SET_FUSE_ACTIVATION_FOR_RELU6; OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "Elu"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, 
dnnl::algorithm::eltwise_elu, 1.0); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "LeakyRelu"}) { this->set_fuse_biasadd(true); @@ -1548,21 +1659,21 @@ class MklFusedConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, leakyrelu_alpha); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "Add"}) { this->set_fuse_biasadd(true); this->set_fuse_add(true); OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"FusedBatchNorm", "Relu"}) { float epsilon; OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); @@ -1571,7 +1682,7 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->SET_FUSE_ACTIVATION_FOR_RELU6; @@ -1580,7 +1691,7 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); @@ -1592,7 +1703,7 @@ class MklFusedConvOp context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, @@ -1603,7 +1714,7 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_swish, 1.0); @@ -1613,7 +1724,7 @@ class MklFusedConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Add", "Relu6"}) { this->set_fuse_biasadd(true); @@ -1621,7 +1732,7 @@ class MklFusedConvOp this->SET_FUSE_ACTIVATION_FOR_RELU6; OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Add", "Elu"}) { this->set_fuse_biasadd(true); @@ -1629,7 +1740,7 @@ class MklFusedConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra 
arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Add", "LeakyRelu"}) { @@ -1642,7 +1753,7 @@ class MklFusedConvOp leakyrelu_alpha); OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Mish"}) { this->set_fuse_biasadd(true); @@ -1654,12 +1765,13 @@ class MklFusedConvOp this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_swish, 1.0); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else { OP_REQUIRES(context, false, - errors::Unimplemented("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]")); + absl::UnimplementedError( + absl::StrCat("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]"))); } if (pad_enabled) { @@ -1706,7 +1818,7 @@ class MklFusedDepthwiseConvOp int num_args; OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); OP_REQUIRES(context, !fused_ops.empty(), - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused DepthwiseConv2D must have at least one fused op.")); if (fused_ops == std::vector{"BiasAdd"}) { @@ -1722,13 +1834,14 @@ class MklFusedDepthwiseConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); } else { OP_REQUIRES(context, false, - errors::Unimplemented("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]")); + absl::InvalidArgumentError( + absl::StrCat("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]"))); } OP_REQUIRES( context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused DepthwiseConv2D must have one extra argument: bias.")); if (pad_enabled) { @@ -1796,7 +1909,7 @@ class MklQuantizedConvOp // TODO(intel-tf): num_fused_ops and legacy_fused_ops should go away once // old API is abandoned. OP_REQUIRES(context, !(fused_ops_attr.size() > 0 && num_fused_ops > 0), - errors::InvalidArgument( + absl::InvalidArgumentError( "QuantizedConv fused ops should be only available through " "either new API or old API, got both.")); @@ -1813,8 +1926,9 @@ class MklQuantizedConvOp std::find(supported_fusions.begin(), supported_fusions.end(), fused_ops_) != supported_fusions.end(); OP_REQUIRES(context, is_fusion_supported, - errors::InvalidArgument("Unsupported QuantizedConv fusion: [", - absl::StrJoin(fused_ops_, ","), "]")); + absl::InvalidArgumentError( + absl::StrCat("Unsupported QuantizedConv fusion: [", + absl::StrJoin(fused_ops_, ","), "]"))); } // Set the flag for every fused op. @@ -1838,9 +1952,10 @@ class MklQuantizedConvOp const bool fuse_requantize = IsFused(oneDNNFusedOps::kRequantize); OP_REQUIRES_OK(context, context->GetAttr("out_type", &out_dt)); if (fuse_requantize) { - OP_REQUIRES(context, out_dt == DT_QINT8 || out_dt == DT_QUINT8, - errors::InvalidArgument("QuantizedConv: unsupported output " - "type when Requantize is fused.")); + OP_REQUIRES( + context, out_dt == DT_QINT8 || out_dt == DT_QUINT8, + absl::InvalidArgumentError("QuantizedConv: unsupported output " + "type when Requantize is fused.")); } if (context->HasAttr("Tsummand")) { @@ -1848,7 +1963,7 @@ class MklQuantizedConvOp if (!this->get_fuse_add()) { OP_REQUIRES( context, summand_dt == out_dt, - errors::InvalidArgument( + absl::InvalidArgumentError( "QuantizedConv: incorrect summand data type. 
When Sum is not " "fused, Tsummand attribute must have same value as out_type.")); } @@ -1857,10 +1972,19 @@ class MklQuantizedConvOp // If Requantize is fused, we set output_scale as first post op since it is // logically applied before any post op. Then we maintain the order of post // ops according to the order of fused_ops. +#ifndef ENABLE_ONEDNN_V3 int idx = fuse_requantize ? 1 : 0; +#else + post_op_to_idx_["src_scale"] = 0; + post_op_to_idx_["wei_scale"] = 1; + post_op_to_idx_["dst_scale"] = 2; + int idx = 3; +#endif // !ENABLE_ONEDNN_V3 for (int i = 0; i < fused_ops_.size(); ++i) { if (fused_ops_[i] == "Requantize") { +#ifndef ENABLE_ONEDNN_V3 post_op_to_idx_["output_scale"] = 0; +#endif // !ENABLE_ONEDNN_V3 } else if (fused_ops_[i] == "Sum") { post_op_to_idx_["sum"] = idx++; } else if (fused_ops_[i] == "Relu") { @@ -1874,7 +1998,7 @@ class MklQuantizedConvOp OP_REQUIRES( context, is_filter_const, - errors::InvalidArgument("QuantizedConv: filter must be a constant")); + absl::InvalidArgumentError("QuantizedConv: filter must be a constant")); if (num_fused_ops == -1) { // If num_fused_ops is -1 then the new API (ops) are being used. @@ -2010,24 +2134,30 @@ class MklQuantizedConvOp /*pad_enabled*/ false, is_depthwise, /*native_format*/ true>::ExtendConvFwdParams(context, params); params.post_op_params.resize(post_op_to_idx_.size()); - // When the output type is quint8, the output data is requantized - // into quint8. A post_op "output_scale" is added to do the conversion. + const float min_input = + context->input(min_input_idx_).template scalar()(); + const float max_input = + context->input(max_input_idx_).template scalar()(); + const Tensor& min_filter_vector = context->input(min_filter_idx_); + const Tensor& max_filter_vector = context->input(max_filter_idx_); + OP_REQUIRES( + context, + ((min_filter_vector.NumElements() > 0) && + (max_filter_vector.NumElements() > 0) && + (min_filter_vector.shape() == max_filter_vector.shape())), + absl::InvalidArgumentError("`min_ and max_filter` must have same" + "shape and contain at least one element.")); + float int_input_limit = + std::is_same::value ? 255.0f : 127.0f; + size_t depth = min_filter_vector.NumElements(); + const float* min_filter = min_filter_vector.flat().data(); + const float* max_filter = max_filter_vector.flat().data(); + std::vector SCALE(depth); + float float_input_range = + std::max(std::abs(min_input), std::abs(max_input)); + const float src_scale = float_input_range / int_input_limit; if (std::is_same::value || std::is_same::value) { - const float min_input = - context->input(min_input_idx_).template scalar()(); - const float max_input = - context->input(max_input_idx_).template scalar()(); - const Tensor& min_filter_vector = context->input(min_filter_idx_); - const Tensor& max_filter_vector = context->input(max_filter_idx_); - OP_REQUIRES( - context, - ((min_filter_vector.NumElements() > 0) && - (max_filter_vector.NumElements() > 0) && - (min_filter_vector.shape() == max_filter_vector.shape())), - errors::InvalidArgument("`min_ and max_filter` must have same" - "shape and contain at least one element.")); - // min_freezed_output and max_freezed_output are the actual range // for the output. const float min_freezed_output = @@ -2037,12 +2167,6 @@ class MklQuantizedConvOp float int_output_limit = std::is_same::value ? 
255.0f : 127.0f; - size_t depth = min_filter_vector.NumElements(); - const float* min_filter = min_filter_vector.flat().data(); - const float* max_filter = max_filter_vector.flat().data(); - std::vector scales(depth); - float float_input_range = - std::max(std::abs(min_input), std::abs(max_input)); float float_output_range = std::max(std::abs(min_freezed_output), std::abs(max_freezed_output)); const float int_const_scale_limit = @@ -2053,13 +2177,18 @@ class MklQuantizedConvOp float float_filter_range = std::max(std::abs(min_filter[i]), std::abs(max_filter[i])); // To understand the scaling, please see mkl_requantize_ops_test. +#ifndef ENABLE_ONEDNN_V3 scales[i] = int_output_limit * float_input_range * float_filter_range / (int_const_scale_limit * float_output_range); +#else + wei_scale[i] = float_filter_range / 127.0; +#endif // !ENABLE_ONEDNN_V3 } // we are creating a partial key here to use with primitive key caching to // improve key creation performance. Instead of using actual values we are // using the pointers for min/max_filter_vector, and this works since the // filter vector here is a constant. +#ifndef ENABLE_ONEDNN_V3 FactoryKeyCreator param_key; param_key.AddAsKey(min_input); param_key.AddAsKey(max_input); @@ -2069,12 +2198,63 @@ class MklQuantizedConvOp param_key.AddAsKey(max_filter); params.post_op_params[post_op_to_idx_["output_scale"]] = { "output_scale", dnnl::algorithm::undef, scales, param_key.GetKey()}; +#else + const float dst_scale = float_output_range / int_output_limit; + FactoryKeyCreator dst_param_key; + dst_param_key.AddAsKey(min_freezed_output); + dst_param_key.AddAsKey(max_freezed_output); + params.post_op_params[post_op_to_idx_["dst_scale"]] = { + "dst_scale", + dnnl::algorithm::undef, + {dst_scale}, + dst_param_key.GetKey()}; +#endif // !ENABLE_ONEDNN_V3 + } else { +#ifdef ENABLE_ONEDNN_V3 + if (!std::is_same::value) + TF_CHECK_OK(absl::FailedPreconditionError( + "Output datatype is expected to be qint32.")); + float min_min_filter = min_filter[0]; + float max_max_filter = max_filter[0]; + for (size_t i = 0; i < depth; ++i) { + float float_filter_range = + std::max(std::abs(min_filter[i]), std::abs(max_filter[i])); + wei_scale[i] = float_filter_range / 127.0; + if (min_filter[i] < min_min_filter) min_min_filter = min_filter[i]; + if (max_filter[i] > max_max_filter) max_max_filter = max_filter[i]; + } + const float single_wei_scale = + std::max(std::abs(min_min_filter), std::abs(max_max_filter)) / 127.0; + const float dst_scale = single_wei_scale * src_scale; + FactoryKeyCreator dst_param_key; + dst_param_key.AddAsKey(dst_scale); + params.post_op_params[post_op_to_idx_["dst_scale"]] = { + "dst_scale", + dnnl::algorithm::undef, + {dst_scale}, + dst_param_key.GetKey()}; +#endif // ENABLE_ONEDNN_V3 } +#ifdef ENABLE_ONEDNN_V3 + FactoryKeyCreator src_param_key; + src_param_key.AddAsKey(min_input); + src_param_key.AddAsKey(max_input); + FactoryKeyCreator wei_param_key; + wei_param_key.AddAsKey(min_filter); + wei_param_key.AddAsKey(max_filter); + params.post_op_params[post_op_to_idx_["src_scale"]] = { + "src_scale", + dnnl::algorithm::undef, + {src_scale}, + src_param_key.GetKey()}; + params.post_op_params[post_op_to_idx_["wei_scale"]] = { + "wei_scale", dnnl::algorithm::undef, wei_scale, wei_param_key.GetKey()}; +#endif // ENABLE_ONEDNN_V3 if (this->get_fuse_add()) { // Calculate the scale (beta in oneDNN api term) for sum + DataType summand_dt = this->input_type(this->get_input_add_idx()); if (std::is_same::value) { - DataType summand_dt = 
this->input_type(this->get_input_add_idx()); bool summand_condition = (summand_dt == DT_QINT8) || (summand_dt == DT_QUINT8); DCHECK((summand_condition)); @@ -2086,15 +2266,15 @@ class MklQuantizedConvOp OP_REQUIRES( context, TensorShapeUtils::IsScalar(min_freezed_output_tensor.shape()), - errors::InvalidArgument( - "`min_freezed_output` must be rank 0 but is rank ", - min_freezed_output_tensor.dims())); + absl::InvalidArgumentError( + absl::StrCat("`min_freezed_output` must be rank 0 but is rank ", + min_freezed_output_tensor.dims()))); OP_REQUIRES( context, TensorShapeUtils::IsScalar(max_freezed_output_tensor.shape()), - errors::InvalidArgument( - "`max_freezed_output` must be rank 0 but is rank ", - max_freezed_output_tensor.dims())); + absl::InvalidArgumentError( + absl::StrCat("`max_freezed_output` must be rank 0 but is rank ", + max_freezed_output_tensor.dims()))); const Tensor& min_freezed_summand_tensor = context->input(min_summand_idx_); const Tensor& max_freezed_summand_tensor = @@ -2102,15 +2282,15 @@ class MklQuantizedConvOp OP_REQUIRES( context, TensorShapeUtils::IsScalar(min_freezed_summand_tensor.shape()), - errors::InvalidArgument( + absl::InvalidArgumentError(absl::StrCat( "`min_freezed_summand` must be rank 0 but is rank ", - min_freezed_summand_tensor.dims())); + min_freezed_summand_tensor.dims()))); OP_REQUIRES( context, TensorShapeUtils::IsScalar(max_freezed_summand_tensor.shape()), - errors::InvalidArgument( + absl::InvalidArgumentError(absl::StrCat( "`max_freezed_summand` must be rank 0 but is rank ", - max_freezed_summand_tensor.dims())); + max_freezed_summand_tensor.dims()))); const float min_freezed_output = min_freezed_output_tensor.template scalar()(); const float max_freezed_output = @@ -2131,18 +2311,24 @@ class MklQuantizedConvOp params.post_op_params[post_op_to_idx_["sum"]] = { "sum", dnnl::algorithm::undef, - {summand_range / output_range}, + {SUMMAND_SCALE_U8(summand_range, output_range)}, ""}; } else { params.post_op_params[post_op_to_idx_["sum"]] = { "sum", dnnl::algorithm::undef, - {255.0f * summand_range / (output_range * 127.0f)}, + {SUMMAND_SCALE_S8(summand_range, output_range)}, ""}; } } else { - params.post_op_params[post_op_to_idx_["sum"]] = { - "sum", dnnl::algorithm::undef, {1.0}, ""}; + params.post_op_params[post_op_to_idx_["sum"]] = {"sum", + dnnl::algorithm::undef, + {1.0}, + "", +#ifdef ENABLE_ONEDNN_V3 + summand_dt +#endif // ENABLE_ONEDNN_V3 + }; } } @@ -2185,10 +2371,11 @@ class MklQuantizedConvOp OP_REQUIRES(context, context->forward_input_to_output_with_shape( summand_idx, 0, summand.shape(), output_tensor), - errors::InvalidArgument( + absl::InvalidArgumentError( "Summand cannot be forwarded in the current fusion.")); return; } +#ifndef ENABLE_ONEDNN_V3 MklConvOp< Device, Tinput, /*Tfilter*/ qint8, Tbias, Toutput, Ttemp_output, /*Tpadding*/ int32, @@ -2200,8 +2387,8 @@ class MklQuantizedConvOp output_tensor); const Tensor& summand = context->input(this->get_input_add_idx()); if (summand.dtype() != DT_FLOAT) - TF_CHECK_OK(Status(absl::StatusCode::kFailedPrecondition, - "Current fusion requires summand to be float")); + TF_CHECK_OK(absl::FailedPreconditionError( + "Current fusion requires summand to be float")); // We need to compute scale for the summand const float min_input = context->input(min_input_idx_).template scalar()(); @@ -2251,15 +2438,34 @@ class MklQuantizedConvOp conv_prim_desc.dst_desc(), reorder_attr); CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_, context); +#else + // In oneDNN v3.0 summand 
does not need to be scaled. + int summand_idx = this->get_input_add_idx(); + DataType summand_dt = this->input_type(summand_idx); + if (summand_dt != DT_FLOAT) + TF_CHECK_OK(absl::FailedPreconditionError( + "Summand datatype is expected to be float.")); + Tensor& summand_float = const_cast(context->input(summand_idx)); + OP_REQUIRES_OK(context, + summand_float.BitcastFrom(summand_float, DT_QINT32, + summand_float.shape())); + OP_REQUIRES(context, + context->forward_input_to_output_with_shape( + summand_idx, 0, summand_float.shape(), output_tensor), + absl::InvalidArgumentError( + "Summand cannot be forwarded in the current fusion.")); + +#endif // !ENABLE_ONEDNN_V3 } } - Tbias* GetBiasHandle(OpKernelContext* context, - std::shared_ptr& conv_fwd_pd, - const Tensor& bias_tensor) override { + void* GetBiasHandle(OpKernelContext* context, + std::shared_ptr& conv_fwd_pd, + const Tensor& bias_tensor) override { if (!this->get_fuse_biasadd()) { return nullptr; } +#ifndef ENABLE_ONEDNN_V3 if (std::is_same::value) { return static_cast( const_cast(bias_tensor.flat().data())); @@ -2342,6 +2548,99 @@ class MklQuantizedConvOp return bias_data; } return GetCachedBias(context); +#else + if (std::is_same::value) { + return static_cast( + const_cast(bias_tensor.flat().data())); + } + // Starting with oneDNN v3.0, bias needs to be passed as is (in float + // datatype). However, for backward compatibility we need to handle the case + // where bias is qint32. Since oneDNN v3.0 does not support qint32 bias, we + // need to dequantize to float. + const float min_input = + context->input(min_input_idx_).template scalar()(); + const float max_input = + context->input(max_input_idx_).template scalar()(); + const Tensor& min_filter_vector = context->input(min_filter_idx_); + const Tensor& max_filter_vector = context->input(max_filter_idx_); + if ((min_filter_vector.NumElements() == 0) || + (max_filter_vector.NumElements() == 0) || + (min_filter_vector.shape() != max_filter_vector.shape())) { + TF_CHECK_OK(absl::FailedPreconditionError( + "`min_filter and max_filter` must have same" + "shape and contain at least one element.")); + } + const float* min_filter = min_filter_vector.flat().data(); + const float* max_filter = max_filter_vector.flat().data(); + const float int_const_scale_limit = + (std::is_same::value) ? 255.0 * 127.0 : 127.0 * 127.0; + + // Re-scale bias if either of following 2 conditions are met: + // 1. Bias is not const; + // 2. Bias is const, bias has not been cached (first iteration). 
+ size_t depth = min_filter_vector.NumElements(); + bool scales_are_valid = (depth == scales_.size()); + scales_.resize(depth); + for (size_t i = 0; i < depth; ++i) { + float tmp_scale = + int_const_scale_limit / + (std::max(std::abs(max_input), std::abs(min_input)) * + std::max(std::abs(max_filter[i]), std::abs(min_filter[i]))); + if (scales_are_valid && std::abs(tmp_scale - scales_[i]) > 1e-6) { + scales_are_valid = false; + } + scales_[i] = tmp_scale; + } + if (!is_bias_const_ || IsBiasCacheEmpty(context) || !scales_are_valid) { + dnnl::primitive_attr reorder_attr; + + if (depth == 1) { + reorder_attr.set_scales_mask(DNNL_ARG_DST, 0); + } else { + reorder_attr.set_scales_mask(DNNL_ARG_DST, 1); + } + + auto bias_md = memory::desc({static_cast(bias_tensor.NumElements())}, + MklDnnType(), memory::format_tag::x); + void* bias_buf = static_cast( + const_cast(bias_tensor.flat().data())); + if (!input_bias_) { + input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf); + } else { + input_bias_->set_data_handle(bias_buf); + } + + if (!scaled_bias_buf_) { + AllocTmpBuffer(context, &scaled_bias_tensor_, + conv_fwd_pd->bias_desc(), &scaled_bias_buf_); + } + if (!scaled_bias_) { + scaled_bias_ = new memory(conv_fwd_pd->bias_desc(), this->cpu_engine_, + scaled_bias_buf_); + } else { + scaled_bias_->set_data_handle(scaled_bias_buf_); + } + std::unique_ptr scale_mem( + new memory({{static_cast(depth)}, + MklDnnType(), + memory::format_tag::x}, + this->cpu_engine_, scales_.data())); + auto reorder_desc = + ReorderPd(this->cpu_engine_, input_bias_->get_desc(), + this->cpu_engine_, scaled_bias_->get_desc(), reorder_attr); + CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_, + this->cpu_engine_, context, scale_mem.get()); + + float* bias_data = + reinterpret_cast(scaled_bias_->get_data_handle()); + if (is_bias_const_) + CacheBias(context, conv_fwd_pd, bias_data, scaled_bias_); + + return bias_data; + } + return GetCachedBias(context); + +#endif // !ENABLE_ONEDNN_V3 } bool is_bias_const_; @@ -2396,9 +2695,9 @@ class MklQuantizedConvOp DCHECK(bias_tensor); TensorShape bias_tf_shape; bias_tf_shape.AddDim( - (conv_prim_desc.bias_desc().get_size() / sizeof(Tbias))); + (conv_prim_desc.bias_desc().get_size() / sizeof(TSCALED_BIAS))); OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum::value, + context->allocate_temp(DataTypeToEnum::value, bias_tf_shape, &cached_bias_data_)); *bias_tensor = &cached_bias_data_; } @@ -2416,7 +2715,7 @@ class MklQuantizedConvOp // Only one thread can execute this method at any given time. 
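Editor's note on the v3 bias path above: oneDNN 3.x no longer accepts an s32 bias for quantized convolution, so a legacy qint32 bias is dequantized back to float with a plain reorder whose DNNL_ARG_DST scales hold the original bias quantization factor (the reorder effectively divides by the DST scale). The following is a condensed, illustrative sketch of that reorder; the wrapper function and its names are not the kernel's.

#include <cstdint>
#include <vector>
#include "dnnl.hpp"

// Dequantize an s32 bias of `channels` elements into `float_bias` using
// per-channel (or common) DST scales. Mirroring the scales_ computation above,
// dst_scales[c] would be 255*127 / (range_input * range_filter[c]) for quint8
// activations (127*127 otherwise).
void DequantizeBias(const dnnl::engine& eng, dnnl::stream& strm,
                    int64_t channels, void* qint32_bias, void* float_bias,
                    std::vector<float>& dst_scales) {
  using namespace dnnl;
  auto src_md = memory::desc({channels}, memory::data_type::s32,
                             memory::format_tag::x);
  auto dst_md = memory::desc({channels}, memory::data_type::f32,
                             memory::format_tag::x);
  memory src_mem(src_md, eng, qint32_bias);
  memory dst_mem(dst_md, eng, float_bias);

  primitive_attr attr;
  // Mask 0: one scale for the whole tensor; mask 1: one scale per channel.
  attr.set_scales_mask(DNNL_ARG_DST, dst_scales.size() == 1 ? 0 : 1);
  memory scale_mem({{static_cast<memory::dim>(dst_scales.size())},
                    memory::data_type::f32, memory::format_tag::x},
                   eng, dst_scales.data());

  reorder::primitive_desc rpd(eng, src_md, eng, dst_md, attr);
  reorder(rpd).execute(strm,
                       {{DNNL_ARG_FROM, src_mem},
                        {DNNL_ARG_TO, dst_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, scale_mem}});
  strm.wait();
}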
void CacheBias(OpKernelContext* context, const std::shared_ptr& conv_fwd_pd, - Tbias* bias_data, const memory* scaled_bias) + TSCALED_BIAS* bias_data, const memory* scaled_bias) TF_LOCKS_EXCLUDED(bias_cache_mu_) { mutex_lock lock(bias_cache_mu_); @@ -2429,18 +2728,18 @@ class MklQuantizedConvOp Tensor* bias_tensor_ptr = nullptr; AllocateTensor(context, *conv_fwd_pd, &bias_tensor_ptr); void* cached_bias_data = const_cast( - static_cast(bias_tensor_ptr->flat().data())); + static_cast(bias_tensor_ptr->flat().data())); size_t cached_bias_data_size = scaled_bias->get_desc().get_size(); memcpy(cached_bias_data, bias_data, cached_bias_data_size); } - Tbias* GetCachedBias(OpKernelContext* context) + TSCALED_BIAS* GetCachedBias(OpKernelContext* context) TF_LOCKS_EXCLUDED(bias_cache_mu_) { tf_shared_lock lock(bias_cache_mu_); const Tensor& cached_bias_data = cached_bias_data_; - return static_cast( - const_cast(cached_bias_data.flat().data())); + return static_cast(const_cast( + cached_bias_data.flat().data())); } }; @@ -2466,13 +2765,14 @@ class MklFusedConv3DOp std::vector padding_list; OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list)); if (padding_list.empty()) { - OP_REQUIRES(context, !fused_ops.empty(), - errors::InvalidArgument("Fused Conv3D must have at least one " - "fused op when Pad is not fused.")); + OP_REQUIRES( + context, !fused_ops.empty(), + absl::InvalidArgumentError("Fused Conv3D must have at least one " + "fused op when Pad is not fused.")); if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") == fused_ops.end()) { OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv3D must have one extra argument: bias.")); } else if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") == fused_ops.end() && @@ -2480,7 +2780,7 @@ class MklFusedConv3DOp fused_ops.end()) { OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv3D must have two extra arguments: bias and add.")); } } @@ -2533,8 +2833,9 @@ class MklFusedConv3DOp } else { if (padding_list.empty()) { OP_REQUIRES(context, false, - errors::Unimplemented("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]")); + absl::UnimplementedError( + absl::StrCat("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]"))); } } } @@ -2949,6 +3250,11 @@ REGISTER_KERNEL_BUILDER( #undef GET_DATA_TYPE #undef SET_FUSE_ACTIVATION_FOR_RELU6 #undef SET_MKL_LAYOUT +#undef OUTPUT_SCALE_DCHECK +#undef TSCALED_BIAS +#undef SCALE +#undef SUMMAND_SCALE_U8 +#undef SUMMAND_SCALE_S8 } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.h b/tensorflow/core/kernels/mkl/mkl_conv_ops.h index 2f35decb3548f0..0384df4b309285 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.h @@ -446,11 +446,11 @@ class MklDnnConvUtil { padding_type = padding_; } OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerboseV2( + GetWindowedOutputSizeVerbose( input_rows, filter_rows, dilation_rows, stride_rows, padding_type, &out_rows, &pad_top, &pad_bottom)); OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerboseV2( + GetWindowedOutputSizeVerbose( input_cols, filter_cols, dilation_cols, stride_cols, padding_type, &out_cols, &pad_left, &pad_right)); } else { @@ -466,16 +466,16 @@ class MklDnnConvUtil { } else { padding_type = padding_; } - OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerboseV2( + OP_REQUIRES_OK(context_, 
GetWindowedOutputSizeVerbose( input_planes, filter_planes, dilation_planes, stride_planes, padding_type, &out_planes, &pad_front, &pad_back)); OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerboseV2( + GetWindowedOutputSizeVerbose( input_rows, filter_rows, dilation_rows, stride_rows, padding_type, &out_rows, &pad_top, &pad_bottom)); OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerboseV2( + GetWindowedOutputSizeVerbose( input_cols, filter_cols, dilation_cols, stride_cols, padding_type, &out_cols, &pad_left, &pad_right)); } diff --git a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc index ce293200bb3ea2..c12b516074da27 100644 --- a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc @@ -72,7 +72,12 @@ class MklDequantizeOp : public OpKernel { MklDnnData dst(&cpu_engine); std::shared_ptr reorder_stream; - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(ctx); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine)); memory::format_tag dst_layout_type; diff --git a/tensorflow/core/kernels/mkl/mkl_einsum_op.cc b/tensorflow/core/kernels/mkl/mkl_einsum_op.cc index 48b495a461b309..698dcdb12ec530 100644 --- a/tensorflow/core/kernels/mkl/mkl_einsum_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_einsum_op.cc @@ -112,8 +112,13 @@ struct MklEinsumHelper { auto params = bmm.CreateMatMulParams(prefix, lhs.shape(), rhs.shape(), out_shape, trans_x, trans_y); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(ctx); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); // Create or retrieve matmul primitive from cache. - MklDnnThreadPool eigen_tp(ctx); MklMatMulPrimitive* matmul_prim = MklMatMulPrimitiveFactory::Get( *params, false /* value for do_not_cache */); diff --git a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc index 611c3709878e11..e5c8bfea686348 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc @@ -837,8 +837,13 @@ class MklFusedBatchNormOp : public OpKernel { MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_, tensor_format_, src_md, activation_mode_); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); // Get forward batch-normalization op from the primitive caching pool. 
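Editor's note: the same threadpool replacement recurs in nearly every kernel touched by this patch, so here it is once in condensed form. The old MklDnnThreadPool(ctx) wrapper is gone; each op now builds a tsl::OneDnnThreadPool over its Eigen threadpool before creating the oneDNN stream. The snippet assumes the usual kernel locals (context, cpu_engine) from the surrounding code; the optional third constructor argument, used by the matmul paths, caps oneDNN at a given thread count (1 forces single-threaded execution for tiny GEMMs).

// Condensed form of the recurring replacement (not tied to one kernel):
Eigen::ThreadPoolInterface* eigen_interface =
    EigenThreadPoolFromTfContext(context);            // TF's Eigen pool
tsl::OneDnnThreadPool eigen_tp(eigen_interface,
                               ThreadPoolUseCallerThread());
// Matmul variants additionally pass a thread cap:
//   tsl::OneDnnThreadPool eigen_tp(eigen_interface,
//                                  ThreadPoolUseCallerThread(),
//                                  single_threaded ? 1 : -1);
std::shared_ptr<dnnl::stream> cpu_stream(
    CreateStream(&eigen_tp, cpu_engine));             // stream bound to the pool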
- MklDnnThreadPool eigen_tp(context); MklFusedBatchNormFwdPrimitive* bn_fwd = MklFusedBatchNormFwdPrimitiveFactory::Get(fwdParams); @@ -1312,7 +1317,10 @@ class MklFusedBatchNormGradOp : public OpKernel { MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_, is_training_, tensor_format_, src_md, diff_dst_md); - MklDnnThreadPool eigen_tp(context); + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklFusedBatchNormBwdPrimitive* bn_bwd = MklFusedBatchNormBwdPrimitiveFactory::Get(bwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc index 6373bf09539fe4..c103986c198faa 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc @@ -41,7 +41,7 @@ class MklFusedInstanceNormOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("reduction_axes", &mean_reduction_axes)); OP_REQUIRES(context, InferDataFormat(mean_reduction_axes), - errors::InvalidArgument( + absl::InvalidArgumentError( "Failed to infer data format from reduction axes")); CheckFusedActivation(context); } @@ -57,21 +57,26 @@ class MklFusedInstanceNormOp : public OpKernel { (src_tensor.dims() == 4 && data_format_ == "NCHW") || (src_tensor.dims() == 5 && data_format_ == "NDHWC") || (src_tensor.dims() == 5 && data_format_ == "NCDHW"), - errors::InvalidArgument( + absl::InvalidArgumentError(absl::StrCat( "Unsupported input: ", src_tensor.shape().DebugString(), - ", ", data_format_)); + ", ", data_format_))); size_t num_elements_scale = scale_tensor.NumElements(); size_t num_elements_shift = shift_tensor.NumElements(); - OP_REQUIRES( - ctx, num_elements_scale == num_elements_shift, - errors::InvalidArgument("Number of elements in scale and shift", - "tensors are not same.")); + OP_REQUIRES(ctx, num_elements_scale == num_elements_shift, + absl::InvalidArgumentError( + absl::StrCat("Number of elements in scale and shift", + "tensors are not same."))); TensorFormat tensor_format; OP_REQUIRES(ctx, FormatFromString(data_format_, &tensor_format), - errors::InvalidArgument("Invalid data format")); - - MklDnnThreadPool eigen_tp(ctx); + absl::InvalidArgumentError("Invalid data format")); + + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. 
+ Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(ctx); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); std::shared_ptr engine_stream_ptr; engine_stream_ptr.reset(CreateStream(&eigen_tp, cpu_engine_)); @@ -102,18 +107,26 @@ class MklFusedInstanceNormOp : public OpKernel { void* src_buf = static_cast(const_cast(src_tensor.flat().data())); +#ifndef ENABLE_ONEDNN_V3 +#define NUM_DUPLICATE 2 +#else +#define NUM_DUPLICATE 1 +#endif // !ENABLE_ONEDNN_V3 memory::dims scale_shift_dims = { - 2, static_cast(num_elements_scale)}; + static_cast(NUM_DUPLICATE * num_elements_scale)}; auto scale_shift_md = memory::desc(scale_shift_dims, MklDnnType(), - memory::format_tag::nc); - Tensor scale_shift_tensor; + memory::format_tag::x); int64_t tensor_shape = scale_shift_md.get_size() / sizeof(float); +#undef NUM_DUPLICATE + +#ifndef ENABLE_ONEDNN_V3 + Tensor scale_shift_tensor; OP_REQUIRES_OK( ctx, ctx->allocate_temp(DataTypeToEnum::v(), {tensor_shape}, &scale_shift_tensor)); void* scale_shift_buf = static_cast(scale_shift_tensor.flat().data()); - SetupScaleShiftBuffer(scale_tensor, shift_tensor, engine_stream_ptr, + SetupScaleShiftBuffer(ctx, scale_tensor, shift_tensor, engine_stream_ptr, num_elements_scale, scale_shift_buf); auto scale_shift_mem = memory(scale_shift_md, cpu_engine_, scale_shift_buf); @@ -122,18 +135,53 @@ class MklFusedInstanceNormOp : public OpKernel { auto bnorm_desc = batch_normalization_forward::desc( prop_kind::forward_inference, src_md, epsilon_, normalization_flags::use_scale_shift); +#else + Tensor scale_fp32_tensor; + Tensor shift_fp32_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::v(), {tensor_shape}, + &scale_fp32_tensor)); + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::v(), {tensor_shape}, + &shift_fp32_tensor)); + void* scale_fp32_buf = + static_cast(scale_fp32_tensor.flat().data()); + void* shift_fp32_buf = + static_cast(shift_fp32_tensor.flat().data()); + + SetupScaleShiftBuffer(ctx, scale_tensor, shift_tensor, engine_stream_ptr, + num_elements_scale, scale_fp32_buf, shift_fp32_buf); + auto scale_mem = memory(scale_shift_md, cpu_engine_, scale_fp32_buf); + auto shift_mem = memory(scale_shift_md, cpu_engine_, shift_fp32_buf); +#endif // !ENABLE_ONEDNN_V3 batch_normalization_forward::primitive_desc bnorm_pd; if (fuse_activation_) { dnnl::post_ops post_ops; dnnl::primitive_attr post_ops_attr; +#ifndef ENABLE_ONEDNN_V3 post_ops.append_eltwise(1.0, dnnl::algorithm::eltwise_relu, leakyrelu_alpha_, 0.0); post_ops_attr.set_post_ops(post_ops); bnorm_pd = batch_normalization_forward::primitive_desc( bnorm_desc, post_ops_attr, cpu_engine_); +#else + post_ops.append_eltwise(dnnl::algorithm::eltwise_relu, leakyrelu_alpha_, + 0.0); + post_ops_attr.set_post_ops(post_ops); + bnorm_pd = batch_normalization_forward::primitive_desc( + cpu_engine_, prop_kind::forward_inference, src_md, src_md, epsilon_, + normalization_flags::use_scale | normalization_flags::use_shift, + post_ops_attr); +#endif // !ENABLE_ONEDNN_V3 } else { +#ifndef ENABLE_ONEDNN_V3 bnorm_pd = batch_normalization_forward::primitive_desc(bnorm_desc, cpu_engine_); +#else + bnorm_pd = batch_normalization_forward::primitive_desc( + cpu_engine_, prop_kind::forward_inference, src_md, src_md, epsilon_, + normalization_flags::use_scale | normalization_flags::use_shift); +#endif // !ENABLE_ONEDNN_V3 } auto bnorm_prim = batch_normalization_forward(bnorm_pd); @@ -154,8 +202,13 @@ class MklFusedInstanceNormOp : public OpKernel { 
std::unordered_map bnorm_args; bnorm_args.insert({DNNL_ARG_SRC, *src_mem_ptr}); - bnorm_args.insert({DNNL_ARG_SCALE_SHIFT, scale_shift_mem}); bnorm_args.insert({DNNL_ARG_DST, *dst_mem_ptr}); +#ifndef ENABLE_ONEDNN_V3 + bnorm_args.insert({DNNL_ARG_SCALE_SHIFT, scale_shift_mem}); +#else + bnorm_args.insert({DNNL_ARG_SCALE, scale_mem}); + bnorm_args.insert({DNNL_ARG_SHIFT, shift_mem}); +#endif // !ENABLE_ONEDNN_V3 // Perform batchnorm computation for each batch in input for (int i = 0; i < batch_size; i++) { @@ -169,8 +222,8 @@ class MklFusedInstanceNormOp : public OpKernel { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + string(__FILE__) + ":" + std::to_string(__LINE__); - OP_REQUIRES_OK( - ctx, errors::Aborted("Operation received an exception:", error_msg)); + OP_REQUIRES_OK(ctx, absl::AbortedError(absl::StrCat( + "Operation received an exception:", error_msg))); } } @@ -199,8 +252,9 @@ class MklFusedInstanceNormOp : public OpKernel { context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha_)); } else { OP_REQUIRES(context, false, - errors::Unimplemented("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]")); + absl::UnimplementedError( + absl::StrCat("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]"))); } } @@ -221,23 +275,35 @@ class MklFusedInstanceNormOp : public OpKernel { return valid; } - // Helper function to add scale and shift data into same buffer in float - // type as requested by oneDNN - void SetupScaleShiftBuffer(const Tensor& scale_tensor, + // Helper function to prepare scale and shift data in float type as + // required by oneDNN library. Prior to oneDNN 3.x version, the library + // requires the final scale and shift data to be passed in the same buffer + // whereas the 3.x version requires separate buffers for scale and shift + // data. 
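Editor's note: the comment above describes the scale/shift layout difference between oneDNN 2.x and 3.x. As a concrete, illustrative sketch (C is the channel count, `eng` an engine, `md_c` / `md_2c` the {C} and {2*C} f32 format_tag::x descriptors, and `bnorm_args` the execution argument map; none of these are the kernel's actual variables):

#ifndef ENABLE_ONEDNN_V3
  // oneDNN 2.x: one interleaved buffer, scale in floats [0, C), shift in
  // [C, 2C), bound to a single DNNL_ARG_SCALE_SHIFT argument.
  std::vector<float> scale_shift(2 * C);
  bnorm_args.insert(
      {DNNL_ARG_SCALE_SHIFT, dnnl::memory(md_2c, eng, scale_shift.data())});
#else
  // oneDNN 3.x: two separate C-element buffers bound to separate arguments,
  // with normalization_flags::use_scale | use_shift set on the primitive.
  std::vector<float> scale(C), shift(C);
  bnorm_args.insert({DNNL_ARG_SCALE, dnnl::memory(md_c, eng, scale.data())});
  bnorm_args.insert({DNNL_ARG_SHIFT, dnnl::memory(md_c, eng, shift.data())});
#endif  // !ENABLE_ONEDNN_V3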
+ void SetupScaleShiftBuffer(OpKernelContext* ctx, const Tensor& scale_tensor, const Tensor& shift_tensor, std::shared_ptr engine_stream_ptr, - int num_elements, void* scale_shift_buf) { + int num_elements, void* fp32_scale_or_combine_buf, + void* fp32_shift_buf = nullptr) { void* scale_buf_src = static_cast(const_cast(scale_tensor.flat().data())); void* shift_buf_src = static_cast(const_cast(shift_tensor.flat().data())); - auto scale_offset = sizeof(float) * num_elements; - void* scale_buf_dst = scale_shift_buf; - void* shift_buf_dst = static_cast(scale_shift_buf) + scale_offset; + auto data_size = sizeof(float) * num_elements; + void* scale_buf_dst = fp32_scale_or_combine_buf; + void* shift_buf_dst = nullptr; +#ifndef ENABLE_ONEDNN_V3 + shift_buf_dst = static_cast(fp32_scale_or_combine_buf) + data_size; + (void)fp32_shift_buf; +#else + OP_REQUIRES(ctx, (fp32_shift_buf != nullptr), + absl::InvalidArgumentError("Invalid shift buffer")); + shift_buf_dst = fp32_shift_buf; +#endif // !ENABLE_ONEDNN_V3 if (std::is_same::value) { - memcpy(scale_buf_dst, scale_buf_src, scale_offset); - memcpy(shift_buf_dst, shift_buf_src, scale_offset); + memcpy(scale_buf_dst, scale_buf_src, data_size); + memcpy(shift_buf_dst, shift_buf_src, data_size); } else { // oneDNN requires float type for scale_shift, need to convert to float // type diff --git a/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc b/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc index 22fea24bf37b00..5f0c50c8b47c3f 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc @@ -1261,7 +1261,6 @@ class BiasCacheTest : public OpsTestBase { } }; -#ifndef ENABLE_ONEDNN_V3 TEST_F(BiasCacheTest, Conv2DBiasCacheTestOldAPI) { TestConv2DBiasCacheTest(true); } @@ -1269,7 +1268,6 @@ TEST_F(BiasCacheTest, Conv2DBiasCacheTestOldAPI) { TEST_F(BiasCacheTest, Conv2DBiasCacheTestNewAPI) { TestConv2DBiasCacheTest(false); } -#endif // !ENABLE_ONEDNN_V3 // Testing fusion of pad and fusedconv2d template diff --git a/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc index 297e95c1cc6f20..ae5ad08b3f4393 100644 --- a/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc @@ -61,7 +61,12 @@ class MklLayerNormOp : public OpKernel { "tensors are not same.")); auto cpu_engine = engine(engine::kind::cpu, 0); - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(ctx); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); auto cpu_stream = std::unique_ptr(CreateStream(&eigen_tp, cpu_engine)); diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op.cc index dc37a7023b42ca..e4122c83b4042d 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_matmul_op.cc @@ -162,7 +162,12 @@ class MklMatMulOp : public OpKernel { char char_transb = transb ? 'T' : 'N'; VLOG(2) << "MKL DNN SGEMM called"; #ifndef ENABLE_ONEDNN_OPENMP - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. 
+ Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(ctx); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); // With threadpool , the runtime overhead is comparable to the kernel // execution for small kernel sizes. For such sizes, it may be better to run // the kernel single threaded. Here we are coming up with a cost model based diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc index 363ca8bbff6c88..1d388705e7ddd0 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc +++ b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc @@ -135,7 +135,12 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { // Extend the basic parameters for data types and fusions. ExtendMklDnnMatMulFwdParams(ctx, matmul_params); auto st = ExecuteSingleThreadedGemm(batch, channel, k, sizeof(T)); - MklDnnThreadPool eigen_tp(ctx, st ? 1 : -1); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(ctx); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread(), + st ? 1 : -1); MklDnnMatMulFwdPrimitive* matmul_prim = MklDnnMatMulFwdPrimitiveFactory::Get(matmul_params, 0); diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h index 33e2e8a646d192..3e55f11cd24abe 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -1016,7 +1016,12 @@ void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k, MklMatMulParams params("dnnl_gemm", a_dims, b_dims, c_dims, a_strides, b_strides, c_strides); auto st = ExecuteSingleThreadedGemm(m, n, k, sizeof(T)); - MklDnnThreadPool eigen_tp(ctx, st ? 1 : -1); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(ctx); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread(), + st ? 1 : -1); MklMatMulPrimitive* matmul_prim = MklMatMulPrimitiveFactory::Get(params, 0); diff --git a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc index d27b36c54f2870..2360cdefba407c 100644 --- a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc @@ -143,7 +143,12 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { pooling_prop_kind, static_cast(this->data_format_mkldnn_), input_md, this->native_format_); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); pooling_fwd = MklPoolingFwdPrimitiveFactory::Get(fwdParams); // Allocate output tensor. 
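// Condensed sketch of the threadpool pattern this change repeats across the MKL
// kernels. Identifiers are the ones used in the hunks above; includes, error
// handling, and the primitive-specific code are omitted, and the stream template
// parameter is assumed from the surrounding code, so treat this as schematic.
void RunOnEigenBackedOneDnnPool(OpKernelContext* ctx, const engine& cpu_engine,
                                bool prefer_single_thread) {
  // 1. Reuse the Eigen threadpool that already backs this op's context.
  Eigen::ThreadPoolInterface* eigen_interface = EigenThreadPoolFromTfContext(ctx);
  // 2. Wrap it so oneDNN schedules on the same pool; the optional third argument
  //    caps oneDNN at one thread when a heuristic such as
  //    ExecuteSingleThreadedGemm decides the kernel is too small to parallelize.
  tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread(),
                                 prefer_single_thread ? 1 : -1);
  // 3. Create a stream on that pool and execute the primitive(s) on it.
  std::unique_ptr<stream> cpu_stream(CreateStream(&eigen_tp, cpu_engine));
  // ... fetch a primitive from its factory and Execute(...) on *cpu_stream ...
}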
this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()), @@ -337,7 +342,10 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { prop_kind::forward_training, static_cast(this->data_format_mkldnn_), src_md, this->native_format_); - MklDnnThreadPool eigen_tp(context); + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklPoolingBwdPrimitive* pooling_bwd = MklPoolingBwdPrimitiveFactory::Get(bwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc index aab190f5ac618d..c73233cab8dd26 100644 --- a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc @@ -375,19 +375,21 @@ void MklPoolParameters::Init(OpKernelContext* context, if (depth_window == 1) { // We are pooling in the D (Pool3D only), H and W. if (!is_pool2d) { - OP_REQUIRES_OK( - context, GetWindowedOutputSizeVerbose(tensor_in_planes, window_planes, - planes_stride, padding, - &out_planes, &pad_P1, &pad_P2)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + tensor_in_planes, window_planes, + /*dilation_rate=*/1, planes_stride, padding, + &out_planes, &pad_P1, &pad_P2)); } - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - tensor_in_rows, window_rows, row_stride, - padding, &out_height, &pad_top, &pad_bottom)); + OP_REQUIRES_OK( + context, GetWindowedOutputSizeVerbose( + tensor_in_rows, window_rows, /*dilation_rate=*/1, + row_stride, padding, &out_height, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - tensor_in_cols, window_cols, col_stride, - padding, &out_width, &pad_left, &pad_right)); + OP_REQUIRES_OK(context, + GetWindowedOutputSizeVerbose( + tensor_in_cols, window_cols, /*dilation_rate=*/1, + col_stride, padding, &out_width, &pad_left, &pad_right)); // TF can work with int64, but oneDNN only supports int32. // Fail if the depth, height or width are greater than MAX_INT. diff --git a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc index 155f9fb4207563..259dfacc0bf51b 100644 --- a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc @@ -104,7 +104,6 @@ limitations under the License. #include "tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h" #include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/util/mkl_threadpool.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/work_sharder.h" @@ -246,8 +245,13 @@ class MklDnnQuantizedMatMulOp // Extend the basic parameters for data types and fusions. this->ExtendMklDnnMatMulFwdParams(context, matmul_fwd_dims); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); // Get a MatMul fwd from primitive pool. 
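// Standalone sketch of what GetWindowedOutputSizeVerbose computes now that a
// dilation_rate argument is threaded through (every call site touched above
// passes /*dilation_rate=*/1, which reduces to the previous behaviour). The
// struct and function names here are mine; the real helper also validates its
// inputs and reports errors instead of returning silently.
#include <algorithm>
#include <cstdint>

struct WindowedOutput {
  int64_t size;        // number of output positions
  int64_t pad_before;  // padding before the input (top/left)
  int64_t pad_after;   // padding after the input (bottom/right)
};

enum class Padding { kValid, kSame };

WindowedOutput ComputeWindowedOutput(int64_t input, int64_t filter,
                                     int64_t dilation, int64_t stride,
                                     Padding padding) {
  // A dilated filter covers (filter - 1) * dilation + 1 input elements.
  const int64_t effective_filter = (filter - 1) * dilation + 1;
  WindowedOutput out{0, 0, 0};
  if (padding == Padding::kValid) {
    out.size = (input - effective_filter + stride) / stride;
  } else {  // SAME: keep ceil(input / stride) outputs and pad symmetrically.
    out.size = (input + stride - 1) / stride;
    const int64_t needed = std::max<int64_t>(
        0, (out.size - 1) * stride + effective_filter - input);
    out.pad_before = needed / 2;
    out.pad_after = needed - out.pad_before;
  }
  return out;
}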
- MklDnnThreadPool eigen_tp(context); matmul_fwd = MklDnnMatMulFwdPrimitiveFactory::Get(matmul_fwd_dims, 0); diff --git a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc index f2f2234f5fa5ae..b8190118a04e93 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc @@ -560,7 +560,12 @@ class MklQuantizeV2Op : public OpKernel { fwdParams.post_op_params.param.push_back(scale_factor); #endif // ENABLE_ONEDNN_V3 - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(ctx); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklReorderWithScalePrimitive* reorder_prim = MklReorderWithScalePrimitiveFactory::Get(src.GetUsrMem(), dst.GetUsrMem(), fwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops_test.cc b/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops_test.cc index 00cc02bfcad397..4dc4634775b075 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3) && defined(ENABLE_MKL) +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #define EIGEN_USE_THREADS #include @@ -1062,4 +1062,4 @@ TEST_F(QuantizedConvTest, BiasAddSumReluFusionFloatSummand) { } } // namespace tensorflow -#endif // INTEL_MKL && !ENABLE_ONEDNN_V3 && ENABLE_MKL +#endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/kernels/mkl/mkl_relu_op.cc b/tensorflow/core/kernels/mkl/mkl_relu_op.cc index 24a0ae60fc02b6..03f19e21da86ca 100644 --- a/tensorflow/core/kernels/mkl/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_relu_op.cc @@ -476,7 +476,12 @@ class MklReluOpBase : public OpKernel { // Try to get an eltwise forward primitive from caching pool MklEltwiseFwdParams fwdParams(src_dims, src_md, alg_kind, alpha_, beta_); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklEltwiseFwdPrimitive* eltwise_fwd = MklEltwiseFwdPrimitiveFactory::Get(fwdParams); auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd(); @@ -683,7 +688,10 @@ class MklReluGradOpBase : public OpKernel { MklEltwiseBwdParams bwdParams(src_dims, common_md, alg_kind, alpha_, beta_, GetTypeOfInputTensorFromFwdOp()); - MklDnnThreadPool eigen_tp(context); + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklEltwiseBwdPrimitive* eltwise_bwd = MklEltwiseBwdPrimitiveFactory::Get(bwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc b/tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc index 6e1daf9ff5babe..31431213397fd8 100644 --- a/tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc @@ -29,7 +29,6 @@ limitations under the License. 
#include "tensorflow/core/kernels/meta_support.h" #include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/util/mkl_threadpool.h" #include "tensorflow/core/util/mkl_util.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc b/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc index 62ac3674f2e048..ece885bb686615 100644 --- a/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc @@ -115,7 +115,12 @@ class MklRequantizePerChannelOp : public OpKernel { cpu_engine_, scales.data()); #endif // !ENABLE_ONEDNN_V3 - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(ctx); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); memory::dims dims_mkl_order = TFShapeToMklDnnDimsInNCHW(input.shape(), FORMAT_NHWC); memory::desc input_md = memory::desc(dims_mkl_order, MklDnnType(), diff --git a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc index 9291d2c099165a..60624f2b7d110f 100644 --- a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc @@ -266,7 +266,12 @@ class MklSoftmaxOp : public OpKernel { fwdParams.aarch64_counter = MklSoftmaxPrimitiveFactory::IncrementCounter(); #endif - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklSoftmaxPrimitive* softmax_fwd = MklSoftmaxPrimitiveFactory::Get(fwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl/mkl_transpose_op.cc index 7ad7e517edc813..b26879dd51556a 100644 --- a/tensorflow/core/kernels/mkl/mkl_transpose_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_transpose_op.cc @@ -83,7 +83,12 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor, out.SetUsrMem(in_dims, out_strides, out_tensor); std::vector net; - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. 
+ Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); auto* prim = FindOrCreateReorder(in.GetUsrMem(), out.GetUsrMem()); transpose_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); in.SetUsrMemDataHandle(&in_tensor, transpose_stream); diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index decb162858650b..1b9db6406b43d9 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -640,13 +640,14 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_atan2_kernels", - op = "atan2", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "atan2", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -749,25 +750,27 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_ceil_kernels", - op = "ceil", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "ceil", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_floor_kernels", - op = "floor", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "floor", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -793,26 +796,28 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_rint_kernels", - jit_types = ["f16"], - op = "rint", - tile_size = "1024", - types = [ + jit_types = [ + "f16", "f32", "f64", ], + op = "rint", + tile_size = "1024", + types = [], ) gpu_kernel_library( name = "gpu_round_kernels", - op = "round", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "i32", "i64", ], + op = "round", + tile_size = "1024", + types = [], ) # Predicate kernels @@ -1030,12 +1035,13 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_conj_kernels", - op = "conj", - tile_size = "256", - types = [ + jit_types = [ "c64", "c128", ], + op = "conj", + tile_size = "256", + types = [], unroll_factors = "2", ) @@ -1172,10 +1178,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "maximum", - tile_size = "1024", - types = [ "f16", "f32", "f64", @@ -1183,6 +1185,9 @@ gpu_kernel_library( "i64", "ui8", ], + op = "maximum", + tile_size = "1024", + types = [], unroll_factors = "4", ) @@ -1193,10 +1198,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "minimum", - tile_size = "1024", - types = [ "f16", "f32", "f64", @@ -1204,6 +1205,9 @@ gpu_kernel_library( "i64", "ui8", ], + op = "minimum", + tile_size = "1024", + types = [], unroll_factors = "4", ) @@ -1255,9 +1259,7 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_neg_kernels", - op = "neg", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", @@ -1268,6 +1270,9 @@ gpu_kernel_library( "c64", "c128", ], + op = "neg", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1276,22 +1281,19 @@ gpu_kernel_library( jit_types = [ "i8", "i16", - ], - op = "pow", - tile_size = "1024", - types = [ "f16", "f32", "f64", "i64", ], + op = "pow", + tile_size = "1024", + types = [], ) gpu_kernel_library( name = "gpu_reciprocal_kernels", - op = "reciprocal", - tile_size = "256", - types = [ + jit_types = [ "c64", "c128", "f16", @@ -1299,6 +1301,9 @@ gpu_kernel_library( "f64", "i64", ], + op = "reciprocal", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1307,10 +1312,6 @@ gpu_kernel_library( jit_types = [ "i8", "i16", - ], - op = "sign", - tile_size = "256", - types = [ "f16", "f32", 
"f64", @@ -1319,6 +1320,9 @@ gpu_kernel_library( "c64", "c128", ], + op = "sign", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1362,80 +1366,86 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_xdivy_kernels", - op = "xdivy", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "c64", "c128", ], + op = "xdivy", + tile_size = "1024", + types = [], unroll_factors = "4", ) # Logarithmic and exponential kernels gpu_kernel_library( name = "gpu_exp_kernels", - op = "exp", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", "c64", "c128", ], + op = "exp", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_expm1_kernels", - op = "expm1", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "expm1", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_log_kernels", - op = "log", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "log", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_log1p_kernels", - op = "log1p", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "log1p", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_xlogy_kernels", - op = "xlogy", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "c64", "c128", ], + op = "xlogy", + tile_size = "1024", + types = [], unroll_factors = "4", # For complex XlogyOp kernels, we don't use unrolling, it would only cause # slowdowns. @@ -1447,15 +1457,16 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_xlog1py_kernels", - op = "xlog1py", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "c64", "c128", ], + op = "xlog1py", + tile_size = "1024", + types = [], unroll_factors = "4", # For complex Xlog1pyOp kernels, we don't use unrolling, it would only cause # slowdowns. 
@@ -1469,25 +1480,27 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_sqrt_kernels", - op = "sqrt", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "sqrt", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_rsqrt_kernels", - op = "rsqrt", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "rsqrt", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1500,28 +1513,28 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "square", - tile_size = "1024", - types = [ "f16", "f32", "f64", "i64", ], + op = "square", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_squared_difference_kernels", - op = "squared_difference", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "i64", ], + op = "squared_difference", + tile_size = "1024", + types = [], unroll_factors = "4", ) @@ -1529,74 +1542,77 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_bitwise_and_kernels", - op = "bitwise_and", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "bitwise_and", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_bitwise_or_kernels", - op = "bitwise_or", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "bitwise_or", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_bitwise_xor_kernels", - op = "bitwise_xor", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "bitwise_xor", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_invert_kernels", - op = "invert", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "invert", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_left_shift_kernels", - op = "left_shift", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "left_shift", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_right_shift_kernels", - op = "right_shift", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", @@ -1606,6 +1622,9 @@ gpu_kernel_library( "ui32", "ui64", ], + op = "right_shift", + tile_size = "1024", + types = [], unroll_factors = "4", ) @@ -1613,52 +1632,57 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_logical_not_kernels", + jit_types = ["i1"], op = "logical_not", tile_size = "256", - types = ["i1"], + types = [], ) gpu_kernel_library( name = "gpu_logical_and_kernels", - op = "logical_and", - tile_size = "1024", - types = [ + jit_types = [ "i1", ], + op = "logical_and", + tile_size = "1024", + types = [], ) gpu_kernel_library( name = "gpu_logical_or_kernels", - op = "logical_or", - tile_size = "1024", - types = [ + jit_types = [ "i1", ], + op = "logical_or", + tile_size = "1024", + types = [], ) # Erf kernels gpu_kernel_library( name = "gpu_erf_kernels", - op = "erf", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "erf", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_erfc_kernels", - op = "erfc", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "erfc", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1666,45 +1690,49 @@ gpu_kernel_library( gpu_kernel_library( name = 
"gpu_polygamma_kernels", - op = "polygamma", - tile_size = "256", - types = [ + jit_types = [ "f32", "f64", ], + op = "polygamma", + tile_size = "256", + types = [], ) gpu_kernel_library( name = "gpu_digamma_kernels", - op = "digamma", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "digamma", + tile_size = "256", + types = [], ) gpu_kernel_library( name = "gpu_lgamma_kernels", - op = "lgamma", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "lgamma", + tile_size = "256", + types = [], ) gpu_kernel_library( # The zeta kernels needs many registers so tile at 256. name = "gpu_zeta_kernels", - op = "zeta", - tile_size = "256", - types = [ + jit_types = [ "f32", "f64", ], + op = "zeta", + tile_size = "256", + types = [], # TODO(b/178388085): Enable unrolling after vectorization is fixed. # unroll_factors = "4", ) @@ -1731,61 +1759,64 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "relu", - tile_size = "256", - types = [ "f16", "f32", "f64", ], + op = "relu", + tile_size = "256", + types = [], unroll_factors = "16B", ) gpu_kernel_library( name = "gpu_elu_kernels", - op = "elu", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "elu", + tile_size = "256", + types = [], ) gpu_kernel_library( name = "gpu_selu_kernels", - op = "selu", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "selu", + tile_size = "256", + types = [], ) gpu_kernel_library( name = "gpu_sigmoid_kernels", - op = "sigmoid", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "sigmoid", + tile_size = "256", + types = [], ) # Kernels that support all floating-point types. [ gpu_kernel_library( name = "gpu_" + op + "_kernels", - op = op, - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = op, + tile_size = "256", + types = [], unroll_factors = "4", ) for op in [ @@ -1837,11 +1868,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - max_supported_rank = 8, - op = "select_v2", - tile_size = "256", - types = [ "i1", "i32", "i64", @@ -1851,6 +1877,10 @@ gpu_kernel_library( "c64", "c128", ], + max_supported_rank = 8, + op = "select_v2", + tile_size = "256", + types = [], ) gpu_kernel_library( @@ -1862,10 +1892,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "zeros_like", - tile_size = "1024", - types = [ "i1", "i64", "f16", @@ -1874,6 +1900,9 @@ gpu_kernel_library( "c64", "c128", ], + op = "zeros_like", + tile_size = "1024", + types = [], ) gpu_kernel_library( @@ -1885,10 +1914,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "ones_like", - tile_size = "1024", - types = [ "i1", "i64", "f16", @@ -1897,14 +1922,18 @@ gpu_kernel_library( "c64", "c128", ], + op = "ones_like", + tile_size = "1024", + types = [], ) gpu_kernel_library( name = "gpu_next_after_kernels", - op = "next_after", - tile_size = "1024", - types = [ + jit_types = [ "f32", "f64", ], + op = "next_after", + tile_size = "1024", + types = [], ) diff --git a/tensorflow/core/kernels/ops_util_test.cc b/tensorflow/core/kernels/ops_util_test.cc index e449947e16def3..a3848b4dca6db6 100644 --- a/tensorflow/core/kernels/ops_util_test.cc +++ b/tensorflow/core/kernels/ops_util_test.cc @@ -107,13 +107,13 @@ class OpsUtilTest : public ::testing::Test { int64_t new_height, new_width, pad_top, pad_bottom, pad_left, pad_right; Status status = GetWindowedOutputSizeVerbose( pad_struct.input.in_height, pad_struct.input.filter_height, - pad_struct.input.row_stride, 
pad_struct.input.padding, &new_height, - &pad_top, &pad_bottom); + /*dilation_rate=*/1, pad_struct.input.row_stride, + pad_struct.input.padding, &new_height, &pad_top, &pad_bottom); EXPECT_EQ(status.code(), code) << status; status = GetWindowedOutputSizeVerbose( pad_struct.input.in_width, pad_struct.input.filter_width, - pad_struct.input.col_stride, pad_struct.input.padding, &new_width, - &pad_left, &pad_right); + /*dilation_rate=*/1, pad_struct.input.col_stride, + pad_struct.input.padding, &new_width, &pad_left, &pad_right); EXPECT_EQ(status.code(), code) << status; EXPECT_EQ(pad_struct.output.new_height, new_height); EXPECT_EQ(pad_struct.output.new_width, new_width); diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index b48287ae1442a4..407d6991608c7e 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -164,12 +164,14 @@ PoolParameters::PoolParameters(OpKernelContext* context, } if (depth_window == 1) { - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - tensor_in_rows, window_rows, row_stride, - padding, &out_height, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - tensor_in_cols, window_cols, col_stride, - padding, &out_width, &pad_left, &pad_right)); + OP_REQUIRES_OK( + context, GetWindowedOutputSizeVerbose( + tensor_in_rows, window_rows, /*dilation_rate=*/1, + row_stride, padding, &out_height, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context, + GetWindowedOutputSizeVerbose( + tensor_in_cols, window_cols, /*dilation_rate=*/1, + col_stride, padding, &out_width, &pad_left, &pad_right)); pad_depth = 0; out_depth = depth; } else { @@ -195,12 +197,14 @@ PoolParameters::PoolParameters(OpKernelContext* context, errors::Unimplemented("Depthwise max pooling is currently " "only implemented for CPU devices.")); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - tensor_in_rows, window_rows, row_stride, - padding, &out_height, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - tensor_in_cols, window_cols, col_stride, - padding, &out_width, &pad_left, &pad_right)); + OP_REQUIRES_OK( + context, GetWindowedOutputSizeVerbose( + tensor_in_rows, window_rows, /*dilation_rate=*/1, + row_stride, padding, &out_height, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context, + GetWindowedOutputSizeVerbose( + tensor_in_cols, window_cols, /*dilation_rate=*/1, + col_stride, padding, &out_width, &pad_left, &pad_right)); pad_depth = 0; out_depth = depth / depth_window; } diff --git a/tensorflow/core/kernels/ragged_cross_op.cc b/tensorflow/core/kernels/ragged_cross_op.cc index 31af55a893a562..71deb58c3c12d0 100644 --- a/tensorflow/core/kernels/ragged_cross_op.cc +++ b/tensorflow/core/kernels/ragged_cross_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -386,13 +387,32 @@ class RaggedCrossOp : public OpKernel { // Validate tensor shapes. 
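// Standalone, simplified sketch of the row-splits invariants the validation loop
// below enforces for each ragged input: splits are non-empty, start at 0, end at
// the number of values, and are non-decreasing. The helper name is hypothetical.
#include <cstddef>
#include <cstdint>
#include <vector>

bool RowSplitsAreValid(const std::vector<int64_t>& splits, int64_t num_values) {
  if (splits.empty()) return false;               // must describe at least 0 rows
  if (splits.front() != 0) return false;          // first offset is always 0
  if (splits.back() != num_values) return false;  // last offset covers all values
  for (std::size_t i = 1; i < splits.size(); ++i) {
    if (splits[i - 1] > splits[i]) return false;  // offsets must be sorted
  }
  return true;
}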
for (int i = 0; i < num_ragged; ++i) { - if (!TensorShapeUtils::IsVector(ragged_values_list[i].shape())) { - return errors::InvalidArgument( + if (!TensorShapeUtils::IsVector(ragged_values_list[i].shape()) || + !TensorShapeUtils::IsVector(ragged_splits_list[i].shape())) { + return absl::InvalidArgumentError( "tf.ragged.cross only supports inputs with rank=2."); } - if (!TensorShapeUtils::IsVector(ragged_splits_list[i].shape()) || - (ragged_splits_list[i].NumElements() == 0)) { - return errors::InvalidArgument("Invalid RaggedTensor"); + if (ragged_splits_list[i].NumElements() == 0) { + return absl::InvalidArgumentError( + "Invalid RaggedTensor: Ragged splits must be non-empty."); + } + auto flat_row_splits = ragged_splits_list[i].flat(); + if (flat_row_splits(0) != 0) { + return absl::InvalidArgumentError( + "Invalid RaggedTensor: Ragged splits must start from 0."); + } + int64_t num_values = ragged_values_list[i].NumElements(); + if (flat_row_splits(flat_row_splits.size() - 1) != num_values) { + return absl::InvalidArgumentError( + "Invalid RaggedTensor: " + "Ragged splits must end with the number of values."); + } + for (int i = 1; i < flat_row_splits.size(); ++i) { + if (flat_row_splits(i - 1) > flat_row_splits(i)) { + return absl::InvalidArgumentError( + "Invalid RaggedTensor: " + "Ragged splits must be sorted in ascending order."); + } } } for (int i = 0; i < num_sparse; ++i) { diff --git a/tensorflow/core/kernels/stochastic_cast_op.cc b/tensorflow/core/kernels/stochastic_cast_op.cc index 626a00894da311..ba2954760610a6 100644 --- a/tensorflow/core/kernels/stochastic_cast_op.cc +++ b/tensorflow/core/kernels/stochastic_cast_op.cc @@ -91,9 +91,9 @@ REGISTER_CAST_TO_INT_CPU_KERNEL(double, int16); REGISTER_CAST_TO_INT_CPU_KERNEL(double, int32); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -REGISTER_CAST_TO_INT_CPU_KERNEL(half, int8); -REGISTER_CAST_TO_INT_CPU_KERNEL(half, int16); -REGISTER_CAST_TO_INT_CPU_KERNEL(half, int32); +REGISTER_CAST_TO_INT_GPU_KERNEL(half, int8); +REGISTER_CAST_TO_INT_GPU_KERNEL(half, int16); +REGISTER_CAST_TO_INT_GPU_KERNEL(half, int32); REGISTER_CAST_TO_INT_GPU_KERNEL(bfloat16, int8); REGISTER_CAST_TO_INT_GPU_KERNEL(bfloat16, int16); diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 42eff63cab03b1..331cfa84728eea 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2702,11 +2702,11 @@ REGISTER_OP("ExtractImagePatches") int64_t output_rows, output_cols; int64_t padding_before, padding_after; TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_rows, ksize_rows_eff, stride_rows, padding, &output_rows, - &padding_before, &padding_after)); + in_rows, ksize_rows_eff, /*dilation_rate=*/1, stride_rows, padding, + &output_rows, &padding_before, &padding_after)); TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_cols, ksize_cols_eff, stride_cols, padding, &output_cols, - &padding_before, &padding_after)); + in_cols, ksize_cols_eff, /*dilation_rate=*/1, stride_cols, padding, + &output_cols, &padding_before, &padding_after)); ShapeHandle output_shape = c->MakeShape( {batch_size_dim, output_rows, output_cols, output_depth_dim}); c->set_output(0, output_shape); @@ -2808,14 +2808,14 @@ REGISTER_OP("ExtractVolumePatches") int64_t output_planes, output_rows, output_cols; int64_t padding_before, padding_after; TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_planes, ksize_planes, stride_planes, padding, &output_planes, - &padding_before, &padding_after)); + in_planes, ksize_planes, /*dilation_rate=*/1, 
stride_planes, padding, + &output_planes, &padding_before, &padding_after)); TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_rows, ksize_rows, stride_rows, padding, &output_rows, - &padding_before, &padding_after)); + in_rows, ksize_rows, /*dilation_rate=*/1, stride_rows, padding, + &output_rows, &padding_before, &padding_after)); TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_cols, ksize_cols, stride_cols, padding, &output_cols, - &padding_before, &padding_after)); + in_cols, ksize_cols, /*dilation_rate=*/1, stride_cols, padding, + &output_cols, &padding_before, &padding_after)); ShapeHandle output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols, output_depth_dim}); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index d28e2c61c3d51a..6e34569da586aa 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1169,11 +1169,11 @@ REGISTER_OP("Dilation2D") int64_t output_rows, output_cols; int64_t padding_before, padding_after; TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_rows, filter_rows_eff, stride_rows, padding, &output_rows, - &padding_before, &padding_after)); + in_rows, filter_rows_eff, /*dilation_rate=*/1, stride_rows, padding, + &output_rows, &padding_before, &padding_after)); TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_cols, filter_cols_eff, stride_cols, padding, &output_cols, - &padding_before, &padding_after)); + in_cols, filter_cols_eff, /*dilation_rate=*/1, stride_cols, padding, + &output_cols, &padding_before, &padding_after)); ShapeHandle output_shape = c->MakeShape( {batch_size_dim, output_rows, output_cols, output_depth_dim}); diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 9c417995040205..a6195e98d0d1ab 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -1262,7 +1262,7 @@ tf_cc_test( tf_cc_test( name = "fake_python_env_test", - size = "small", + size = "medium", srcs = ["fake_python_env_test.cc"], args = [ "/some/path/to/pythontest.runfiles/org_tensorflow/stuff/to/run.py", diff --git a/tensorflow/core/profiler/utils/hlo_proto_map.cc b/tensorflow/core/profiler/utils/hlo_proto_map.cc index 1cca0c182ede14..a0f90aaecb0d78 100644 --- a/tensorflow/core/profiler/utils/hlo_proto_map.cc +++ b/tensorflow/core/profiler/utils/hlo_proto_map.cc @@ -53,41 +53,24 @@ ParseHloProtosFromXSpace(const XSpace& space) { const XPlane* raw_plane = FindPlaneWithName(space, kMetadataPlaneName); if (raw_plane != nullptr) { XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(raw_plane); - if (raw_plane->stats_size() > 0) { - // Fallback for legacy aggregated XPlane. - // TODO(b/235990417): Remove after 06/14/2023. - plane.ForEachStat([&](const XStatVisitor& stat) { - if (stat.ValueCase() != XStat::kBytesValue) return; - auto hlo_proto = std::make_unique(); - absl::string_view byte_value = stat.BytesValue(); - if (hlo_proto->ParseFromArray(byte_value.data(), byte_value.size())) { - hlo_protos.emplace_back(stat.Id(), std::move(hlo_proto)); - } - }); - } else { - const XStatMetadata* hlo_proto_stat_metadata = - plane.GetStatMetadataByType(StatType::kHloProto); - if (hlo_proto_stat_metadata == nullptr) { - // Fallback for legacy XPlane. - // TODO(b/235990417): Remove after 06/14/2023. 
- hlo_proto_stat_metadata = plane.GetStatMetadata(StatType::kHloProto); - } - if (hlo_proto_stat_metadata != nullptr) { - plane.ForEachEventMetadata( - [&](const XEventMetadataVisitor& event_metadata) { - auto hlo_proto_stat = event_metadata.GetStat( - StatType::kHloProto, *hlo_proto_stat_metadata); - if (!hlo_proto_stat) return; - if (hlo_proto_stat->ValueCase() != XStat::kBytesValue) return; - auto hlo_proto = std::make_unique(); - absl::string_view byte_value = hlo_proto_stat->BytesValue(); - if (hlo_proto->ParseFromArray(byte_value.data(), - byte_value.size())) { - hlo_protos.emplace_back(event_metadata.Id(), - std::move(hlo_proto)); - } - }); - } + + const XStatMetadata* hlo_proto_stat_metadata = + plane.GetStatMetadataByType(StatType::kHloProto); + if (hlo_proto_stat_metadata != nullptr) { + plane.ForEachEventMetadata( + [&](const XEventMetadataVisitor& event_metadata) { + auto hlo_proto_stat = event_metadata.GetStat( + StatType::kHloProto, *hlo_proto_stat_metadata); + if (!hlo_proto_stat) return; + if (hlo_proto_stat->ValueCase() != XStat::kBytesValue) return; + auto hlo_proto = std::make_unique(); + absl::string_view byte_value = hlo_proto_stat->BytesValue(); + if (hlo_proto->ParseFromArray(byte_value.data(), + byte_value.size())) { + hlo_protos.emplace_back(event_metadata.Id(), + std::move(hlo_proto)); + } + }); } } return hlo_protos; diff --git a/tensorflow/core/protobuf/fingerprint.proto b/tensorflow/core/protobuf/fingerprint.proto index 837b9a04d61db0..6ac5307ebacab7 100644 --- a/tensorflow/core/protobuf/fingerprint.proto +++ b/tensorflow/core/protobuf/fingerprint.proto @@ -27,4 +27,5 @@ message FingerprintDef { uint64 checkpoint_hash = 5; // Version specification of the fingerprint. VersionDef version = 6; + // TODO(b/290068219): add USM version when GA } diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 03427a1910b0ce..ef972445c7fc58 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1553 // Updated: 2023/7/10 +#define TF_GRAPH_DEF_VERSION 1560 // Updated: 2023/7/17 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). 
// diff --git a/tensorflow/core/runtime_fallback/BUILD b/tensorflow/core/runtime_fallback/BUILD index 435d5474147060..e04b69e250d8b0 100644 --- a/tensorflow/core/runtime_fallback/BUILD +++ b/tensorflow/core/runtime_fallback/BUILD @@ -31,7 +31,6 @@ tf_cc_binary( deps = [ ":bef_executor_lib", "@com_google_absl//absl/strings", - "//tensorflow/compiler/mlir/tfrt:tf_jitrt_kernels_alwayslink", "//tensorflow/core/platform:stream_executor", "//tensorflow/core/runtime_fallback/conversion:conversion_alwayslink", "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_kernels_alwayslink", diff --git a/tensorflow/core/runtime_fallback/kernel/BUILD b/tensorflow/core/runtime_fallback/kernel/BUILD index 2aea6feb9c657d..23ca51aa4db1cf 100644 --- a/tensorflow/core/runtime_fallback/kernel/BUILD +++ b/tensorflow/core/runtime_fallback/kernel/BUILD @@ -492,10 +492,8 @@ cc_library( ], deps = [ ":kernel_fallback_compat_request_state", - ":kernel_fallback_tensor", ":kernel_fallback_utils", ":tensor_util", - "//tensorflow/core/runtime_fallback/runtime:kernel_utils", "//tensorflow/core/tfrt/utils:fallback_tensor", "//tensorflow/core/tfrt/utils:gpu_variables_table", "//tensorflow/core/tfrt/utils:tensor_util", diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc index 54b017d71a61b1..16ef9cd3eefa91 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/platform/threadpool_interface.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/tfrt/graph_executor/config.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tfrt/host_context/resource_context.h" // from @tf_runtime #include "tfrt/support/pointer_util.h" // from @tf_runtime @@ -60,13 +61,6 @@ void FallbackResourceArray::SetResource( *resource_storage_[index], resources_.back().get()); } -static CancellationManager* GetDefaultCancellationManager() { - // TODO(b/167630926): Support cancellation by hooking up with TFRT's - // mechanism. - static auto* const default_cancellation_manager = new CancellationManager; - return default_cancellation_manager; -} - KernelFallbackCompatRequestState::KernelFallbackCompatRequestState( std::function)>* runner, const tensorflow::DeviceMgr* device_manager, int64_t step_id, @@ -85,7 +79,6 @@ KernelFallbackCompatRequestState::KernelFallbackCompatRequestState( ? 
collective_executor_handle_->get() : nullptr), rendezvous_(std::move(rendezvous)), - default_cancellation_manager_(GetDefaultCancellationManager()), device_manager_(device_manager), runner_table_(runner_table), resource_array_(resource_array), @@ -165,7 +158,9 @@ Status SetUpKernelFallbackCompatRequestContext( const absl::optional& model_metadata, std::function)>* runner, tfrt_stub::CostRecorder* cost_recorder, - tfrt::ResourceContext* client_graph_resource_context) { + tfrt::ResourceContext* client_graph_resource_context, + tensorflow::CancellationManager* cancellation_manager, + const tensorflow::tfrt_stub::RuntimeConfig* runtime_config) { DCHECK(builder); DCHECK(device_manager); DCHECK(pflr); @@ -181,6 +176,8 @@ Status SetUpKernelFallbackCompatRequestContext( fallback_request_state.set_cost_recorder(cost_recorder); fallback_request_state.set_client_graph_resource_context( client_graph_resource_context); + fallback_request_state.set_cancellation_manager(cancellation_manager); + fallback_request_state.set_runtime_config(runtime_config); return OkStatus(); } diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h index 0645f02481e690..201eae2e1c6f5d 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h @@ -138,7 +138,10 @@ class KernelFallbackCompatRequestState { std::function)>* runner() const { return runner_; } CancellationManager* cancellation_manager() const { - return default_cancellation_manager_; + return cancellation_manager_; + } + void set_cancellation_manager(CancellationManager* cancellation_manager) { + cancellation_manager_ = cancellation_manager; } RendezvousInterface* rendezvous() const { return rendezvous_.get(); } @@ -192,7 +195,7 @@ class KernelFallbackCompatRequestState { std::unique_ptr collective_executor_handle_; CollectiveExecutor* collective_executor_ = nullptr; core::RefCountPtr rendezvous_; - CancellationManager* default_cancellation_manager_ = nullptr; + CancellationManager* cancellation_manager_ = nullptr; const tensorflow::DeviceMgr* device_manager_ = nullptr; @@ -232,11 +235,13 @@ Status SetUpKernelFallbackCompatRequestContext( const tensorflow::ProcessFunctionLibraryRuntime* pflr, tfrt_stub::OpKernelRunnerTable* runner_table, FallbackResourceArray* resource_array, - tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr, - const std::optional& model_metadata = std::nullopt, - std::function)>* runner = nullptr, - tfrt_stub::CostRecorder* cost_recorder = nullptr, - tfrt::ResourceContext* client_graph_resource_context = nullptr); + tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool, + const std::optional& model_metadata, + std::function)>* runner, + tfrt_stub::CostRecorder* cost_recorder, + tfrt::ResourceContext* client_graph_resource_context, + tensorflow::CancellationManager* cancellation_manager, + const tensorflow::tfrt_stub::RuntimeConfig* runtime_config); } // namespace tfd } // namespace tensorflow diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc index 426f189a503224..2309310a4d3a26 100644 --- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc +++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc 
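// Minimal sketch of the wiring the new setter above enables (the helper name is
// hypothetical): each request supplies its own CancellationManager instead of
// sharing a never-cancelled process-wide default, so kernels reached through the
// fallback can both observe and trigger cancellation for just that request.
void WirePerRequestCancellation(KernelFallbackCompatRequestState& state,
                                CancellationManager* per_request_manager) {
  // Kernels see this manager through OpKernelContext::cancellation_manager().
  state.set_cancellation_manager(per_request_manager);
}
// Later, any party holding the manager can cancel the in-flight request:
//   per_request_manager->StartCancelWithStatus(absl::CancelledError());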
@@ -267,7 +267,12 @@ Status SetUpKernelFallbackCompatRequestContextForBatch( return SetUpKernelFallbackCompatRequestContext( builder, device_manager, pflr, runner_table, resource_array, - intra_op_threadpool, session_metadata, /*runner=*/nullptr); + intra_op_threadpool, session_metadata, + src_fallback_request_state->runner(), + src_fallback_request_state->cost_recorder(), + src_fallback_request_state->client_graph_resource_context(), + src_fallback_request_state->cancellation_manager(), + src_fallback_request_state->runtime_config()); } StatusOr> SetUpRequestContext( diff --git a/tensorflow/core/runtime_fallback/util/BUILD b/tensorflow/core/runtime_fallback/util/BUILD index a820dad23a3dda..ee575490dd35a6 100644 --- a/tensorflow/core/runtime_fallback/util/BUILD +++ b/tensorflow/core/runtime_fallback/util/BUILD @@ -81,7 +81,6 @@ cc_library( hdrs = ["fallback_test_util.h"], tags = ["no_oss"], deps = [ - "//tensorflow/compiler/mlir/tfrt:tf_jitrt_request_context", "//tensorflow/core:framework", "//tensorflow/core/platform:threadpool_interface", "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_execute_compat", diff --git a/tensorflow/core/runtime_fallback/util/fallback_test_util.cc b/tensorflow/core/runtime_fallback/util/fallback_test_util.cc index 3e451b3199408d..43d617cca46664 100644 --- a/tensorflow/core/runtime_fallback/util/fallback_test_util.cc +++ b/tensorflow/core/runtime_fallback/util/fallback_test_util.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/core/runtime_fallback/util/fallback_test_util.h" #include +#include #include -#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h" #include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h" @@ -69,10 +69,13 @@ tfrt::ExecutionContext CreateFallbackTestExecutionContext( status = SetUpKernelFallbackCompatRequestContext( &request_context_builder, eager_context->local_device_mgr(), eager_context->pflr(), runner_table, resource_array, - user_intra_op_threadpool); + user_intra_op_threadpool, /*model_metadata=*/std::nullopt, + /*runner=*/nullptr, /*cost_recorder=*/nullptr, + /*client_graph_resource_context=*/resource_context, + /*cancellation_manager=*/nullptr, + /*runtime_config=*/nullptr); TF_DCHECK_OK(status); - status = SetUpTfJitRtRequestContext(&request_context_builder); TF_DCHECK_OK(status); auto request_context = std::move(request_context_builder).build(); diff --git a/tensorflow/core/tfrt/fallback/BUILD b/tensorflow/core/tfrt/fallback/BUILD index 765008be386279..84566c792ef35d 100644 --- a/tensorflow/core/tfrt/fallback/BUILD +++ b/tensorflow/core/tfrt/fallback/BUILD @@ -126,9 +126,8 @@ tf_cc_test( srcs = ["cost_recorder_test.cc"], deps = [ ":cost_recorder", + ":op_cost_map_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core/platform:status", - "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/core/tfrt/fallback/cost_recorder.h b/tensorflow/core/tfrt/fallback/cost_recorder.h index 9929ca76e1028c..b275e0a5d35791 100644 --- a/tensorflow/core/tfrt/fallback/cost_recorder.h +++ b/tensorflow/core/tfrt/fallback/cost_recorder.h @@ -19,14 +19,12 @@ limitations under the License. 
#define TENSORFLOW_CORE_TFRT_FALLBACK_COST_RECORDER_H_ #include -#include #include #include "absl/container/flat_hash_map.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/tfrt/fallback/op_cost_map.pb.h" namespace tensorflow { namespace tfrt_stub { @@ -45,8 +43,7 @@ class CostRecorder { // Returns the normalized average execution duration of the op keyed by // `op_key`. If there is no record for `op_key`, returns the uint32_t::max to // avoid stream merging. Note that we don't use uint64_t::max because - // otherwise adding op costs would cause overflow. (See details in - // go/tfrt-stream-analysis-doc.) + // otherwise adding op costs would cause overflow. uint64_t GetCost(int64_t op_key) const; // Writes the op cost map (in format of `OpCostMapProto`) to a file specified @@ -65,8 +62,7 @@ class CostRecorder { uint64_t normalize_ratio_; mutable tensorflow::mutex op_cost_map_mutex_; - // Map op key to {sum of op execution duration in nanoseconds, #occurences of - // the op}. + // Map op key to {sum of op execution duration, #occurences of the op}. absl::flat_hash_map> op_cost_map_ TF_GUARDED_BY(op_cost_map_mutex_); }; diff --git a/tensorflow/core/tfrt/fallback/cost_recorder_test.cc b/tensorflow/core/tfrt/fallback/cost_recorder_test.cc index 6f5d6486eb1c82..827259c0990fb6 100644 --- a/tensorflow/core/tfrt/fallback/cost_recorder_test.cc +++ b/tensorflow/core/tfrt/fallback/cost_recorder_test.cc @@ -19,9 +19,8 @@ limitations under the License. #include #include -#include "absl/strings/string_view.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tfrt/fallback/op_cost_map.pb.h" namespace tensorflow { namespace tfrt_stub { diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index 295e86a6ca267c..0eb60064cbecd1 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -29,6 +29,7 @@ cc_library( ":config", "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options", "//tensorflow/core:core_cpu", + "//tensorflow/core/framework:tensor", "//tensorflow/core/protobuf:for_core_protos_cc", "//tensorflow/core/tfrt/runtime:work_queue_interface", "//tensorflow/core/tfrt/utils:bridge_graph_analysis", @@ -63,7 +64,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tfrt:import_model", - "//tensorflow/compiler/mlir/tfrt:tf_jitrt_request_context", "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options", "//tensorflow/compiler/mlir/tfrt:transforms/update_op_cost_in_tfrt_mlir", "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:import_model", @@ -88,6 +88,7 @@ cc_library( "//tensorflow/core/tfrt/mlrt/interpreter:execute", "//tensorflow/core/tfrt/mlrt/kernel:context", "//tensorflow/core/tfrt/runtime", + "//tensorflow/core/tfrt/runtime:stream", "//tensorflow/core/tfrt/runtime:work_queue_interface", "//tensorflow/core/tfrt/stubs:tfrt_native_lowering_stub", "//tensorflow/core/tfrt/utils", @@ -129,6 +130,7 @@ tf_cc_test( "//tensorflow/cc:array_ops", "//tensorflow/cc:cc_ops", "//tensorflow/cc:const_op", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:test", "//tensorflow/core/framework:graph_proto_cc", "//tensorflow/core/framework:types_proto_cc", diff --git a/tensorflow/core/tfrt/graph_executor/graph_execution_options.h 
b/tensorflow/core/tfrt/graph_executor/graph_execution_options.h index bb18cc39d38ad0..09f48b592f2f56 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_execution_options.h +++ b/tensorflow/core/tfrt/graph_executor/graph_execution_options.h @@ -15,11 +15,14 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTION_OPTIONS_H_ #define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTION_OPTIONS_H_ +#include #include #include +#include #include "absl/types/optional.h" #include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/tfrt/graph_executor/config.h" @@ -101,6 +104,9 @@ struct GraphExecutionRunOptions { // If true, just-in-time host compilation is disabled, and then if the // specified graph is not compiled, the execution will return an error. bool disable_compilation = false; + + std::function)> + streamed_output_callback; }; // Creates the default `SessionOptions` from a `GraphExecutionOptions`. diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc index c59ef950f95946..03079e960dd081 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc @@ -44,7 +44,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h" #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h" #include "tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h" #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h" @@ -70,6 +69,7 @@ limitations under the License. #include "tensorflow/core/tfrt/mlrt/interpreter/execute.h" #include "tensorflow/core/tfrt/mlrt/kernel/context.h" #include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tensorflow/core/tfrt/runtime/stream.h" #include "tensorflow/core/tfrt/runtime/work_queue_interface.h" #include "tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" @@ -234,9 +234,9 @@ StatusOr> CreateRequestInfo( fallback_request_state.set_client_graph_resource_context( client_graph_resource_context); fallback_request_state.set_runtime_config(&options.runtime_config); + fallback_request_state.set_cancellation_manager( + &request_info->cancellation_manager); - TF_RETURN_IF_ERROR( - tensorflow::SetUpTfJitRtRequestContext(&request_context_builder)); // Set priority in the builder. 
tfrt::RequestOptions request_options; request_options.priority = run_options.priority; @@ -265,7 +265,8 @@ tensorflow::Status GraphExecutionRunOnFunction( tfd::FallbackResourceArray* resource_array, const Runtime& runtime, const FallbackState& fallback_state, tfrt::RequestDeadlineTracker* req_deadline_tracker, - CostRecorder* cost_recorder) { + CostRecorder* cost_recorder, + std::optional stream_callback_id) { TF_ASSIGN_OR_RETURN( auto request_info, CreateRequestInfo(options, run_options, run_options.work_queue, @@ -273,10 +274,10 @@ tensorflow::Status GraphExecutionRunOnFunction( runner_table, resource_array, fallback_state, cost_recorder)); + int64_t request_id = request_info->tfrt_request_context->id(); tensorflow::profiler::TraceMeProducer traceme( // To TraceMeConsumers in RunHandlerThreadPool::WorkerLoop. - [request_id = request_info->tfrt_request_context->id(), signature_name, - &options, symbol_uids] { + [request_id, signature_name, &options, symbol_uids] { return tensorflow::profiler::TraceMeEncode( "TfrtModelRun", {{"_r", 1}, @@ -287,8 +288,7 @@ tensorflow::Status GraphExecutionRunOnFunction( {"tf_symbol_uid", symbol_uids.tf_symbol_uid}, {"tfrt_symbol_uid", symbol_uids.tfrt_symbol_uid}}); }, - tensorflow::profiler::ContextType::kTfrtExecutor, - request_info->tfrt_request_context->id()); + tensorflow::profiler::ContextType::kTfrtExecutor, request_id); // Only configure timer when the deadline is set. if (run_options.deadline.has_value()) { @@ -304,6 +304,23 @@ tensorflow::Status GraphExecutionRunOnFunction( deadline, request_info->tfrt_request_context); } + ScopedStreamCallback scoped_stream_callback; + + if (stream_callback_id.has_value()) { + if (!run_options.streamed_output_callback) { + return absl::InvalidArgumentError( + "streamed_output_callback is not provided for a streaming model."); + } + + auto streamed_output_callback = run_options.streamed_output_callback; + + TF_ASSIGN_OR_RETURN( + scoped_stream_callback, + GetGlobalStreamCallbackRegistry().Register( + options.model_metadata.name(), *stream_callback_id, + StepId(request_id), std::move(streamed_output_callback))); + } + if (loaded_executable) { auto function = loaded_executable->GetFunction(signature_name); if (!function) { @@ -431,9 +448,6 @@ StatusOr> GraphExecutor::Create( TfrtGraphExecutionState::Options graph_execution_state_options; graph_execution_state_options.run_placer_grappler_on_functions = options.run_placer_grappler_on_functions; - graph_execution_state_options.enable_tfrt_gpu = options.enable_tfrt_gpu; - graph_execution_state_options.use_bridge_for_gpu = - options.compile_options.use_bridge_for_gpu; options.compile_options.fuse_get_resource_ops_in_hoisting = !options.enable_mlrt; @@ -556,7 +570,8 @@ tensorflow::Status GraphExecutor::Run( &executable_context->resource_context, &loaded_client_graph.runner_table(), &loaded_client_graph.resource_array(), runtime(), fallback_state_, - &req_deadline_tracker_, cost_recorder.get())); + &req_deadline_tracker_, cost_recorder.get(), + loaded_client_graph.stream_callback_id())); if (cost_recorder != nullptr) { TF_RETURN_IF_ERROR( @@ -595,6 +610,11 @@ GraphExecutor::ImportAndCompileClientGraph( registry, mlir::MLIRContext::Threading::DISABLED); ASSIGN_OR_RETURN_IN_IMPORT( auto module, ImportClientGraphToMlirModule(client_graph, context.get())); + + TF_ASSIGN_OR_RETURN( + auto stream_callback_id, + CreateStreamCallbackId(options().model_metadata.name(), module.get())); + // TODO(b/278143179): Upload module w/o control flow. 
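// Sketch of how a caller might populate the new streaming hook on the run
// options. The callback's exact parameter type is elided in the hunk above, so a
// generic lambda is used, namespace qualifiers are omitted, and the surrounding
// function is hypothetical; treat this as schematic rather than a verified call
// sequence.
void RunWithStreaming(GraphExecutor& graph_executor /* , inputs, outputs, ... */) {
  GraphExecutionRunOptions run_options;
  run_options.streamed_output_callback = [](auto streamed_outputs) {
    // Handle partial results as the runtime produces them.
  };
  // Then call graph_executor.Run(run_options, ...) as usual. If the loaded graph
  // has a stream callback id and this field is left empty, Run() returns
  // InvalidArgument, per the check above.
}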
SymbolUids symbol_uids; symbol_uids.tf_symbol_uid = MaybeUploadMlirToXsymbol(module.get()); @@ -658,7 +678,8 @@ GraphExecutor::ImportAndCompileClientGraph( return std::make_unique( client_graph.name, std::move(symbol_uids), this, std::move(context), std::move(module_with_op_keys), std::move(module), - std::move(executable_context), options_.enable_online_cost_analysis); + std::move(executable_context), options_.enable_online_cost_analysis, + std::move(stream_callback_id)); } StatusOr> diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.h b/tensorflow/core/tfrt/graph_executor/graph_executor.h index f345d4191160e1..bde356e9107fd7 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.h +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.h @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" #include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tensorflow/core/tfrt/runtime/stream.h" #include "tensorflow/core/tfrt/runtime/work_queue_interface.h" #include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h" #include "tensorflow/tsl/platform/thread_annotations.h" @@ -66,6 +67,8 @@ struct RequestInfo { WorkQueueInterface* request_queue = nullptr; // The task runner used by tensorflow::OpKernel. std::function)> runner; + + tensorflow::CancellationManager cancellation_manager; }; struct SymbolUids { @@ -91,6 +94,9 @@ StatusOr> CreateRequestInfo( // Note: `resource_context` is per-graph-executor and // `client_graph_resource_context` is per-loaded-client-graph. See the comment // above `GraphExecutor::resource_context_` about the todo to merge these two. +// +// TODO(chky): Refactor this function to take `LoadedClientGraph` instead of +// having a long list of parameters. tensorflow::Status GraphExecutionRunOnFunction( const GraphExecutionOptions& options, const GraphExecutionRunOptions& run_options, @@ -104,7 +110,8 @@ tensorflow::Status GraphExecutionRunOnFunction( tfd::FallbackResourceArray* resource_array, const Runtime& runtime, const FallbackState& fallback_state, tfrt::RequestDeadlineTracker* req_deadline_tracker, - CostRecorder* cost_recorder = nullptr); + CostRecorder* cost_recorder = nullptr, + std::optional stream_callback_id = std::nullopt); // Runs a MLRT function for executing tensorflow graphs. 
tensorflow::Status RunMlrtFunction( @@ -131,12 +138,14 @@ class GraphExecutor { mlir::OwningOpRef tf_mlir_with_op_keys, mlir::OwningOpRef tfrt_mlir, std::shared_ptr executable_context, - bool enable_online_cost_analysis) + bool enable_online_cost_analysis, + std::optional stream_callback_id) : name_(std::move(name)), symbol_uids_(std::move(symbol_uids)), graph_executor_(graph_executor), mlir_context_(std::move(mlir_context)), - executable_context_(std::move(executable_context)) { + executable_context_(std::move(executable_context)), + stream_callback_id_(std::move(stream_callback_id)) { if (enable_online_cost_analysis) { tf_mlir_with_op_keys_ = std::move(tf_mlir_with_op_keys); tfrt_mlir_ = std::move(tfrt_mlir); @@ -165,6 +174,10 @@ class GraphExecutor { tfd::FallbackResourceArray& resource_array() { return resource_array_; } SyncResourceState& sync_resource_state() { return sync_resource_state_; } + const std::optional& stream_callback_id() const { + return stream_callback_id_; + } + private: std::string name_; SymbolUids symbol_uids_; @@ -185,6 +198,8 @@ class GraphExecutor { TF_GUARDED_BY(executable_context_mu_); mutable absl::once_flag create_cost_recorder_once_; SyncResourceState sync_resource_state_; + + std::optional stream_callback_id_; }; // A subgraph constructed by specifying input/output tensors. diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc index 1b704c4660b0c6..34505ce6bcac25 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/tfrt/graph_executor/graph_executor.h" #include +#include #include #include #include @@ -27,6 +28,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/const_op.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" @@ -44,6 +46,8 @@ namespace tensorflow { namespace tfrt_stub { namespace { +using ::testing::status::StatusIs; + class GraphExecutorTest : public ::testing::TestWithParam {}; tensorflow::Status GetSimpleGraphDef(GraphDef& graph_def) { @@ -145,6 +149,104 @@ TEST_P(GraphExecutorTest, BasicWithOnlineCostAnalysis) { ::testing::ElementsAreArray({2})); } +REGISTER_OP("TestCancel") + .Input("x: T") + .Output("z: T") + .Attr("T: {int32}") + .SetShapeFn(::tensorflow::shape_inference::UnchangedShape); + +class TestCancelKernel : public OpKernel { + public: + explicit TestCancelKernel(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + auto status = absl::CancelledError(); + ctx->cancellation_manager()->StartCancelWithStatus(status); + ctx->SetStatus(status); + } +}; + +REGISTER_KERNEL_BUILDER(Name("TestCancel").Device(DEVICE_CPU), + TestCancelKernel); + +REGISTER_OP("TestIsCancelled").Output("z: T").Attr("T: {bool}").SetIsStateful(); + +class TestIsCancelledKernel : public OpKernel { + public: + explicit TestIsCancelledKernel(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + ctx->set_output( + 0, tensorflow::Tensor(ctx->cancellation_manager()->IsCancelled())); + } +}; + +REGISTER_KERNEL_BUILDER(Name("TestIsCancelled").Device(DEVICE_CPU), + TestIsCancelledKernel); + +TEST_P(GraphExecutorTest, Cancellation) { + GraphDef graph_def; + + tensorflow::GraphDefBuilder builder( + tensorflow::GraphDefBuilder::kFailImmediately); + + const tensorflow::TensorShape tensor_shape({10, 9}); + tensorflow::Node* input = tensorflow::ops::SourceOp( + "Placeholder", builder.opts() + .WithName("input") + .WithAttr("dtype", tensorflow::DT_INT32) + .WithAttr("shape", tensor_shape)); + tensorflow::ops::SourceOp("TestIsCancelled", + builder.opts() + .WithName("is_cancelled") + .WithAttr("T", tensorflow::DT_BOOL)); + tensorflow::ops::UnaryOp("TestCancel", input, + builder.opts() + .WithName("test_cancel") + .WithAttr("T", tensorflow::DT_INT32)); + + TF_ASSERT_OK(builder.ToGraphDef(&graph_def)); + + auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); + GraphExecutor::Options options(runtime.get()); + options.enable_mlrt = GetParam(); + + TF_ASSERT_OK_AND_ASSIGN( + auto fallback_state, + tensorflow::tfrt_stub::FallbackState::Create( + CreateDefaultSessionOptions(options), graph_def.library())) + auto resource_context = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + auto graph_executor, + GraphExecutor::Create(std::move(options), *fallback_state, + std::move(resource_context), graph_def, + GetKernelRegistry())); + { + std::vector> inputs; + inputs.push_back({"input", CreateTfTensor( + /*shape=*/{1, 3}, /*data=*/{1, 1, 1})}); + + std::vector outputs; + EXPECT_THAT(graph_executor->Run(/*run_options=*/{}, inputs, + /*output_tensor_names=*/{"test_cancel:0"}, + /*target_tensor_names=*/{}, &outputs), + StatusIs(absl::StatusCode::kCancelled)); + } + + { + std::vector outputs; + TF_ASSERT_OK(graph_executor->Run(/*run_options=*/{}, /*inputs=*/{}, + /*output_tensor_names=*/{"is_cancelled:0"}, + /*target_tensor_names=*/{}, &outputs)); + ASSERT_EQ(outputs.size(), 1); + + 
EXPECT_THAT(GetTfTensorData(outputs[0]), + ::testing::ElementsAreArray({false})); + } +} + INSTANTIATE_TEST_SUITE_P(GraphExecutorTestSuite, GraphExecutorTest, ::testing::Bool()); @@ -155,14 +257,16 @@ TEST_F(GraphExecutorTest, DoOnlineCostAnalysisExactlyOnce) { /*mlir_context=*/nullptr, /*tf_mlir_with_op_keys=*/{}, /*tfrt_mlir=*/{}, /*executable_context=*/nullptr, - /*enable_online_cost_analysis=*/true); + /*enable_online_cost_analysis=*/true, + /*stream_callback_id=*/std::nullopt); GraphExecutor::LoadedClientGraph loaded_client_graph_1( "name1", /*symbol_uids=*/{}, /*graph_executor=*/nullptr, /*mlir_context=*/nullptr, /*tf_mlir_with_op_keys=*/{}, /*tfrt_mlir=*/{}, /*executable_context=*/nullptr, - /*enable_online_cost_analysis=*/true); + /*enable_online_cost_analysis=*/true, + /*stream_callback_id=*/std::nullopt); // For each `LoadedClientGraph`, `MaybeCreateCostRecorder()` only returns a // cost recorder for once. diff --git a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc index 0d274cab8ada57..20fd6d69a3ef2f 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc @@ -322,6 +322,11 @@ void MlrtBatchResource::ProcessFuncBatchImpl( fallback_request_state.set_client_graph_resource_context( caller_fallback_request_state.client_graph_resource_context()); + fallback_request_state.set_cancellation_manager( + caller_fallback_request_state.cancellation_manager()); + fallback_request_state.set_runtime_config( + caller_fallback_request_state.runtime_config()); + tensorflow::profiler::TraceMeProducer activity( // To TraceMeConsumers in WorkQueue. [step_id] { diff --git a/tensorflow/core/tfrt/runtime/BUILD b/tensorflow/core/tfrt/runtime/BUILD index 16978e7ce9306c..c4e50acd172bb5 100644 --- a/tensorflow/core/tfrt/runtime/BUILD +++ b/tensorflow/core/tfrt/runtime/BUILD @@ -19,6 +19,7 @@ package_group( # copybara:uncomment "//learning/brain/experimental/tfrt/...", # copybara:uncomment "//learning/brain/tfrt/...", # copybara:uncomment "//learning/infra/mira/...", + # copybara:uncomment "//learning/pathways/serving/...", # copybara:uncomment "//learning/serving/...", # copybara:uncomment "//quality/webanswers/servo2/...", ], @@ -101,6 +102,54 @@ cc_library( ], ) +cc_library( + name = "stream", + srcs = ["stream.cc"], + hdrs = ["stream.h"], + deps = [ + ":channel", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core/framework:tensor", + "//tensorflow/core/framework:tensor_proto_cc", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:random", + "//tensorflow/tsl/profiler/lib:traceme", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/utility", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "channel", + hdrs = ["channel.h"], + deps = [ + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + ], +) + +tf_cc_test( + name = "stream_test", + srcs = ["stream_test.cc"], + deps = [ + ":stream", + "//tensorflow/core/framework:tensor_testutil", + "//tensorflow/core/tfrt/saved_model:saved_model_testutil", + "//tensorflow/tsl/platform:env", + 
"@com_google_googletest//:gtest_main", + ], +) + tf_cc_test( name = "tf_threadpool_concurrent_work_queue_test", srcs = ["tf_threadpool_concurrent_work_queue_test.cc"], @@ -120,3 +169,14 @@ tf_cc_test( "@tf_runtime//:support", ], ) + +tf_cc_test( + name = "channel_test", + srcs = ["channel_test.cc"], + deps = [ + ":channel", + "//tensorflow/tsl/platform:env", + "@com_google_absl//absl/synchronization", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/core/tfrt/runtime/channel.h b/tensorflow/core/tfrt/runtime/channel.h new file mode 100644 index 00000000000000..5a01e78677064f --- /dev/null +++ b/tensorflow/core/tfrt/runtime/channel.h @@ -0,0 +1,79 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_RUNTIME_CHANNEL_H_ +#define TENSORFLOW_CORE_TFRT_RUNTIME_CHANNEL_H_ + +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" + +namespace tensorflow { +namespace tfrt_stub { + +// An unbounded queue for communicating between threads. This class is +// thread-safe. +template +class UnboundedChannel { + public: + absl::Status Write(T value) { + absl::MutexLock lock(&mu_); + + if (closed_) { + return absl::InternalError( + "Failed to write to the UnboundedChannel that is closed."); + } + + channel_.push(std::move(value)); + + return absl::OkStatus(); + } + + bool Read(T& value) { + absl::MutexLock lock(&mu_); + + mu_.Await(absl::Condition( + +[](UnboundedChannel* channel) ABSL_SHARED_LOCKS_REQUIRED(mu_) { + return !channel->channel_.empty() || channel->closed_; + }, + this)); + + if (!channel_.empty()) { + value = std::move(channel_.front()); + channel_.pop(); + return true; + } + + // If channel_ is empty, then it must be closed at this point. + DCHECK(closed_); + return false; + } + + void Close() { + absl::MutexLock lock(&mu_); + closed_ = true; + } + + private: + absl::Mutex mu_; + std::queue channel_ ABSL_GUARDED_BY(mu_); + bool closed_ ABSL_GUARDED_BY(mu_) = false; +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_RUNTIME_CHANNEL_H_ diff --git a/tensorflow/core/tfrt/runtime/channel_test.cc b/tensorflow/core/tfrt/runtime/channel_test.cc new file mode 100644 index 00000000000000..ed30b88720d479 --- /dev/null +++ b/tensorflow/core/tfrt/runtime/channel_test.cc @@ -0,0 +1,137 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tfrt/runtime/channel.h" + +#include +#include + +#include +#include +#include "absl/synchronization/blocking_counter.h" +#include "tensorflow/tsl/platform/env.h" + +namespace tensorflow { +namespace tfrt_stub { +namespace { + +using ::testing::ElementsAreArray; +using ::testing::UnorderedElementsAreArray; +using ::testing::status::StatusIs; + +TEST(ChannelTest, Basic) { + UnboundedChannel channel; + + std::vector expected(100); + std::iota(expected.begin(), expected.end(), 0); + + tsl::Env::Default()->SchedClosure([&]() { + for (int v : expected) { + CHECK_OK(channel.Write(v)); + } + channel.Close(); + }); + + std::vector outputs; + int v = -1; + while (channel.Read(v)) { + outputs.push_back(v); + } + + EXPECT_THAT(outputs, ElementsAreArray(expected)); + + EXPECT_THAT(channel.Write(100), StatusIs(absl::StatusCode::kInternal)); +} + +TEST(ChannelTest, MultipleWriters) { + UnboundedChannel channel; + + std::vector expected(100); + std::iota(expected.begin(), expected.end(), 0); + + tsl::Env::Default()->SchedClosure([&]() { + absl::BlockingCounter bcount(expected.size()); + for (int v : expected) { + tsl::Env::Default()->SchedClosure([&, v]() { + CHECK_OK(channel.Write(v)); + bcount.DecrementCount(); + }); + } + bcount.Wait(); + channel.Close(); + }); + + std::vector outputs; + int v = 0; + while (channel.Read(v)) { + outputs.push_back(v); + } + + EXPECT_THAT(outputs, UnorderedElementsAreArray(expected)); +} + +TEST(ChannelTest, MultipleReaders) { + UnboundedChannel channel; + + std::vector expected(100); + std::iota(expected.begin(), expected.end(), 0); + + absl::Mutex mu; + std::vector outputs; + + int num_readers = 200; + absl::BlockingCounter bcount(num_readers); + for (int i = 0; i < num_readers; ++i) { + tsl::Env::Default()->SchedClosure([&]() { + int v = 0; + while (channel.Read(v)) { + absl::MutexLock lock(&mu); + outputs.push_back(v); + } + bcount.DecrementCount(); + }); + } + + for (int v : expected) { + CHECK_OK(channel.Write(v)); + } + channel.Close(); + + bcount.Wait(); + EXPECT_THAT(outputs, UnorderedElementsAreArray(expected)); +} + +TEST(ChannelTest, FullyBuffered) { + UnboundedChannel channel; + + std::vector expected(100); + std::iota(expected.begin(), expected.end(), 0); + + for (int v : expected) { + CHECK_OK(channel.Write(v)); + } + channel.Close(); + + std::vector outputs; + int v = -1; + while (channel.Read(v)) { + outputs.push_back(v); + } + + EXPECT_THAT(outputs, ElementsAreArray(expected)); +} + +} // namespace +} // namespace tfrt_stub +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/runtime/stream.cc b/tensorflow/core/tfrt/runtime/stream.cc new file mode 100644 index 00000000000000..01ed8ad9e6824f --- /dev/null +++ b/tensorflow/core/tfrt/runtime/stream.cc @@ -0,0 +1,211 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tfrt/runtime/stream.h" + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/utility/utility.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/tsl/platform/random.h" +#include "tensorflow/tsl/profiler/lib/traceme.h" + +namespace tensorflow { +namespace tfrt_stub { + +absl::StatusOr> CreateStreamCallbackId( + absl::string_view model_name, mlir::ModuleOp module) { + mlir::Builder builder(module.getContext()); + + // Inject information about the callback to `tf.PwStreamResults` ops. The + // attribute names must match `PwStreamResult` op's implementation. + + std::vector ops; + module->walk([&](mlir::TF::PwStreamResultsOp op) { ops.push_back(op); }); + + if (ops.empty()) { + return std::nullopt; + } + + auto& stream_interface = GetGlobalStreamCallbackRegistry().stream_interface(); + + auto controller_address = stream_interface.controller_address(); + auto controller_address_attr = builder.getStringAttr(controller_address); + + auto model_name_attr = builder.getStringAttr(model_name); + + // We use int64_t instead of uint64_t returned by `New64()` because + // TensorFlow doesn't support uint64 attributes. 
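  // Note: the cast below can produce a negative id when the random uint64
  // value has its top bit set; that is assumed to be acceptable here since
  // the id is only used as an opaque key (an op attribute and a registry
  // map key), never as a count or index.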
+ const StreamCallbackId callback_id( + static_cast(tsl::random::New64())); + auto callback_id_attr = builder.getI64IntegerAttr(callback_id.id); + + for (auto op : ops) { + op->setAttr("_controller_address", controller_address_attr); + op->setAttr("_model_name", model_name_attr); + op->setAttr("_callback_id", callback_id_attr); + } + + return callback_id; +} + +absl::StatusOr StreamCallbackRegistry::Register( + absl::string_view model_name, StreamCallbackId callback_id, StepId step_id, + absl::AnyInvocable< + void(absl::flat_hash_map)> + callback) { + absl::MutexLock l(&mu_); + + const auto [it, inserted] = + stream_callbacks_.insert({std::make_pair(callback_id, step_id), nullptr}); + if (!inserted) { + return absl::AlreadyExistsError(absl::StrCat( + "Stream callback ", callback_id, " @ ", step_id, " already exists")); + } + + it->second = std::make_unique(); + it->second->thread = absl::WrapUnique(tsl::Env::Default()->StartThread( + tensorflow::ThreadOptions(), + /*name=*/absl::StrCat("stream_handler_", callback_id, "_", step_id), + [model_name = std::string(model_name), callback_id, step_id, + callback = std::move(callback), state = it->second.get(), + this]() mutable { + StreamedResult result; + while (state->channel.Read(result)) { + absl::Duration dequeue_latency = absl::Now() - result.enqueued_time; + interface_->RecordDequeueLatency(model_name, dequeue_latency); + + tsl::profiler::TraceMe trace_me("StreamCallbackInvocation"); + trace_me.AppendMetadata([&]() { + return tsl::profiler::TraceMeEncode({ + {"callback_id", callback_id.id}, + {"step_id", step_id.id}, + }); + }); + + absl::Time start_time = absl::Now(); + callback(std::move(result.tensors)); + interface_->RecordCallbackLatency(model_name, + absl::Now() - start_time); + } + })); + + return ScopedStreamCallback(this, callback_id, step_id); +} + +absl::Status StreamCallbackRegistry::Write(StreamCallbackId callback_id, + StepId step_id, + StreamedResult result) { + absl::MutexLock lock(&mu_); + auto iter = stream_callbacks_.find({callback_id, step_id}); + if (iter == stream_callbacks_.end()) { + return absl::NotFoundError(absl::StrCat( + "Stream callback ", callback_id, " @ ", step_id, + " does not exist; this usually indicates that a streaming signature " + "was called by a non-streaming request")); + } + + auto* state = iter->second.get(); + DCHECK(state); + return state->channel.Write(std::move(result)); +} + +std::unique_ptr +StreamCallbackRegistry::Unregister(StreamCallbackId callback_id, + StepId step_id) { + absl::MutexLock l(&mu_); + const auto it = stream_callbacks_.find({callback_id, step_id}); + if (it == stream_callbacks_.end()) { + return nullptr; + } + auto state = std::move(it->second); + stream_callbacks_.erase(it); + return state; +} + +ScopedStreamCallback::ScopedStreamCallback(ScopedStreamCallback&& other) + : registry_(other.registry_), + callback_id_(other.callback_id_), + step_id_(other.step_id_) { + other.callback_id_ = std::nullopt; + other.step_id_ = StepId::GetInvalidStepId(); +} + +ScopedStreamCallback& ScopedStreamCallback::operator=( + ScopedStreamCallback&& other) { + Unregister(); + + registry_ = other.registry_; + callback_id_ = other.callback_id_; + step_id_ = other.step_id_; + other.callback_id_ = std::nullopt; + other.step_id_ = StepId::GetInvalidStepId(); + + return *this; +} + +void ScopedStreamCallback::Unregister() { + if (!callback_id_.has_value()) { + return; + } + + tsl::profiler::TraceMe trace_me("ScopedStreamCallback::Unregister"); + trace_me.AppendMetadata([&]() { + return 
tsl::profiler::TraceMeEncode({ + {"callback_id", callback_id_->id}, + {"step_id", step_id_.id}, + }); + }); + + DCHECK(registry_); + auto state = registry_->Unregister(*callback_id_, step_id_); + DCHECK(state); + + // At this point, it is safe to close the channel. + state->channel.Close(); + + // Wait until the stream handler finishes. + state->thread.reset(); + + callback_id_.reset(); +} + +StreamInterfaceFactory& GetGlobalStreamInterfaceFactory() { + static auto* stream_interface_factory = new StreamInterfaceFactory; + return *stream_interface_factory; +} + +StreamCallbackRegistry& GetGlobalStreamCallbackRegistry() { + static auto* stream_callback_registry = new StreamCallbackRegistry( + GetGlobalStreamInterfaceFactory().CreateStreamInterface().value()); + return *stream_callback_registry; +} + +} // namespace tfrt_stub +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/runtime/stream.h b/tensorflow/core/tfrt/runtime/stream.h new file mode 100644 index 00000000000000..7fea8ffe88c01b --- /dev/null +++ b/tensorflow/core/tfrt/runtime/stream.h @@ -0,0 +1,221 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and limitations +under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_RUNTIME_STREAM_H_ +#define TENSORFLOW_CORE_TFRT_RUNTIME_STREAM_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/tfrt/runtime/channel.h" +#include "tensorflow/tsl/platform/env.h" + +namespace tensorflow { +namespace tfrt_stub { + +template +struct SafeId { + SafeId() : id(0) {} + explicit constexpr SafeId(int64_t id) : id(id) {} + + using Base = SafeId; + + int64_t id; + + friend bool operator==(const Derived& x, const Derived& y) { + return x.id == y.id; + } + + template + friend void AbslStringify(Sink& sink, const Derived& x) { + absl::Format(&sink, "%d", x.id); + } + + template + friend H AbslHashValue(H h, const Derived& x) { + return H::combine(std::move(h), x.id); + } +}; + +struct StreamedResult { + absl::flat_hash_map tensors; + absl::Time enqueued_time; +}; + +struct StreamCallbackId : SafeId { + using Base::Base; +}; + +struct StepId : SafeId { + using Base::Base; + + bool valid() const { return id != 0; } + static constexpr StepId GetInvalidStepId() { return StepId(0); } +}; + +class StreamInterface { + public: + explicit StreamInterface(std::string controller_address) + : controller_address_(std::move(controller_address)) {} + virtual ~StreamInterface() = default; + + absl::string_view controller_address() const { return controller_address_; } + + virtual void RecordDequeueLatency(absl::string_view model_name, + absl::Duration latency) {} + + virtual void RecordCallbackLatency(absl::string_view model_name, + 
absl::Duration latency) {} + + private: + std::string controller_address_; +}; + +class ScopedStreamCallback; + +class StreamInterfaceFactory { + public: + void Register(absl::AnyInvocable< + absl::StatusOr>() const> + interface_factory) { + absl::MutexLock lock(&mu_); + interface_factory_ = std::move(interface_factory); + } + + absl::StatusOr> CreateStreamInterface() + const { + absl::MutexLock lock(&mu_); + return interface_factory_(); + } + + private: + mutable absl::Mutex mu_; + absl::AnyInvocable>() const> + interface_factory_ ABSL_GUARDED_BY(mu_) = []() { + return absl::InternalError( + "The factory for StreamInterface is not registered."); + }; +}; + +// Returns the global factory for the stream interface. The factory for the +// stream interface must be registered first before calling +// GetGlobalStreamCallbackRegistry(). +StreamInterfaceFactory& GetGlobalStreamInterfaceFactory(); + +// Mapping from tuples of (callback_id, step_id) to callback states. The mapping +// is stored in a global variable so that it can be shared between +// `ScopedStreamCallback` and `InvokeStreamCallbackOp`. +// +// This class is thread-safe. +class StreamCallbackRegistry { + public: + explicit StreamCallbackRegistry(std::unique_ptr interface) + : interface_(std::move(interface)) { + DCHECK(interface_); + } + + // Registers a callback under the given id. A stream callback is uniquely + // identified by a tuple of a callback id (unique to each executable) and a + // step id (unique to each invocation of a given executable). Returns an RAII + // object that removes the callback from the registry on its deallocation, or + // an error if the id already exists in the registry. + // + // If a program runs `tf.PwStreamResults` with a matching callback/step id, + // `callback` will be called with the arguments of `tf.PwStreamResults`. + // + // All invocations to `callback` are handled serially by a single thread, so + // `callback` doesn't need to be thread-safe even if multiple + // `tf.PwStreamResults` ops may run concurrently. + absl::StatusOr Register( + absl::string_view model_name, StreamCallbackId callback_id, + StepId step_id, + absl::AnyInvocable< + void(absl::flat_hash_map)> + callback); + + absl::Status Write(StreamCallbackId callback_id, StepId step_id, + StreamedResult result); + + StreamInterface& stream_interface() const { return *interface_; } + + private: + friend class ScopedStreamCallback; + + struct CallbackState { + std::unique_ptr thread; + UnboundedChannel channel; + }; + + std::unique_ptr Unregister(StreamCallbackId callback_id, + StepId step_id); + + std::unique_ptr interface_; + + mutable absl::Mutex mu_; + absl::flat_hash_map, + std::unique_ptr> + stream_callbacks_ ABSL_GUARDED_BY(mu_); +}; + +// Returns the global registry for the stream callbacks. The stream interface +// must have been registered through GetGlobalStreamInterfaceFactory() before +// calling this function. +StreamCallbackRegistry& GetGlobalStreamCallbackRegistry(); + +// Creates a new stream callback id and rewrites the given module with +// information required to trigger this callback remotely. Returns the callback +// id, or `std::nullopt` if the module has no stream outputs. +absl::StatusOr> CreateStreamCallbackId( + absl::string_view model_name, mlir::ModuleOp module); + +// Implements an RAII object that registers a callback to be called on receiving +// streamed tensors. +class ScopedStreamCallback { + public: + ScopedStreamCallback() = default; + + // Moveable but not copyable. 
+ ScopedStreamCallback(ScopedStreamCallback&& other); + ScopedStreamCallback& operator=(ScopedStreamCallback&& other); + + ~ScopedStreamCallback() { Unregister(); } + + private: + friend class StreamCallbackRegistry; + + explicit ScopedStreamCallback(StreamCallbackRegistry* registry, + StreamCallbackId callback_id, StepId step_id) + : registry_(registry), callback_id_(callback_id), step_id_(step_id) {} + + void Unregister(); + + StreamCallbackRegistry* registry_ = nullptr; + std::optional callback_id_; + StepId step_id_ = StepId::GetInvalidStepId(); +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_RUNTIME_STREAM_H_ diff --git a/tensorflow/core/tfrt/runtime/stream_test.cc b/tensorflow/core/tfrt/runtime/stream_test.cc new file mode 100644 index 00000000000000..df377486627941 --- /dev/null +++ b/tensorflow/core/tfrt/runtime/stream_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tfrt/runtime/stream.h" + +#include +#include +#include + +#include +#include +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" +#include "tensorflow/tsl/platform/env.h" + +namespace tensorflow { +namespace tfrt_stub { +namespace { + +using ::tensorflow::test::AsTensor; +using ::testing::AnyOf; +using ::testing::ElementsAreArray; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +class TestStreamInterface : public StreamInterface { + public: + TestStreamInterface() : StreamInterface("test_address") {} +}; + +const bool kUnused = []() { + GetGlobalStreamInterfaceFactory().Register( + []() { return std::make_unique(); }); + return true; +}(); + +TEST(StreamTest, Simple) { + StreamCallbackId callback_id(1234); + StepId step_id(5678); + + std::vector> outputs; + + { + ASSERT_OK_AND_ASSIGN( + auto scoped_stream_callback, + GetGlobalStreamCallbackRegistry().Register( + "test_model", callback_id, step_id, + [&](absl::flat_hash_map arg) { + outputs.push_back(std::move(arg)); + })); + + std::vector> expected = + {{{"a", AsTensor({100})}, {"b", AsTensor({200})}}, + {{"c", AsTensor({300})}}}; + auto thread = absl::WrapUnique(tsl::Env::Default()->StartThread( + tsl::ThreadOptions(), "fake_stream_client", [&]() { + for (const auto& map : expected) { + CHECK_OK(GetGlobalStreamCallbackRegistry().Write( + callback_id, step_id, {map, absl::Now()})); + } + })); + } + + EXPECT_EQ(outputs.size(), 2); + EXPECT_THAT(GetTfTensorData(outputs[0]["a"]), + ElementsAreArray({100})); + EXPECT_THAT(GetTfTensorData(outputs[0]["b"]), + ElementsAreArray({200})); + EXPECT_THAT(GetTfTensorData(outputs[1]["c"]), + ElementsAreArray({300})); +} + +TEST(StreamTest, MultipleWriters) { + StreamCallbackId callback_id(1234); + StepId step_id(5678); + + std::vector>> outputs; + + { + ASSERT_OK_AND_ASSIGN( + auto scoped_stream_callback, + 
GetGlobalStreamCallbackRegistry().Register( + "test_model", callback_id, step_id, + [&](absl::flat_hash_map arg) { + absl::flat_hash_map> out; + for (const auto& p : arg) { + out[p.first] = GetTfTensorData(p.second); + } + outputs.push_back(std::move(out)); + })); + + std::vector> expected = + {{{"a", AsTensor({100})}, {"b", AsTensor({200})}}, + {{"c", AsTensor({300})}}}; + + for (const auto& p : expected) { + tsl::Env::Default()->SchedClosure([callback_id, step_id, p]() { + // The stream callback may be dropped early, and in that case we ignore + // the error. + GetGlobalStreamCallbackRegistry() + .Write(callback_id, step_id, {p, absl::Now()}) + .IgnoreError(); + }); + } + + absl::SleepFor(absl::Microseconds(100)); + } + + LOG(INFO) << "StreamCallback receives " << outputs.size() << " outputs."; + + for (const auto& output : outputs) { + EXPECT_THAT( + output, + AnyOf(UnorderedElementsAre(Pair("a", ElementsAreArray({100})), + Pair("b", ElementsAreArray({200}))), + UnorderedElementsAre(Pair("c", ElementsAreArray({300}))))); + } +} + +} // namespace +} // namespace tfrt_stub +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD index 98eb3e205c53d5..3f08eabe5e5626 100644 --- a/tensorflow/core/tfrt/saved_model/BUILD +++ b/tensorflow/core/tfrt/saved_model/BUILD @@ -16,10 +16,7 @@ package_group( # copybara:uncomment "//learning/serving/...", "//tensorflow/core/runtime_fallback/...", "//tensorflow/core/tfrt/mlrt/application/tensorflow/tests/...", - "//tensorflow/core/tfrt/saved_model/tests/...", - "//tensorflow/core/tfrt/graph_executor/...", - "//tensorflow/core/tfrt/tfrt_session/...", - "//tensorflow/core/tfrt/utils/debug/...", + "//tensorflow/core/tfrt/...", "//tensorflow_serving/...", "//tensorflow/core/tfrt/saved_model/python/...", # copybara:uncomment "//platforms/xla/tests/saved_models/...", @@ -150,8 +147,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "//tensorflow/compiler/mlir/tfrt:tf_jitrt_kernels_alwayslink", - "//tensorflow/compiler/mlir/tfrt:tfrt_jitrt_passes", "//tensorflow/core/framework:graph_proto_cc", "//tensorflow/core/platform:thread_annotations", "//tensorflow/core/protobuf:for_core_protos_cc", diff --git a/tensorflow/core/tfrt/saved_model/python/BUILD b/tensorflow/core/tfrt/saved_model/python/BUILD index 92d1944a8cf1a7..2b4fb5668cc835 100644 --- a/tensorflow/core/tfrt/saved_model/python/BUILD +++ b/tensorflow/core/tfrt/saved_model/python/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:tensorflow.default.bzl", "tf_python_pybind_extension") -load("//tensorflow:pytype.default.bzl", "pytype_strict_binary") +load("//tensorflow:pytype.default.bzl", "pytype_strict_binary", "pytype_strict_contrib_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -29,19 +29,6 @@ pytype_strict_binary( ], ) -py_binary( - name = "saved_model_aot_compile_py", - srcs = ["saved_model_aot_compile.py"], - main = "saved_model_aot_compile.py", - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":_pywrap_saved_model_aot_compile", - "//tensorflow/core/tfrt/graph_executor/python:_pywrap_graph_execution_options", - "@absl_py//absl:app", - ], -) - tf_python_pybind_extension( name = "_pywrap_saved_model_aot_compile", srcs = ["saved_model_aot_compile_wrapper.cc"], @@ -89,3 +76,22 @@ tf_python_pybind_extension( "@pybind11_abseil//pybind11_abseil:status_casters", ], ) + +pytype_strict_contrib_test( + name = 
"saved_model_aot_compile_test", + size = "small", + srcs = [ + "saved_model_aot_compile_test.py", + ], + data = [ + "//learning/brain/tfrt/cpp_tests/gpu_inference:testdata", + ], + python_version = "PY3", + deps = [ + ":_pywrap_saved_model_aot_compile", + "//base/python:pywrapbase", + "//tensorflow/python/platform:client_testlib", + "//testing/pybase", + "//third_party/py/lingvo:compat", + ], +) diff --git a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile.py b/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile.py deleted file mode 100644 index da11aa9b22aa68..00000000000000 --- a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Test .py file for pybind11 files for AotOptions and AotCompileSavedModel, currently unable to test due to nullptr in AotOptions.""" - - -from absl import app -from tensorflow.core.tfrt.graph_executor.python import _pywrap_graph_execution_options -from tensorflow.core.tfrt.saved_model.python import _pywrap_saved_model_aot_compile - - -def main(unused_argv): - if not _pywrap_saved_model_aot_compile: - return - try: - # Test for creating an instance of GraphExecutionOptions - test = _pywrap_graph_execution_options.GraphExecutionOptions() - print(test) - - # Executes AoTOptions and AotCompileSavedModel for Wrapping Tests - _pywrap_saved_model_aot_compile.AotOptions() - - # TODO(cesarmagana): Once AotCompileSavedModel is complete - # update this test script to read from CNS - _pywrap_saved_model_aot_compile.AotCompileSavedModel("random") - - # Could also do except status.StatusNotOk if testing for AotCompileSavedModel - except Exception as exception: # pylint: disable=broad-exception-caught - print(exception) - - -if __name__ == "__main__": - app.run(main) diff --git a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_test.py b/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_test.py new file mode 100644 index 00000000000000..2e849fb743855c --- /dev/null +++ b/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_test.py @@ -0,0 +1,37 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import os + +import lingvo.compat as tf + +from tensorflow.core.tfrt.saved_model.python import _pywrap_saved_model_aot_compile +from tensorflow.python.platform import test + + +class SavedModelAotCompileTest(test.TestCase): + + def testVerify_saved_model(self): + outputpath = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR") + filepath = "learning/brain/tfrt/cpp_tests/gpu_inference/test_data/translate_converted_placed/" + _pywrap_saved_model_aot_compile.AotCompileSavedModel( + filepath, _pywrap_saved_model_aot_compile.AotOptions(), outputpath + ) + + # Verifies that .pbtxt is created correctly in the output directory + self.assertTrue(tf.io.gfile.exists(outputpath + "/aot_saved_model.pbtxt")) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_wrapper.cc b/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_wrapper.cc index b8b0f6985007d4..7c1e31fab55e7e 100644 --- a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_wrapper.cc +++ b/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_wrapper.cc @@ -25,12 +25,8 @@ namespace py = pybind11; PYBIND11_MODULE(_pywrap_saved_model_aot_compile, m) { py::google::ImportStatusModule(); - py::class_(m, "AotOptions", - py::dynamic_attr()) - .def(py::init<>()) - .def_readwrite( - "graph_execution_options", - &tensorflow::tfrt_stub::AotOptions::graph_execution_options); + py::class_(m, "AotOptions") + .def(py::init<>()); m.doc() = "pybind11 AotOptions Python - C++ Wrapper"; m.def("AotCompileSavedModel", &tensorflow::tfrt_stub::AotCompileSavedModel, diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc index 348d1b3dd903aa..488c27d8c1ae67 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model.cc @@ -319,8 +319,7 @@ std::vector FindNamesForValidSignatures( StatusOr> ImportSavedModel( mlir::MLIRContext* context, const tensorflow::MetaGraphDef& meta_graph_def, const FallbackState& fallback_state, std::string saved_model_dir, - bool import_user_signatures, bool run_placer_grappler_on_functions, - bool enable_tfrt_gpu, bool use_bridge_for_gpu) { + bool import_user_signatures, bool run_placer_grappler_on_functions) { std::vector signature_names; if (import_user_signatures) { signature_names = FindNamesForValidSignatures(meta_graph_def); @@ -338,8 +337,7 @@ StatusOr> ImportSavedModel( TF_ASSIGN_OR_RETURN(auto import_input, TfrtSavedModelMLIRImportInput::Create( fallback_state, &meta_graph_def, /*debug_info=*/{}, - run_placer_grappler_on_functions, enable_tfrt_gpu, - use_bridge_for_gpu)); + run_placer_grappler_on_functions)); TF_ASSIGN_OR_RETURN( auto module, @@ -530,9 +528,6 @@ void UpdateCompileOptions(SavedModel::Options& options) { if (options.graph_execution_options.enable_tfrt_gpu) { options.graph_execution_options.compile_options.decompose_resource_ops = false; - // TODO(b/260915352): Remove this flag and use GPU bridge by default, and - // remove the obsolete TFRT GPU runtime as well. 
- options.graph_execution_options.compile_options.use_bridge_for_gpu = true; } options.graph_execution_options.compile_options @@ -625,9 +620,7 @@ SavedModelImpl::LoadSavedModel(Options options, &context, meta_graph_def, *fallback_state, std::string(saved_model_dir), /*import_user_signatures=*/!options.enable_lazy_loading, - options.graph_execution_options.run_placer_grappler_on_functions, - options.graph_execution_options.enable_tfrt_gpu, - options.graph_execution_options.compile_options.use_bridge_for_gpu)); + options.graph_execution_options.run_placer_grappler_on_functions)); // TODO(b/278143179): Upload module w/o control flow. SymbolUids symbol_uids; symbol_uids.tf_symbol_uid = MaybeUploadMlirToXsymbol(mlir_module.get()); diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc index 2711a631822539..aa5fcd51e98195 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h" +#include #include #include "absl/status/status.h" @@ -33,7 +34,7 @@ limitations under the License. namespace tensorflow::tfrt_stub { -AotOptions::AotOptions() : graph_execution_options(GetGlobalRuntime()) {} +AotOptions::AotOptions() : graph_execution_options(nullptr) {} Status AotCompileSavedModel(absl::string_view input_model_dir, const AotOptions& aot_options, diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h index a3cffc385b8d7e..5547d3506e0ace 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h +++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_AOT_COMPILE_H_ #define TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_AOT_COMPILE_H_ +#include #include #include "tensorflow/compiler/xla/service/compiler.h" @@ -23,8 +24,9 @@ limitations under the License. 
namespace tensorflow::tfrt_stub { struct AotOptions { - GraphExecutionOptions graph_execution_options; AotOptions(); + + std::unique_ptr graph_execution_options; }; // AOT Compiles saved_model in input_model_dir, writing output diff --git a/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc b/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc index bfd25ccb7699af..379ed634f98606 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc @@ -28,15 +28,12 @@ namespace tfrt_stub { StatusOr TfrtSavedModelMLIRImportInput::Create( const FallbackState& fallback_state, const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info, - bool run_placer_grappler_on_nested_functions, bool enable_tfrt_gpu, - bool use_bridge_for_gpu) { + bool run_placer_grappler_on_nested_functions) { DCHECK(meta_graph_def); TfrtGraphExecutionState::Options options; options.run_placer_grappler_on_functions = run_placer_grappler_on_nested_functions; - options.enable_tfrt_gpu = enable_tfrt_gpu; - options.use_bridge_for_gpu = use_bridge_for_gpu; TF_ASSIGN_OR_RETURN( auto graph_execution_state, TfrtGraphExecutionState::Create(options, meta_graph_def->graph_def(), diff --git a/tensorflow/core/tfrt/saved_model/saved_model_import_input.h b/tensorflow/core/tfrt/saved_model/saved_model_import_input.h index 3c1b9fca053ffb..f1913359935801 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_import_input.h +++ b/tensorflow/core/tfrt/saved_model/saved_model_import_input.h @@ -33,8 +33,7 @@ class TfrtSavedModelMLIRImportInput : public SavedModelMLIRImportInput { static StatusOr Create( const FallbackState& fallback_state, const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info, - bool run_placer_grappler_on_nested_functions = false, - bool enable_tfrt_gpu = false, bool use_bridge_for_gpu = false); + bool run_placer_grappler_on_nested_functions = false); TfrtSavedModelMLIRImportInput( const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info, diff --git a/tensorflow/core/tfrt/utils/BUILD b/tensorflow/core/tfrt/utils/BUILD index c1c83267233098..44e29ea288284f 100644 --- a/tensorflow/core/tfrt/utils/BUILD +++ b/tensorflow/core/tfrt/utils/BUILD @@ -198,8 +198,6 @@ cc_library( srcs = ["tfrt_graph_execution_state.cc"], hdrs = ["tfrt_graph_execution_state.h"], deps = [ - "//tensorflow/compiler/jit:common", - "//tensorflow/compiler/jit:compilation_passes", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:upgrade_graph", "//tensorflow/core:core_cpu_base", @@ -236,25 +234,14 @@ tf_cc_test( "//tensorflow/cc:array_ops", "//tensorflow/cc:cc_ops", "//tensorflow/cc:const_op", - "//tensorflow/cc:function_ops", "//tensorflow/cc:functional_ops", - "//tensorflow/cc:math_ops", - "//tensorflow/cc:resource_variable_ops", - "//tensorflow/cc:scope", - "//tensorflow/cc:sendrecv_ops", "//tensorflow/cc:while_loop", - "//tensorflow/compiler/jit:common", - "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", "//tensorflow/core:test", - "//tensorflow/core/framework:attr_value_proto_cc", "//tensorflow/core/framework:graph_proto_cc", - "//tensorflow/core/framework:node_def_proto_cc", "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/grappler/utils:grappler_test", - "//tensorflow/core/kernels:resource_variable_ops", 
"//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc index 22ffa9a6b32ee5..8a80444993f4ff 100644 --- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc +++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc @@ -26,16 +26,10 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/types/span.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h" #include "tensorflow/core/common_runtime/function_body.h" #include "tensorflow/core/common_runtime/function_def_utils.h" #include "tensorflow/core/common_runtime/graph_constructor.h" -#include "tensorflow/core/common_runtime/lower_functional_ops.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/common_runtime/partitioning_utils.h" -#include "tensorflow/core/common_runtime/placer.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" @@ -46,7 +40,6 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" @@ -231,171 +224,6 @@ NodeDef CreateNewIdentityNode(const NodeDef& node, return identity; } -// Inlines functions into the top level graph. -Status InlineFunctions(std::unique_ptr* graph, - const DeviceSet* device_set) { - GraphOptimizationPassOptions optimization_options; - SessionOptions session_options; - // We don't lower v2 control flow to v1 for now. - session_options.config.mutable_experimental()->set_use_tfrt(true); - session_options.config.mutable_graph_options() - ->mutable_optimizer_options() - ->set_do_function_inlining(true); - optimization_options.session_options = &session_options; - optimization_options.graph = graph; - optimization_options.flib_def = (*graph)->mutable_flib_def(); - optimization_options.device_set = device_set; - optimization_options.is_function_graph = false; - - LowerFunctionalOpsPass pass; - return pass.Run(optimization_options); -} - -// Assigns input/output nodes to the host. -Status PlaceInputOutputNodesOnHost(const std::vector& inputs, - const std::vector& outputs, - const Device* cpu_device, Graph* graph) { - std::unordered_map name_to_node_map = - graph->BuildNodeNameIndex(); - for (const auto& input : inputs) { - name_to_node_map.at(grappler::NodeName(input)) - ->set_assigned_device_name(cpu_device->name()); - } - - // Collect all output nodes. - absl::flat_hash_set output_nodes; - for (const auto& output : outputs) { - output_nodes.insert(name_to_node_map.at(grappler::NodeName(output))); - } - for (const auto& output_node : output_nodes) { - // Append an IdentityN node to the original output node if it is not - // assigned to the host. - if (!output_node->IsIdentity() && - output_node->type_string() != "IdentityN" && - output_node->assigned_device_name() != cpu_device->name()) { - // Rename the original output node. 
- std::string output_node_name = output_node->name(); - output_node->set_name(output_node_name + "/tfrt_renamed"); - - // Append an IdentityN node with the original output node name. - std::vector output_tensors; - output_tensors.reserve(output_node->num_outputs()); - for (int i = 0; i < output_node->num_outputs(); i++) { - output_tensors.push_back(NodeBuilder::NodeOut(output_node, i)); - } - TF_RETURN_IF_ERROR(NodeBuilder(output_node_name, "IdentityN") - .AssignedDevice(cpu_device->name()) - .Input(output_tensors) - .Finalize(graph, /*created_node=*/nullptr)); - } else { - output_node->set_assigned_device_name(cpu_device->name()); - } - } - return OkStatus(); -} - -Status AdjustDeviceAssignment(const std::vector& inputs, - const std::vector& outputs, - const std::vector& control_outputs, - const Device* cpu_device, Graph* graph) { - // TODO(b/232299232): We don't inline and partition v2 control flow currently. - // All ops within control flow are placed on CPU for now. Figure out a better - // way to handle v2 control flow. - for (Node* node : graph->op_nodes()) { - if (node->IsWhileNode() || node->IsIfNode()) { - LOG(WARNING) << "The control flow node " << node->name() - << " is placed on CPU."; - node->set_assigned_device_name(cpu_device->name()); - } - } - - TF_RETURN_IF_ERROR( - PlaceInputOutputNodesOnHost(inputs, outputs, cpu_device, graph)); - return OkStatus(); -} - -bool IsTpuGraph(const Graph* graph) { - static const auto* const kTpuOps = new absl::flat_hash_set{ - "TPUPartitionedCall", "TPUCompile", "TPUReplicateMetadata"}; - for (const Node* node : graph->nodes()) { - if (kTpuOps->contains(node->type_string())) { - return true; - } - } - for (const std::string& func_name : graph->flib_def().ListFunctionNames()) { - const FunctionDef* func_def = graph->flib_def().Find(func_name); - for (const NodeDef& node_def : func_def->node_def()) { - if (kTpuOps->contains(node_def.op())) return true; - } - } - return false; -} - -// Adds Send/Recv ops to `graph` for data transfer, if ops are run on different -// devices. Returns a new graph with the added Send/Recv ops. -// This is done by partitioning `graph` and add Send/Recv ops on the edges -// across devices. -StatusOr> BuildXlaOpsAndMaybeInsertTransferOps( - const std::string& graph_func_name, const FallbackState& fallback_state, - const std::vector& inputs, - const std::vector& outputs, - const std::vector& control_outputs, - std::unique_ptr graph) { - // Skip inserting transfer ops if this is a TPU graph. - // Our stack currently cannot run the old bridge on TPU graphs, as it will - // generate ops that are not supported by the subsequent MLIR passes. - // In the case where TPU related ops are not wrapped in TPUPartitionedCall, - // running placer and partitioning on such graphs will fail. So we skip TPU - // graphs for now. - // TODO(b/228510957): In the long term, we will want a unified way for data - // transfer, i.e., using Send/Recv ops for data transfer for TPU as well. - if (IsTpuGraph(graph.get())) { - return graph; - } - - // Inline functions to facilitate partitioning nodes in the functions. - TF_RETURN_IF_ERROR(InlineFunctions(&graph, &fallback_state.device_set())); - if (VLOG_IS_ON(1)) { - DumpGraphToFile("after_inlining", *graph); - } - - // Replace the StatefulPartitionedCall op that should be compiled to an - // XlaLaunch op. - // TODO(b/239089915): Clean this up after the logic is implemented in TFXLA - // bridge. 
- TF_RETURN_IF_ERROR(BuildXlaLaunchOps(graph.get())); - if (VLOG_IS_ON(1)) { - DumpGraphToFile("after_build_xla_launch", *graph); - } - - // Run placer. - const Device* cpu_device = fallback_state.device_manager().HostCPU(); - if (cpu_device == nullptr) { - return errors::Internal("No CPU device found."); - } - Placer placer(graph.get(), /*function_name=*/"", &graph->flib_def(), - &fallback_state.device_set(), cpu_device, - /*allow_soft_placement=*/true, - /*log_device_placement=*/false); - TF_RETURN_IF_ERROR(placer.Run()); - if (VLOG_IS_ON(1)) { - DumpGraphToFile("after_placer", *graph); - } - - TF_RETURN_IF_ERROR(AdjustDeviceAssignment(inputs, outputs, control_outputs, - cpu_device, graph.get())); - - // Insert send/recv ops to the graph. - TF_ASSIGN_OR_RETURN( - std::unique_ptr new_graph, - InsertTransferOps(fallback_state.device_set(), std::move(graph))); - if (VLOG_IS_ON(1)) { - DumpGraphToFile("after_transfer_ops_insertion", *new_graph); - } - - return new_graph; -} - } // namespace StatusOr @@ -463,22 +291,6 @@ TfrtGraphExecutionState::CreateOptimizedGraph( result.grappler_duration = absl::Now() - grappler_start_time; - if (options_.enable_tfrt_gpu && !options_.use_bridge_for_gpu) { - TF_ASSIGN_OR_RETURN( - result.graph, - BuildXlaOpsAndMaybeInsertTransferOps( - graph_import_config.graph_func_name, fallback_state_, inputs, - graph_import_config.outputs, graph_import_config.control_outputs, - std::move(result.graph))); - - // Update `control_outputs` as there might be newly added Send ops. - for (const Node* node : result.graph->nodes()) { - if (node->IsSend()) { - graph_import_config.control_outputs.push_back(node->name()); - } - } - } - return result; } @@ -865,41 +677,5 @@ TfrtGraphExecutionState::OptimizeGraph( return optimized_graph; } -// TODO(b/239089915): Clean this up after the logic is implemented in TFXLA -// bridge. -Status BuildXlaLaunchOps(Graph* graph) { - const auto is_xla_launch_node = [](const Node& n) -> StatusOr { - if (!n.IsPartitionedCall()) { - return false; - } - bool xla_must_compile = false; - const bool has_attribute = - TryGetNodeAttr(n.attrs(), kXlaMustCompileAttr, &xla_must_compile); - return has_attribute && xla_must_compile; - }; - - const auto get_xla_function_info = [](const Node& launch) - -> StatusOr { - EncapsulateXlaComputationsPass::XlaFunctionInfo result; - std::vector tin_dtypes; - TF_RETURN_IF_ERROR(GetNodeAttr(launch.def(), "Tin", &tin_dtypes)); - int variable_start_index = 0; - for (; variable_start_index < tin_dtypes.size(); ++variable_start_index) { - if (tin_dtypes.at(variable_start_index) == DT_RESOURCE) break; - } - result.variable_start_index = variable_start_index; - - NameAttrList func; - TF_RETURN_IF_ERROR(GetNodeAttr(launch.attrs(), "f", &func)); - result.function_name = func.name(); - - return result; - }; - - return EncapsulateXlaComputationsPass::BuildXlaLaunchOps( - graph, is_xla_launch_node, get_xla_function_info, - /*add_edges_to_output_of_downstream_nodes=*/false); -} - } // namespace tfrt_stub } // namespace tensorflow diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h index d592d857fd6769..e347412ec532f6 100644 --- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h +++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h @@ -52,10 +52,6 @@ class TfrtGraphExecutionState { struct Options { bool run_placer_grappler_on_functions = false; - // TODO(b/262826012): Remove the flag after we switch to using bridge. 
- bool enable_tfrt_gpu = false; - // TODO(b/260915352): Remove the flag and default to using bridge. - bool use_bridge_for_gpu = false; bool run_placer_on_graph = true; }; @@ -138,12 +134,6 @@ Status EliminateRefVariablesFromV1ControlFlow(GraphDef& graph_def); // Removes the "_input_shapes" attribute of functions in the graph. void RemoveInputShapesInFunctions(tensorflow::GraphDef& graph_def); -// Replaces partitioned calls in the graph that have _XlaMustCompile attribute -// set to true with XlaLaunch op. -// TODO(b/239089915): Clean this up after the logic is implemented in TFXLA -// bridge. -Status BuildXlaLaunchOps(Graph* graph); - } // namespace tfrt_stub } // namespace tensorflow diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc index aa99c168ebd1c1..e16b941cc46c4a 100644 --- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc +++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc @@ -21,32 +21,18 @@ limitations under the License. #include #include -#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/functional_ops.h" -#include "tensorflow/cc/ops/math_ops.h" -#include "tensorflow/cc/ops/resource_variable_ops.h" -#include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/while_loop.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" -#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/device_factory.h" -#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/graph_to_functiondef.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/grappler/utils/grappler_test.h" -#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" -#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { namespace tfrt_stub { @@ -761,481 +747,6 @@ TEST_F(ExtendGraphTest, ExtendGraph) { CompareGraphs(expected, *graph_execution_state->original_graph_def()); } -// An auxiliary struct to verify the graph after partitioning and inserting -// transfer ops. 
-struct GraphInfo { - NodeDef* input_node = nullptr; - NodeDef* output_node = nullptr; - NodeDef* stateful_partitioned_call_node = nullptr; - std::vector partitioned_call_nodes; - std::vector fdefs; -}; - -class InsertTransferOpsTest : public grappler::GrapplerTest { - protected: - void SetUp() override { - SessionOptions options; - auto* device_count = options.config.mutable_device_count(); - device_count->insert({"CPU", 2}); - std::vector> devices; - TF_ASSERT_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0", - &devices)); - device0_ = devices[0].get(); - device1_ = devices[1].get(); - - fallback_state_ = - std::make_unique(options, std::move(devices), fdef_lib_); - } - - GraphInfo GetGraphInfo(const std::string& input, const std::string& output, - GraphDef& graphdef) { - GraphInfo graph_info; - for (NodeDef& node : *graphdef.mutable_node()) { - if (node.op() == "PartitionedCall") { - graph_info.partitioned_call_nodes.push_back(&node); - } else if (node.op() == "StatefulPartitionedCall") { - graph_info.stateful_partitioned_call_node = &node; - } else if (node.name() == input) { - graph_info.input_node = &node; - } else if (node.name() == output) { - graph_info.output_node = &node; - } - } - - // Find the corresponding function called by the PartitionedCall nodes. - absl::flat_hash_map func_name_to_func; - for (const FunctionDef& fdef : graphdef.library().function()) { - func_name_to_func[fdef.signature().name()] = fdef; - } - for (NodeDef* node : graph_info.partitioned_call_nodes) { - CHECK(node->attr().contains("f")); - CHECK(func_name_to_func.contains(node->attr().at("f").func().name())); - const FunctionDef& fdef = - func_name_to_func.at(node->attr().at("f").func().name()); - graph_info.fdefs.push_back(fdef); - } - return graph_info; - } - - std::unique_ptr fallback_state_; - Device* device0_ = nullptr; // Not owned. - Device* device1_ = nullptr; // Not owned. - tensorflow::FunctionDefLibrary fdef_lib_; -}; - -TEST_F(InsertTransferOpsTest, InsertTransferOps) { - GraphDef graphdef; - { - Scope scope = Scope::NewRootScope(); - Scope scope1 = scope.WithDevice(device0_->name()); - Scope scope2 = scope.WithDevice(device1_->name()); - - // A graph whose nodes are on different devices. - // a(Const, on device0) -> b(Abs, on device1) -> c(Identity, on device0) - Output a = ops::Const(scope1.WithOpName("a"), 2.0, {1, 1}); - Output b = ops::Abs(scope2.WithOpName("b"), a); - Output c = ops::Identity(scope1.WithOpName("c"), b); - - // Before partitioning, there is no send/recv nodes. 
- int send_count = 0, recv_count = 0; - for (const auto* op : scope.graph()->op_nodes()) { - if (op->IsSend()) - ++send_count; - else if (op->IsRecv()) - ++recv_count; - } - ASSERT_EQ(scope.graph()->num_op_nodes(), 3); - ASSERT_EQ(send_count, 0); - ASSERT_EQ(recv_count, 0); - - TF_ASSERT_OK(scope.ToGraphDef(&graphdef)); - } - - TfrtGraphExecutionState::Options options; - options.run_placer_grappler_on_functions = false; - options.enable_tfrt_gpu = true; - TF_ASSERT_OK_AND_ASSIGN( - auto graph_execution_state, - TfrtGraphExecutionState::Create(options, graphdef, *fallback_state_)); - - tensorflow::GraphImportConfig graph_import_config; - graph_import_config.prune_unused_nodes = true; - graph_import_config.enable_shape_inference = false; - tensorflow::ArrayInfo array_info; - array_info.imported_dtype = DT_FLOAT; - array_info.shape.set_unknown_rank(true); - graph_import_config.inputs["a"] = array_info; - graph_import_config.outputs = {"c"}; - - TF_ASSERT_OK_AND_ASSIGN( - auto optimized_graph, - graph_execution_state->CreateOptimizedGraph(graph_import_config)); - - // Verify that two paris of Send/Recv nodes are added. - int send_count = 0, recv_count = 0; - for (const auto* op : optimized_graph.graph->op_nodes()) { - if (op->IsSend()) - ++send_count; - else if (op->IsRecv()) - ++recv_count; - } - EXPECT_EQ(optimized_graph.graph->num_op_nodes(), 7); - EXPECT_EQ(send_count, 2); - EXPECT_EQ(recv_count, 2); -} - -TEST_F(InsertTransferOpsTest, InsertTransferOpsWithFunctionInlining) { - GraphDef graphdef; - { - Scope scope = Scope::NewRootScope(); - Scope scope1 = scope.WithDevice(device0_->name()); - Scope scope2 = scope.WithDevice(device1_->name()); - - // A graph whose nodes are on different devices. - // a(Const, on device0) -> b(PartitionedCall) -> c(Identity, on device0) - // where PartitionedCall invokes a function with two nodes assigned to - // different devices. - const Tensor kThree = test::AsScalar(3.0); - auto fdef = tensorflow::FunctionDefHelper::Create( - "_Pow3", {"x: float"}, {"y: float"}, {}, - {// The two nodes in the function are assigned to different devices. - {{"three"}, - "Const", - {}, - {{"dtype", DT_FLOAT}, {"value", kThree}}, - /*dep=*/{}, - device0_->name()}, - {{"pow3"}, - "Pow", - {"x", "three:output:0"}, - {{"T", DT_FLOAT}}, - /*dep=*/{}, - device1_->name()}}, - {{"y", "pow3:z:0"}}); - - tensorflow::FunctionDefLibrary fdef_lib; - *fdef_lib.add_function() = fdef; - TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib)); - - Output a = ops::Const(scope1.WithOpName("a"), 2.0, {1, 1}); - - std::vector inputs = {a}; - std::vector output_dtypes = { - fdef.signature().output_arg(0).type()}; - tensorflow::NameAttrList func_attr; - func_attr.set_name(fdef.signature().name()); - auto pcall = ops::PartitionedCall(scope2, inputs, output_dtypes, func_attr); - Output b = pcall.output.front(); - - Output c = ops::Identity(scope1.WithOpName("c"), b); - - TF_ASSERT_OK(scope.ToGraphDef(&graphdef)); - - // Before partitioning, there is no send/recv nodes. 
- int partitioned_call_count = 0, mul_count = 0, send_count = 0, - recv_count = 0; - for (const auto* op : scope.graph()->op_nodes()) { - if (op->IsPartitionedCall()) - ++partitioned_call_count; - else if (op->IsSend()) - ++send_count; - else if (op->IsRecv()) - ++recv_count; - else if (op->type_string() == "Mul") - ++mul_count; - } - ASSERT_EQ(partitioned_call_count, 1); - ASSERT_EQ(send_count, 0); - ASSERT_EQ(recv_count, 0); - ASSERT_EQ(mul_count, 0); - } - - TfrtGraphExecutionState::Options options; - options.run_placer_grappler_on_functions = false; - options.enable_tfrt_gpu = true; - TF_ASSERT_OK_AND_ASSIGN( - auto graph_execution_state, - TfrtGraphExecutionState::Create(options, graphdef, *fallback_state_)); - - tensorflow::GraphImportConfig graph_import_config; - graph_import_config.prune_unused_nodes = true; - graph_import_config.enable_shape_inference = false; - tensorflow::ArrayInfo array_info; - array_info.imported_dtype = DT_FLOAT; - array_info.shape.set_unknown_rank(true); - graph_import_config.inputs["a"] = array_info; - graph_import_config.outputs = {"c"}; - - TF_ASSERT_OK_AND_ASSIGN( - auto optimized_graph, - graph_execution_state->CreateOptimizedGraph(graph_import_config)); - - // Verify that the resultant graph has no PartitionedCall ops, function body - // is inlined into the main graph, and send/recv ops are added. - int partitioned_call_count = 0, mul_count = 0, send_count = 0, recv_count = 0; - for (const auto* op : optimized_graph.graph->op_nodes()) { - if (op->IsPartitionedCall()) - ++partitioned_call_count; - else if (op->IsSend()) - ++send_count; - else if (op->IsRecv()) - ++recv_count; - else if (op->type_string() == "Mul") - ++mul_count; - } - - EXPECT_EQ(partitioned_call_count, 0); - EXPECT_EQ(send_count, 2); - EXPECT_EQ(recv_count, 2); - EXPECT_EQ(mul_count, 1); -} - -TEST_F(InsertTransferOpsTest, AppendIdentityN) { - GraphDef graphdef; - { - Scope scope = Scope::NewRootScope(); - Scope scope1 = scope.WithDevice(device0_->name()); - Scope scope2 = scope.WithDevice(device1_->name()); - - // A graph with two nodes assigned on different devices. - // a(Const, on device0) -> b(Abs, on device1) - Output a = ops::Const(scope1.WithOpName("a"), 2.0, {1, 1}); - Output b = ops::Abs(scope2.WithOpName("b"), a); - - TF_ASSERT_OK(scope.ToGraphDef(&graphdef)); - - // There is no IdentityN/Send/Recv nodes originally. 
- int identity_count = 0, abs_count = 0, const_count = 0, send_count = 0, - recv_count = 0; - for (const auto* op : scope.graph()->op_nodes()) { - if (op->type_string() == "IdentityN") - ++identity_count; - else if (op->IsConstant()) - ++const_count; - else if (op->type_string() == "Abs") - ++abs_count; - else if (op->IsSend()) - ++send_count; - else if (op->IsRecv()) - ++recv_count; - } - ASSERT_EQ(scope.graph()->num_op_nodes(), 2); - ASSERT_EQ(identity_count, 0); - ASSERT_EQ(const_count, 1); - ASSERT_EQ(abs_count, 1); - ASSERT_EQ(send_count, 0); - ASSERT_EQ(recv_count, 0); - } - TfrtGraphExecutionState::Options options; - options.run_placer_grappler_on_functions = false; - options.enable_tfrt_gpu = true; - TF_ASSERT_OK_AND_ASSIGN( - auto graph_execution_state, - TfrtGraphExecutionState::Create(options, graphdef, *fallback_state_)); - - tensorflow::GraphImportConfig graph_import_config; - graph_import_config.prune_unused_nodes = true; - graph_import_config.enable_shape_inference = false; - tensorflow::ArrayInfo array_info; - array_info.imported_dtype = DT_FLOAT; - array_info.shape.set_unknown_rank(true); - graph_import_config.inputs["a"] = array_info; - graph_import_config.outputs = {"b"}; - - TF_ASSERT_OK_AND_ASSIGN( - auto optimized_graph, - graph_execution_state->CreateOptimizedGraph(graph_import_config)); - GraphDef optimized_graphdef; - optimized_graph.graph->ToGraphDef(&optimized_graphdef); - - // Verify that IdentityN/Send/Recv nodes are added. - int identity_count = 0, abs_count = 0, const_count = 0, send_count = 0, - recv_count = 0; - for (const auto* op : optimized_graph.graph->op_nodes()) { - if (op->type_string() == "IdentityN") - ++identity_count; - else if (op->IsConstant()) - ++const_count; - else if (op->type_string() == "Abs") - ++abs_count; - else if (op->IsSend()) - ++send_count; - else if (op->IsRecv()) - ++recv_count; - } - EXPECT_EQ(optimized_graph.graph->num_op_nodes(), 7); - EXPECT_EQ(identity_count, 1); - EXPECT_EQ(const_count, 1); - EXPECT_EQ(abs_count, 1); - EXPECT_EQ(send_count, 2); - EXPECT_EQ(recv_count, 2); -} - -std::unique_ptr MakeOuterGraph(const FunctionLibraryDefinition& flib_def, - const std::string& function_name) { - Scope scope = Scope::NewRootScope().ExitOnError(); - TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto())); - - auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); - auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); - auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); - auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); - auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); - auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); - auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); - - std::vector func_inputs; - func_inputs.push_back( - tensorflow::NodeDefBuilder::NodeOut(a.node()->name(), 0, DT_INT32)); - func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(b.node()->name(), 0, - b.output.type())); - func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(c.node()->name(), 0, - c.output.type())); - func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(d.node()->name(), 0, - d.output.type())); - func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(u.node()->name(), 0, - u.output.type())); - func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(v.node()->name(), 0, - v.output.type())); - func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(w.node()->name(), 0, - w.output.type())); - - std::vector input_dtypes; - for (const 
NodeDefBuilder::NodeOut& func_input : func_inputs) { - input_dtypes.push_back(func_input.data_type); - } - - std::vector output_dtypes = {DT_FLOAT, DT_INT32, DT_FLOAT, - DT_FLOAT}; - - NameAttrList f; - f.set_name(function_name); - - NodeDef def; - TF_CHECK_OK(NodeDefBuilder("xla_call_0", "StatefulPartitionedCall", &flib_def) - .Input(func_inputs) - .Attr("Tin", input_dtypes) - .Attr("Tout", output_dtypes) - .Attr("f", f) - .Device("/gpu:0") - .Attr(kXlaMustCompileAttr, true) - .Finalize(&def)); - - Status status; - Node* launch = scope.graph()->AddNode(def, &status); - TF_CHECK_OK(status); - TF_CHECK_OK(scope.DoShapeInference(launch)); - scope.graph()->AddEdge(a.node(), 0, launch, 0); - scope.graph()->AddEdge(b.node(), 0, launch, 1); - scope.graph()->AddEdge(c.node(), 0, launch, 2); - scope.graph()->AddEdge(d.node(), 0, launch, 3); - scope.graph()->AddEdge(u.node(), 0, launch, 4); - scope.graph()->AddEdge(v.node(), 0, launch, 5); - scope.graph()->AddEdge(w.node(), 0, launch, 6); - - auto consumer0_a = - ops::Identity(scope.WithOpName("consumer0_a"), Output(launch, 0)); - auto consumer0_b = - ops::Identity(scope.WithOpName("consumer0_b"), Output(launch, 0)); - auto consumer0_c = - ops::Identity(scope.WithOpName("consumer0_c"), Output(launch, 0)); - auto consumer1 = - ops::Identity(scope.WithOpName("consumer1"), Output(launch, 1)); - auto consumer2 = - ops::Identity(scope.WithOpName("consumer2"), Output(launch, 2)); - auto consumer3 = - ops::Identity(scope.WithOpName("consumer3"), Output(launch, 3)); - - std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_CHECK_OK(scope.ToGraph(graph.get())); - return graph; -} - -// Makes an encapsulate body graph for use in tests. -std::unique_ptr MakeBodyGraph() { - Scope scope = Scope::NewRootScope().ExitOnError(); - - auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1); - auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3); - - auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4); - auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5); - auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6); - - auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1); - auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT); - auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT); - auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT); - - auto e = ops::Add(scope.WithOpName("E"), arg0, arg2); - auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); - auto g = ops::Add(scope.WithOpName("G"), f, arg3); - - auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"), - b_identity, 0); - auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1); - auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2); - auto out3 = - ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3); - - std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_CHECK_OK(scope.ToGraph(graph.get())); - return graph; -} - -TEST(BuildXlaOpsTest, BuildXlaLaunchOp) { - std::unique_ptr body_graph = MakeBodyGraph(); - FunctionDefLibrary flib; - TF_ASSERT_OK( - GraphToFunctionDef(*body_graph, "xla_func_0", flib.add_function())); - - FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); - - std::unique_ptr graph = MakeOuterGraph(flib_def, "xla_func_0"); - 
TF_ASSERT_OK(BuildXlaLaunchOps(graph.get())); - - Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError(); - TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); - - auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); - auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); - auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); - auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); - auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); - auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); - auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); - - NameAttrList function; - function.set_name("xla_func_0"); - auto launch = ops::XlaLaunch( - scope.WithOpName("xla_call_0").WithDevice("/gpu:0"), - std::initializer_list{}, std::initializer_list{a, b, c, d}, - std::initializer_list{u, v, w}, - DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function); - - auto consumer0_a = - ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]); - auto consumer0_b = - ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]); - auto consumer0_c = - ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]); - auto consumer1 = - ops::Identity(scope.WithOpName("consumer1"), launch.results[1]); - auto consumer2 = - ops::Identity(scope.WithOpName("consumer2"), launch.results[2]); - auto consumer3 = - ops::Identity(scope.WithOpName("consumer3"), launch.results[3]); - - GraphDef expected_def; - TF_ASSERT_OK(scope.ToGraphDef(&expected_def)); - - GraphDef actual_def; - graph->ToGraphDef(&actual_def); - TF_EXPECT_GRAPH_EQ(expected_def, actual_def); -} - } // namespace } // namespace tfrt_stub } // namespace tensorflow diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 096caff26ef884..1005b9d061b5fa 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -12,6 +12,7 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/mlir/tf2xla:__subpackages__", + "//tensorflow/compiler/xla:__subpackages__", "//tensorflow/compiler/xrt:__subpackages__", "//tensorflow/core/tpu:__subpackages__", "//tensorflow/dtensor:__subpackages__", diff --git a/tensorflow/core/transforms/BUILD b/tensorflow/core/transforms/BUILD index cd20ec56853754..bc4d89f8aafbeb 100644 --- a/tensorflow/core/transforms/BUILD +++ b/tensorflow/core/transforms/BUILD @@ -1,10 +1,10 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", "//tensorflow/tools/tfg_graph_transforms:__subpackages__", diff --git a/tensorflow/core/transforms/cf_sink/BUILD b/tensorflow/core/transforms/cf_sink/BUILD index e5a77916b5a1cb..6cf78e4eeee0c9 100644 --- a/tensorflow/core/transforms/cf_sink/BUILD +++ b/tensorflow/core/transforms/cf_sink/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") 
package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/consolidate_attrs/BUILD b/tensorflow/core/transforms/consolidate_attrs/BUILD index 50525a5058200e..5558172c0669e2 100644 --- a/tensorflow/core/transforms/consolidate_attrs/BUILD +++ b/tensorflow/core/transforms/consolidate_attrs/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/const_dedupe_hoist/BUILD b/tensorflow/core/transforms/const_dedupe_hoist/BUILD index b6a81a5a93f848..381b666a80a711 100644 --- a/tensorflow/core/transforms/const_dedupe_hoist/BUILD +++ b/tensorflow/core/transforms/const_dedupe_hoist/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/constant_folding/BUILD b/tensorflow/core/transforms/constant_folding/BUILD index 1b5b0fb43f4c34..e64e9d868f2677 100644 --- a/tensorflow/core/transforms/constant_folding/BUILD +++ b/tensorflow/core/transforms/constant_folding/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/cse/BUILD b/tensorflow/core/transforms/cse/BUILD index a6c6914204cd8f..6a4dd774bbcbbc 100644 --- a/tensorflow/core/transforms/cse/BUILD +++ b/tensorflow/core/transforms/cse/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/drop_unregistered_attribute/BUILD b/tensorflow/core/transforms/drop_unregistered_attribute/BUILD index 98a5fe7d236f19..73cc8918341602 100644 --- 
a/tensorflow/core/transforms/drop_unregistered_attribute/BUILD +++ b/tensorflow/core/transforms/drop_unregistered_attribute/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/eliminate_passthrough_iter_args/BUILD b/tensorflow/core/transforms/eliminate_passthrough_iter_args/BUILD index fe69bb24386eb2..b19c211a8abdcb 100644 --- a/tensorflow/core/transforms/eliminate_passthrough_iter_args/BUILD +++ b/tensorflow/core/transforms/eliminate_passthrough_iter_args/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/func_to_graph/BUILD b/tensorflow/core/transforms/func_to_graph/BUILD index 4cd2e365f3d384..0c62a5a2f90894 100644 --- a/tensorflow/core/transforms/func_to_graph/BUILD +++ b/tensorflow/core/transforms/func_to_graph/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/functional_to_region/BUILD b/tensorflow/core/transforms/functional_to_region/BUILD index 14addc62ce7b47..428c83441aad56 100644 --- a/tensorflow/core/transforms/functional_to_region/BUILD +++ b/tensorflow/core/transforms/functional_to_region/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/graph_compactor/BUILD b/tensorflow/core/transforms/graph_compactor/BUILD index 360246876f52df..35635b3e0169e5 100644 --- a/tensorflow/core/transforms/graph_compactor/BUILD +++ b/tensorflow/core/transforms/graph_compactor/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # 
copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/graph_to_func/BUILD b/tensorflow/core/transforms/graph_to_func/BUILD index c4bcc7fb83bac2..69023cc514fb83 100644 --- a/tensorflow/core/transforms/graph_to_func/BUILD +++ b/tensorflow/core/transforms/graph_to_func/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/legacy_call/BUILD b/tensorflow/core/transforms/legacy_call/BUILD index c010fdb7333637..1784c67edc712d 100644 --- a/tensorflow/core/transforms/legacy_call/BUILD +++ b/tensorflow/core/transforms/legacy_call/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/region_to_functional/BUILD b/tensorflow/core/transforms/region_to_functional/BUILD index b49cb34ce65075..becc78b878bd8d 100644 --- a/tensorflow/core/transforms/region_to_functional/BUILD +++ b/tensorflow/core/transforms/region_to_functional/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/remapper/BUILD b/tensorflow/core/transforms/remapper/BUILD index 4a3ab37da4294b..a75461d412b23b 100644 --- a/tensorflow/core/transforms/remapper/BUILD +++ b/tensorflow/core/transforms/remapper/BUILD @@ -1,10 +1,10 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/shape_inference/BUILD b/tensorflow/core/transforms/shape_inference/BUILD index c1fd69fbe2b619..d9eb50e9762b2a 100644 --- 
a/tensorflow/core/transforms/shape_inference/BUILD +++ b/tensorflow/core/transforms/shape_inference/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core/transforms:__subpackages__", ], diff --git a/tensorflow/core/transforms/toposort/BUILD b/tensorflow/core/transforms/toposort/BUILD index 7b39cc616a414d..be03bed1002e0b 100644 --- a/tensorflow/core/transforms/toposort/BUILD +++ b/tensorflow/core/transforms/toposort/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/compiler:__subpackages__", "//tensorflow/core:__subpackages__", diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index 7076340d5a5df8..f7436f0a35c7b1 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -163,7 +163,6 @@ filegroup( "matmul_autotune.h", "matmul_bcast.h", "mirror_pad_mode.h", - "mkl_threadpool.h", "mkl_util.h", "onednn_env_vars.h", "overflow.h", @@ -296,9 +295,9 @@ filegroup( filegroup( name = "mkl_util_hdrs", srcs = [ - "mkl_threadpool.h", "mkl_util.h", "onednn_env_vars.h", + "//tensorflow/tsl/util:onednn_util_hdrs", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 322991376f9924..4eaee7d13884ee 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -36,13 +36,13 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/env_var.h" -#include "tensorflow/core/util/mkl_threadpool.h" #include "tensorflow/core/util/onednn_env_vars.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" #ifdef DNNL_AARCH64_USE_ACL #include "tensorflow/core/platform/mutex.h" #endif +#include "tensorflow/tsl/util/onednn_threadpool.h" using dnnl::engine; using dnnl::memory; @@ -274,7 +274,7 @@ inline bool array_cmp(const T* a1, const T* a2, size_t size) { return true; } -inline dnnl::stream* CreateStream(MklDnnThreadPool* eigen_tp, +inline dnnl::stream* CreateStream(tsl::OneDnnThreadPool* eigen_tp, const engine& engine) { #ifndef ENABLE_ONEDNN_OPENMP if (eigen_tp != nullptr) { @@ -649,6 +649,13 @@ class MklDnnShape { } }; +inline Eigen::ThreadPoolInterface* EigenThreadPoolFromTfContext( + OpKernelContext* context) { + return context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); +} + // List of MklShape objects. Used in Concat/Split layers. 
typedef std::vector MklDnnShapeList; @@ -663,9 +670,14 @@ inline void ExecutePrimitive(const std::vector& net, DCHECK(net_args); DCHECK_EQ(net.size(), net_args->size()); std::unique_ptr cpu_stream; - MklDnnThreadPool eigen_tp; + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + tsl::OneDnnThreadPool eigen_tp; if (context != nullptr) { - eigen_tp = MklDnnThreadPool(context); + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + eigen_tp = + tsl::OneDnnThreadPool(eigen_interface, ThreadPoolUseCallerThread()); cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine)); } else { cpu_stream.reset(CreateStream(nullptr, cpu_engine)); @@ -706,8 +718,8 @@ inline Status ConvertMklToTF(OpKernelContext* context, bool status = input.CheckReorderToOpMem(output_tf_md, output_tf_tensor, net, net_args, cpu_engine); if (!status) { - return Status(absl::StatusCode::kInternal, - "ConvertMklToTF(): Failed to create reorder for input"); + return absl::InternalError( + "ConvertMklToTF(): Failed to create reorder for input"); } ExecutePrimitive(net, &net_args, cpu_engine, context); } else { @@ -715,8 +727,7 @@ inline Status ConvertMklToTF(OpKernelContext* context, bool status = output_tf_tensor->CopyFrom(input_mkl_tensor, output_tf_shape); if (!status) { - return Status( - absl::StatusCode::kInternal, + return absl::InternalError( "ConvertMklToTF(): Failed to forward input tensor to output"); } } @@ -1114,8 +1125,7 @@ inline memory::format_tag MklTensorFormatToMklDnnDataFormat( inline MklTensorFormat TFDataFormatToMklDnn3DDataFormat(TensorFormat format) { if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NDHWC; if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCDHW; - TF_CHECK_OK( - Status(absl::StatusCode::kInvalidArgument, "Unsupported data format")); + TF_CHECK_OK(absl::InvalidArgumentError("Unsupported data format")); return MklTensorFormat::FORMAT_INVALID; } @@ -1127,8 +1137,7 @@ inline MklTensorFormat TFDataFormatToMklDnn3DDataFormat(TensorFormat format) { inline MklTensorFormat TFDataFormatToMklDnnDataFormat(TensorFormat format) { if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NHWC; if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCHW; - TF_CHECK_OK( - Status(absl::StatusCode::kInvalidArgument, "Unsupported data format")); + TF_CHECK_OK(absl::InvalidArgumentError("Unsupported data format")); return MklTensorFormat::FORMAT_INVALID; } @@ -1144,8 +1153,7 @@ inline TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format) { if (format == MklTensorFormat::FORMAT_NCHW || format == MklTensorFormat::FORMAT_NCDHW) return FORMAT_NCHW; - TF_CHECK_OK( - Status(absl::StatusCode::kInvalidArgument, "Unsupported data format")); + TF_CHECK_OK(absl::InvalidArgumentError("Unsupported data format")); // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure // that we don't come here. 
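The mkl_util.h hunks above replace MklDnnThreadPool with the TSL wrapper: a kernel now pulls the Eigen threadpool interface out of its OpKernelContext via the new EigenThreadPoolFromTfContext() helper, wraps it in tsl::OneDnnThreadPool, and hands that wrapper to CreateStream() so oneDNN work runs on TensorFlow's intra-op pool. A minimal sketch of that pattern, assuming the TF-internal headers used above; MyOneDnnKernelCompute() is a hypothetical caller and the primitive/memory setup is elided.

#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {

// Sketch only: mirrors the pattern introduced in the diff above.
void MyOneDnnKernelCompute(OpKernelContext* context) {
  dnnl::engine cpu_engine(dnnl::engine::kind::cpu, 0);

  // Wrap the Eigen intra-op pool from the kernel context so that oneDNN
  // schedules its work on TensorFlow's own threads.
  Eigen::ThreadPoolInterface* eigen_interface =
      EigenThreadPoolFromTfContext(context);
  tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread());
  std::unique_ptr<dnnl::stream> cpu_stream(CreateStream(&eigen_tp, cpu_engine));

  // ... build primitives and their argument maps here ...
  std::vector<dnnl::primitive> net;
  std::vector<std::unordered_map<int, dnnl::memory>> net_args;

  for (size_t i = 0; i < net.size(); ++i) {
    net.at(i).execute(*cpu_stream, net_args.at(i));
  }
  cpu_stream->wait();
}

}  // namespace tensorflow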
@@ -1311,10 +1319,9 @@ inline Status CreateBlockedMemDescHelper(const memory::dims& dim, } catch (dnnl::error& e) { delete[] input_dims; delete[] input_strides; - return Status(absl::StatusCode::kInternal, - tensorflow::strings::StrCat( - "Failed to create blocked memory descriptor.", - "Status: ", e.status, ", message: ", e.message)); + return absl::InternalError( + absl::StrCat("Failed to create blocked memory descriptor.", + "Status: ", e.status, ", message: ", e.message)); } return OkStatus(); } @@ -1322,11 +1329,22 @@ inline Status CreateBlockedMemDescHelper(const memory::dims& dim, inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc, const memory& src_mem, const memory& dst_mem, const engine& engine, - OpKernelContext* ctx = nullptr) { + OpKernelContext* ctx = nullptr, + memory* scale_mem = nullptr) { std::vector net; net.push_back(dnnl::reorder(reorder_desc)); std::vector net_args; +#ifndef ENABLE_ONEDNN_V3 net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}}); +#else + if (scale_mem != nullptr) { + net_args.push_back({{DNNL_ARG_FROM, src_mem}, + {DNNL_ARG_TO, dst_mem}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *scale_mem}}); + } else { + net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}}); + } +#endif // !ENABLE_ONEDNN_V3 ExecutePrimitive(net, &net_args, engine, ctx); } @@ -1596,9 +1614,12 @@ class MklDnnData { reorder_memory_ = new memory(op_md, engine); auto* prim = FindOrCreateReorder(user_memory_, reorder_memory_); std::shared_ptr cpu_stream; - MklDnnThreadPool eigen_tp; + tsl::OneDnnThreadPool eigen_tp; if (context != nullptr) { - eigen_tp = MklDnnThreadPool(context); + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + eigen_tp = + tsl::OneDnnThreadPool(eigen_interface, ThreadPoolUseCallerThread()); cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); } else { cpu_stream.reset(CreateStream(nullptr, prim->GetEngine())); @@ -1663,9 +1684,12 @@ class MklDnnData { reorder_memory_ = new memory(op_md, engine, reorder_data_handle); auto* prim = FindOrCreateReorder(user_memory_, reorder_memory_); std::shared_ptr cpu_stream; - MklDnnThreadPool eigen_tp; + tsl::OneDnnThreadPool eigen_tp; if (context != nullptr) { - eigen_tp = MklDnnThreadPool(context); + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(context); + eigen_tp = + tsl::OneDnnThreadPool(eigen_interface, ThreadPoolUseCallerThread()); cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); } else { cpu_stream.reset(CreateStream(nullptr, prim->GetEngine())); @@ -1774,9 +1798,12 @@ class MklDnnData { net_args.push_back( {{DNNL_ARG_FROM, *reorder_memory_}, {DNNL_ARG_TO, *user_memory_}}); std::shared_ptr cpu_stream; - MklDnnThreadPool eigen_tp; + tsl::OneDnnThreadPool eigen_tp; if (ctx != nullptr) { - eigen_tp = MklDnnThreadPool(ctx); + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(ctx); + eigen_tp = + tsl::OneDnnThreadPool(eigen_interface, ThreadPoolUseCallerThread()); cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); } else { cpu_stream.reset(CreateStream(nullptr, prim->GetEngine())); diff --git a/tensorflow/dtensor/cc/constants.h b/tensorflow/dtensor/cc/constants.h index 9dea0928fe4e1d..b14ea2438af7f2 100644 --- a/tensorflow/dtensor/cc/constants.h +++ b/tensorflow/dtensor/cc/constants.h @@ -58,6 +58,8 @@ static constexpr char kNewResourceLayoutIndices[] = // Attribute carries layout for newly inferred layout of resource handle. 
static constexpr char kNewResourceArgLayouts[] = "_inferred_resource_layouts"; +static constexpr char kNumLocalOutputsAttr[] = "_num_local_outputs"; + // Attribute carries input layout information for shape op. static constexpr char kShapeOpInputLayout[] = "_shape_input_layout"; @@ -136,6 +138,8 @@ static constexpr int kSparseTensorNum = 3; // Attribute which stores the environment variable value for all_reduce // optimization group size: DTENSOR_ALLREDUCE_COMBINE_OPTIMIZATION_GROUP_SIZE. +// This represents the maximum number of AllReduce ops to merge into one op. It +// is a determining factor used during dtensor_allreduce_combine_optimization. static constexpr char kAllReduceNumOpsInGroup[] = "dtensor.all_reduce_combiner.num_ops_in_group"; @@ -144,6 +148,14 @@ static constexpr char kAllReduceNumOpsInGroup[] = static constexpr char kEnableMultiDeviceMode[] = "dtensor.enable_multi_device_mode"; +// Attribute which stores the environment variable value for the all_reduce +// optimization topological distance: +// DTENSOR_ALLREDUCE_COMBINE_OPTIMIZATION_TOPOLOGICAL_DISTANCE. This represents +// the maximum topological distance between two AllReduce ops on the compute +// graph. It is a determining factor used during dtensor_allreduce_combine_optimization. +static constexpr char kAllReduceTopologicalDistance[] = + "dtensor.all_reduce_combiner.topological_distance"; + } // namespace dtensor } // namespace tensorflow diff --git a/tensorflow/dtensor/cc/dtensor_device.cc b/tensorflow/dtensor/cc/dtensor_device.cc index d413b37141c3dc..deb26a332d7581 100644 --- a/tensorflow/dtensor/cc/dtensor_device.cc +++ b/tensorflow/dtensor/cc/dtensor_device.cc @@ -1839,8 +1839,8 @@ void DTensorDevice::ExecuteMultiDeviceOperation( int output_offset = 0; for (int i = 0; i < num_output_layouts; i++) { const Layout& output_layout = function.output_layouts[i]; + const int num_devices = function.num_local_outputs[i]; std::vector layout_outputs; - const int num_devices = output_layout.num_devices(); for (int j = 0; j < num_devices; j++) { const int output_idx = output_offset + j; layout_outputs.emplace_back(std::move(eager_outputs[output_idx])); diff --git a/tensorflow/dtensor/cc/dtensor_device_util.cc b/tensorflow/dtensor/cc/dtensor_device_util.cc index 319fc9bc521f6c..8cc8fa32c6277c 100644 --- a/tensorflow/dtensor/cc/dtensor_device_util.cc +++ b/tensorflow/dtensor/cc/dtensor_device_util.cc @@ -922,6 +922,21 @@ Status PrepareGraphForMlir( return OkStatus(); } +StatusOr<std::vector<int64_t>> GetNumLocalOutputs(Node* node) { + const AttrValue* num_local_outputs = + (node->attrs()).Find(kNumLocalOutputsAttr); + if (num_local_outputs == nullptr) { + return absl::InvalidArgumentError("missing num_local_outputs attribute"); + } else { + const AttrValue_ListValue& list = num_local_outputs->list(); + std::vector<int64_t> res; + res.reserve(list.i_size()); + std::copy(list.i().begin(), list.i().end(), std::back_inserter(res)); + return res; + } +} + +namespace { Status SetMultiDeviceFunctionOutputs( TranslatedFunction& function, Node* node, const std::vector& global_output_shapes) { @@ -929,6 +944,8 @@ Status SetMultiDeviceFunctionOutputs( if (serialized_layouts == nullptr) { return absl::InvalidArgumentError("missing layout attribute"); } + TF_ASSIGN_OR_RETURN(std::vector<int64_t> num_local_outputs, + GetNumLocalOutputs(node)); const auto& serialized_layout_list = serialized_layouts->list(); for (int i = 0; i < serialized_layout_list.s_size(); i++) { const auto& serialized_layout = serialized_layout_list.s(i);
Layout::FromString(serialized_layout)); function.output_layouts.emplace_back(std::move(layout)); } - for (int i = 0; i < function.output_layouts.size(); i++) { - const Layout& output_layout = function.output_layouts[i]; + int num_output_layouts = function.output_layouts.size(); + for (int i = 0; i < num_output_layouts; i++) { + const Layout* output_layout = &(function.output_layouts[i]); + if (output_layout->IsEmpty()) { + const auto search = function.resource_input_layouts.find(i); + if (search != function.resource_input_layouts.end()) { + output_layout = &(search->second); + } + } PartialTensorShape local_shape = - output_layout.LocalShapeFromGlobalShape(global_output_shapes[i]); - const int num_devices = output_layout.num_devices(); + output_layout->LocalShapeFromGlobalShape(global_output_shapes[i]); + const int64_t num_devices = num_local_outputs[i]; for (int j = 0; j < num_devices; j++) { function.local_output_shapes.emplace_back(local_shape); } } + function.num_local_outputs = std::move(num_local_outputs); return OkStatus(); } +} // namespace // Returns set of functions to run to execute DTensor computation. StatusOr IdentifyAllFunctionsToExecute( @@ -1033,6 +1059,7 @@ StatusOr IdentifyAllFunctionsToExecute( function.local_output_shapes.emplace_back( output_layout.LocalShapeFromGlobalShape( global_output_shapes[global_index])); + function.num_local_outputs.emplace_back(1); } } diff --git a/tensorflow/dtensor/cc/dtensor_device_util.h b/tensorflow/dtensor/cc/dtensor_device_util.h index 00ef825c92ca26..6bc5b2c4295aa9 100644 --- a/tensorflow/dtensor/cc/dtensor_device_util.h +++ b/tensorflow/dtensor/cc/dtensor_device_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_UTIL_H_ #include +#include #include #include #include @@ -114,6 +115,8 @@ struct TranslatedFunction { std::vector local_output_shapes; // Output data types. std::vector output_dtypes; + // Number of local outputs for each layout. 
+ std::vector<int64_t> num_local_outputs; }; struct ExecutionFunctions { diff --git a/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc b/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc index 1e207cbca6fe37..28691771e2376e 100644 --- a/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc +++ b/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc @@ -128,6 +128,12 @@ DTensorMlirPassRunner::ImportGraphToMlir( mlir::IntegerAttr::get(mlir::IntegerType::get(&context_, /*width=*/64), group_size)); + int topo_dist = dtensor::AllReduceCombineOptimizationTopologicalDistance(); + module->setAttr( + dtensor::kAllReduceTopologicalDistance, + mlir::IntegerAttr::get(mlir::IntegerType::get(&context_, /*width=*/64), + topo_dist)); + if (dtensor::EnableMultiDeviceMode()) { module->setAttr(dtensor::kEnableMultiDeviceMode, mlir::BoolAttr::get(&context_, true)); diff --git a/tensorflow/dtensor/cc/dtensor_utils.cc b/tensorflow/dtensor/cc/dtensor_utils.cc index 3c915c4b25c845..b22c9ad5b0023d 100644 --- a/tensorflow/dtensor/cc/dtensor_utils.cc +++ b/tensorflow/dtensor/cc/dtensor_utils.cc @@ -158,6 +158,24 @@ int AllReduceCombineOptimizationGroupSize() { return 0; } +int AllReduceCombineOptimizationTopologicalDistance() { + int64_t topo_dist; + absl::Status status = tsl::ReadInt64FromEnvVar( + "DTENSOR_ALLREDUCE_COMBINE_OPTIMIZATION_TOPOLOGICAL_DISTANCE", + /*default_val=*/0, &topo_dist); + if (!status.ok()) { + LOG(WARNING) << "Invalid DTENSOR_ALLREDUCE_COMBINE_OPTIMIZATION_TOPOLOGICAL" + "_DISTANCE, using the default value 0."; + return 0; + } else if (topo_dist < 0) { + LOG(WARNING) << "Invalid DTENSOR_ALLREDUCE_COMBINE_OPTIMIZATION_TOPOLOGICAL" + "_DISTANCE, value must be a non-negative integer, using the " + "default value 0."; + return 0; + } + return topo_dist; +} + bool EnableMultiDeviceMode() { bool multi_device_mode; absl::Status status = tsl::ReadBoolFromEnvVar( diff --git a/tensorflow/dtensor/cc/dtensor_utils.h b/tensorflow/dtensor/cc/dtensor_utils.h index 06626ddcda6bd1..b89a85cd5ec5b6 100644 --- a/tensorflow/dtensor/cc/dtensor_utils.h +++ b/tensorflow/dtensor/cc/dtensor_utils.h @@ -66,9 +66,24 @@ bool EnableReplicatedSpmdAsDefault(const std::string& op_name); // Returns whether to use all-to-all collective for relayout when possible. bool EnableAllToAllForRelayout(); -// Returns the maximum number of AllReduce ops to merge into a group. +// Returns the maximum number of AllReduce ops to merge into a group. This value +// determines the AllReduce grouping in dtensor_allreduce_combine_optimization. +// The input value should be in range of [0, INT_MAX]. It is advised to pick +// a value based on knowledge of the total number of AllReduces. When the value +// is too big, the behaviour will act as aggressive grouping. When the value is +// too small, the behaviour will act as having no extended grouping. int AllReduceCombineOptimizationGroupSize(); +// Returns the maximum topological distance between two AllReduce ops to merge +// into a single AllReduce. This value is used to determine AllReduce grouping +// in dtensor_allreduce_combine_optimization. The input value should be in range +// of [0, INT_MAX]. However, it is advised to select a value based on knowledge +// of the compute graph, such as the minimum distance between two model layers. +// When the input value is too big, the behaviour will act as aggressive +// grouping. When the input value is too small, the behaviour will act as +// having no extended grouping.
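Taken together, the two knobs documented above bound how the combiner merges AllReduces: a group stops growing once it holds the maximum number of ops, or once the next candidate sits too many topological levels away. The MLIR pass itself is not part of this diff, so the standalone sketch below uses hypothetical types and is only meant to illustrate how the two limits interact, not to reproduce dtensor_allreduce_combine_optimization.

#include <cstdint>
#include <vector>

// Hypothetical stand-in for an AllReduce op: only its topological level
// matters for grouping purposes.
struct AllReduceInfo {
  int64_t topological_level;
};

// Groups AllReduces (assumed sorted by topological level) so that no group
// holds more than `group_size` ops and no group spans more than
// `topological_distance` levels between its first and last op. Both limits
// are assumed to be positive; how the real pass treats the 0 defaults is not
// modeled here.
std::vector<std::vector<AllReduceInfo>> GroupAllReduces(
    const std::vector<AllReduceInfo>& all_reduces, int64_t group_size,
    int64_t topological_distance) {
  std::vector<std::vector<AllReduceInfo>> groups;
  for (const AllReduceInfo& op : all_reduces) {
    bool start_new_group = groups.empty();
    if (!start_new_group) {
      const std::vector<AllReduceInfo>& current = groups.back();
      const bool too_many =
          static_cast<int64_t>(current.size()) >= group_size;
      const bool too_far =
          op.topological_level - current.front().topological_level >
          topological_distance;
      start_new_group = too_many || too_far;
    }
    if (start_new_group) groups.emplace_back();
    groups.back().push_back(op);
  }
  return groups;
}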
+int AllReduceCombineOptimizationTopologicalDistance(); + // Returns whether to perform multi-device expansion. bool EnableMultiDeviceMode(); } // namespace dtensor diff --git a/tensorflow/dtensor/cc/tensor_layout.cc b/tensorflow/dtensor/cc/tensor_layout.cc index 4e3a69d7d83ff3..f38a70137eca30 100644 --- a/tensorflow/dtensor/cc/tensor_layout.cc +++ b/tensorflow/dtensor/cc/tensor_layout.cc @@ -800,27 +800,13 @@ Mesh Mesh::CreateMesh(const std::string& mesh_name, } StatusOr Layout::GetLayout( - const std::vector& sharding_spec_strs, const Mesh& mesh) { - // Re-format sharding specs. - std::vector sharding_specs; - sharding_specs.reserve(sharding_spec_strs.size()); - for (const std::string& spec_str : sharding_spec_strs) { - ShardingSpec spec; - spec.set_sharding_spec(spec_str); - sharding_specs.push_back(spec); - } - return GetLayout(sharding_specs, mesh); -} - -StatusOr Layout::GetLayout( - const std::vector& sharding_specs, const Mesh& mesh) { + const std::vector& sharding_specs, const Mesh& mesh) { Layout layout; // Append mesh, then check sharding_specs are legal. layout.mesh_ = mesh; // Check sharding_specs are either mesh dimension or special value. - for (const auto& dim : sharding_specs) { - const std::string& sharding_spec = dim.sharding_spec(); + for (const auto& sharding_spec : sharding_specs) { if (!(sharding_spec == kUnshardedDim || sharding_spec == kAny || sharding_spec == kMatch || mesh.IsMeshDim(sharding_spec) || sharding_spec == "scalar")) @@ -831,8 +817,7 @@ StatusOr Layout::GetLayout( } // Check same tensor dimensions not sharded over same mesh dimension twice. std::set dims_set; - for (const auto& dim : sharding_specs) { - const std::string& sharding_spec = dim.sharding_spec(); + for (const auto& sharding_spec : sharding_specs) { if (sharding_spec == kUnshardedDim || sharding_spec == kAny) continue; // If scalar, delete all sharding specs. if (sharding_spec == "scalar") { @@ -876,8 +861,7 @@ bool Layout::IsEmpty() const { return mesh_.IsEmpty(); } namespace { Mesh ReducedAbstractMesh(const Layout* layout) { - const std::vector& shard_spec_strs = - layout->sharding_spec_strs(); + const std::vector shard_spec_strs = layout->sharding_spec_strs(); std::vector reduced_mesh_dims; reduced_mesh_dims.reserve(layout->mesh().dims().size()); for (const MeshDimension& mesh_dim : layout->mesh().dims()) { @@ -934,12 +918,9 @@ Mesh Layout::ReducedMesh() const { namespace { Layout ReducedLayout(const Layout* layout) { - // Change format sharding specs. - std::vector shard_specs(layout->sharding_specs().size()); - for (size_t i = 0; i < shard_specs.size(); ++i) - shard_specs[i] = layout->dim(i); // Retrieve layout. - return Layout::GetLayout(shard_specs, layout->ReducedMesh()).value(); + return Layout::GetLayout(layout->sharding_spec_strs(), layout->ReducedMesh()) + .value(); } // Returns index of the given mesh dimension or mesh dim size if not found. @@ -952,16 +933,13 @@ StatusOr IndexOfMeshDimension(const Mesh& mesh, } // namespace ShardVector Layout::GetShardVector() const { - // Change format sharding specs. - std::vector shard_specs(sharding_specs().size()); - for (size_t i = 0; i < shard_specs.size(); ++i) shard_specs[i] = dim(i); // Obtain a shard position (i.e. sharded section of a tensor) from a mesh // location, using the sharding specs. 
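The tensor_layout.cc changes in this hunk drop the ShardingSpec proto wrapper: Layout now stores its sharding specs as plain strings, and Layout::GetLayout accepts a std::vector<std::string> of mesh-dimension names directly. A minimal sketch of calling the new signature, assuming an existing Mesh with a dimension named by `batch_dim`; the BatchShardedLayout helper is illustrative only and not part of this patch.

#include <string>
#include <vector>

#include "tensorflow/dtensor/cc/tensor_layout.h"

namespace tensorflow {
namespace dtensor {

// Builds a rank-`rank` layout that shards tensor dimension 0 over the mesh
// dimension `batch_dim` and replicates every other tensor dimension.
StatusOr<Layout> BatchShardedLayout(const Mesh& mesh,
                                    const std::string& batch_dim, int rank) {
  std::vector<std::string> sharding_specs(rank, Layout::kUnshardedDim);
  if (rank > 0) sharding_specs[0] = batch_dim;
  // With the refactor, the specs are passed as strings; GetLayout still
  // validates that each entry is a mesh dimension or a special value.
  return Layout::GetLayout(sharding_specs, mesh);
}

}  // namespace dtensor
}  // namespace tensorflow

Unsharded tensor dimensions are expressed with Layout::kUnshardedDim, exactly as the rewritten ToString()/FromString() code in this hunk does.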
auto GetShardFromDeviceLocation = [&](const DeviceLocation& loc) -> Shard { Shard shard; - for (size_t i = 0; i < shard_specs.size(); ++i) { + for (size_t i = 0; i < sharding_specs_.size(); ++i) { // If unsharded, there is only one shard, that is 1. - std::string spec = shard_specs[i].sharding_spec(); + std::string spec = sharding_specs_[i]; if (spec == Layout::kUnshardedDim) { shard.push_back(1); } else { @@ -974,11 +952,11 @@ ShardVector Layout::GetShardVector() const { }; // Obtain dims of shard vector. auto ShardVectorDims = [&]() -> std::vector { - std::vector num_shards_per_dim(shard_specs.size()); - for (size_t i = 0; i < sharding_specs().size(); ++i) { - ShardingSpec spec = sharding_specs()[i]; - if (Layout::IsShardedSpec(spec)) { - StatusOr dim_size = mesh().dim_size(spec.sharding_spec()); + std::vector num_shards_per_dim(sharding_specs_.size()); + for (size_t i = 0; i < sharding_specs_.size(); ++i) { + std::string spec = sharding_specs_[i]; + if (Layout::IsShardedDimension(spec)) { + StatusOr dim_size = mesh().dim_size(spec); num_shards_per_dim[i] = dim_size.value(); } else { num_shards_per_dim[i] = 1; @@ -1033,28 +1011,25 @@ std::map Layout::HostShardMap() const { } const std::string& Layout::sharding_spec(int idx) const { - return sharding_specs_[idx].sharding_spec(); + return sharding_specs_[idx]; } std::vector Layout::num_shards() const { std::vector num_shards; num_shards.reserve(sharding_specs_.size()); - for (const auto& sharding_spec : sharding_specs_) { - num_shards.push_back(num_shards_for_dim(sharding_spec)); + for (int64_t index = 0; index < sharding_specs_.size(); ++index) { + num_shards.push_back(num_shards_for_dim(index)); } return num_shards; } -size_t Layout::num_shards_for_dim(const ShardingSpec& dim) const { - absl::string_view name = dim.sharding_spec(); - if (name == Layout::kUnshardedDim) return 1; - if (name == Layout::kMatch) return -1; - - return mesh().dim_size(name).value(); -} size_t Layout::num_shards_for_dim(int dim) const { - return num_shards_for_dim(sharding_specs_[dim]); + const std::string spec = sharding_specs_[dim]; + if (spec == Layout::kUnshardedDim) return 1; + if (spec == Layout::kMatch) return -1; + + return mesh().dim_size(spec).value(); } bool Layout::IsFullyReplicated() const { @@ -1062,7 +1037,7 @@ bool Layout::IsFullyReplicated() const { return false; } for (const auto& sharding_spec : sharding_specs_) { - if (sharding_spec.sharding_spec() != Layout::kUnshardedDim) return false; + if (sharding_spec != Layout::kUnshardedDim) return false; } return true; } @@ -1070,7 +1045,7 @@ bool Layout::IsFullyReplicated() const { bool Layout::IsLastDimReplicated() const { return (mesh_.IsTile() && ((sharding_specs_.empty()) || - (sharding_specs_.back().sharding_spec() == Layout::kUnshardedDim))); + (sharding_specs_.back() == Layout::kUnshardedDim))); } bool Layout::IsBatchParallel() const { @@ -1082,12 +1057,12 @@ bool Layout::IsBatchParallel() const { } for (int i = 1; i < sharding_specs_.size(); ++i) { - const auto& dim = sharding_specs_[i]; - if (dim.sharding_spec() != Layout::kUnshardedDim) { + const auto& spec = sharding_specs_[i]; + if (spec != Layout::kUnshardedDim) { return false; } } - return sharding_specs_[0].sharding_spec() != Layout::kUnshardedDim; + return sharding_specs_[0] != Layout::kUnshardedDim; } // TODO(samuelslee) Replace this with the IsBatchParallel() everywhere @@ -1097,7 +1072,7 @@ bool Layout::IsBatchParallel(int non_batch_rank) const { } if (sharding_specs_.empty()) return true; for (int i = rank() - non_batch_rank; i < 
rank(); ++i) { - if (num_shards_for_dim(sharding_specs_[i]) != 1) return false; + if (num_shards_for_dim(i) != 1) return false; } return true; } @@ -1105,8 +1080,8 @@ bool Layout::IsBatchParallel(int non_batch_rank) const { StatusOr Layout::ToProto() const { LayoutProto proto; TF_ASSIGN_OR_RETURN(*proto.mutable_mesh_config(), mesh_.ToProto()); - for (const auto& dim : sharding_specs_) { - *proto.add_sharding_specs() = dim; + for (const auto& spec : sharding_specs_) { + proto.add_sharding_specs()->set_sharding_spec(spec); } return proto; } @@ -1115,10 +1090,8 @@ bool Layout::IsEquivalent(const Layout& b) const { if (this->rank() != b.rank()) return false; if (this->mesh() != b.mesh()) return false; for (int i = 0; i < this->rank(); ++i) { - if (this->sharding_specs_[i].sharding_spec() != - b.sharding_specs_[i].sharding_spec()) { - if ((this->num_shards_for_dim(this->sharding_specs_[i]) != 1) || - (b.num_shards_for_dim(b.sharding_specs_[i]) != 1)) + if (this->sharding_specs_[i] != b.sharding_specs_[i]) { + if ((this->num_shards_for_dim(i) != 1) || (b.num_shards_for_dim(i) != 1)) return false; } } @@ -1142,7 +1115,7 @@ std::vector Layout::GlobalShapeFromLocalShape( } std::vector stride_for_dim; - stride_for_dim.resize(sharding_specs().size()); + stride_for_dim.resize(sharding_specs_.size()); size_t stride = mesh().num_local_devices(); for (int i = 0; i < stride_for_dim.size(); i++) { stride = stride / num_shards_for_dim(i); @@ -1168,8 +1141,8 @@ std::vector Layout::GlobalShapeFromLocalShape( }; std::vector global_shape; - global_shape.reserve(sharding_specs().size()); - for (int i = 0; i < sharding_specs().size(); ++i) { + global_shape.reserve(sharding_specs_.size()); + for (int i = 0; i < sharding_specs_.size(); ++i) { global_shape.push_back(dimension_size(i)); } return global_shape; @@ -1182,7 +1155,7 @@ std::vector Layout::LocalShapeFromGlobalShape( } std::vector shards = num_shards(); std::vector local_shape; - for (int i = 0; i < sharding_specs().size(); ++i) { + for (int i = 0; i < sharding_specs_.size(); ++i) { int64_t dim_shards = shards[i]; // TODO(hthu): Shape might not be always divisible. int64_t local_size = IsDynamicSize(global_shape[i]) @@ -1200,7 +1173,7 @@ PartialTensorShape Layout::LocalShapeFromGlobalShape( } std::vector shards = num_shards(); PartialTensorShape local_shape({}); - for (int spec_index = 0; spec_index < sharding_specs().size(); ++spec_index) { + for (int spec_index = 0; spec_index < sharding_specs_.size(); ++spec_index) { int64_t dim_size = global_shape.dim_size(spec_index); int64_t local_size = IsDynamicSize(dim_size) ? 
dim_size : dim_size / shards[spec_index]; @@ -1213,7 +1186,7 @@ StatusOr Layout::FromProto(const LayoutProto& proto) { Layout layout; if (proto.mesh_config().single_device().empty()) { for (const auto& spec : proto.sharding_specs()) - layout.sharding_specs_.push_back(spec); + layout.sharding_specs_.push_back(spec.sharding_spec()); TF_ASSIGN_OR_RETURN(auto mesh, Mesh::ParseFromProto(proto.mesh_config())); layout.mesh_ = std::move(mesh); @@ -1299,10 +1272,7 @@ StatusOr Layout::FromString(absl::string_view layout_str) { } std::vector Layout::sharding_spec_strs() const { - std::vector sharding_spec_strs(sharding_specs().size()); - for (size_t i = 0; i < sharding_specs().size(); ++i) - sharding_spec_strs[i] = sharding_spec(i); - return sharding_spec_strs; + return sharding_specs_; } std::string Layout::ToString() const { @@ -1315,8 +1285,7 @@ std::string Layout::ToString() const { std::string layout_str = "sharding_specs:"; // Print sharding specs. - for (const ShardingSpec& dim : sharding_specs_) { - std::string dim_name = dim.sharding_spec(); + for (const auto& dim_name : sharding_specs_) { absl::StrAppend(&layout_str, dim_name + ","); } // Append mesh. @@ -1326,19 +1295,16 @@ std::string Layout::ToString() const { StatusOr Layout::GetLayoutWithReducedDims( const absl::flat_hash_set& reduced_dims, bool keep_dims) const { - dtensor::LayoutProto output_layout; - TF_ASSIGN_OR_RETURN(*output_layout.mutable_mesh_config(), mesh().ToProto()); - + std::vector sharding_specs; for (int i = 0; i < rank(); ++i) { // reduced_dims may contain negative values. if (!reduced_dims.contains(i) && !reduced_dims.contains(i - rank())) { - *output_layout.add_sharding_specs() = dim(i); + sharding_specs.push_back(sharding_spec(i)); } else if (keep_dims) { - auto* replicated_dim = output_layout.add_sharding_specs(); - replicated_dim->set_sharding_spec(kUnshardedDim); + sharding_specs.push_back(kUnshardedDim); } } - return Layout::FromProto(output_layout).value(); + return Layout::GetLayout(sharding_specs, mesh()); } Layout Layout::Truncate(int64 split_point, bool end) const { @@ -1362,9 +1328,7 @@ Layout Layout::LeftPad(int64_t rank) const { Layout output_layout(*this); auto& specs = output_layout.sharding_specs_; - ShardingSpec spec; - spec.set_sharding_spec(Layout::kUnshardedDim); - specs.insert(specs.begin(), rank - this->rank(), spec); + specs.insert(specs.begin(), rank - this->rank(), Layout::kUnshardedDim); return output_layout; } @@ -1408,7 +1372,7 @@ StatusOr GetMostShardedLayout(const std::vector& layouts) { absl::flat_hash_map> layout_map; for (const Layout& layout : layouts) { for (int i = 0; i < layout.rank(); ++i) { - const std::string& mesh_dim = layout.dim(i).sharding_spec(); + const std::string& mesh_dim = layout.sharding_spec(i); if (mesh_dim == Layout::kUnshardedDim) continue; layout_map[mesh_dim].insert(i); @@ -1461,7 +1425,7 @@ StatusOr GetLeastShardedLayout(const std::vector& layouts) { } specs.resize(rank, Layout::kAny); for (const auto& layout : layouts) { - auto current_specs = layout.sharding_spec_strs(); + const auto current_specs = layout.sharding_spec_strs(); for (int i = 0; i < rank; i++) { auto current_spec = current_specs[i]; if (specs[i] == Layout::kAny) { diff --git a/tensorflow/dtensor/cc/tensor_layout.h b/tensorflow/dtensor/cc/tensor_layout.h index 2cd59d123b9cab..4bee6860646539 100644 --- a/tensorflow/dtensor/cc/tensor_layout.h +++ b/tensorflow/dtensor/cc/tensor_layout.h @@ -340,16 +340,8 @@ class Layout { static bool IsShardedDimension(const absl::string_view name) { return 
!IsUnshardedDimension(name); } - static bool IsUnshardedSpec(const ShardingSpec& spec) { - return IsUnshardedDimension(spec.sharding_spec()); - } - static bool IsShardedSpec(const ShardingSpec& spec) { - return !IsUnshardedDimension(spec.sharding_spec()); - } static StatusOr GetLayout( const std::vector& sharding_spec_strs, const Mesh& mesh); - static StatusOr GetLayout( - const std::vector& sharding_specs, const Mesh& mesh); static StatusOr GetSingleDeviceLayout(const Mesh& mesh); // Makes a new layout from this one dropping the given dimensions. @@ -391,15 +383,9 @@ class Layout { const PartialTensorShape& global_shape) const; int64 rank() const { return sharding_specs_.size(); } - size_t num_shards_for_dim(const ShardingSpec& dim) const; size_t num_shards_for_dim(int) const; std::vector num_shards() const; - const ShardingSpec& dim(int64 idx) const { return sharding_specs_[idx]; } - absl::Span sharding_specs() const { - return sharding_specs_; - } - // Computes the corresponding shard vector to this layout. ShardVector GetShardVector() const; @@ -426,7 +412,7 @@ class Layout { } private: - std::vector sharding_specs_; + std::vector sharding_specs_; Mesh mesh_; }; diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index 95c79716168a9b..a6ef38cf6bb63d 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -3,7 +3,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,7 +18,7 @@ package( gentbl_cc_library( name = "tensorflow_dtensor_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -46,7 +46,7 @@ gentbl_cc_library( gentbl_cc_library( name = "dtensor_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [( [ "-gen-pass-decls", diff --git a/tensorflow/dtensor/mlir/collectives.cc b/tensorflow/dtensor/mlir/collectives.cc index c54eb2b725366d..95c91dc086b0a3 100644 --- a/tensorflow/dtensor/mlir/collectives.cc +++ b/tensorflow/dtensor/mlir/collectives.cc @@ -165,22 +165,22 @@ bool CanUseAllToAll(const dtensor::Layout& src_layout, // all-to-all in addition to these which can be supported later. 
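Editor's note: the tensor_layout.cc/.h hunks above replace the ShardingSpec proto with plain std::string entries in sharding_specs_, so shard counting and replication checks reduce to string comparisons plus a mesh-dimension lookup. A minimal standalone sketch of that logic, where MiniMesh/MiniLayout and the "unsharded" marker are illustrative stand-ins rather than the real DTensor classes:

#include <cstdint>
#include <map>
#include <string>
#include <vector>

namespace sketch {
constexpr char kUnshardedDim[] = "unsharded";  // placeholder for Layout::kUnshardedDim

struct MiniMesh {
  std::map<std::string, int64_t> dim_sizes;  // mesh dimension name -> size
  int64_t dim_size(const std::string& name) const { return dim_sizes.at(name); }
};

struct MiniLayout {
  std::vector<std::string> sharding_specs;  // one entry per tensor dimension
  MiniMesh mesh;

  // Mirrors the new index-based num_shards_for_dim: unsharded dims have one
  // shard, sharded dims have as many shards as the named mesh dimension.
  int64_t num_shards_for_dim(int dim) const {
    const std::string& spec = sharding_specs[dim];
    if (spec == kUnshardedDim) return 1;
    return mesh.dim_size(spec);
  }

  // A layout is fully replicated when every spec is the unsharded marker.
  bool IsFullyReplicated() const {
    for (const std::string& spec : sharding_specs)
      if (spec != kUnshardedDim) return false;
    return true;
  }
};
}  // namespace sketch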
int num_split_dims = 0; int num_concat_dims = 0; - ShardingSpec split_spec; - ShardingSpec concat_spec; + std::string split_spec; + std::string concat_spec; for (int i = 0; i < src_layout.rank(); ++i) { if (src_layout.sharding_spec(i) == tgt_layout.sharding_spec(i)) continue; if (Layout::IsUnshardedDimension(src_layout.sharding_spec(i)) && Layout::IsShardedDimension(tgt_layout.sharding_spec(i))) { num_split_dims++; - split_spec = tgt_layout.dim(i); + split_spec = tgt_layout.sharding_spec(i); } else if (Layout::IsShardedDimension(src_layout.sharding_spec(i)) && Layout::IsUnshardedDimension(tgt_layout.sharding_spec(i))) { num_concat_dims++; - concat_spec = src_layout.dim(i); + concat_spec = src_layout.sharding_spec(i); } } return num_split_dims == 1 && num_concat_dims == 1 && - split_spec.sharding_spec() == concat_spec.sharding_spec(); + split_spec == concat_spec; } StatusOr EmitAllToAll( @@ -339,14 +339,14 @@ StatusOr EmitRelayout( for (int i = 0; i < src_layout.rank(); ++i) src_sharding_dims.emplace(src_layout.sharding_spec(i)); - std::vector intermediate_specs_1(src_layout.rank()); + std::vector intermediate_specs_1(src_layout.rank()); for (int i = 0; i < src_layout.rank(); ++i) { - if (Layout::IsShardedSpec(tgt_layout.dim(i)) && - !Layout::IsShardedSpec(src_layout.dim(i)) && + if (Layout::IsShardedDimension(tgt_layout.sharding_spec(i)) && + !Layout::IsShardedDimension(src_layout.sharding_spec(i)) && !src_sharding_dims.contains(tgt_layout.sharding_spec(i))) - intermediate_specs_1[i] = tgt_layout.dim(i); + intermediate_specs_1[i] = tgt_layout.sharding_spec(i); else - intermediate_specs_1[i] = src_layout.dim(i); + intermediate_specs_1[i] = src_layout.sharding_spec(i); } TF_ASSIGN_OR_RETURN( Layout intermediate_layout_1, @@ -357,11 +357,11 @@ StatusOr EmitRelayout( EmitAllScatter(builder, input, src_layout, intermediate_layout_1, newly_created_ops)); - std::vector intermediate_specs_2(src_layout.rank()); + std::vector intermediate_specs_2(src_layout.rank()); for (int i = 0; i < src_layout.rank(); ++i) { - if (Layout::IsShardedSpec(intermediate_specs_1[i]) && - intermediate_specs_1[i].sharding_spec() != tgt_layout.sharding_spec(i)) - intermediate_specs_2[i].set_sharding_spec(Layout::kUnshardedDim); + if (Layout::IsShardedDimension(intermediate_specs_1[i]) && + intermediate_specs_1[i] != tgt_layout.sharding_spec(i)) + intermediate_specs_2[i] = Layout::kUnshardedDim; else intermediate_specs_2[i] = intermediate_specs_1[i]; } diff --git a/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc b/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc index f4e9ee9d3beed7..11099801b45ae7 100644 --- a/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc +++ b/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc @@ -546,6 +546,10 @@ createSubgroupsByGroupAssignment( // Experimental extended grouping logics to avoid aggressive grouping. // This function performs the same grouping method as tf.distribute, which group // all reduce ops by user defined group size (number of ops) in the input order. +// Note that group_size will be in range of [0, INT_MAX]. It is advised to pick +// a value based on knowledge of the total number of AllReduces. When group_size +// is too big, the function will act as aggressive grouping. When group_size is +// too small, the function will act as having no extended grouping. 
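Editor's note: the comment just above documents the experimental grouping by a fixed number of ops; createSubgroupsByExtendedNumOps, whose signature follows, applies it to DTensorAllReduceOp groups. A simplified standalone illustration of that grouping rule, with plain ints standing in for the AllReduce ops (group_size <= 0 disables the extra grouping, as in the pass):

#include <vector>

// Splits each existing group into consecutive subgroups of at most group_size
// elements, preserving input order. With group_size <= 0 the groups are
// returned unchanged, mirroring the "no extended grouping" behaviour.
std::vector<std::vector<int>> GroupByCount(
    const std::vector<std::vector<int>>& groups, int group_size) {
  if (group_size <= 0) return groups;
  std::vector<std::vector<int>> result;
  for (const std::vector<int>& group : groups) {
    std::vector<int> current;
    for (int op : group) {
      current.push_back(op);
      if (static_cast<int>(current.size()) == group_size) {
        result.push_back(current);
        current.clear();
      }
    }
    if (!current.empty()) result.push_back(current);
  }
  return result;
}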
std::vector> createSubgroupsByExtendedNumOps( std::vector> all_reduce_groups, @@ -580,6 +584,103 @@ createSubgroupsByExtendedNumOps( return all_reduce_new_groups; } +// Experimental grouping logics to optimize from aggressive grouping. +// This function first sort by topological level, then create AllReduce sub- +// groups by accessing each topological distance from its previous AllReduce. +// Note that topo_dist will be in range of [0, INT_MAX]. It is advised to select +// a value based on knowledge of the compute graph, such as the minimum distance +// between two model layers. When topo_dist is too big, the function will act +// as aggressive grouping. When topo_dist is too small, the function will act as +// having no extended grouping. +StatusOr>> +createSubgroupsByTopoDist( + std::vector> all_reduce_groups, + llvm::DenseMap all_reduce_topo, + int topo_dist) { + // Disable extended grouping if topological distance is set to zero or less + if (topo_dist <= 0) return all_reduce_groups; + VLOG(4) << "current number of groups: " << all_reduce_groups.size(); + std::vector> all_reduce_new_groups; + + // Further break down the current all_reduced_groups by topological distance + // between two ops + for (auto& all_reduce_group : all_reduce_groups) { + std::vector new_group; + Status status = absl::OkStatus(); + + // Sort AllReduces by topological level as the input order may not reflect + // their dependencies on the operands in the compute graph. + std::sort(all_reduce_group.begin(), all_reduce_group.end(), + [&all_reduce_topo, &status](mlir::TF::DTensorAllReduceOp& lhs, + mlir::TF::DTensorAllReduceOp& rhs) { + if ((all_reduce_topo.find(lhs) == all_reduce_topo.end()) || + (all_reduce_topo.find(rhs) == all_reduce_topo.end())) { + status = absl::InternalError( + "Error: encounter AllReduce op with no topological level" + " assignment."); + return false; + } + return all_reduce_topo[lhs] < all_reduce_topo[rhs]; + }); + // Unable to sort AllReduces based on topological level due to error. Return + // directly as we are not able to group based on incorrect/partial topology. + if (!status.ok()) return status; + + // Form AllReduce groups based on the topological distance between ops + DCHECK(!all_reduce_group.empty()); + int prev_topo_level = all_reduce_topo[all_reduce_group[0]]; + for (const auto& all_reduce : all_reduce_group) { + DCHECK(all_reduce_topo.find(all_reduce) != all_reduce_topo.end()); + int cur_topo_level = all_reduce_topo[all_reduce]; + if (abs(cur_topo_level - prev_topo_level) <= topo_dist) { + new_group.push_back(all_reduce); + } else { + // Start a new group + all_reduce_new_groups.push_back( + std::vector(new_group.begin(), + new_group.end())); + new_group.clear(); + new_group.push_back(all_reduce); + } + prev_topo_level = cur_topo_level; + } + all_reduce_new_groups.push_back(new_group); + } + VLOG(4) << "new number of groups: " << all_reduce_new_groups.size(); + return all_reduce_new_groups; +} + +// Compute the topological level for each AllReduce op in a cluster. The level +// is defined as 1 + max operands' depth in the compute graph. If an op do not +// depend on any input/operand, then it is level 0. +llvm::DenseMap computeAllReduceTopoLevel( + mlir::tf_device::ClusterOp cluster) { + llvm::DenseMap op_topo_level; + llvm::DenseMap all_reduce_topo; + + // Compute topological level for each op. 
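Editor's note: createSubgroupsByTopoDist above sorts each group by topological level and starts a new subgroup whenever the distance to the previous AllReduce exceeds topo_dist. A standalone sketch of that splitting rule, using plain (id, level) pairs instead of MLIR ops:

#include <algorithm>
#include <cstdlib>
#include <utility>
#include <vector>

// Each element is (op id, topological level). Groups are first sorted by
// level, then split wherever the gap between consecutive levels exceeds
// topo_dist; topo_dist <= 0 leaves the groups untouched.
std::vector<std::vector<std::pair<int, int>>> GroupByTopoDistance(
    std::vector<std::vector<std::pair<int, int>>> groups, int topo_dist) {
  if (topo_dist <= 0) return groups;
  std::vector<std::vector<std::pair<int, int>>> result;
  for (std::vector<std::pair<int, int>>& group : groups) {
    std::sort(group.begin(), group.end(),
              [](const std::pair<int, int>& lhs,
                 const std::pair<int, int>& rhs) { return lhs.second < rhs.second; });
    std::vector<std::pair<int, int>> current;
    int prev_level = group.empty() ? 0 : group.front().second;
    for (const std::pair<int, int>& op : group) {
      if (!current.empty() && std::abs(op.second - prev_level) > topo_dist) {
        result.push_back(current);  // distance too large: start a new subgroup
        current.clear();
      }
      current.push_back(op);
      prev_level = op.second;
    }
    if (!current.empty()) result.push_back(current);
  }
  return result;
}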
+ cluster.getBody().walk([&](mlir::Operation* op) { + int max_depth = 0; + for (mlir::Value operand : op->getOperands()) { + if (mlir::Operation* operand_op = operand.getDefiningOp()) { + if (op_topo_level.find(operand_op) != op_topo_level.end()) { + max_depth = fmax(max_depth, op_topo_level[operand_op]); + } + } + } + op_topo_level[op] = max_depth + 1; + + // Save the AllReduce topological level + mlir::TF::DTensorAllReduceOp all_reduce = + llvm::dyn_cast(op); + if (all_reduce && !all_reduce.getDeviceType().contains("TPU")) { + all_reduce_topo[all_reduce] = op_topo_level[op]; + } + }); + + return all_reduce_topo; +} + struct DTensorAllReduceCombineOptimization : public impl::DTensorAllReduceCombineOptimizationBase< DTensorAllReduceCombineOptimization> { @@ -590,11 +691,10 @@ struct DTensorAllReduceCombineOptimization std::vector ordered_all_reduces; std::vector ordered_blocks; llvm::DenseSet blocks; - cluster.GetBody().walk([&](mlir::TF::DTensorAllReduceOp all_reduce) { if (!all_reduce.getDeviceType().contains("TPU")) { // Only combine all reduces for GPU and CPU - auto all_reduce_ranked_type = + mlir::RankedTensorType all_reduce_ranked_type = all_reduce.getType().dyn_cast(); if (all_reduce_ranked_type && @@ -621,15 +721,37 @@ struct DTensorAllReduceCombineOptimization all_reduce_groups = createSubgroupsByReductionAttr(all_reduce_groups); all_reduce_groups = createSubgroupsByGroupAssignment(all_reduce_groups); - // Experimental extended grouping - int group_size = 0; + // Experimental extended grouping: topological distance + if (module->hasAttrOfType( + kAllReduceTopologicalDistance)) { + llvm::DenseMap all_reduce_topo = + computeAllReduceTopoLevel(cluster); + + StatusOr>> + group = createSubgroupsByTopoDist( + all_reduce_groups, all_reduce_topo, + module + ->getAttrOfType( + kAllReduceTopologicalDistance) + .getInt()); + if (!group.ok()) { + // This is a non-fatal error since topological level distance is one + // of the optimizations in this combiner pass. Output an error and + // continue with the rest of the grouping optimization. + LOG(WARNING) << "Failed to create subgroups using topological " + << "level distance: " << group.status(); + } else { + all_reduce_groups = group.value(); + } + } + + // Experimental extended grouping: fixed number of AllReduce ops if (module->hasAttrOfType(kAllReduceNumOpsInGroup)) { - group_size = + all_reduce_groups = createSubgroupsByExtendedNumOps( + all_reduce_groups, module->getAttrOfType(kAllReduceNumOpsInGroup) - .getInt(); + .getInt()); } - all_reduce_groups = - createSubgroupsByExtendedNumOps(all_reduce_groups, group_size); // Maintain relative order of ALLReduces within the block. 
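Editor's note: the walk in computeAllReduceTopoLevel above assigns each op a level of 1 plus the maximum recorded level of its operands' defining ops. A small sketch of the same recurrence over a generic op list in producer order, with operand edges expressed as producer indices (illustrative only, not the MLIR walk):

#include <algorithm>
#include <map>
#include <vector>

// ops[i] lists the indices of the ops that produce op i's operands. Ops are
// assumed to appear after their producers, as in a walk over a block.
std::map<int, int> ComputeTopoLevels(const std::vector<std::vector<int>>& ops) {
  std::map<int, int> level;
  for (int i = 0; i < static_cast<int>(ops.size()); ++i) {
    int max_depth = 0;
    for (int producer : ops[i]) {
      auto it = level.find(producer);
      if (it != level.end()) max_depth = std::max(max_depth, it->second);
    }
    level[i] = max_depth + 1;  // 1 + max operand depth
  }
  return level;
}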
std::sort(all_reduce_groups.begin(), all_reduce_groups.end(), diff --git a/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc b/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc index ad19eb415f3478..b3aa8a70d9fe4d 100644 --- a/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc +++ b/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc @@ -121,21 +121,21 @@ mlir::LogicalResult ConvertShortIntReduce(ReduceOpType reduce_op) { << "Received '" << reduce_op.getReduceOpAttr().getValue().str() << "'"; } - if (mlir::isa(tensor_input_type.getElementType())) { + if (auto integer_type = mlir::dyn_cast( + tensor_input_type.getElementType())) { int32_t min_width = 64; if (output_layout->mesh().is_tpu_mesh()) { min_width = 32; } - if (tensor_input_type.getElementType().getIntOrFloatBitWidth() >= - min_width) { + if (integer_type.getWidth() >= min_width) { return mlir::success(); } auto input_type = mlir::RankedTensorType::get( tensor_input_type.getShape(), builder.getIntegerType(min_width)); auto output_type = mlir::RankedTensorType::get( - tensor_output_type.getShape(), tensor_input_type.getElementType()); + tensor_output_type.getShape(), integer_type); return WrapOpWithCasts(input_type, output_type, reduce_op); } if (mlir::isa(tensor_input_type.getElementType())) { diff --git a/tensorflow/dtensor/mlir/dtensor_dialect/BUILD b/tensorflow/dtensor/mlir/dtensor_dialect/BUILD index cee04f594da170..feb41e226c21c0 100644 --- a/tensorflow/dtensor/mlir/dtensor_dialect/BUILD +++ b/tensorflow/dtensor/mlir/dtensor_dialect/BUILD @@ -1,7 +1,7 @@ # DTensor MLIR dialect. load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -21,7 +21,7 @@ td_library( "ir/dtensor_dialect.td", "ir/dtensor_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:FuncTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", @@ -31,7 +31,7 @@ td_library( gentbl_cc_library( name = "DialectIncGen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], diff --git a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc index eb8d83034038dd..1080ca93f6c434 100644 --- a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc +++ b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include #include #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -72,7 +75,20 @@ using ExpandedArgumentMap = absl::flat_hash_map>>; -using ExpandedResultsMap = absl::flat_hash_map>; +struct ExpandedResults { + std::optional layout; + std::vector results; + + template + void insert(Value&& value) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + results.emplace_back(std::forward(value)); + } else { + results.insert(results.end(), value.begin(), value.end()); + } + } +}; mlir::BlockArgument InsertArgumentForDevice(mlir::OpBuilder& builder, mlir::func::FuncOp func, @@ -92,29 +108,28 @@ mlir::BlockArgument InsertArgumentForDevice(mlir::OpBuilder& builder, // Returns the user of all the ops in the span iff it is a single return op. // Otherwise, returns nullptr; for example, if there are multiple return ops. -template -mlir::func::ReturnOp GetReturnOpFromUsers(absl::Span ops) { - mlir::func::ReturnOp return_op; - - for (Operation op : ops) { +template +mlir::LogicalResult GetReturnOpFromUsers(Operations&& ops, + mlir::func::ReturnOp* return_op) { + for (mlir::Operation* op : ops) { for (mlir::Operation* user : op->getUsers()) { // TODO(twelve): Determine whether we should follow identity ops. if (mlir::func::ReturnOp op = llvm::dyn_cast_or_null(user)) { - if (return_op) { - if (return_op != op) { - return nullptr; + if (*return_op) { + if (*return_op != op) { + return mlir::failure(); } } else { - return_op = op; + *return_op = op; } } else { - return nullptr; + return mlir::failure(); } } } - return return_op; + return mlir::success(); } // Returns the devices for a given mesh. @@ -132,6 +147,62 @@ StatusOr> GetExpandedArguments( ExpandedArgumentMap& expanded_arguments, mlir::BlockArgument argument, const Mesh* target_mesh = nullptr); +StatusOr>> GetResourceLayouts( + mlir::Operation* op) { + if (op->hasAttr(kNewResourceArgLayouts)) { + auto attrs = op->getAttrOfType(kNewResourceArgLayouts); + std::vector layouts; + layouts.reserve(attrs.size()); + for (mlir::Attribute attr : attrs) { + auto string_attr = attr.cast(); + auto layout = Layout::FromString(string_attr.str()); + if (layout.ok()) { + layouts.emplace_back(std::move(layout.value())); + } else { + return layout.status(); + } + } + return layouts; + } else { + return std::nullopt; + } +} + +bool IsResource(mlir::Value value) { + return getElementTypeOrSelf(value.getType()).isa(); +} + +StatusOr> FindResourceLayout(mlir::BlockArgument arg) { + uint32_t arg_num = arg.getArgNumber(); + for (mlir::Operation* user : arg.getUsers()) { + auto resource_layouts = GetResourceLayouts(user); + if (resource_layouts.ok()) { + const auto& opt = resource_layouts.value(); + if (!opt || opt->empty()) { + continue; + } + } else { + return resource_layouts.status(); + } + + auto resource_indices = user->getAttrOfType( + kNewResourceLayoutIndices); + if (!resource_indices) { + return absl::InvalidArgumentError( + absl::StrCat("missing ", kNewResourceLayoutIndices)); + } + + for (auto [i, index] : llvm::enumerate(resource_indices)) { + int64_t index_value = index.getSExtValue(); + if (index_value == arg_num) { + return (resource_layouts.value())->at(i); + } + } + } + + return std::nullopt; +} + mlir::tf_device::ClusterFuncOp ExtractDeviceClusterFromFunctionCall( mlir::TF::StatefulPartitionedCallOp op) { mlir::tf_device::ClusterFuncOp result; @@ -162,11 +233,11 @@ void AddMetadataToTPUCluster(const Mesh& mesh_config, 
int64_t num_devices, // into a cluster func that has partitioned inputs and outputs ops; // it will be rewritten by TPURewritePass into per-device TPUExecute ops. template -mlir::LogicalResult ExpandTPUOperation(mlir::func::FuncOp target_func, - mlir::func::ReturnOp return_op, - ExpandedArgumentMap& expanded_arguments, - ExpandedResultsMap& expanded_results, - const Mesh& target_mesh, Operation op) { +mlir::LogicalResult ExpandTPUOperation( + mlir::func::FuncOp target_func, mlir::func::ReturnOp return_op, + ExpandedArgumentMap& expanded_arguments, + std::vector& expanded_results, const Mesh& target_mesh, + Operation op) { const absl::Span devices = GetDevices(target_mesh); const std::size_t num_devices = devices.size(); @@ -232,9 +303,7 @@ mlir::LogicalResult ExpandTPUOperation(mlir::func::FuncOp target_func, const std::size_t result_number = search - results.begin(); const mlir::Operation::result_range replicated_results = replications.at(result_number); - expanded_results[i].insert(expanded_results[i].end(), - replicated_results.begin(), - replicated_results.end()); + expanded_results[i].insert(replicated_results); } } } @@ -246,11 +315,11 @@ mlir::LogicalResult ExpandTPUOperation(mlir::func::FuncOp target_func, // de/multiplexes the per-device inputs/outputs for each "expanded" op. // Only usable on CPU/GPU devices, which do not require additional rewriting. template -mlir::LogicalResult ExpandOperation(mlir::func::FuncOp target_func, - mlir::func::ReturnOp return_op, - ExpandedArgumentMap& expanded_arguments, - ExpandedResultsMap& expanded_results, - const Mesh& target_mesh, Operation op) { +mlir::LogicalResult ExpandOperation( + mlir::func::FuncOp target_func, mlir::func::ReturnOp return_op, + ExpandedArgumentMap& expanded_arguments, + std::vector& expanded_results, const Mesh& target_mesh, + Operation op) { mlir::OpBuilder builder(target_func.getBody()); const absl::Span devices = GetDevices(target_mesh); const std::size_t num_devices = devices.size(); @@ -299,8 +368,8 @@ mlir::LogicalResult ExpandOperation(mlir::func::FuncOp target_func, llvm::find(results, operand); const std::size_t result_number = search - results.begin(); for (const Operation& replication : replications) { - expanded_results[i].emplace_back( - replication->getResult(result_number)); + expanded_results[i].insert( + (mlir::Value)replication->getResult(result_number)); } } } @@ -350,10 +419,25 @@ StatusOr> GetExpandedArguments( mesh = *target_mesh; } } else { - TF_ASSIGN_OR_RETURN(const std::optional layout, + TF_ASSIGN_OR_RETURN(std::optional layout, ExtractLayoutFromOperand(arg)); if (layout) { mesh = layout->mesh(); + + if (mesh->IsEmpty()) { + if (target_mesh) { + mesh = *target_mesh; + } else if (IsResource(arg)) { + TF_ASSIGN_OR_RETURN(layout, FindResourceLayout(arg)); + if (layout) { + mesh = layout->mesh(); + } else { + return absl::InvalidArgumentError( + absl::StrCat("Could not find resource layout for %arg", + arg.getArgNumber(), "!")); + } + } + } } } if (mesh.has_value()) { @@ -401,25 +485,50 @@ mlir::FunctionType GetFunctionType(mlir::OpBuilder& builder, return builder.getFunctionType(input_types, result_types); } +struct InferredResourceAttributes { + mlir::Attribute layouts; + mlir::Attribute indices; + + InferredResourceAttributes(mlir::Attribute layouts_, mlir::Attribute indices_) + : layouts(layouts_), indices(indices_) {} +}; + +template +mlir::LogicalResult GetInferredResourceAttributes( + mlir::OpBuilder& builder, const Operations& call_ops, + std::optional* resource_attrs) { + 
llvm::SmallVector resource_layouts; + llvm::SmallVector resource_indices; + for (mlir::Operation* call_op : call_ops) { + const auto resource_layouts_attr = + call_op->getAttrOfType(kNewResourceArgLayouts); + const auto resource_indices_attr = + call_op->getAttrOfType( + kNewResourceLayoutIndices); + if (resource_indices_attr && resource_layouts_attr) { + for (auto [index, layout] : + llvm::zip(resource_indices_attr, resource_layouts_attr)) { + // Build up the lists of resource indices and layouts. + resource_indices.emplace_back(index.getSExtValue()); + resource_layouts.emplace_back(layout); + } + } + } + if (!resource_layouts.empty()) { + resource_attrs->emplace(builder.getArrayAttr(resource_layouts), + builder.getI32VectorAttr(resource_indices)); + } + return mlir::success(); +} + // Build a new main function that calls the multi-device/translated function. +template mlir::LogicalResult BuildOuterMainFunc( mlir::ModuleOp module, mlir::func::FuncOp old_main_func, mlir::func::FuncOp translated_func, mlir::func::ReturnOp return_op, - absl::Span call_ops) { - llvm::SmallVector output_layouts; - for (mlir::TF::StatefulPartitionedCallOp call_op : call_ops) { - // Then extract all their output layouts. - mlir::ArrayAttr layouts = - call_op->getAttr(kLayoutAttr).dyn_cast_or_null(); - if (!layouts) { - call_op.emitOpError() << "Could not find op's layouts."; - return mlir::failure(); - } - // Here, we assume that the output layouts and the results are in the same - // ordering--this property should be guaranteed as long as all the results - // have been expanded (produced by ExpandOperation). - output_layouts.insert(output_layouts.end(), layouts.begin(), layouts.end()); - } + const std::vector& expanded_results, + mlir::ArrayAttr num_local_outputs_attr, Operations&& call_ops) { + using CallOp = typename std::decay_t::value_type; mlir::SymbolTable symbol_table(module); mlir::Block* module_body = module.getBody(); @@ -442,25 +551,41 @@ mlir::LogicalResult BuildOuterMainFunc( // Get the type of the translated function. mlir::FunctionType func_type = translated_func.getFunctionType(); - // Then build a call op targeting it (reflecting its result types). - auto expanded_call_op = builder.create( - call_ops[0].getLoc(), func_type.getResults(), inputs, - translated_func.getSymName(), - /*config=*/builder.getStringAttr(""), - /*config_proto=*/builder.getStringAttr(""), - /*executor_type=*/builder.getStringAttr("")); + // Then build a call op targeting it (reflecting its result types) + auto expanded_call_op = + builder.create(call_ops[0].getLoc(), func_type.getResults(), + inputs, translated_func.getSymName(), + /*config=*/builder.getStringAttr(""), + /*config_proto=*/builder.getStringAttr(""), + /*executor_type=*/builder.getStringAttr("")); // Set the output layout attribute on the new call op. 
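Editor's note: GetInferredResourceAttributes above collects matching (index, layout) pairs from every call op into two parallel lists and only emits the attributes when at least one pair was found. A simplified standalone version of that gathering step, with strings and ints standing in for the MLIR attributes:

#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

struct CallInfo {
  std::vector<int64_t> resource_indices;      // stands in for kNewResourceLayoutIndices
  std::vector<std::string> resource_layouts;  // stands in for kNewResourceArgLayouts
};

// Zips each call's indices with its layouts and appends them to two parallel
// output vectors; calls missing either attribute are skipped.
std::pair<std::vector<int64_t>, std::vector<std::string>> GatherResourceAttrs(
    const std::vector<CallInfo>& calls) {
  std::vector<int64_t> indices;
  std::vector<std::string> layouts;
  for (const CallInfo& call : calls) {
    if (call.resource_indices.empty() || call.resource_layouts.empty()) continue;
    const size_t n =
        std::min(call.resource_indices.size(), call.resource_layouts.size());
    for (size_t i = 0; i < n; ++i) {
      indices.push_back(call.resource_indices[i]);
      layouts.push_back(call.resource_layouts[i]);
    }
  }
  return {indices, layouts};
}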
- llvm::ArrayRef output_layouts_ref(output_layouts); - mlir::ArrayAttr output_layouts_attr = - builder.getArrayAttr(output_layouts_ref); - expanded_call_op->setAttr(kLayoutAttr, output_layouts_attr); + std::vector> output_layouts; + std::transform(expanded_results.begin(), expanded_results.end(), + std::back_inserter(output_layouts), + [](const ExpandedResults& result) { return result.layout; }); + SetLayoutOnOp(expanded_call_op, builder, output_layouts); + + expanded_call_op->setAttr(kNumLocalOutputsAttr, num_local_outputs_attr); + + std::optional resource_attrs; + if (failed( + GetInferredResourceAttributes(builder, call_ops, &resource_attrs))) { + return mlir::failure(); + } + + if (resource_attrs) { + expanded_call_op->setAttr(kNewResourceArgLayouts, resource_attrs->layouts); + expanded_call_op->setAttr(kNewResourceLayoutIndices, + resource_attrs->indices); + } // Return all the values from the new call op. mlir::Operation::result_range outputs = expanded_call_op.getResults(); - if (return_op) { - builder.create(return_op.getLoc(), outputs); - } else if (!outputs.empty()) { + if (return_op || outputs.empty()) { + mlir::Location loc = return_op ? return_op.getLoc() : main_func.getLoc(); + builder.create(loc, outputs); + } else { call_ops[0]->emitOpError("Call had results, but they were not used."); return mlir::failure(); } @@ -479,6 +604,25 @@ mlir::LogicalResult BuildOuterMainFunc( return mlir::success(); } +Status ExtractResultLayouts(mlir::Operation* op, mlir::func::ReturnOp return_op, + std::vector& expanded_results) { + if (!return_op || (return_op.getNumOperands() == 0)) { + return OkStatus(); + } + TF_ASSIGN_OR_RETURN(std::vector> layouts, + ExtractLayoutFromOp(op)); + mlir::Operation::operand_range operands = return_op.getOperands(); + for (auto [layout_index, result] : llvm::enumerate(op->getResults())) { + auto search = std::find(operands.begin(), operands.end(), result); + if (search == operands.end()) { + continue; + } + std::size_t result_index = std::distance(operands.begin(), search); + expanded_results[result_index].layout = layouts[layout_index]; + } + return OkStatus(); +} + struct DTensorMultiDeviceExpansion : public impl::DTensorMultiDeviceExpansionBase< DTensorMultiDeviceExpansion> { @@ -542,19 +686,21 @@ struct DTensorMultiDeviceExpansion }); // Ensure that all the call ops return results via the same op. - mlir::func::ReturnOp return_op = GetReturnOpFromUsers( - absl::Span(stateful_call_ops)); - if (!return_op && !stateful_call_ops.empty()) { + mlir::func::ReturnOp return_op; + if (GetReturnOpFromUsers(stateful_call_ops, &return_op).failed()) { stateful_call_ops[0]->emitOpError( "Calls must be used by exactly one return op."); return; } - ExpandedResultsMap expanded_results; + std::vector expanded_results( + return_op ? return_op->getNumOperands() : 0); for (const mlir::TF::StatefulPartitionedCallOp& stateful_call_op : stateful_call_ops) { + const Status status = + ExtractResultLayouts(stateful_call_op, return_op, expanded_results); const StatusOr> mesh = - ExtractDeviceMeshFromOp(stateful_call_op); + status.ok() ? 
ExtractDeviceMeshFromOp(stateful_call_op) : status; if (!(mesh.ok() && *mesh)) { stateful_call_op->emitOpError("Failed to retrieve op mesh or layout."); return; @@ -581,16 +727,25 @@ struct DTensorMultiDeviceExpansion } std::vector results; - for (unsigned i = 0; i < return_op->getNumOperands(); ++i) { - ExpandedResultsMap::iterator search = expanded_results.find(i); - if (search == expanded_results.end()) { - results.emplace_back(return_op->getOperand(i)); - } else { - std::vector& values = search->second; - results.insert(results.end(), values.begin(), values.end()); + llvm::SmallVector num_local_outputs; + if (return_op) { + for (unsigned i = 0; i < return_op->getNumOperands(); ++i) { + std::vector& values = expanded_results[i].results; + int num_outputs; + if (values.empty()) { + results.emplace_back(return_op->getOperand(i)); + num_outputs = 1; + } else { + results.insert(results.end(), values.begin(), values.end()); + num_outputs = values.size(); + } + num_local_outputs.emplace_back(builder.getI64IntegerAttr(num_outputs)); } } + mlir::ArrayAttr num_local_outputs_attr = + builder.getArrayAttr(num_local_outputs); + // update the operands of the translated return op translated_terminator_op->setOperands(results); // and, update the function's type accordingly @@ -599,8 +754,8 @@ struct DTensorMultiDeviceExpansion UpdateEntryFuncAttr(builder, translated_func); mlir::LogicalResult status = BuildOuterMainFunc( - module, main_func, translated_func, return_op, - absl::Span(stateful_call_ops)); + module, main_func, translated_func, return_op, expanded_results, + num_local_outputs_attr, stateful_call_ops); if (mlir::failed(status)) { return; } diff --git a/tensorflow/dtensor/mlir/dtensor_send_recv.cc b/tensorflow/dtensor/mlir/dtensor_send_recv.cc index 14e865fc789173..4631f2cfd7e905 100644 --- a/tensorflow/dtensor/mlir/dtensor_send_recv.cc +++ b/tensorflow/dtensor/mlir/dtensor_send_recv.cc @@ -538,99 +538,105 @@ StatusOr LowerDTensorSend(mlir::Operation* send_op, // Is tensor transfer is from TPU mesh to host mesh and send layout and recv // layout is identical, then tensor from each source device is sent to // target device asynchronously. + mlir::Operation* lowered_send; if (one_to_one && IsTpuToHostMeshTransfer(input_mesh, target_mesh)) { - return LowerDTensorSendToXlaOp(input_layout, dtensor_send.getInput(), - dtensor_send, - /*send_from_device_zero=*/false); + TF_ASSIGN_OR_RETURN(lowered_send, + LowerDTensorSendToXlaOp( + input_layout, dtensor_send.getInput(), dtensor_send, + /*send_from_device_zero=*/false)); } else if (one_to_one && IsGpuToHostMeshTransfer(input_mesh, target_mesh) && !recv_layout.IsFullyReplicated()) { - return LowerOneToOneDTensorSendToTFHostSend(input_layout, target_mesh, - dtensor_send); - } - - // Calculate input tensor layout of data to send and target fully replicated - // layout. For now, we ensure that all data transfer happen with fully - // replicated tensors. - const int rank = ValueRank(dtensor_send.getInput()); - const Layout target_layout = Layout::ReplicatedOnMesh(input_mesh, rank); - - // Convert tensor to send to replicated layout. - mlir::OpBuilder builder(dtensor_send); - TF_ASSIGN_OR_RETURN(mlir::Value send_input, - EmitAllGather(builder, dtensor_send.getInput(), - input_layout, target_layout)); - - // Insert control flow such that only device with device ordinal == 0 sends - // the tensor data across mesh. 
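Editor's note: the multi-device expansion pass above flattens each original result into its per-device values and records how many local outputs each original result produced (1 when a result was not expanded), which is what feeds the new num_local_outputs attribute. A minimal sketch of that bookkeeping with plain values:

#include <cstdint>
#include <utility>
#include <vector>

// expanded[i] holds the per-device replacements for original result i, or is
// empty when the result was left untouched. Returns the flattened result list
// and the number of local outputs contributed by each original result.
std::pair<std::vector<int>, std::vector<int64_t>> FlattenResults(
    const std::vector<int>& original_results,
    const std::vector<std::vector<int>>& expanded) {
  std::vector<int> flattened;
  std::vector<int64_t> num_local_outputs;
  for (size_t i = 0; i < original_results.size(); ++i) {
    if (expanded[i].empty()) {
      flattened.push_back(original_results[i]);
      num_local_outputs.push_back(1);
    } else {
      flattened.insert(flattened.end(), expanded[i].begin(), expanded[i].end());
      num_local_outputs.push_back(static_cast<int64_t>(expanded[i].size()));
    }
  }
  return {flattened, num_local_outputs};
}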
- auto send_cluster = - dtensor_send->getParentOfType(); - TF_ASSIGN_OR_RETURN(std::optional mesh, - ExtractDeviceMeshFromOp(send_cluster)); - if (!mesh.has_value()) - return errors::InvalidArgument( - "failed to lower DTensor CopyToMesh op as sending side mesh is not " - "specified."); - - mlir::Location loc = dtensor_send.getLoc(); - TF_ASSIGN_OR_RETURN( - mlir::Value device_ordinal, - GetDeviceOrdinal(*mesh, loc, - send_cluster->getParentOfType(), - &builder)); - mlir::Value predicate = builder.create( - loc, device_ordinal, CreateIntScalarConst(0, builder, loc), - /*incompatible_shape_error=*/builder.getBoolAttr(true)); - - auto send_if = builder.create( - loc, llvm::SmallVector{}, predicate, - /*is_stateless=*/builder.getBoolAttr(true), - GetUniqueControlflowFnName("copy_to_mesh_send_if_then", builder), - GetUniqueControlflowFnName("copy_to_mesh_send_if_else", builder)); - - // Create empty else branch region. - auto& else_branch = send_if.getElseBranch(); - else_branch.push_back(new mlir::Block); - builder.setInsertionPointToEnd(&else_branch.front()); - builder.create(loc, - /*operands=*/llvm::ArrayRef{}); - - // Create then branch region with DTensorSend op. - auto& then_branch = send_if.getThenBranch(); - then_branch.push_back(new mlir::Block); - builder.setInsertionPointToEnd(&then_branch.front()); - auto yield = builder.create( - loc, /*operands=*/llvm::ArrayRef{}); - dtensor_send->moveBefore(yield); - - // Lower DTensorSend op to actual TF op. - TF_ASSIGN_OR_RETURN(const Mesh recv_mesh, - ExtractDeviceMeshEnclosingCluster(recv_op)); - mlir::Operation* lowered_send; - if (SendRecvOpUsesXla(input_layout.mesh(), recv_mesh)) { - // Lower DTensorSend op to Xla Send ops. - TF_ASSIGN_OR_RETURN( - lowered_send, - LowerDTensorSendToXlaOp(input_layout, send_input, dtensor_send, - /*send_from_device_zero=*/true)); - } else if (input_layout.mesh().is_cpu_mesh() && recv_mesh.is_cpu_mesh()) { - // Lower DTensorSend op to TF Host Send op. - TF_ASSIGN_OR_RETURN( - lowered_send, - LowerDTensorSendFromCPUToTFOp(input_layout, send_input, dtensor_send)); + TF_ASSIGN_OR_RETURN(lowered_send, + LowerOneToOneDTensorSendToTFHostSend( + input_layout, target_mesh, dtensor_send)); } else { - mlir::TensorType send_type = send_input.getType().cast(); - if (!recv_mesh.is_cpu_mesh() && send_type.getElementType().isInteger(32)) { - builder.setInsertionPointAfter(send_input.getDefiningOp()); - auto cast_to_int64 = builder.create( - send_input.getLoc(), - mlir::RankedTensorType::get(send_type.getShape(), - builder.getIntegerType(64)), - send_input); - send_input = cast_to_int64->getResult(0); + // Calculate input tensor layout of data to send and target fully replicated + // layout. For now, we ensure that all data transfer happen with fully + // replicated tensors. + const int rank = ValueRank(dtensor_send.getInput()); + const Layout target_layout = Layout::ReplicatedOnMesh(input_mesh, rank); + + // Convert tensor to send to replicated layout. + mlir::OpBuilder builder(dtensor_send); + TF_ASSIGN_OR_RETURN(mlir::Value send_input, + EmitAllGather(builder, dtensor_send.getInput(), + input_layout, target_layout)); + + // Insert control flow such that only device with device ordinal == 0 sends + // the tensor data across mesh. 
+ auto send_cluster = + dtensor_send->getParentOfType(); + TF_ASSIGN_OR_RETURN(std::optional mesh, + ExtractDeviceMeshFromOp(send_cluster)); + if (!mesh.has_value()) { + return absl::InvalidArgumentError( + "failed to lower DTensor CopyToMesh op as sending side mesh is not " + "specified."); } + + mlir::Location loc = dtensor_send.getLoc(); TF_ASSIGN_OR_RETURN( - lowered_send, - LowerDTensorSendToTFOp(input_layout, send_input, dtensor_send)); + mlir::Value device_ordinal, + GetDeviceOrdinal(*mesh, loc, + send_cluster->getParentOfType(), + &builder)); + mlir::Value predicate = builder.create( + loc, device_ordinal, CreateIntScalarConst(0, builder, loc), + /*incompatible_shape_error=*/builder.getBoolAttr(true)); + + auto send_if = builder.create( + loc, llvm::SmallVector{}, predicate, + /*is_stateless=*/builder.getBoolAttr(true), + GetUniqueControlflowFnName("copy_to_mesh_send_if_then", builder), + GetUniqueControlflowFnName("copy_to_mesh_send_if_else", builder)); + + // Create empty else branch region. + auto& else_branch = send_if.getElseBranch(); + else_branch.push_back(new mlir::Block); + builder.setInsertionPointToEnd(&else_branch.front()); + builder.create( + loc, + /*operands=*/llvm::ArrayRef{}); + + // Create then branch region with DTensorSend op. + auto& then_branch = send_if.getThenBranch(); + then_branch.push_back(new mlir::Block); + builder.setInsertionPointToEnd(&then_branch.front()); + auto yield = builder.create( + loc, /*operands=*/llvm::ArrayRef{}); + dtensor_send->moveBefore(yield); + + // Lower DTensorSend op to actual TF op. + TF_ASSIGN_OR_RETURN(const Mesh recv_mesh, + ExtractDeviceMeshEnclosingCluster(recv_op)); + if (SendRecvOpUsesXla(input_layout.mesh(), recv_mesh)) { + // Lower DTensorSend op to Xla Send ops. + TF_ASSIGN_OR_RETURN( + lowered_send, + LowerDTensorSendToXlaOp(input_layout, send_input, dtensor_send, + /*send_from_device_zero=*/true)); + } else if (input_layout.mesh().is_cpu_mesh() && recv_mesh.is_cpu_mesh()) { + // Lower DTensorSend op to TF Host Send op. + TF_ASSIGN_OR_RETURN( + lowered_send, LowerDTensorSendFromCPUToTFOp(input_layout, send_input, + dtensor_send)); + } else { + mlir::TensorType send_type = + send_input.getType().cast(); + if (!recv_mesh.is_cpu_mesh() && + send_type.getElementType().isInteger(32)) { + builder.setInsertionPointAfter(send_input.getDefiningOp()); + auto cast_to_int64 = builder.create( + send_input.getLoc(), + mlir::RankedTensorType::get(send_type.getShape(), + builder.getIntegerType(64)), + send_input); + send_input = cast_to_int64->getResult(0); + } + TF_ASSIGN_OR_RETURN( + lowered_send, + LowerDTensorSendToTFOp(input_layout, send_input, dtensor_send)); + } } return lowered_send; @@ -655,8 +661,7 @@ StatusOr LowerDTensorRecv(mlir::Operation* send_op, const Mesh& recv_mesh = recv_layout.mesh(); mlir::OpBuilder builder(dtensor_recv); - bool cpu_to_cpu = - dtensor_recv.getLayout().mesh().is_cpu_mesh() && send_mesh.is_cpu_mesh(); + bool cpu_to_cpu = recv_mesh.is_cpu_mesh() && send_mesh.is_cpu_mesh(); bool one_to_one = IsOneToOneMeshTransfer(send_layout, recv_layout); bool send_recv_xla = SendRecvOpUsesXla(send_mesh, recv_mesh); @@ -672,137 +677,134 @@ StatusOr LowerDTensorRecv(mlir::Operation* send_op, } return lowered_recv; - } else if (send_recv_xla || !cpu_to_cpu) { - if (send_recv_xla && - ((one_to_one && IsTpuToHostMeshTransfer(send_mesh, recv_mesh)) || - recv_mesh.is_cpu_mesh())) { - // Recv can be lowered directly for a 1-to-1 transfer between host and - // device (*for XLA/TPUs). 
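Editor's note: the restructured LowerDTensorSend above now assigns every branch to a single lowered_send and nests the general all-gather-then-send path in the final else. A compact sketch of just the path selection, with booleans standing in for the mesh/layout queries (the enum and function names are illustrative, not real DTensor symbols):

enum class SendPath {
  kDirectXlaSend,     // one-to-one TPU -> host transfer
  kOneToOneHostSend,  // one-to-one GPU -> host transfer, non-replicated recv
  kGatherThenSend,    // general case: all-gather, then device ordinal 0 sends
};

// Mirrors the branch order of the rewritten LowerDTensorSend.
SendPath ChooseSendPath(bool one_to_one, bool tpu_to_host, bool gpu_to_host,
                        bool recv_fully_replicated) {
  if (one_to_one && tpu_to_host) return SendPath::kDirectXlaSend;
  if (one_to_one && gpu_to_host && !recv_fully_replicated)
    return SendPath::kOneToOneHostSend;
  return SendPath::kGatherThenSend;
}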
- TF_ASSIGN_OR_RETURN(mlir::TensorType local_output_type, - LocalTypeFromGlobalType( - dtensor_recv.getLayout(), - dtensor_recv.getType().cast())); - TF_ASSIGN_OR_RETURN(lowered_recv, LowerDTensorRecvToXlaOp( - dtensor_recv, local_output_type)); - dtensor_recv->replaceAllUsesWith(lowered_recv); - dtensor_recv.erase(); - } else { - // Choose which receive lowering function to use. - auto lower_fn = - send_recv_xla - ? (decltype(&LowerDTensorRecvToTFOp))LowerDTensorRecvToXlaOp - : LowerDTensorRecvToTFOp; - - // For other send/recv layouts, the tensor needs to be replicated. - if (!dtensor_recv.getLayout().IsFullyReplicated()) { - return errors::InvalidArgument( - "CopyToMesh where target mesh is GPU/TPU requires a replicated " - "target layout."); - } + } else if (cpu_to_cpu) { + // Lower DTensorRecv op to TF Host Recv op. + TF_ASSIGN_OR_RETURN(lowered_recv, + LowerDTensorRecvFromCPUToTFOp(send_mesh, dtensor_recv)); + } else if ((one_to_one && IsTpuToHostMeshTransfer(send_mesh, recv_mesh)) || + (send_recv_xla && recv_mesh.is_cpu_mesh())) { + // Recv can be lowered directly for a 1-to-1 transfer between host and + // device (*for XLA/TPUs). + TF_ASSIGN_OR_RETURN(mlir::TensorType local_output_type, + LocalTypeFromGlobalType( + dtensor_recv.getLayout(), + dtensor_recv.getType().cast())); + TF_ASSIGN_OR_RETURN( + lowered_recv, LowerDTensorRecvToXlaOp(dtensor_recv, local_output_type)); + dtensor_recv->replaceAllUsesWith(lowered_recv); + dtensor_recv.erase(); + } else { + // Choose which receive lowering function to use. + auto lower_fn = + send_recv_xla + ? (decltype(&LowerDTensorRecvToTFOp))LowerDTensorRecvToXlaOp + : LowerDTensorRecvToTFOp; + + // For other send/recv layouts, the tensor needs to be replicated. + if (!dtensor_recv.getLayout().IsFullyReplicated()) { + return absl::InvalidArgumentError( + "CopyToMesh where target mesh is GPU/TPU requires a replicated " + "target layout."); + } - // For Receiving at GPU/TPU, only device 0 (ordinal) receives from the - // host, then it shares the tensor with its peers. - auto recv_cluster = - dtensor_recv->getParentOfType(); - mlir::Location loc = dtensor_recv.getLoc(); - TF_ASSIGN_OR_RETURN( - mlir::Value device_ordinal, - GetDeviceOrdinal(recv_mesh, loc, - recv_cluster->getParentOfType(), - &builder)); - mlir::Value predicate = builder.create( - loc, device_ordinal, CreateIntScalarConst(0, builder, loc), - /*incompatible_shape_error=*/builder.getBoolAttr(true)); - - mlir::TensorType recv_type = dtensor_recv.getType(); - bool i32_copy = recv_type.getElementType().isInteger(32); - bool need_i32_to_i64_upcast = - i32_copy && !(recv_mesh.is_cpu_mesh() || send_recv_xla); - mlir::TensorType output_type = - need_i32_to_i64_upcast - ? mlir::RankedTensorType::get(recv_type.getShape(), - builder.getIntegerType(64)) - : recv_type; - - auto recv_if = builder.create( - loc, llvm::SmallVector{output_type}, predicate, - /*is_stateless=*/builder.getBoolAttr(true), - GetUniqueControlflowFnName("copy_to_mesh_recv_if_then", builder), - GetUniqueControlflowFnName("copy_to_mesh_recv_if_else", builder)); - - // Create empty else branch region that outputs zeros. - auto& else_branch = recv_if.getElseBranch(); - else_branch.push_back(new mlir::Block); - builder.setInsertionPointToEnd(&else_branch.front()); - - // Create a zero constant. 
- mlir::Attribute const_attr; - auto output_element_type = output_type.getElementType(); - if (output_element_type.isIntOrIndex()) { - if (output_element_type.isInteger(64)) { - const_attr = mlir::DenseIntElementsAttr::get( - output_type, llvm::SmallVector{0}); - } else { - const_attr = mlir::DenseIntElementsAttr::get( - output_type, llvm::SmallVector{0}); - } - } else if (output_element_type.isBF16()) { - mlir::FloatAttr zero = mlir::FloatAttr::get(output_element_type, 0.); - const_attr = mlir::DenseElementsAttr::get( - output_type, llvm::SmallVector{zero}); - } else if (output_element_type.isF16() || output_element_type.isF32()) { - const_attr = mlir::DenseFPElementsAttr::get( - output_type, llvm::SmallVector{0.0}); - } else if (output_element_type.isF64()) { - const_attr = mlir::DenseFPElementsAttr::get( - output_type, llvm::SmallVector{0.0}); + // For Receiving at GPU/TPU, only device 0 (ordinal) receives from the + // host, then it shares the tensor with its peers. + auto recv_cluster = + dtensor_recv->getParentOfType(); + mlir::Location loc = dtensor_recv.getLoc(); + TF_ASSIGN_OR_RETURN( + mlir::Value device_ordinal, + GetDeviceOrdinal(recv_mesh, loc, + recv_cluster->getParentOfType(), + &builder)); + mlir::Value predicate = builder.create( + loc, device_ordinal, CreateIntScalarConst(0, builder, loc), + /*incompatible_shape_error=*/builder.getBoolAttr(true)); + + mlir::TensorType recv_type = dtensor_recv.getType(); + bool i32_copy = recv_type.getElementType().isInteger(32); + bool need_i32_to_i64_upcast = + i32_copy && !(recv_mesh.is_cpu_mesh() || send_recv_xla); + mlir::TensorType output_type = + need_i32_to_i64_upcast + ? mlir::RankedTensorType::get(recv_type.getShape(), + builder.getIntegerType(64)) + : recv_type; + + auto recv_if = builder.create( + loc, llvm::SmallVector{output_type}, predicate, + /*is_stateless=*/builder.getBoolAttr(true), + GetUniqueControlflowFnName("copy_to_mesh_recv_if_then", builder), + GetUniqueControlflowFnName("copy_to_mesh_recv_if_else", builder)); + + // Create empty else branch region that outputs zeros. + auto& else_branch = recv_if.getElseBranch(); + else_branch.push_back(new mlir::Block); + builder.setInsertionPointToEnd(&else_branch.front()); + + // Create a zero constant. + mlir::Attribute const_attr; + auto output_element_type = output_type.getElementType(); + if (output_element_type.isIntOrIndex()) { + if (output_element_type.isInteger(64)) { + const_attr = mlir::DenseIntElementsAttr::get( + output_type, llvm::SmallVector{0}); } else { - return errors::InvalidArgument("unsupported output type"); + const_attr = mlir::DenseIntElementsAttr::get( + output_type, llvm::SmallVector{0}); } + } else if (output_element_type.isBF16()) { + mlir::FloatAttr zero = mlir::FloatAttr::get(output_element_type, 0.); + const_attr = mlir::DenseElementsAttr::get( + output_type, llvm::SmallVector{zero}); + } else if (output_element_type.isF16() || output_element_type.isF32()) { + const_attr = mlir::DenseFPElementsAttr::get( + output_type, llvm::SmallVector{0.0}); + } else if (output_element_type.isF64()) { + const_attr = mlir::DenseFPElementsAttr::get( + output_type, llvm::SmallVector{0.0}); + } else { + return absl::InvalidArgumentError("unsupported output type"); + } - mlir::Value zeros = builder.create(loc, const_attr); - builder.create( - loc, /*operands=*/llvm::ArrayRef{zeros}); - - // Create then branch region with DTensorRecv op. 
- auto& then_branch = recv_if.getThenBranch(); - then_branch.push_back(new mlir::Block); - builder.setInsertionPointToEnd(&then_branch.front()); - dtensor_recv->moveBefore(&then_branch.front(), then_branch.front().end()); - - TF_ASSIGN_OR_RETURN(mlir::Operation * xla_recv, - lower_fn(send_mesh, dtensor_recv, output_type)); - builder.create( - loc, - /*operands=*/llvm::ArrayRef{xla_recv->getResult(0)}); - - // Broadcast the received output to all GPU/TPU devices. - mlir::Value if_output = recv_if->getResult(0); - builder.setInsertionPointAfterValue(if_output); - absl::flat_hash_set reduced_dims; - for (const auto& mesh_dim : recv_mesh.dims()) - reduced_dims.insert(mesh_dim.name); - - TF_ASSIGN_OR_RETURN(lowered_recv, - EmitAllReduce(builder, recv_layout, reduced_dims, - recv_if, kReduceOpAdd)); - - if (need_i32_to_i64_upcast) { - lowered_recv = builder.create( - loc, recv_type, lowered_recv->getResult(0)); - } + mlir::Value zeros = builder.create(loc, const_attr); + builder.create( + loc, /*operands=*/llvm::ArrayRef{zeros}); + + // Create then branch region with DTensorRecv op. + auto& then_branch = recv_if.getThenBranch(); + then_branch.push_back(new mlir::Block); + builder.setInsertionPointToEnd(&then_branch.front()); + dtensor_recv->moveBefore(&then_branch.front(), then_branch.front().end()); + + TF_ASSIGN_OR_RETURN(mlir::Operation * xla_recv, + lower_fn(send_mesh, dtensor_recv, output_type)); + builder.create( + loc, + /*operands=*/llvm::ArrayRef{xla_recv->getResult(0)}); + + // Broadcast the received output to all GPU/TPU devices. + mlir::Value if_output = recv_if->getResult(0); + builder.setInsertionPointAfterValue(if_output); + absl::flat_hash_set reduced_dims; + for (const auto& mesh_dim : recv_mesh.dims()) + reduced_dims.insert(mesh_dim.name); - // Replaces usages of DTensorRecv op with the broadcasted value. - dtensor_recv.getOutput().replaceUsesWithIf( - lowered_recv->getResult(0), [&](mlir::OpOperand& operand) { - return !recv_if->isProperAncestor(operand.getOwner()); - }); - dtensor_recv.erase(); + TF_ASSIGN_OR_RETURN( + lowered_recv, EmitAllReduce(builder, recv_layout, reduced_dims, recv_if, + kReduceOpAdd)); + + if (need_i32_to_i64_upcast) { + lowered_recv = builder.create( + loc, recv_type, lowered_recv->getResult(0)); } - } else { - // Lower DTensorRecv op to TF Host Recv op. - TF_ASSIGN_OR_RETURN(lowered_recv, - LowerDTensorRecvFromCPUToTFOp(send_mesh, dtensor_recv)); + + // Replaces usages of DTensorRecv op with the broadcasted value. + dtensor_recv.getOutput().replaceUsesWithIf( + lowered_recv->getResult(0), [&](mlir::OpOperand& operand) { + return !recv_if->isProperAncestor(operand.getOwner()); + }); + dtensor_recv.erase(); } llvm::SmallPtrSet newly_created_ops; diff --git a/tensorflow/dtensor/mlir/expansions/bias_add_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/bias_add_spmd_expander.cc index fa9ce8c6435c69..4b97bd0786641a 100644 --- a/tensorflow/dtensor/mlir/expansions/bias_add_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/bias_add_spmd_expander.cc @@ -69,10 +69,8 @@ StatusOr BiasAddExpander::ExpandOp(mlir::Operation* op) { // Check if output is sharded more, change input layout to match output // layout. 
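Editor's note: in the receive path above only device ordinal 0 issues the actual Recv; every other device yields a zero tensor of the same type, and the following add all-reduce turns that into a broadcast. A tiny standalone simulation of why that works:

#include <vector>

// Simulates the broadcast-by-all-reduce trick: device 0 contributes the
// received tensor, all other devices contribute zeros, and an elementwise sum
// across devices leaves every device holding the received values.
std::vector<float> BroadcastViaAllReduce(const std::vector<float>& received,
                                         int num_devices) {
  std::vector<std::vector<float>> per_device(
      num_devices, std::vector<float>(received.size(), 0.0f));
  per_device[0] = received;  // only ordinal 0 actually receives

  std::vector<float> reduced(received.size(), 0.0f);
  for (const std::vector<float>& contribution : per_device)
    for (size_t i = 0; i < reduced.size(); ++i) reduced[i] += contribution[i];
  return reduced;  // equal to `received` on every device
}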
- int64_t num_input_shards = - input_layout.num_shards_for_dim(input_layout.dim(c_dim_idx)); - int64_t num_output_shards = - output_layout.num_shards_for_dim(output_layout.dim(c_dim_idx)); + int64_t num_input_shards = input_layout.num_shards_for_dim(c_dim_idx); + int64_t num_output_shards = output_layout.num_shards_for_dim(c_dim_idx); if (num_input_shards < num_output_shards) { mlir::Value output; diff --git a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc index 6b439730931bfe..eff66fa0354784 100644 --- a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc @@ -30,13 +30,6 @@ limitations under the License. namespace tensorflow { namespace dtensor { -namespace { - -bool Equal(const ShardingSpec& a, const ShardingSpec& b) { - return a.sharding_spec() == b.sharding_spec(); -} - -} // namespace // Einsum, like reductions, is implemented as a local operation followed by // an all-reduce over dimensions that have been reduced. @@ -150,10 +143,10 @@ Status ExtractEquationRelations( // sharding specs we raise an error if replicate_incompatible_dimensions is // false. Otherwise we treat the dimension as if it were unsharded. // Labels with unsharded dimensions are not recorded in the output. -StatusOr> GetLabelToShardingSpec( +StatusOr> GetLabelToShardingSpec( bool replicate_incompatible_dimensions, const std::vector& layouts, const std::vector>>& mappings) { - absl::flat_hash_map label_to_sharding_spec; + absl::flat_hash_map label_to_sharding_spec; absl::flat_hash_set incompatible_labels; // For each mapping, identify the mesh dimension and whether it has been @@ -171,23 +164,23 @@ StatusOr> GetLabelToShardingSpec( layouts[index].rank()) .str()); - const ShardingSpec& sharding_spec = layouts[index].dim(offset); + const std::string& sharding_spec = layouts[index].sharding_spec(offset); if (label_to_sharding_spec.contains(mapping.first)) { - if (Layout::IsShardedSpec(sharding_spec) && - !Equal(label_to_sharding_spec[mapping.first], sharding_spec)) { + if (Layout::IsShardedDimension(sharding_spec) && + label_to_sharding_spec[mapping.first] != sharding_spec) { if (!replicate_incompatible_dimensions) return errors::InvalidArgument( llvm::formatv( "incompatible mesh dimensions in equation, label '{0}' " "is mapped to mesh dimension '{1}' and '{2}'", - mapping.first, sharding_spec.sharding_spec(), - label_to_sharding_spec[mapping.first].sharding_spec()) + mapping.first, sharding_spec, + label_to_sharding_spec[mapping.first]) .str()); else incompatible_labels.insert(mapping.first); } - } else if (Layout::IsShardedSpec(sharding_spec)) { + } else if (Layout::IsShardedDimension(sharding_spec)) { label_to_sharding_spec[mapping.first] = sharding_spec; } } @@ -205,42 +198,41 @@ StatusOr> GetLabelToShardingSpec( // multiple times. E.g. ab,bc->ac (i.e. matmul) with a and c sharded over the // same dim. In this case we mark all such dimensions as replicated. 
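Editor's note: GetLabelToShardingSpec in the einsum hunk above now maps each label directly to a mesh-dimension string and either rejects or replicates labels that see two different sharded dimensions. A simplified sketch of that conflict handling, with labels paired with per-occurrence specs and "unsharded" as a placeholder marker:

#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Each entry is (einsum label, sharding spec seen at one occurrence). Labels
// whose sharded specs disagree are dropped from the result, which matches the
// replicate_incompatible_dimensions == true behaviour.
std::map<char, std::string> LabelToShardingSpec(
    const std::vector<std::pair<char, std::string>>& occurrences) {
  const std::string kUnsharded = "unsharded";
  std::map<char, std::string> label_to_spec;
  std::set<char> incompatible;
  for (const auto& [label, spec] : occurrences) {
    if (spec == kUnsharded) continue;  // unsharded occurrences are not recorded
    auto it = label_to_spec.find(label);
    if (it == label_to_spec.end()) {
      label_to_spec[label] = spec;
    } else if (it->second != spec) {
      incompatible.insert(label);  // conflicting sharded mesh dimensions
    }
  }
  for (char label : incompatible) label_to_spec.erase(label);
  return label_to_spec;
}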
StatusOr VerifyOrFixLayout( - std::pair, absl::flat_hash_map> + std::pair, absl::flat_hash_map> pair, const Mesh& mesh) { - std::vector sharding_specs = pair.first; + std::vector sharding_specs = pair.first; absl::flat_hash_map dimension_use_count = pair.second; for (int i = 0; i < sharding_specs.size(); ++i) - if (Layout::IsShardedSpec(sharding_specs[i]) && - dimension_use_count[sharding_specs[i].sharding_spec()] > 1) - sharding_specs[i].set_sharding_spec(Layout::kUnshardedDim); + if (Layout::IsShardedDimension(sharding_specs[i]) && + dimension_use_count[sharding_specs[i]] > 1) + sharding_specs[i] = Layout::kUnshardedDim; return Layout::GetLayout(sharding_specs, mesh); } // Construct a layout on a given mesh from the label to tensor dimension map // and the label to mesh_dimension map. -std::pair, absl::flat_hash_map> +std::pair, absl::flat_hash_map> GetSpecsFromLabelsAndMap( const absl::flat_hash_map>& label_to_index, - const absl::flat_hash_map& label_to_sharding_spec) { + const absl::flat_hash_map& label_to_sharding_spec) { int layout_rank = 0; for (const auto& label_and_indices : label_to_index) layout_rank += label_and_indices.second.size(); - std::vector sharding_specs(layout_rank); + std::vector sharding_specs(layout_rank); absl::flat_hash_map dimension_use_count; absl::flat_hash_set dimension_use_set; for (const auto& label_and_indices : label_to_index) { const auto& loc = label_to_sharding_spec.find(label_and_indices.first); if (loc != label_to_sharding_spec.end()) { - const ShardingSpec& sharding_spec = loc->second; + const std::string& sharding_spec = loc->second; for (int index : label_and_indices.second) sharding_specs[index] = sharding_spec; - dimension_use_count[sharding_spec.sharding_spec()] += - label_and_indices.second.size(); + dimension_use_count[sharding_spec] += label_and_indices.second.size(); } else { for (int index : label_and_indices.second) - sharding_specs[index].set_sharding_spec(Layout::kUnshardedDim); + sharding_specs[index] = Layout::kUnshardedDim; } } return std::make_pair(sharding_specs, dimension_use_count); @@ -353,11 +345,11 @@ StatusOr> EinsumSPMDExpander::ComputeLayoutBackward( for (size_t i = 0; i < num_inputs; ++i) { absl::flat_hash_map> labels_to_indices = input_mappings[i]; - std::pair, absl::flat_hash_map> + std::pair, absl::flat_hash_map> sharding_specs_and_dim_count = GetSpecsFromLabelsAndMap( labels_to_indices, output_label_to_sharding_spec); - std::vector sharding_specs = + std::vector sharding_specs = sharding_specs_and_dim_count.first; absl::flat_hash_map dim_count = sharding_specs_and_dim_count.second; @@ -367,7 +359,7 @@ StatusOr> EinsumSPMDExpander::ComputeLayoutBackward( char label = label_to_indices.first; if (labels_for_any.contains(label)) { int index = label_to_indices.second[0]; - sharding_specs[index].set_sharding_spec(Layout::kAny); + sharding_specs[index] = Layout::kAny; } } TF_ASSIGN_OR_RETURN( @@ -422,13 +414,12 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( for (const char label : all_labels) { if (input_label_to_sharding_spec.contains(label) && output_label_to_sharding_spec.contains(label) && - !Equal(input_label_to_sharding_spec[label], - output_label_to_sharding_spec.find(label)->second)) + input_label_to_sharding_spec[label] != + output_label_to_sharding_spec.find(label)->second) return errors::InvalidArgument( "for label ", label, " input and output layouts are sharded on ", - " non-equal dimensions ", - input_label_to_sharding_spec[label].sharding_spec(), " and ", - 
output_label_to_sharding_spec.find(label)->second.sharding_spec(), + " non-equal dimensions ", input_label_to_sharding_spec[label], + " and ", output_label_to_sharding_spec.find(label)->second, "respectively"); } @@ -438,23 +429,21 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( for (const auto& input_mapping : input_mappings) for (const auto& char_and_positions : input_mapping) if (char_and_positions.second.size() > 1) - input_label_to_sharding_spec[char_and_positions.first] - .set_sharding_spec(Layout::kUnshardedDim); + input_label_to_sharding_spec[char_and_positions.first] = + Layout::kUnshardedDim; absl::flat_hash_map> sharding_dim_to_non_contracting_labels; absl::flat_hash_map> sharding_dim_to_contracting_labels; for (const auto& label_and_spec : input_label_to_sharding_spec) { - if (Layout::IsShardedSpec(label_and_spec.second)) { + if (Layout::IsShardedDimension(label_and_spec.second)) { if (contracting_labels.contains(label_and_spec.first)) - sharding_dim_to_contracting_labels[label_and_spec.second - .sharding_spec()] - .insert(label_and_spec.first); + sharding_dim_to_contracting_labels[label_and_spec.second].insert( + label_and_spec.first); else - sharding_dim_to_non_contracting_labels[label_and_spec.second - .sharding_spec()] - .insert(label_and_spec.first); + sharding_dim_to_non_contracting_labels[label_and_spec.second].insert( + label_and_spec.first); } } @@ -469,12 +458,11 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( if (!contracting_labels.contains(label) && output_label_to_sharding_spec.contains(label) && !input_label_to_sharding_spec.contains(label)) { - const ShardingSpec& sharding_spec = + const std::string& string_spec = output_label_to_sharding_spec.find(label)->second; - const std::string& string_spec = sharding_spec.sharding_spec(); if (!sharding_dim_to_non_contracting_labels.contains(string_spec) && !sharding_dim_to_contracting_labels.contains(string_spec)) { - input_label_to_sharding_spec[label] = sharding_spec; + input_label_to_sharding_spec[label] = string_spec; sharding_dim_to_non_contracting_labels[string_spec].insert(label); } } @@ -503,8 +491,7 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( // keep this stable with respect to ordering. 
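As a hedged aside on the einsum changes above: the rewritten VerifyOrFixLayout now builds layouts from plain string sharding specs. A condensed restatement of that rule follows, assuming StatusOr is the alias from tensorflow/dtensor/cc/dstatus.h and with the function name invented here:

#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"

using tensorflow::dtensor::Layout;
using tensorflow::dtensor::Mesh;
using tensorflow::dtensor::StatusOr;

// Condensed restatement of the VerifyOrFixLayout rule above: a mesh dimension
// used by more than one tensor dimension is demoted to replicated, then the
// layout is rebuilt from the string specs.
StatusOr<Layout> ReplicateOverusedMeshDims(
    std::vector<std::string> sharding_specs,
    absl::flat_hash_map<std::string, int> dimension_use_count,
    const Mesh& mesh) {
  for (std::string& spec : sharding_specs) {
    if (Layout::IsShardedDimension(spec) && dimension_use_count[spec] > 1) {
      spec = Layout::kUnshardedDim;  // plain assignment replaces set_sharding_spec()
    }
  }
  return Layout::GetLayout(sharding_specs, mesh);
}
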
for (const char label : sharding_dim_to_non_contracting_labels[dim]) { if (output_label_to_sharding_spec.contains(label) && - output_label_to_sharding_spec.find(label)->second.sharding_spec() == - dim) { + output_label_to_sharding_spec.find(label)->second == dim) { label_to_keep = label; break; } else if (label < label_to_keep) { @@ -513,8 +500,7 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( } for (const char label : sharding_dim_to_non_contracting_labels[dim]) if (label != label_to_keep) - input_label_to_sharding_spec[label].set_sharding_spec( - Layout::kUnshardedDim); + input_label_to_sharding_spec[label] = Layout::kUnshardedDim; sharding_dim_to_non_contracting_labels[dim].clear(); sharding_dim_to_non_contracting_labels[dim].insert(label_to_keep); } @@ -530,8 +516,8 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( assert(spec_and_labels.second.size() == 1); assert(sharding_dim_to_non_contracting_labels[spec_and_labels.first] .size() == 1); - input_label_to_sharding_spec[*spec_and_labels.second.begin()] - .set_sharding_spec(Layout::kUnshardedDim); + input_label_to_sharding_spec[*spec_and_labels.second.begin()] = + Layout::kUnshardedDim; } } @@ -557,8 +543,7 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( output_layout.mesh())); for (const auto& contracting : contracting_labels) - reduce_dims.emplace( - input_label_to_sharding_spec[contracting].sharding_spec()); + reduce_dims.emplace(input_label_to_sharding_spec[contracting]); return OkStatus(); } diff --git a/tensorflow/dtensor/mlir/expansions/elementwise_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/elementwise_spmd_expander.cc index e8e24c9b5a2efb..216025d62f5ee5 100644 --- a/tensorflow/dtensor/mlir/expansions/elementwise_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/elementwise_spmd_expander.cc @@ -164,7 +164,7 @@ ElementwiseSPMDExpander::ComputeLayoutBackward( TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(operand)); Layout output_layout_truncated = output_layout.Truncate( - output_layout.sharding_specs().size() - operand_shape.size(), + output_layout.sharding_spec_strs().size() - operand_shape.size(), /*end=*/true); auto inferred_operand_layout_strs = output_layout_truncated.sharding_spec_strs(); diff --git a/tensorflow/dtensor/mlir/expansions/expanddims_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/expanddims_spmd_expander.cc index d2c58de60a3423..1151d46cb9159b 100644 --- a/tensorflow/dtensor/mlir/expansions/expanddims_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/expanddims_spmd_expander.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include +#include #include "absl/types/optional.h" #include "tensorflow/dtensor/mlir/collectives.h" @@ -47,14 +48,14 @@ StatusOr ExpandDimsExpander::ExpandOp(mlir::Operation* op) { ExtractConstIntFromValue(expand_dims_op.getDim())); if (dim < 0) dim += global_output_shape.size(); - std::vector sharding_specs(global_output_shape.size()); + std::vector sharding_specs(global_output_shape.size()); for (int i = 0; i < global_output_shape.size(); ++i) { if (i < dim) - sharding_specs[i] = operand_layout->dim(i); + sharding_specs[i] = operand_layout->sharding_spec(i); else if (i == dim) - sharding_specs[i].set_sharding_spec(Layout::kUnshardedDim); + sharding_specs[i] = Layout::kUnshardedDim; else - sharding_specs[i] = operand_layout->dim(i - 1); + sharding_specs[i] = operand_layout->sharding_spec(i - 1); } TF_ASSIGN_OR_RETURN(const Layout current_output_layout, Layout::GetLayout(sharding_specs, output_layout->mesh())); diff --git a/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc index f6bd92b4ea9770..fd42942f7e4130 100644 --- a/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc @@ -209,26 +209,28 @@ StatusOr GatherNdGetOutputLayoutFromInput( // replace them with replicated. // If sharding dimension is used by both params and indices, the params // layout will be respected as generally params is larger than indices. - std::vector output_specs(params_rank - index_dimensions + - indices_rank - 1); + std::vector output_specs(params_rank - index_dimensions + + indices_rank - 1); absl::flat_hash_set used_dimensions; const int params_offset = -index_dimensions + indices_rank - 1; for (int i = index_dimensions; i < params_rank; ++i) { - if (params_layout && Layout::IsShardedSpec(params_layout->dim(i))) { - const ShardingSpec& params_spec = params_layout->dim(i); + if (params_layout && + Layout::IsShardedDimension(params_layout->sharding_spec(i))) { + const auto& params_spec = params_layout->sharding_spec(i); output_specs[i + params_offset] = params_spec; - used_dimensions.emplace(params_spec.sharding_spec()); + used_dimensions.emplace(params_spec); } else { - output_specs[i + params_offset].set_sharding_spec(Layout::kUnshardedDim); + output_specs[i + params_offset] = Layout::kUnshardedDim; } } for (int i = 0; i < indices_rank - 1; ++i) { - if (indices_layout && Layout::IsShardedSpec(indices_layout->dim(i)) && + if (indices_layout && + Layout::IsShardedDimension(indices_layout->sharding_spec(i)) && !used_dimensions.contains(indices_layout->sharding_spec(i))) - output_specs[i] = indices_layout->dim(i); + output_specs[i] = indices_layout->sharding_spec(i); else - output_specs[i].set_sharding_spec(Layout::kUnshardedDim); + output_specs[i] = Layout::kUnshardedDim; } return Layout::GetLayout(output_specs, mesh); } @@ -242,20 +244,20 @@ Status GatherNdGetInputLayoutFromOutput(const Layout& output_layout, // indices_layout (with the last dimensions replicated) and the remaining // dimensions to params_layout (with the first index_dimensions dimensions // replicated). 
- std::vector params_specs(params_rank); - std::vector indices_specs(indices_rank); + std::vector params_specs(params_rank); + std::vector indices_specs(indices_rank); for (int i = 0; i < index_dimensions; ++i) - params_specs[i].set_sharding_spec(Layout::kUnshardedDim); + params_specs[i] = Layout::kUnshardedDim; const int params_offset = -index_dimensions + indices_rank - 1; for (int i = index_dimensions; i < params_rank; ++i) - params_specs[i] = output_layout.dim(i + params_offset); + params_specs[i] = output_layout.sharding_spec(i + params_offset); for (int i = 0; i < indices_rank - 1; ++i) - indices_specs[i] = output_layout.dim(i); + indices_specs[i] = output_layout.sharding_spec(i); - indices_specs[indices_rank - 1].set_sharding_spec(Layout::kUnshardedDim); + indices_specs[indices_rank - 1] = Layout::kUnshardedDim; TF_ASSIGN_OR_RETURN(*params_layout, Layout::GetLayout(params_specs, mesh)); TF_ASSIGN_OR_RETURN(*indices_layout, Layout::GetLayout(indices_specs, mesh)); @@ -309,21 +311,20 @@ StatusOr GatherNdSPMDExpander::ExpandOp(mlir::Operation* op) { // Step 2) llvm::DenseSet used_dimensions; - for (const ShardingSpec& spec : pre_output_layout.sharding_specs()) - if (Layout::IsShardedSpec(spec)) - used_dimensions.insert(spec.sharding_spec()); + for (const auto& spec : pre_output_layout.sharding_spec_strs()) + if (Layout::IsShardedDimension(spec)) used_dimensions.insert(spec); - std::vector sharding_specs(output_layout.rank()); + std::vector sharding_specs(output_layout.rank()); for (int i = 0; i < sharding_specs.size(); ++i) { - if (Layout::IsShardedSpec(pre_output_layout.dim(i))) - sharding_specs[i] = pre_output_layout.dim(i); + if (Layout::IsShardedDimension(pre_output_layout.sharding_spec(i))) + sharding_specs[i] = pre_output_layout.sharding_spec(i); // Merge in sharded dimensions from the output which aren't already used // by the pre_output_layout. - else if (Layout::IsShardedSpec(output_layout.dim(i)) && + else if (Layout::IsShardedDimension(output_layout.sharding_spec(i)) && !used_dimensions.contains(output_layout.sharding_spec(i))) - sharding_specs[i] = output_layout.dim(i); + sharding_specs[i] = output_layout.sharding_spec(i); else - sharding_specs[i].set_sharding_spec(Layout::kUnshardedDim); + sharding_specs[i] = Layout::kUnshardedDim; } // Step 3) diff --git a/tensorflow/dtensor/mlir/expansions/in_top_k_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/in_top_k_spmd_expander.cc index 8771b956097519..caae0ade1822ae 100644 --- a/tensorflow/dtensor/mlir/expansions/in_top_k_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/in_top_k_spmd_expander.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/dtensor/mlir/expansions/in_top_k_spmd_expander.h" #include +#include #include "absl/types/optional.h" #include "mlir/IR/IRMapping.h" // from @llvm-project @@ -35,9 +36,9 @@ namespace { // layout, ensuring that the 2nd dimension is replicated. StatusOr GetSuggestedPredictionsLayout(const Layout& layout) { // predictions is a rank-2 tensor (batch_size x num_classes) - std::vector layout_specs(2); - layout_specs[0].set_sharding_spec(layout.sharding_spec(0)); - layout_specs[1].set_sharding_spec(Layout::kUnshardedDim); + std::vector layout_specs(2); + layout_specs[0] = layout.sharding_spec(0); + layout_specs[1] = Layout::kUnshardedDim; return Layout::GetLayout(layout_specs, layout.mesh()); } @@ -46,10 +47,10 @@ StatusOr GetSuggestedPredictionsLayout(const Layout& layout) { // of "other_layout". 
StatusOr MatchBatchDim(const Layout& layout, const Layout& other_layout) { - std::vector layout_specs(layout.rank()); - layout_specs[0].set_sharding_spec(other_layout.sharding_spec(0)); + std::vector layout_specs(layout.rank()); + layout_specs[0] = other_layout.sharding_spec(0); for (int i = 1; i < layout.rank(); ++i) { - layout_specs[i].set_sharding_spec(layout.sharding_spec(i)); + layout_specs[i] = layout.sharding_spec(i); } return Layout::GetLayout(layout_specs, layout.mesh()); diff --git a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc index 6cd8ac7f174599..af109f98a15a07 100644 --- a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc @@ -156,21 +156,19 @@ StatusOr MatMulSPMDExpander::OutputLayoutAndReducedDims( // Input layouts are [batch...],a,b;[batch...],b,c // Output layout is [batch...],a,c - const auto& batch_sharding_specs = batch_layout.sharding_specs(); - std::vector output_dims(batch_sharding_specs.begin(), - batch_sharding_specs.end()); + const auto& batch_sharding_specs = batch_layout.sharding_spec_strs(); + std::vector output_dims(batch_sharding_specs.begin(), + batch_sharding_specs.end()); if (Layout::IsShardedDimension(left_layout.sharding_spec(0)) && left_layout.sharding_spec(0) == right_layout.sharding_spec(1)) { // If a and c above are the same and sharded, we should output a replicated // layout during propagation. This is so we don't create an illegal layout. output_dims.resize(output_dims.size() + 2); - output_dims[output_dims.size() - 2].set_sharding_spec( - Layout::kUnshardedDim); - output_dims[output_dims.size() - 1].set_sharding_spec( - Layout::kUnshardedDim); + output_dims[output_dims.size() - 2] = Layout::kUnshardedDim; + output_dims[output_dims.size() - 1] = Layout::kUnshardedDim; } else { - output_dims.emplace_back(left_layout.dim(0)); - output_dims.emplace_back(right_layout.dim(1)); + output_dims.emplace_back(left_layout.sharding_spec(0)); + output_dims.emplace_back(right_layout.sharding_spec(1)); } return Layout::GetLayout(output_dims, left_layout.mesh()); diff --git a/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc index d6896a5540d832..a100300d5320a3 100644 --- a/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc @@ -197,20 +197,21 @@ StatusOr UnpackSPMDExpander::ExpandOp(mlir::Operation* op) { TF_ASSIGN_OR_RETURN( int axis, CanonicalizeAxis(unpack.getAxis(), /*packed_rank=*/input_rank)); - if (input_layout->num_shards_for_dim(input_layout->dim(axis)) != 1) { + if (input_layout->num_shards_for_dim(axis) != 1) { // If the axis being unpacked is sharded, relayout to replicated along that // axis since each device needs to split across it. 
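To make the Unpack change below easier to follow, here is a rough sketch (names and header paths are assumptions, not part of the patch) of the relayout target it constructs: the input layout with the unpacked axis forced to replicated, which the real code then hands to EmitRelayout:

#include <string>
#include <vector>

#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"

using tensorflow::dtensor::Layout;
using tensorflow::dtensor::StatusOr;

// Keep every string spec from the input layout except the unpacked axis,
// which becomes replicated so each device can split across it locally.
StatusOr<Layout> LayoutWithAxisReplicated(const Layout& input_layout,
                                          int axis) {
  std::vector<std::string> specs(input_layout.rank());
  for (int i = 0; i < input_layout.rank(); ++i) {
    if (i == axis) {
      specs[i] = Layout::kUnshardedDim;
    } else {
      specs[i] = input_layout.sharding_spec(i);
    }
  }
  return Layout::GetLayout(specs, input_layout.mesh());
}
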
- std::vector new_layout_specs(input_rank); + std::vector new_layout_specs(input_rank); for (int input_index = 0; input_index < input_rank; ++input_index) { if (input_index == axis) { - new_layout_specs[input_index].set_sharding_spec(Layout::kUnshardedDim); + new_layout_specs[input_index] = Layout::kUnshardedDim; } else { - new_layout_specs[input_index] = input_layout->dim(input_index); + new_layout_specs[input_index] = + input_layout->sharding_spec(input_index); } } TF_ASSIGN_OR_RETURN( Layout new_input_layout, - Layout::GetLayout(std::move(new_layout_specs), input_layout->mesh())); + Layout::GetLayout(new_layout_specs, input_layout->mesh())); TF_ASSIGN_OR_RETURN( mlir::Value new_input, EmitRelayout(unpack.getOperand(), *input_layout, new_input_layout)); @@ -250,15 +251,12 @@ Status VerifyPaddedDimensionNotSharded(const Layout& layout, const auto input_shape = input_type.getShape(); const auto output_shape = input_type.getShape(); - for (const auto& dim_shard_and_index : - llvm::enumerate(layout.sharding_specs())) { - const int index = dim_shard_and_index.index(); - const auto& tensor_dimension = dim_shard_and_index.value(); + for (int index = 0; index < layout.rank(); ++index) { const int input_shape_for_dim = input_shape[index]; const int output_shape_for_dim = output_shape[index]; if ((input_shape_for_dim == -1 || output_shape_for_dim == -1 || output_shape_for_dim != input_shape_for_dim) && - layout.num_shards_for_dim(tensor_dimension) > 1) { + layout.num_shards_for_dim(index) > 1) { return errors::InvalidArgument( "Padding over sharded dimension is not allowed."); } @@ -346,11 +344,10 @@ namespace { Status VerifyTileOperandLayout(const Layout& operand_layout, llvm::ArrayRef static_multiples) { for (const auto& tensor_dim_and_multiple : - llvm::zip(operand_layout.sharding_specs(), static_multiples)) { - const auto& tensor_dimension = std::get<0>(tensor_dim_and_multiple); - const int64_t multiple_factor = std::get<1>(tensor_dim_and_multiple); - if (multiple_factor > 1 && - operand_layout.num_shards_for_dim(tensor_dimension) > 1) + llvm::enumerate(static_multiples)) { + const auto& index = tensor_dim_and_multiple.index(); + const int64_t multiple_factor = tensor_dim_and_multiple.value(); + if (multiple_factor > 1 && operand_layout.num_shards_for_dim(index) > 1) return errors::InvalidArgument( "tile op with input sharded at dimension where `multiple` > 1 is not " "supported."); @@ -486,12 +483,11 @@ StatusOr> TileSPMDExpander::ComputeLayoutForward( const Layout input_layout = input_layouts.lookup(0); std::vector output_layout_specs; for (const auto& multiple_and_dim_sharding : - llvm::zip(static_multiple, input_layout.sharding_specs())) { + llvm::zip(static_multiple, input_layout.sharding_spec_strs())) { const int multiple = std::get<0>(multiple_and_dim_sharding); const auto& tensor_dimension = std::get<1>(multiple_and_dim_sharding); - output_layout_specs.push_back(multiple == 1 - ? tensor_dimension.sharding_spec() - : Layout::kUnshardedDim); + output_layout_specs.push_back(multiple == 1 ? 
tensor_dimension + : Layout::kUnshardedDim); } TF_ASSIGN_OR_RETURN(const Layout output_layout, @@ -542,12 +538,11 @@ StatusOr> TileSPMDExpander::ComputeLayoutBackward( const Layout output_layout = output_layouts.lookup(0); std::vector input_layout_specs; for (const auto& multiple_and_dim_sharding : - llvm::zip(static_multiple, output_layout.sharding_specs())) { + llvm::zip(static_multiple, output_layout.sharding_spec_strs())) { const int multiple = std::get<0>(multiple_and_dim_sharding); const auto& tensor_dimension = std::get<1>(multiple_and_dim_sharding); - input_layout_specs.push_back(multiple == 1 - ? tensor_dimension.sharding_spec() - : Layout::kUnshardedDim); + input_layout_specs.push_back(multiple == 1 ? tensor_dimension + : Layout::kUnshardedDim); } TF_ASSIGN_OR_RETURN(const Layout input_layout, Layout::GetLayout(input_layout_specs, mesh)); @@ -648,8 +643,8 @@ StatusOr MakeLayoutForReshape( // first entry of the input segment divides the output shape on the first // entry of the output segment, we request a sharded layout on that axis. for (int i = 0; i < input_segment_start.size(); ++i) { - const int num_shards = input_layout.num_shards_for_dim( - input_layout.dim(input_segment_start[i])); + const int num_shards = + input_layout.num_shards_for_dim(input_segment_start[i]); if (output_shape[output_segment_start[i]] % num_shards == 0) layout_specs[output_segment_start[i]] = input_layout.sharding_spec(input_segment_start[i]); @@ -711,8 +706,8 @@ StatusOr ReshapeSPMDExpander::ExpandOp(mlir::Operation* op) { // inserted. For example, reshape a [2, 4, 3] shape tensor with layout // ['not_sharded', 'x', 'not_sharded'] to [2, 12] shape tensor fully // replicated can be supported. - std::vector tgt_input_layout(input_layout->rank()); - std::vector tgt_output_layout(output_layout->rank()); + std::vector tgt_input_layout(input_layout->rank()); + std::vector tgt_output_layout(output_layout->rank()); for (int i = 0; i < input_segment_start.size(); ++i) { const int input_start = input_segment_start[i]; @@ -724,14 +719,13 @@ StatusOr ReshapeSPMDExpander::ExpandOp(mlir::Operation* op) { // Between this segment and the last segment, if there is a gap, insert // dimensions of size 1 and kUnshardedDim as output layout dim. for (int j = prev_input_segment_end; j < input_start; ++j) - tgt_input_layout[j].set_sharding_spec(Layout::kUnshardedDim); + tgt_input_layout[j] = Layout::kUnshardedDim; for (int j = prev_output_segment_end; j < output_start; ++j) { local_reshape_const.emplace_back(1); - tgt_output_layout[j].set_sharding_spec(Layout::kUnshardedDim); + tgt_output_layout[j] = Layout::kUnshardedDim; } - const int num_input_shards = - input_layout->num_shards_for_dim(input_layout->dim(input_start)); + const int num_input_shards = input_layout->num_shards_for_dim(input_start); // Decide on the sharding of the input for this segment. // If the input is already sharded, we try to keep this sharding (unless @@ -740,31 +734,32 @@ StatusOr ReshapeSPMDExpander::ExpandOp(mlir::Operation* op) { // we could 'preshard' the input on this dimension before the reshape. // This is unlikely to have any major gains in performance. 
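A small illustrative helper (hypothetical, not in the patch) for the divisibility rule the reshape code below applies when deciding whether a segment keeps its input sharding:

#include <cstdint>
#include <string>

#include "tensorflow/dtensor/cc/tensor_layout.h"

using tensorflow::dtensor::Layout;

// A reshape segment keeps the input sharding only when the global output
// dimension divides evenly across the input shards on the segment's first
// input dimension; otherwise both sides fall back to replicated, which is
// the branch taken in the hunk below.
std::string SegmentSharding(const Layout& input_layout, int input_start,
                            int64_t global_output_dim) {
  const int num_input_shards = input_layout.num_shards_for_dim(input_start);
  if (global_output_dim % num_input_shards != 0) {
    return Layout::kUnshardedDim;
  }
  return input_layout.sharding_spec(input_start);
}
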
if (global_output_shape[output_start] % num_input_shards != 0) { - tgt_input_layout[input_start].set_sharding_spec(Layout::kUnshardedDim); - tgt_output_layout[output_start].set_sharding_spec(Layout::kUnshardedDim); + tgt_input_layout[input_start] = Layout::kUnshardedDim; + tgt_output_layout[output_start] = Layout::kUnshardedDim; local_reshape_const.emplace_back(global_output_shape[output_start]); } else { - tgt_input_layout[input_start] = input_layout->dim(input_start); - tgt_output_layout[output_start] = input_layout->dim(input_start); + tgt_input_layout[input_start] = input_layout->sharding_spec(input_start); + tgt_output_layout[output_start] = + input_layout->sharding_spec(input_start); local_reshape_const.emplace_back(global_output_shape[output_start] / num_input_shards); } for (int j = input_start + 1; j < input_segment_end[i]; ++j) - tgt_input_layout[j].set_sharding_spec(Layout::kUnshardedDim); + tgt_input_layout[j] = Layout::kUnshardedDim; for (int j = output_start + 1; j < output_segment_end[i]; ++j) { local_reshape_const.emplace_back(global_output_shape[j]); - tgt_output_layout[j].set_sharding_spec(Layout::kUnshardedDim); + tgt_output_layout[j] = Layout::kUnshardedDim; } } // Fill any remaining dimensions of size 1 and sharding dim on the end of the // layout. for (int j = input_segment_end.back(); j < tgt_input_layout.size(); ++j) - tgt_input_layout[j].set_sharding_spec(Layout::kUnshardedDim); + tgt_input_layout[j] = Layout::kUnshardedDim; for (int j = output_segment_end.back(); j < tgt_output_layout.size(); ++j) { local_reshape_const.emplace_back(1); - tgt_output_layout[j].set_sharding_spec(Layout::kUnshardedDim); + tgt_output_layout[j] = Layout::kUnshardedDim; } TF_ASSIGN_OR_RETURN( @@ -889,8 +884,8 @@ StatusOr TransposeSPMDExpander::ExpandOp( TF_RETURN_IF_ERROR(ExtractConstVectorFromValue(transpose.getPerm(), &perm)); for (const auto& p : llvm::enumerate(perm)) { - if (operand_layout->dim(p.value()).sharding_spec() != - output_layout->dim(p.index()).sharding_spec()) { + if (operand_layout->sharding_spec(p.value()) != + output_layout->sharding_spec(p.index())) { return errors::InvalidArgument( "TransposeOp SPMD needs communication is not supported yet. \n " "operand layout: ", @@ -969,12 +964,12 @@ Status RelayoutOneHotInput(const absl::optional& input_layout, " SPMD expansion. 
Consider adding Relayout() op to specify the " "layout."); - std::vector sharding_specs(input_layout->rank()); + std::vector sharding_specs(input_layout->rank()); for (int i = 0; i < input_layout->rank(); ++i) { if (i < axis) - sharding_specs[i] = output_layout->dim(i); + sharding_specs[i] = output_layout->sharding_spec(i); else - sharding_specs[i] = output_layout->dim(i + 1); + sharding_specs[i] = output_layout->sharding_spec(i + 1); } TF_ASSIGN_OR_RETURN(const Layout new_input_layout, Layout::GetLayout(sharding_specs, input_layout->mesh())); diff --git a/tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.cc index dafb83d289e6aa..20ca961d55c8f4 100644 --- a/tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.cc @@ -63,8 +63,7 @@ StatusOr NullarySPMDExpander::ExpandOp(mlir::Operation* op) { auto shape = dense.getType().getShape(); std::vector new_shape(dense.getType().getRank()); for (int i = 0; i < op_layouts[0]->rank(); ++i) { - const int num_shards = - op_layouts[0]->num_shards_for_dim(op_layouts[0]->dim(i)); + const int num_shards = op_layouts[0]->num_shards_for_dim(i); if (shape[i] % num_shards != 0) return errors::InvalidArgument( "has output dimension size ", shape[i], diff --git a/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc index d6ec2663fc951b..32e6f4315849ba 100644 --- a/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc @@ -85,10 +85,9 @@ StatusOr GetDeviceSeed(const Layout& layout, mlir::Operation* op) { // to use as the attribute attached to the squeeze op. llvm::SmallVector layout_dims; llvm::SmallSet layout_dims_set; - for (const ShardingSpec& spec : layout.sharding_specs()) { - if (Layout::IsUnshardedSpec(spec)) continue; - layout_dims.emplace_back( - layout.mesh().GetMeshDimIndexWithName(spec.sharding_spec())); + for (const auto& spec : layout.sharding_spec_strs()) { + if (Layout::IsUnshardedDimension(spec)) continue; + layout_dims.emplace_back(layout.mesh().GetMeshDimIndexWithName(spec)); layout_dims_set.insert(layout_dims.back()); } llvm::sort(layout_dims); diff --git a/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc index 448d7652aa7fbb..217b41c986bbf0 100644 --- a/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc @@ -82,11 +82,11 @@ Status AssertReplicated(mlir::Value operand) { absl::flat_hash_set ReducedMeshDimensions( const dtensor::Layout& input, const dtensor::Layout& output) { absl::flat_hash_set mesh_dims; - for (const auto& dim : input.sharding_specs()) { - mesh_dims.insert(dim.sharding_spec()); + for (const auto& dim : input.sharding_spec_strs()) { + mesh_dims.insert(dim); } - for (const auto& dim : output.sharding_specs()) { - mesh_dims.erase(dim.sharding_spec()); + for (const auto& dim : output.sharding_spec_strs()) { + mesh_dims.erase(dim); } return mesh_dims; } diff --git a/tensorflow/dtensor/mlir/expansions/scatter_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/scatter_spmd_expander.cc index f469ca35e069d8..de11f7494e5750 100644 --- a/tensorflow/dtensor/mlir/expansions/scatter_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/scatter_spmd_expander.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include +#include #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -41,33 +42,32 @@ StatusOr GetOutputLayout(const absl::optional& tensor_layout, // to replicated. The remainder are set from tensor_layout and updates_layout // with tensor_layout taking priority, as it is generally larger than updates // (as unsharding updates is faster). - std::vector output_specs(tensor_rank); + std::vector output_specs(tensor_rank); // The number of dimensions at the start of the tensor input that are used // for the index, also the size of the second dimension of the indices tensor. const int index_dimensions = tensor_rank - (updates_rank - 1); - for (int i = 0; i < tensor_rank; ++i) - output_specs[i].set_sharding_spec(Layout::kUnshardedDim); + for (int i = 0; i < tensor_rank; ++i) output_specs[i] = Layout::kUnshardedDim; absl::flat_hash_set used_mesh_dims; if (tensor_layout) { for (int i = index_dimensions; i < tensor_rank; ++i) { - output_specs[i] = tensor_layout->dim(i); - if (Layout::IsShardedSpec(output_specs[i])) - used_mesh_dims.emplace(output_specs[i].sharding_spec()); + output_specs[i] = tensor_layout->sharding_spec(i); + if (Layout::IsShardedDimension(output_specs[i])) + used_mesh_dims.emplace(output_specs[i]); } } if (updates_layout) { for (int i = index_dimensions; i < tensor_rank; ++i) { - const ShardingSpec& update_spec = - updates_layout->dim(i - index_dimensions + 1); + const auto& update_spec = + updates_layout->sharding_spec(i - index_dimensions + 1); - if (Layout::IsUnshardedSpec(output_specs[i]) && - Layout::IsShardedSpec(update_spec) && - !used_mesh_dims.contains(update_spec.sharding_spec())) + if (Layout::IsUnshardedDimension(output_specs[i]) && + Layout::IsShardedDimension(update_spec) && + !used_mesh_dims.contains(update_spec)) output_specs[i] = update_spec; } } @@ -122,13 +122,14 @@ StatusOr TensorScatterOpExpand(mlir::Operation* op) { GetOutputLayout(tensor_layout, tensor_rank, updates_layout, updates_rank, tensor_layout->mesh())); - std::vector updates_specs(updates_rank); - updates_specs[0].set_sharding_spec(Layout::kUnshardedDim); + std::vector updates_specs(updates_rank); + updates_specs[0] = Layout::kUnshardedDim; const int index_dimensions = tensor_rank - (updates_rank - 1); for (int i = 0; i < updates_rank - 1; ++i) - updates_specs[i + 1] = pre_output_layout.dim(index_dimensions + i); + updates_specs[i + 1] = + pre_output_layout.sharding_spec(index_dimensions + i); TF_ASSIGN_OR_RETURN(Layout new_updates_layout, Layout::GetLayout(updates_specs, updates_layout->mesh())); diff --git a/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc index d28847e2cf896f..ec798f7faf0532 100644 --- a/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc @@ -71,7 +71,7 @@ StatusOr ComputeGlobalReduce( // Then an all reduce. absl::flat_hash_set reduced_sharding_specs; for (const int dim : reduced_dims) - if (Layout::IsShardedSpec(input_layout.dim(dim))) + if (Layout::IsShardedDimension(input_layout.sharding_spec(dim))) reduced_sharding_specs.emplace(input_layout.sharding_spec(dim)); TF_ASSIGN_OR_RETURN( mlir::Operation * global_reduce, @@ -192,12 +192,12 @@ StatusOr ComputeShardedSoftmax(mlir::OpBuilder& builder, // 1) Left truncated to match the size of global_shape. // 2) Has unsharded dimensions where ever global_shape is 1. 
StatusOr GetBroadcastedLayout(llvm::ArrayRef global_shape, - const std::vector& specs, + const std::vector& specs, const Mesh& mesh) { - std::vector new_specs(global_shape.size()); + std::vector new_specs(global_shape.size()); for (int i = 0; i < global_shape.size(); ++i) { if (global_shape[i] == 1) - new_specs[i].set_sharding_spec(Layout::kUnshardedDim); + new_specs[i] = Layout::kUnshardedDim; else new_specs[i] = specs[i + specs.size() - global_shape.size()]; } @@ -248,11 +248,10 @@ StatusOr ComputeOneHot(mlir::OpBuilder& builder, "expected feature input to have at least rank 1, but found rank 0"); const int64_t local_classes = features_type.getShape().back(); - const int64_t classes = - local_classes * - desired_layout.num_shards_for_dim(desired_layout.sharding_specs().back()); + const int64_t classes = local_classes * desired_layout.num_shards_for_dim( + desired_layout.rank() - 1); - int64_t num_shards = desired_layout.num_shards_for_dim(desired_layout.dim(1)); + int64_t num_shards = desired_layout.num_shards_for_dim(1); if (classes % num_shards) return errors::InvalidArgument("unable to shard onehot with size ", classes, " over dimension with ", num_shards, @@ -399,9 +398,9 @@ StatusOr SoftmaxLossOpSPMDExpander::MaybeRelayoutInputs( // This layout represents the 'internal layout' that the softmax will be // operating on. Inputs will be relayout'ed to this layout and outputs will be // relayout'ed from this layout to their desired layout. - std::vector internal_layout(2); - internal_layout[0].set_sharding_spec(Layout::kUnshardedDim); - internal_layout[1].set_sharding_spec(Layout::kUnshardedDim); + std::vector internal_layout(2); + internal_layout[0] = Layout::kUnshardedDim; + internal_layout[1] = Layout::kUnshardedDim; // Choose an internal layout, ideally this layout would be chosen so that // the relayout costs for the inputs (from features_layout/labels_layout to @@ -412,32 +411,34 @@ StatusOr SoftmaxLossOpSPMDExpander::MaybeRelayoutInputs( // Pick a batch sharding, first from features, then labels, loss and backprop. // Due to possible broadcasting on features and labels, they will only // have a batch dim if they are rank 2. - if (features_layout.rank() == 2) internal_layout[0] = features_layout.dim(0); + if (features_layout.rank() == 2) + internal_layout[0] = features_layout.sharding_spec(0); if (((labels_layout.rank() == 2) || (is_sparse && labels_layout.rank() == 1)) && - Layout::IsUnshardedSpec(internal_layout[0])) - internal_layout[0] = labels_layout.dim(0); - if (Layout::IsUnshardedSpec(internal_layout[0])) - internal_layout[0] = loss_layout.dim(0); - if (Layout::IsUnshardedSpec(internal_layout[0])) - internal_layout[0] = backprop_layout.dim(0); + Layout::IsUnshardedDimension(internal_layout[0])) + internal_layout[0] = labels_layout.sharding_spec(0); + if (Layout::IsUnshardedDimension(internal_layout[0])) + internal_layout[0] = loss_layout.sharding_spec(0); + if (Layout::IsUnshardedDimension(internal_layout[0])) + internal_layout[0] = backprop_layout.sharding_spec(0); // Pick a class sharding, first from features, then labels and backprop. // The class dim for features and labels is always the last dim if it exists. // Note that loss and backprop have fixed ranks 1 and 2 respectively where as // ranks of features and labels may involved broadcasting. 
if (features_layout.rank() > 0 && - (internal_layout[0].sharding_spec() != + (internal_layout[0] != features_layout.sharding_spec(features_layout.rank() - 1))) - internal_layout[1] = features_layout.dim(features_layout.rank() - 1); + internal_layout[1] = + features_layout.sharding_spec(features_layout.rank() - 1); if (!is_sparse && labels_layout.rank() > 0 && - Layout::IsUnshardedSpec(internal_layout[1]) && - (internal_layout[0].sharding_spec() != + Layout::IsUnshardedDimension(internal_layout[1]) && + (internal_layout[0] != labels_layout.sharding_spec(labels_layout.rank() - 1))) - internal_layout[1] = labels_layout.dim(labels_layout.rank() - 1); - if (Layout::IsUnshardedSpec(internal_layout[1]) && - (internal_layout[0].sharding_spec() != backprop_layout.sharding_spec(1))) - internal_layout[1] = backprop_layout.dim(1); + internal_layout[1] = labels_layout.sharding_spec(labels_layout.rank() - 1); + if (Layout::IsUnshardedDimension(internal_layout[1]) && + (internal_layout[0] != backprop_layout.sharding_spec(1))) + internal_layout[1] = backprop_layout.sharding_spec(1); TF_ASSIGN_OR_RETURN( llvm::ArrayRef features_global_shape, @@ -464,7 +465,7 @@ StatusOr SoftmaxLossOpSPMDExpander::MaybeRelayoutInputs( if (is_sparse) { // If we are sparse, then the only possible dimension is the batch_dim. - std::vector sparse_specs = {internal_layout[0]}; + std::vector sparse_specs = {internal_layout[0]}; TF_ASSIGN_OR_RETURN(new_labels_layout, GetBroadcastedLayout(labels_global_shape, sparse_specs, labels_layout.mesh())); @@ -560,7 +561,7 @@ StatusOr SoftmaxLossOpSPMDExpander::ExpandOp( assert(internal_layout.rank() == 2); // If the class dim is unshared, we can emit a local op. - if (Layout::IsUnshardedSpec(internal_layout.dim(1))) { + if (Layout::IsUnshardedDimension(internal_layout.sharding_spec(1))) { op = InferSPMDExpandedLocalShape(op); return MaybeRelayoutOutputs(op, op->getResult(0), op->getResult(1), internal_layout, output_layouts[0], @@ -662,31 +663,32 @@ SoftmaxLossOpSPMDExpander::ComputeLayoutForward( labels_layout.emplace(input_layouts.lookup(1)); // We need to compute shardings for two dimensions: batch and class. - std::vector layout_specs(2); - layout_specs[0].set_sharding_spec(Layout::kUnshardedDim); - layout_specs[1].set_sharding_spec(Layout::kUnshardedDim); + std::vector layout_specs(2); + layout_specs[0] = Layout::kUnshardedDim; + layout_specs[1] = Layout::kUnshardedDim; // First pick the batch dimension, set it to the batch dimension of features // if it exists otherwise to the batch dimesion of labels. if (features_layout && (features_layout->rank() == 2)) - layout_specs[0] = features_layout->dim(0); + layout_specs[0] = features_layout->sharding_spec(0); if (labels_layout && (labels_layout->rank() == 2 || (is_sparse && labels_layout->rank() == 1)) && - Layout::IsUnshardedSpec(layout_specs[0])) - layout_specs[0] = labels_layout->dim(0); + Layout::IsUnshardedDimension(layout_specs[0])) + layout_specs[0] = labels_layout->sharding_spec(0); - // The class dim for features and labels is always the last dim if it - // exists. + // The class sharding_spec for features and labels is always the last + // sharding_spec if it exists. 
if (features_layout && (features_layout->rank() > 0) && - (layout_specs[0].sharding_spec() != + (layout_specs[0] != features_layout->sharding_spec(features_layout->rank() - 1))) - layout_specs[1] = features_layout->dim(features_layout->rank() - 1); + layout_specs[1] = + features_layout->sharding_spec(features_layout->rank() - 1); if (!is_sparse && labels_layout && (labels_layout->rank() > 0) && - Layout::IsUnshardedSpec(layout_specs[1]) && - (layout_specs[0].sharding_spec() != + Layout::IsUnshardedDimension(layout_specs[1]) && + (layout_specs[0] != labels_layout->sharding_spec(labels_layout->rank() - 1))) - layout_specs[1] = labels_layout->dim(labels_layout->rank() - 1); + layout_specs[1] = labels_layout->sharding_spec(labels_layout->rank() - 1); TF_ASSIGN_OR_RETURN(const Layout backprop_layout, Layout::GetLayout(layout_specs, mesh)); @@ -711,20 +713,19 @@ SoftmaxLossOpSPMDExpander::ComputeLayoutBackward( // We need to compute two possible shardings: // One for the batch dimension and one for the class dimension. - std::vector layout_specs(2); - layout_specs[0].set_sharding_spec(Layout::kUnshardedDim); - layout_specs[1].set_sharding_spec(Layout::kUnshardedDim); + std::vector layout_specs(2); + layout_specs[0] = Layout::kUnshardedDim; + layout_specs[1] = Layout::kUnshardedDim; // Respect the loss layout if it is set, otherwise use the backprop // layout for the batch_dim. - if (loss_layout) layout_specs[0] = loss_layout->dim(0); - if (backprop_layout && Layout::IsUnshardedSpec(layout_specs[0])) - layout_specs[0] = backprop_layout->dim(0); + if (loss_layout) layout_specs[0] = loss_layout->sharding_spec(0); + if (backprop_layout && Layout::IsUnshardedDimension(layout_specs[0])) + layout_specs[0] = backprop_layout->sharding_spec(0); // Only backprop has class dim so use that if it is available. - if (backprop_layout && - backprop_layout->sharding_spec(1) != layout_specs[0].sharding_spec()) - layout_specs[1] = backprop_layout->dim(1); + if (backprop_layout && backprop_layout->sharding_spec(1) != layout_specs[0]) + layout_specs[1] = backprop_layout->sharding_spec(1); TF_ASSIGN_OR_RETURN(const auto features_shape, GetShapeOfValue(op->getOperand(0))); diff --git a/tensorflow/dtensor/mlir/expansions/split_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/split_spmd_expander.cc index f72b22d98df62f..8b271cccb469d9 100644 --- a/tensorflow/dtensor/mlir/expansions/split_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/split_spmd_expander.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include "absl/types/optional.h" #include "mlir/IR/Value.h" // from @llvm-project @@ -40,25 +42,22 @@ StatusOr MergeLayoutsForSplitOutput( int64_t split_dim, const llvm::DenseMap& layouts) { assert(!layouts.empty()); const Layout& first_layout = layouts.begin()->getSecond(); - std::vector sharding_specs( - first_layout.sharding_specs().begin(), - first_layout.sharding_specs().end()); + std::vector sharding_specs = first_layout.sharding_spec_strs(); // Merge remaining layouts. If there is a conflicting sharding, then set the // dim to replicated. 
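For reference, a two-layout sketch of the merge rule described above (the function name and StatusOr alias are assumptions; the real MergeLayoutsForSplitOutput below iterates over a map of output layouts):

#include <cstdint>
#include <string>
#include <vector>

#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"

using tensorflow::dtensor::Layout;
using tensorflow::dtensor::StatusOr;

// Start from one layout's string specs, replicate any dimension where the
// two layouts disagree on a sharded spec, and force the split dimension to
// be unsharded before rebuilding the layout.
StatusOr<Layout> MergeForSplit(const Layout& first, const Layout& second,
                               int64_t split_dim) {
  std::vector<std::string> specs = first.sharding_spec_strs();
  for (int dim = 0; dim < second.rank(); ++dim) {
    if (Layout::IsShardedDimension(second.sharding_spec(dim)) &&
        Layout::IsShardedDimension(specs[dim]) &&
        second.sharding_spec(dim) != specs[dim]) {
      specs[dim] = Layout::kUnshardedDim;
    }
  }
  specs[split_dim] = Layout::kUnshardedDim;
  return Layout::GetLayout(specs, first.mesh());
}
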
for (auto it = layouts.begin(); it != layouts.end(); ++it) { const Layout& output_layout = it->getSecond(); for (int dim = 0; dim < output_layout.rank(); ++dim) { - if (Layout::IsShardedDimension(output_layout.dim(dim).sharding_spec()) && - Layout::IsShardedDimension(sharding_specs[dim].sharding_spec()) && - output_layout.dim(dim).sharding_spec() != - sharding_specs[dim].sharding_spec()) { - sharding_specs[dim].set_sharding_spec(Layout::kUnshardedDim); + if (Layout::IsShardedDimension(output_layout.sharding_spec(dim)) && + Layout::IsShardedDimension(sharding_specs[dim]) && + output_layout.sharding_spec(dim) != sharding_specs[dim]) { + sharding_specs[dim] = Layout::kUnshardedDim; } } } // Force the split_dim to be unsharded. - sharding_specs[split_dim].set_sharding_spec(Layout::kUnshardedDim); + sharding_specs[split_dim] = Layout::kUnshardedDim; return Layout::GetLayout(sharding_specs, first_layout.mesh()); } @@ -89,7 +88,7 @@ StatusOr SplitSPMDExpander::ExpandOp(mlir::Operation* op) { const int64_t split_dim, GetAdjustedSplitDim(split_op.getSplitDim(), split_op.getValue())); - if (Layout::IsShardedDimension(input_layout.dim(split_dim).sharding_spec())) { + if (Layout::IsShardedDimension(input_layout.sharding_spec(split_dim))) { return errors::InvalidArgument( "Spliting over sharded dimension is not supported."); } @@ -142,7 +141,7 @@ StatusOr SplitVSPMDExpander::ExpandOp(mlir::Operation* op) { const int64_t split_dim, GetAdjustedSplitDim(split_v_op.getSplitDim(), split_v_op.getValue())); - if (Layout::IsShardedDimension(input_layout.dim(split_dim).sharding_spec())) { + if (Layout::IsShardedDimension(input_layout.sharding_spec(split_dim))) { return errors::InvalidArgument( "Spliting over sharded dimension is not supported."); } diff --git a/tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.cc index ebfd6d067196d4..c40e08814c36a1 100644 --- a/tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.cc @@ -15,7 +15,9 @@ limitations under the License. 
#include "tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.h" +#include #include +#include #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "tensorflow/core/platform/errors.h" @@ -55,16 +57,16 @@ StatusOr> SqueezeSPMDExpander::ComputeLayoutForward( TF_ASSIGN_OR_RETURN(auto shape, ExtractGlobalInputShape(op->getOpOperand(0))); std::set squeeze_dims = GetSqueezeDims(op, /*rank=*/shape.size()); - std::vector layout_specs; + std::vector layout_specs; layout_specs.reserve(input_layout.rank()); for (int64 i = 0; i < input_layout.rank(); ++i) { if (squeeze_dims.empty()) { if (shape[i] > 1) { - layout_specs.push_back(input_layout.dim(i)); + layout_specs.push_back(input_layout.sharding_spec(i)); } } else { if (squeeze_dims.find(i) == squeeze_dims.end()) { - layout_specs.push_back(input_layout.dim(i)); + layout_specs.push_back(input_layout.sharding_spec(i)); } } } @@ -85,24 +87,21 @@ SqueezeSPMDExpander::ComputeLayoutBackward( TF_ASSIGN_OR_RETURN(auto shape, ExtractGlobalInputShape(op->getOpOperand(0))); std::set squeeze_dims = GetSqueezeDims(op, /*rank=*/shape.size()); - ShardingSpec unsharded_spec; - unsharded_spec.set_sharding_spec(Layout::kUnshardedDim); - - std::vector layout_specs; + std::vector layout_specs; layout_specs.reserve(output_layout.rank()); size_t j = 0; for (size_t i = 0; i < shape.size(); ++i) { if (squeeze_dims.empty()) { if (shape[i] > 1) { - layout_specs.push_back(output_layout.dim(j++)); + layout_specs.push_back(output_layout.sharding_spec(j++)); } else { - layout_specs.push_back(unsharded_spec); + layout_specs.push_back(Layout::kUnshardedDim); } } else { if (squeeze_dims.find(i) == squeeze_dims.end()) { - layout_specs.push_back(output_layout.dim(j++)); + layout_specs.push_back(output_layout.sharding_spec(j++)); } else { - layout_specs.push_back(unsharded_spec); + layout_specs.push_back(Layout::kUnshardedDim); } } } diff --git a/tensorflow/dtensor/mlir/expansions/top_k_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/top_k_spmd_expander.cc index 6b8a5002489da0..2de0f5851b0575 100644 --- a/tensorflow/dtensor/mlir/expansions/top_k_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/top_k_spmd_expander.cc @@ -15,6 +15,9 @@ limitations under the License. 
#include "tensorflow/dtensor/mlir/expansions/top_k_spmd_expander.h" +#include +#include + #include "mlir/IR/IRMapping.h" // from @llvm-project #include "tensorflow/core/platform/errors.h" #include "tensorflow/dtensor/cc/dstatus.h" @@ -29,14 +32,12 @@ namespace dtensor { // layout -> layout[:-1] + unsharded StatusOr GetSuggestedLayout(const Layout& input_layout) { - std::vector layout_specs(input_layout.rank()); + std::vector layout_specs(input_layout.rank()); for (int i = 0; i < input_layout.rank() - 1; ++i) { - layout_specs[i].set_sharding_spec(input_layout.sharding_spec(i)); + layout_specs[i] = input_layout.sharding_spec(i); } - layout_specs[input_layout.rank() - 1].set_sharding_spec( - Layout::kUnshardedDim); - + layout_specs[input_layout.rank() - 1] = Layout::kUnshardedDim; return Layout::GetLayout(layout_specs, input_layout.mesh()); } diff --git a/tensorflow/dtensor/mlir/expansions/where_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/where_spmd_expander.cc index 0ee755d440332c..c93ecf9e4ca679 100644 --- a/tensorflow/dtensor/mlir/expansions/where_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/where_spmd_expander.cc @@ -119,7 +119,7 @@ StatusOr> WhereOpSPMDExpander::ComputeLayoutForward( // Append an unsharded sharding spec for the index dimension generated by the // Where op. std::vector layout_specs; - layout_specs.push_back(layout.dim(0).sharding_spec()); + layout_specs.push_back(layout.sharding_spec(0)); layout_specs.push_back(Layout::kUnshardedDim); TF_ASSIGN_OR_RETURN(Layout new_layout, Layout::GetLayout(layout_specs, layout.mesh())); @@ -138,7 +138,7 @@ WhereOpSPMDExpander::ComputeLayoutBackward( std::vector layout_specs; layout_specs.reserve(layout.rank() - 1); for (int i = 0; i < layout.rank() - 1; i++) { - layout_specs.push_back(layout.dim(i).sharding_spec()); + layout_specs.push_back(layout.sharding_spec(i)); } TF_ASSIGN_OR_RETURN(Layout new_layout, Layout::GetLayout(layout_specs, layout.mesh())); diff --git a/tensorflow/dtensor/mlir/ir/tf_dtensor.cc b/tensorflow/dtensor/mlir/ir/tf_dtensor.cc index c9b4fe235ec7db..14e66c98b6cb2c 100644 --- a/tensorflow/dtensor/mlir/ir/tf_dtensor.cc +++ b/tensorflow/dtensor/mlir/ir/tf_dtensor.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" #include +#include #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -210,8 +211,8 @@ mlir::LogicalResult DTensorAllToAllOp::verify() { int32_t num_split_dims = 0; int32_t num_concat_dims = 0; - tensorflow::dtensor::ShardingSpec split_spec; - tensorflow::dtensor::ShardingSpec concat_spec; + std::string split_spec; + std::string concat_spec; for (int32_t i = 0; i < input_layout.rank(); ++i) { if (input_layout.sharding_spec(i) == output_layout.sharding_spec(i)) continue; @@ -220,17 +221,17 @@ mlir::LogicalResult DTensorAllToAllOp::verify() { tensorflow::dtensor::Layout::IsShardedDimension( output_layout.sharding_spec(i))) { num_split_dims++; - split_spec = output_layout.dim(i); + split_spec = output_layout.sharding_spec(i); } else if (tensorflow::dtensor::Layout::IsShardedDimension( input_layout.sharding_spec(i)) && tensorflow::dtensor::Layout::IsUnshardedDimension( output_layout.sharding_spec(i))) { num_concat_dims++; - concat_spec = input_layout.dim(i); + concat_spec = input_layout.sharding_spec(i); } } if (num_split_dims != 1 || num_concat_dims != 1 || - split_spec.sharding_spec() != concat_spec.sharding_spec()) { + split_spec != concat_spec) { return op.emitOpError() << "must have one mesh dimension which is being " "unsharded in one axis and sharded in another"; } diff --git a/tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir b/tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir index 5da8b06d804c5f..de72a60589b4dc 100644 --- a/tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir +++ b/tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir @@ -198,4 +198,78 @@ module attributes {dtensor.all_reduce_combiner.num_ops_in_group = 2} { }) : () -> tensor<4x4xf32> "func.return"() : () -> () } +} + +// ----- +module attributes {dtensor.all_reduce_combiner.topological_distance = 2} { + // Check that when topologicial grouping is enabled in AllReduce combiner, the + // independent DTensorAllReduce ops of the same element type and group assign- + // ment are combined according to the topological distance between two ops. + // + // The following scenario would have 1 group of 7 AllReduces when topological + // distance is *not* set. + // - level 1: %4, %5 (case: <= topo_dist, simple case with same level) + // - level 2: %7 (case: <= topo_dist, simple case for eligible to group) + // - level 4: %16 (case: <= topo_dist, out of order, test for topo sort) + // - level 5: %15 (case: < topo_dist, out of order, test for topo sort) + // - level 8: %14 (case: > topo_dist, ineligible to group and out of order), + // %17 (case: > topo_dist, ineligible to group with 1st group, + // but should get grouped with %14) + // + // Detailed level computations are listed in the test below. + // + // With topological_distance set to 2, we expect the following grouping result + // - group 1: %4, %5, %7, %15, %16 + // - group 2: %14, %17 + // + // Note use of dummy AllReduces (with the same input) gaurantees ops to be + // grouped together if topologicial grouping is not enabled. 
+ // + // CHECK-LABEL: func @main + func.func @main() { + // CHECK: %[[ALL_REDUCE_1:.*]] = "tf.DTensorAllReduce" + // CHECK-SAME: (tensor<1024xf32>, tensor<2x2xi32>) -> tensor<1024xf32> + // CHECK: %[[ALL_REDUCE_2:.*]] = "tf.DTensorAllReduce" + // CHECK-SAME: (tensor<1024xf32>, tensor<2x2xi32>) -> tensor<1024xf32> + // CHECK: %[[ADD:.*]] = "tf.Add" + // CHECK-SAME: (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %0 = "tf_device.cluster"() ({ + // topological level 0 for all tf.Const + %1 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + %2 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %3 = "tf.Const"() {value = dense<1.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // %4 topological_level: 1 = max(0, 0) + 1 + %4 = "tf.DTensorAllReduce"(%1, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + // %5 topological_level: 1 = max(0, 0) + 1 + %5 = "tf.DTensorAllReduce"(%3, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + // %6 topological_level: 1 = max(0, 0) + 1 + %6 = "tf.Add"(%1, %3) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // %7 topological_level: 2 = max(1, 0) + 1 + %7 = "tf.DTensorAllReduce"(%6, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + // Dummy Adds to construct depth in compute graph + // %8 topological_level: 2 = max(1, 0) + 1 + // %9 topological_level: 3 = max(2, 0) + 1 + // %10 topological_level: 4 = max(3, 0) + 1 + // %11 topological_level: 5 = max(4, 0) + 1 + // %12 topological_level: 6 = max(5, 0) + 1 + // %13 topological_level: 7 = max(6, 0) + 1 + %8 = "tf.Add"(%6, %3) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %9 = "tf.Add"(%8, %1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %10 = "tf.Add"(%9, %3) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %11 = "tf.Add"(%10, %1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %12 = "tf.Add"(%11, %3) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %13 = "tf.Add"(%12, %1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // %14 topological_level: 8 = max(7, 0) + 1 + %14 = "tf.DTensorAllReduce"(%13, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + // %15 topological_level: 5 = max(4, 0) + 1 + %15 = "tf.DTensorAllReduce"(%10, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + // %16 topological_level: 4 = max(3, 0) + 1 + %16 = "tf.DTensorAllReduce"(%9, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + // %17 topological_level: 8 = max(7, 0) + 1 + %17 = "tf.DTensorAllReduce"(%13, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + %18 = "tf.Add"(%15, %7) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + "tf_device.return"(%18) : (tensor<4x4xf32>) -> () + }) : () -> tensor<4x4xf32> + "func.return"() : () -> () + } } \ No newline at 
end of file diff --git a/tensorflow/dtensor/mlir/tests/multi_device_expansion.mlir b/tensorflow/dtensor/mlir/tests/multi_device_expansion.mlir index a2f33ae33bc654..4de4f74ff094f0 100644 --- a/tensorflow/dtensor/mlir/tests/multi_device_expansion.mlir +++ b/tensorflow/dtensor/mlir/tests/multi_device_expansion.mlir @@ -162,3 +162,24 @@ module @test_tpu_with_inputs attributes {dtensor.enable_multi_device_mode = true return %arg0 : tensor<4xf32> } } + +// ----- + +// CHECK-LABEL: module @test_inferred_resource_attributes +// CHECK-LABEL: func.func @main +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: _inferred_resource_indices = dense<[1, 2]> +// CHECK-SAME: _inferred_resource_layouts = ["sharding_specs:x,unsharded +// CHECK-SAME , "sharding_specs:unsharded,y + +module @test_inferred_resource_attributes attributes {dtensor.all_reduce_combiner.num_ops_in_group = 0 : i64, dtensor.all_reduce_combiner.topological_distance = 0 : i64, dtensor.eager_operation_name = "AssignVariableOp", dtensor.enable_multi_device_mode = true, tf._default_mesh = "|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:1"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1555 : i32}} { + func.func @main(%arg0: tensor {tf._global_shape = #tf_type.shape<>}, %arg1: tensor>> {tf._assigned_resource_local_shape = #tf_type.shape<>, tf._global_shape = #tf_type.shape<>, tf._layout = "empty_layout", tf._mesh = "empty_mesh"}, %arg2: tensor>> {tf._assigned_resource_local_shape = #tf_type.shape<>, tf._global_shape = #tf_type.shape<>, tf._layout = "empty_layout", tf._mesh = "empty_mesh"}) attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "eager_operation", inputs = "device_id,op_input_0,op_input_1", outputs = ""}} { + "tf.StatefulPartitionedCall"(%arg0, %arg1) {_inferred_resource_indices = dense<1> : vector<1xi32>, _inferred_resource_layouts = ["sharding_specs:x,unsharded, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"], _layout = [], _mesh = "|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", config = "|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", config_proto = "", executor_type = "", f = @_func} : (tensor, tensor>>) -> () + "tf.StatefulPartitionedCall"(%arg0, %arg2) {_inferred_resource_indices = dense<2> : vector<1xi32>, _inferred_resource_layouts = ["sharding_specs:unsharded,y, mesh:|x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"], _layout = [], _mesh = "|x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", config = "|x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", config_proto = "", executor_type = "", f = @_func} : (tensor, tensor>>) -> () + return + } + func.func private @_func(%arg0: tensor, %arg1: tensor>>) { + "tf.AssignVariableOp"(%arg1, %arg0) {_global_shape = [], _layout = [], device = "", validate_shape = false} : (tensor>>, tensor) -> () + return + } +} diff --git a/tensorflow/dtensor/mlir/utils/collective_lowering.cc b/tensorflow/dtensor/mlir/utils/collective_lowering.cc index 9f874b73d92e00..5c03db58e8e9e3 100644 --- 
a/tensorflow/dtensor/mlir/utils/collective_lowering.cc +++ b/tensorflow/dtensor/mlir/utils/collective_lowering.cc @@ -960,9 +960,8 @@ mlir::LogicalResult LowerAllGatherOpToCollective( } for (int i = 0; i < src_layout.rank(); i++) { - if (src_layout.num_shards_for_dim(src_layout.dim(i)) == - tgt_layout.num_shards_for_dim(tgt_layout.dim(i)) || - src_layout.num_shards_for_dim(src_layout.dim(i)) == 1) { + if (src_layout.num_shards_for_dim(i) == tgt_layout.num_shards_for_dim(i) || + src_layout.num_shards_for_dim(i) == 1) { continue; } @@ -970,8 +969,7 @@ mlir::LogicalResult LowerAllGatherOpToCollective( perm_for_transpose[0] = perm_for_transpose[i]; perm_for_transpose[i] = temp; - num_shards_per_dim.push_back( - src_layout.num_shards_for_dim(src_layout.dim(i))); + num_shards_per_dim.push_back(src_layout.num_shards_for_dim(i)); previous_sharded_dim[i] = last_sharded_dim; last_sharded_dim = i; @@ -1013,9 +1011,9 @@ mlir::LogicalResult LowerAllGatherOpToCollective( prev_op_result = reshape_op->getResult(0); for (int i = src_layout.rank() - 1; i >= 0; i--) { - if (src_layout.num_shards_for_dim(src_layout.dim(i)) == - tgt_layout.num_shards_for_dim(tgt_layout.dim(i)) || - src_layout.num_shards_for_dim(src_layout.dim(i)) == 1) { + if (src_layout.num_shards_for_dim(i) == + tgt_layout.num_shards_for_dim(i) || + src_layout.num_shards_for_dim(i) == 1) { continue; } @@ -1058,7 +1056,7 @@ mlir::LogicalResult LowerAllGatherOp(mlir::TF::DTensorAllGatherOp all_gather) { llvm::SmallVector concat_dims; for (int64 i = 0; i < src_layout.rank(); ++i) - if (src_layout.num_shards_for_dim(src_layout.dim(i)) > 1 && + if (src_layout.num_shards_for_dim(i) > 1 && Layout::IsUnshardedDimension(tgt_layout.sharding_spec(i))) concat_dims.push_back(i); @@ -1366,9 +1364,8 @@ mlir::LogicalResult LowerAllToAllHelper( absl::flat_hash_set dims_to_gather; for (int i = 0; i < src_layout.rank(); i++) { - if (src_layout.num_shards_for_dim(src_layout.dim(i)) == - tgt_layout.num_shards_for_dim(tgt_layout.dim(i)) || - src_layout.num_shards_for_dim(src_layout.dim(i)) == 1) { + if (src_layout.num_shards_for_dim(i) == tgt_layout.num_shards_for_dim(i) || + src_layout.num_shards_for_dim(i) == 1) { continue; } dims_to_gather.insert(src_layout.sharding_spec(i)); diff --git a/tensorflow/dtensor/python/BUILD b/tensorflow/dtensor/python/BUILD index 2676f040d0cd86..56d789c6c0bc65 100644 --- a/tensorflow/dtensor/python/BUILD +++ b/tensorflow/dtensor/python/BUILD @@ -52,7 +52,7 @@ pytype_strict_library( ":layout", "//tensorflow/python/eager:context", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_util", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", ], @@ -94,6 +94,7 @@ pytype_strict_library( "//tensorflow/python:_pywrap_dtensor_device", "//tensorflow/python/framework:device", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", ], @@ -179,7 +180,7 @@ pytype_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", - "//tensorflow/python/ops:array_ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:io_ops", "//tensorflow/python/ops:variables", "//tensorflow/python/util:tf_export", @@ -204,6 +205,7 @@ pytype_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + 
"//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/util:_pywrap_utils", "//third_party/py/numpy", @@ -310,6 +312,7 @@ pytype_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/ops:array_ops", diff --git a/tensorflow/dtensor/python/api.py b/tensorflow/dtensor/python/api.py index d971a4dfa7bb0c..2f303e9aa3a218 100644 --- a/tensorflow/dtensor/python/api.py +++ b/tensorflow/dtensor/python/api.py @@ -23,6 +23,7 @@ from tensorflow.dtensor.python import layout as layout_lib from tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export @@ -167,7 +168,7 @@ def is_dtensor(tensor) -> bool: def copy_to_mesh( tensor: Any, layout: layout_lib.Layout, - source_layout: Optional[layout_lib.Layout] = None) -> ops.Tensor: + source_layout: Optional[layout_lib.Layout] = None) -> tensor_lib.Tensor: """Copies a tf.Tensor onto the DTensor device with the given layout. Copies a regular tf.Tensor onto the DTensor device. Use the mesh attached to @@ -377,7 +378,7 @@ def unpack(tensor: Any) -> Sequence[Any]: @tf_export("experimental.dtensor.fetch_layout", v1=[]) -def fetch_layout(tensor: ops.Tensor) -> layout_lib.Layout: +def fetch_layout(tensor: tensor_lib.Tensor) -> layout_lib.Layout: """Fetches the layout of a DTensor. Args: @@ -393,7 +394,7 @@ def fetch_layout(tensor: ops.Tensor) -> layout_lib.Layout: @tf_export("experimental.dtensor.check_layout", v1=[]) -def check_layout(tensor: ops.Tensor, layout: layout_lib.Layout) -> None: +def check_layout(tensor: tensor_lib.Tensor, layout: layout_lib.Layout) -> None: """Asserts that the layout of the DTensor is `layout`. Args: @@ -410,8 +411,10 @@ def check_layout(tensor: ops.Tensor, layout: layout_lib.Layout) -> None: @tf_export("experimental.dtensor.relayout", v1=[]) def relayout( - tensor: ops.Tensor, layout: layout_lib.Layout, name: Optional[str] = None -) -> ops.Tensor: + tensor: tensor_lib.Tensor, + layout: layout_lib.Layout, + name: Optional[str] = None, +) -> tensor_lib.Tensor: """Changes the layout of `tensor`. Changes the layout of `tensor` to `layout`. This is used to fine-tune the @@ -449,8 +452,10 @@ def relayout( @tf_export("experimental.dtensor.relayout_like", v1=[]) def relayout_like( - tensor: ops.Tensor, layout_tensor: ops.Tensor, name: Optional[str] = None -) -> ops.Tensor: + tensor: tensor_lib.Tensor, + layout_tensor: tensor_lib.Tensor, + name: Optional[str] = None, +) -> tensor_lib.Tensor: """Changes the layout of `tensor` to the same as `layout_tensor`. 
`relayout_like` is often used inside a `tf.function`, to ensure a tensor is diff --git a/tensorflow/dtensor/python/dtensor_device.py b/tensorflow/dtensor/python/dtensor_device.py index c475b8fb8158ec..8a252e64f1e402 100644 --- a/tensorflow/dtensor/python/dtensor_device.py +++ b/tensorflow/dtensor/python/dtensor_device.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.util import _pywrap_utils @@ -119,7 +120,7 @@ def _register_mesh(self, mesh: layout_lib.Mesh): def meshes(self) -> Set[layout_lib.Mesh]: return self._meshes - def copy_to_mesh(self, tensor, new_layout) -> ops.Tensor: + def copy_to_mesh(self, tensor, new_layout) -> tensor_lib.Tensor: """Copy `tensor` to `device` with the given layout.""" self._register_mesh(new_layout.mesh) with ops.device(self.name): diff --git a/tensorflow/dtensor/python/input_util.py b/tensorflow/dtensor/python/input_util.py index 230ae42fdcb5fc..a504c3639a1467 100644 --- a/tensorflow/dtensor/python/input_util.py +++ b/tensorflow/dtensor/python/input_util.py @@ -74,6 +74,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops @@ -110,7 +111,7 @@ class _DTensorIterator(iterator_ops.OwnedIterator): def __init__( self, - dtensor_components: Tuple[ops.Tensor], + dtensor_components: Tuple[tensor.Tensor], global_element_spec: tensor_spec.TensorSpec, layouts: Any): """Initializes a distributed iterator for DTensor datasets. @@ -283,7 +284,7 @@ def _shard_counts(layout: layout_lib.Layout, def _index_matrix(layout: layout_lib.Layout, - elem_spec: tensor_spec.TensorSpec) -> ops.Tensor: + elem_spec: tensor_spec.TensorSpec) -> tensor.Tensor: """Computes a utility matrix to derive device-based slice offsets. This function builds a matrix of shape `[mesh.rank, layout.rank]` for each diff --git a/tensorflow/dtensor/python/layout.py b/tensorflow/dtensor/python/layout.py index d17cc1f18ba911..052e8de2de626e 100644 --- a/tensorflow/dtensor/python/layout.py +++ b/tensorflow/dtensor/python/layout.py @@ -25,6 +25,7 @@ from tensorflow.python import _pywrap_dtensor_device from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.util.tf_export import tf_export # UNSHARDED indicates a tensor dimension is not sharded over any mesh dimension. @@ -220,6 +221,14 @@ def to_spec(d) -> tf_device.DeviceSpec: use_xla_spmd, ) + @classmethod + def _new_object(cls, *args, **kwargs): + # Need to explicitly invoke the base class __init__ because + # Mesh.__init__ overrode it with a different signature. 
+ self = _pywrap_dtensor_device.Mesh.__new__(cls) + super().__init__(self, *args, **kwargs) + return self + def global_device_ids(self) -> np.ndarray: """Returns a global device list as an array.""" return np.array(super().global_device_ids(), dtype=np.int64).reshape( @@ -245,7 +254,7 @@ def __reduce__(self): return Mesh.from_string, (self.to_string(),) # TODO(b/242201545): implement this in Mesh C++ class - def coords(self, device_idx: int) -> ops.Tensor: + def coords(self, device_idx: int) -> tensor.Tensor: """Converts the device index into a tensor of mesh coordinates.""" strides = ops.convert_to_tensor(self.strides) shape = ops.convert_to_tensor(self.shape()) @@ -254,26 +263,25 @@ def coords(self, device_idx: int) -> ops.Tensor: @classmethod def from_proto(cls, proto: layout_pb2.MeshProto) -> 'Mesh': """Construct a mesh instance from input `proto`.""" - mesh = _pywrap_dtensor_device.Mesh.__new__(cls) - _pywrap_dtensor_device.Mesh.__init__(mesh, mesh_proto=proto) - return mesh + return cls._new_object(mesh_proto=proto) @classmethod def from_string(cls, mesh_str: str) -> 'Mesh': - mesh = _pywrap_dtensor_device.Mesh.__new__(cls) - _pywrap_dtensor_device.Mesh.__init__(mesh, mesh_str=mesh_str) - return mesh + return cls._new_object(mesh_str=mesh_str) @classmethod def from_device(cls, device: str) -> 'Mesh': """Constructs a single device mesh from a device string.""" - mesh = _pywrap_dtensor_device.Mesh.__new__(cls) - _pywrap_dtensor_device.Mesh.__init__(mesh, single_device=device) - return mesh + return cls._new_object(single_device=device) + + @classmethod + def _from_mesh(cls, mesh: _pywrap_dtensor_device.Mesh): + """Creates a copy from an existing pywrap mesh object.""" + return cls._new_object(mesh=mesh) @functools.cached_property def _host_mesh(self) -> 'Mesh': - return Mesh.from_string(super().host_mesh().to_string()) + return Mesh._from_mesh(super().host_mesh()) def host_mesh(self) -> 'Mesh': """Returns a host mesh.""" @@ -425,6 +433,14 @@ def __init__(self, sharding_specs: List[str], mesh: Mesh): super().__init__(sharding_specs=sharding_specs, mesh=mesh) + @classmethod + def _new_object(cls, *args, **kwargs): + # Need to explicitly invoke the base class __init__ because + # Layout.__init__ overrode it with a different signature. + self = _pywrap_dtensor_device.Layout.__new__(cls) + super().__init__(self, *args, **kwargs) + return self + def __repr__(self) -> str: return f'Layout(sharding_specs={self.sharding_specs}, mesh={self.mesh})' @@ -435,10 +451,9 @@ def __hash__(self): def __reduce__(self): return Layout.from_string, (self.to_string(),) - # TODO(b/242201545): Find a way to return Mesh object from the pywrap module. @property def mesh(self): - return Mesh.from_proto(super().mesh.as_proto()) + return Mesh._from_mesh(mesh=super().mesh) # pylint: disable=protected-access @property def shape(self): @@ -449,16 +464,13 @@ def batch_sharded( cls, mesh: Mesh, batch_dim: str, rank: int, axis: int = 0 ) -> 'Layout': """Returns a layout sharded on batch dimension.""" - layout_obj = _pywrap_dtensor_device.Layout.__new__(cls) - _pywrap_dtensor_device.Layout.__init__( + return cls._new_object( # Watchout for the different ordering. - layout_obj, mesh=mesh, rank=rank, batch_dim=batch_dim, axis=axis, ) - return layout_obj # TODO(b/242201545): Move this to C++ / find the corresponding function there. 
def delete(self, dims: List[int]) -> 'Layout': @@ -473,18 +485,12 @@ def delete(self, dims: List[int]) -> 'Layout': @classmethod def from_proto(cls, layout_proto: layout_pb2.LayoutProto) -> 'Layout': """Creates an instance from a LayoutProto.""" - layout_obj = _pywrap_dtensor_device.Layout.__new__(cls) - _pywrap_dtensor_device.Layout.__init__( - layout_obj, layout_proto=layout_proto - ) - return layout_obj + return cls._new_object(layout_proto=layout_proto) @classmethod def from_string(cls, layout_str: str) -> 'Layout': """Creates an instance from a human-readable string.""" - layout_obj = _pywrap_dtensor_device.Layout.__new__(cls) - _pywrap_dtensor_device.Layout.__init__(layout_obj, layout_str=layout_str) - return layout_obj + return cls._new_object(layout_str=layout_str) @classmethod def inner_sharded(cls, mesh: Mesh, inner_dim: str, rank: int) -> 'Layout': @@ -494,9 +500,7 @@ def inner_sharded(cls, mesh: Mesh, inner_dim: str, rank: int) -> 'Layout': @classmethod def from_single_device_mesh(cls, mesh: Mesh) -> 'Layout': """Constructs a single device layout from a single device mesh.""" - layout = _pywrap_dtensor_device.Layout.__new__(cls) - _pywrap_dtensor_device.Layout.__init__(layout, mesh=mesh) - return layout + return cls._new_object(mesh=mesh) @classmethod def from_device(cls, device: str) -> 'Layout': @@ -533,6 +537,4 @@ def offset_tuple_to_global_index(self, offset_tuple): @classmethod def replicated(cls, mesh: Mesh, rank: int) -> 'Layout': """Returns a replicated layout of rank `rank`.""" - layout_obj = _pywrap_dtensor_device.Layout.__new__(cls) - _pywrap_dtensor_device.Layout.__init__(layout_obj, mesh=mesh, rank=rank) - return layout_obj + return cls._new_object(mesh=mesh, rank=rank) diff --git a/tensorflow/dtensor/python/save_restore.py b/tensorflow/dtensor/python/save_restore.py index dde0b00d4d5e27..25bd78cdf00a99 100644 --- a/tensorflow/dtensor/python/save_restore.py +++ b/tensorflow/dtensor/python/save_restore.py @@ -25,6 +25,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import io_ops from tensorflow.python.ops import variables as tf_variables from tensorflow.python.util.tf_export import tf_export @@ -33,10 +34,10 @@ @tf_export('experimental.dtensor.sharded_save', v1=[]) def sharded_save( mesh: layout_lib.Mesh, - file_prefix: Union[str, ops.Tensor], - tensor_names: Union[List[str], ops.Tensor], - shape_and_slices: Union[List[str], ops.Tensor], - tensors: List[Union[ops.Tensor, tf_variables.Variable]], + file_prefix: Union[str, tensor_lib.Tensor], + tensor_names: Union[List[str], tensor_lib.Tensor], + shape_and_slices: Union[List[str], tensor_lib.Tensor], + tensors: List[Union[tensor_lib.Tensor, tf_variables.Variable]], ): """Saves given named tensor slices in a sharded, multi-client safe fashion. @@ -100,7 +101,8 @@ def enable_save_as_bf16(variables: List[tf_variables.Variable]): def name_based_restore( mesh: layout_lib.Mesh, checkpoint_prefix: str, - name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]], + name_tensor_dict: Dict[ + str, Union[tensor_lib.Tensor, tf_variables.Variable]], ): """Restores from checkpoint_prefix to name based DTensors. 
@@ -163,17 +165,21 @@ def name_based_restore( shape_and_slices=shape_and_slices, input_shapes=input_shapes, input_layouts=input_layouts, - dtypes=[tensor.dtype for tensor in ordered_name_tensor_dict.values()]) + dtypes=[tensor.dtype for tensor in ordered_name_tensor_dict.values()], + ) return collections.OrderedDict( - zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors)) + zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors) + ) @tf_export('experimental.dtensor.name_based_save', v1=[]) -def name_based_save(mesh: layout_lib.Mesh, checkpoint_prefix: Union[str, - ops.Tensor], - name_tensor_dict: Dict[str, Union[ops.Tensor, - tf_variables.Variable]]): +def name_based_save( + mesh: layout_lib.Mesh, + checkpoint_prefix: Union[str, tensor_lib.Tensor], + name_tensor_dict: Dict[ + str, Union[tensor_lib.Tensor, tf_variables.Variable]], +): """Saves name based Tensor into a Checkpoint. The function prepares the input dictionary to the format of a `sharded_save`, diff --git a/tensorflow/dtensor/python/tests/BUILD b/tensorflow/dtensor/python/tests/BUILD index 0547f177b31bbc..37c27ed040e706 100644 --- a/tensorflow/dtensor/python/tests/BUILD +++ b/tensorflow/dtensor/python/tests/BUILD @@ -305,23 +305,6 @@ pytype_strict_library( ], ) -pytype_strict_library( - name = "test_backend_util_oss", - srcs = ["test_backend_util.oss.py"], - deps = [ - ":test_util", - "//tensorflow/dtensor/python:config", - "//tensorflow/dtensor/python:layout", - "//tensorflow/dtensor/python:tpu_util", - "//tensorflow/python/platform:client_testlib", - ], -) - -pytype_strict_library( - name = "test_backend_name_oss", - srcs = ["test_backend_name.oss.py"], -) - dtensor_test( name = "multi_client_test", srcs = ["multi_client_test.py"], diff --git a/tensorflow/examples/custom_ops_doc/multiplex_1/README.md b/tensorflow/examples/custom_ops_doc/multiplex_1/README.md index e9800a5f4afd0f..15cac59ad9c28d 100644 --- a/tensorflow/examples/custom_ops_doc/multiplex_1/README.md +++ b/tensorflow/examples/custom_ops_doc/multiplex_1/README.md @@ -338,7 +338,7 @@ py_strict_library( ], ) -tf_py_test( +tf_py_strict_test( name = "multiplex_1_test", size = "small", srcs = ["multiplex_1_test.py"], @@ -399,5 +399,5 @@ Op components | Build rule | Build target Kernels (C++) | `tf_custom_op_library` | `multiplex_1_kernel` | `multiplex_1_kernel.cc`, `multiplex_1_op.cc` Wrapper (automatically generated) | N/A | `gen_multiplex_1_op` | N/A Wrapper (with public API and docstring) | `py_strict_library` | `multiplex_1_op` | `multiplex_1_op.py` -Tests | `tf_py_test` | `multiplex_1_test` | `multiplex_1_test.py` +Tests | `tf_py_strict_test` | `multiplex_1_test` | `multiplex_1_test.py` diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index 78a3f922ec655b..bc97bac8a1b102 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -337,6 +337,7 @@ if(TFLITE_ENABLE_GPU) ${TFLITE_SOURCE_DIR}/delegates/gpu/api.cc ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.cc ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate_options.cc + ${TFLITE_SOURCE_DIR}/delegates/gpu/tflite_profile.cc ${TFLITE_SOURCE_DIR}/experimental/acceleration/compatibility/android_info.cc ${TFLITE_DELEGATES_GPU_CL_SRCS} ${TFLITE_DELEGATES_GPU_CL_DEFAULT_SRCS} diff --git a/tensorflow/lite/arena_planner.cc b/tensorflow/lite/arena_planner.cc index 0284beb72a4134..549c447f7413a9 100644 --- a/tensorflow/lite/arena_planner.cc +++ b/tensorflow/lite/arena_planner.cc @@ -236,7 +236,9 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // artificially adding one to 
their ref-counts so they are never selected // for deallocation. for (int tensor_index : graph_info_->outputs()) { - ++refcounts_[tensor_index]; + if (tensor_index != kTfLiteOptionalTensor) { + ++refcounts_[tensor_index]; + } } // Variable tensors also should be ensured to be never overwritten and need to diff --git a/tensorflow/lite/arena_planner_test.cc b/tensorflow/lite/arena_planner_test.cc index 9cbbb61563f159..ded2345c93e8c7 100644 --- a/tensorflow/lite/arena_planner_test.cc +++ b/tensorflow/lite/arena_planner_test.cc @@ -709,6 +709,25 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithOptionals) { EXPECT_EQ(GetOffset(2), GetOffsetAfter(4)); } +TEST_F(ArenaPlannerTest, SimpleGraphWithOptionalOutput) { + TestGraph graph({0, -1, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4, 5}, {}}, // Second op + {{4, 5}, {3}, {}} // Third op, with optional + }, + {-1, 3}); + SetGraph(&graph); + Execute(0, graph.nodes().size() - 1); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 +4 +5 -2 +3 -4 -5 + EXPECT_EQ(GetOffset(5), 12); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(3), GetOffsetAfter(4)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(4)); +} + TEST_F(ArenaPlannerTest, SimpleGraphWithLargeTensor) { TestGraph graph({0, -1}, { diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index a0e7a3f6640600..28fd64d1021052 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -182,6 +182,7 @@ cc_test( ":c_api_types", ":common", "//tensorflow/core/platform:resource_loader", + "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/core:subgraph", "//tensorflow/lite/delegates:delegate_test_util", @@ -208,6 +209,7 @@ cc_test( ":c_api_without_op_resolver_without_alwayslink", ":common", "//tensorflow/core/platform:resource_loader", + "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:selectively_built_c_api_test_lib", "//tensorflow/lite/core:subgraph", @@ -349,6 +351,7 @@ tflite_cc_library_with_c_headers_test( "//tensorflow/lite:builtin_ops", "//tensorflow/lite:framework", "//tensorflow/lite:kernel_api", + "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_opaque_internal", "//tensorflow/lite/core:framework", @@ -390,6 +393,7 @@ tflite_cc_library_with_c_headers_test( ":common", "//tensorflow/lite:framework", "//tensorflow/lite:kernel_api", + "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_opaque_internal_without_alwayslink", "//tensorflow/lite/core:framework", @@ -416,7 +420,10 @@ cc_test( size = "small", srcs = ["c_api_experimental_test.cc"], copts = tflite_copts(), - data = ["//tensorflow/lite:testdata/add.bin"], + data = [ + "//tensorflow/lite:testdata/add.bin", + "//tensorflow/lite:testdata/custom_sinh.bin", + ], deps = [ ":c_api", ":c_api_experimental", diff --git a/tensorflow/lite/core/c/c_api.cc b/tensorflow/lite/core/c/c_api.cc index 6cef29fc8259f5..fd07b58478b4fd 100644 --- a/tensorflow/lite/core/c/c_api.cc +++ b/tensorflow/lite/core/c/c_api.cc @@ -298,13 +298,6 @@ static TfLiteRegistration* RegistrationExternalToRegistration( // FindOp for builtin op query. const TfLiteRegistration* CallbackOpResolver::FindOp(tflite::BuiltinOperator op, int version) const { - // Use Registration V3 API to find op. 
- if (op_resolver_callbacks_.find_builtin_op) { - return op_resolver_callbacks_.find_builtin_op( - op_resolver_callbacks_.user_data, - static_cast(op), version); - } - // Check if cached Registration is available. std::lock_guard lock(mutex_); for (const auto& created_registration : temporary_builtin_registrations_) { @@ -314,50 +307,58 @@ const TfLiteRegistration* CallbackOpResolver::FindOp(tflite::BuiltinOperator op, } } + // Try using newer RegistrationExternal API. + if (op_resolver_callbacks_.find_builtin_op_external) { + // Get a RegistrationExternal object and create a Registration (V3) object. + const TfLiteRegistrationExternal* registration_external = + op_resolver_callbacks_.find_builtin_op_external( + op_resolver_callbacks_.user_data, + static_cast(op), version); + if (registration_external && (registration_external->init != nullptr || + registration_external->free != nullptr || + registration_external->invoke != nullptr || + registration_external->prepare != nullptr)) { + TfLiteRegistration* new_registration = + RegistrationExternalToRegistration(registration_external); + temporary_builtin_registrations_.push_back( + std::unique_ptr(new_registration)); + return new_registration; + } + } + + // Use Registration V4 API to find op. + if (op_resolver_callbacks_.find_builtin_op) { + return op_resolver_callbacks_.find_builtin_op( + op_resolver_callbacks_.user_data, + static_cast(op), version); + } + // Try using older Registration V3 API to find op. if (auto* registration = BuildBuiltinOpFromLegacyRegistration( op, version, op_resolver_callbacks_.find_builtin_op_v3); registration) { return registration; } + // Try using older Registration V2 API to find op. if (auto* registration = BuildBuiltinOpFromLegacyRegistration( op, version, op_resolver_callbacks_.find_builtin_op_v2); registration) { return registration; } + // Try using older Registration V1 API to find op. if (auto* registration = BuildBuiltinOpFromLegacyRegistration( op, version, op_resolver_callbacks_.find_builtin_op_v1); registration) { return registration; } - // Try using newer RegistrationExternal API. - if (op_resolver_callbacks_.find_builtin_op_external) { - // Get a RegistrationExternal object and create a Registration (V3) object. - const TfLiteRegistrationExternal* registration_external = - op_resolver_callbacks_.find_builtin_op_external( - op_resolver_callbacks_.user_data, - static_cast(op), version); - if (registration_external) { - TfLiteRegistration* new_registration = - RegistrationExternalToRegistration(registration_external); - temporary_builtin_registrations_.push_back( - std::unique_ptr(new_registration)); - return new_registration; - } - } return nullptr; } // FindOp for custom op query. const TfLiteRegistration* CallbackOpResolver::FindOp(const char* op, int version) const { - // Use TfLiteRegistration API to find op. - if (op_resolver_callbacks_.find_custom_op) { - return op_resolver_callbacks_.find_custom_op( - op_resolver_callbacks_.user_data, op, version); - } // Check if cached Registration is available. std::lock_guard lock(mutex_); for (const auto& created_registration : temporary_custom_registrations_) { @@ -367,37 +368,48 @@ const TfLiteRegistration* CallbackOpResolver::FindOp(const char* op, } } + if (op_resolver_callbacks_.find_custom_op_external) { + // Get a RegistrationExternal object and create a Registration (V3) object. 
+ const TfLiteRegistrationExternal* registration_external = + op_resolver_callbacks_.find_custom_op_external( + op_resolver_callbacks_.user_data, op, version); + if (registration_external && (registration_external->init != nullptr || + registration_external->free != nullptr || + registration_external->invoke != nullptr || + registration_external->prepare != nullptr)) { + TfLiteRegistration* new_registration = + RegistrationExternalToRegistration(registration_external); + temporary_builtin_registrations_.push_back( + std::unique_ptr(new_registration)); + return new_registration; + } + } + // Use TfLiteRegistration V4 API to find op. + if (op_resolver_callbacks_.find_custom_op) { + return op_resolver_callbacks_.find_custom_op( + op_resolver_callbacks_.user_data, op, version); + } + // Use older TfLiteRegistration V3 API to find op. if (auto* registration = BuildCustomOpFromLegacyRegistration( op, version, op_resolver_callbacks_.find_custom_op_v3); registration) { return registration; } + // Use older TfLiteRegistration V2 API to find op. if (auto* registration = BuildCustomOpFromLegacyRegistration( op, version, op_resolver_callbacks_.find_custom_op_v2); registration) { return registration; } + // Use even older TfLiteRegistration V1 API to find op. if (auto* registration = BuildCustomOpFromLegacyRegistration( op, version, op_resolver_callbacks_.find_custom_op_v1); registration) { return registration; } - if (op_resolver_callbacks_.find_custom_op_external) { - // Get a RegistrationExternal object and create a Registration (V2) object. - const TfLiteRegistrationExternal* registration_external = - op_resolver_callbacks_.find_custom_op_external( - op_resolver_callbacks_.user_data, op, version); - if (registration_external) { - TfLiteRegistration* new_registration = - RegistrationExternalToRegistration(registration_external); - temporary_builtin_registrations_.push_back( - std::unique_ptr(new_registration)); - return new_registration; - } - } return nullptr; } @@ -441,6 +453,8 @@ TfLiteInterpreter* InterpreterCreateWithOpResolver( optional_options->op_resolver_callbacks.find_custom_op_v1 != nullptr || optional_options->op_resolver_callbacks.find_builtin_op_v2 != nullptr || optional_options->op_resolver_callbacks.find_custom_op_v2 != nullptr || + optional_options->op_resolver_callbacks.find_builtin_op_v3 != nullptr || + optional_options->op_resolver_callbacks.find_custom_op_v3 != nullptr || optional_options->op_resolver_callbacks.find_builtin_op_external != nullptr || optional_options->op_resolver_callbacks.find_custom_op_external != diff --git a/tensorflow/lite/core/c/c_api_experimental.cc b/tensorflow/lite/core/c/c_api_experimental.cc index 45a8d0f99241a1..a6c5baa6c9a066 100644 --- a/tensorflow/lite/core/c/c_api_experimental.cc +++ b/tensorflow/lite/core/c/c_api_experimental.cc @@ -75,6 +75,28 @@ void TfLiteInterpreterOptionsSetOpResolverExternal( options->op_resolver_callbacks.user_data = op_resolver_user_data; } +void TfLiteInterpreterOptionsSetOpResolverExternalWithFallback( + TfLiteInterpreterOptions* options, + const TfLiteRegistrationExternal* (*find_builtin_op_external)( + void* user_data, int op, int version), + const TfLiteRegistrationExternal* (*find_custom_op_external)( + void* user_data, const char* custom_op, int version), + const TfLiteRegistration* (*find_builtin_op)(void* user_data, + TfLiteBuiltinOperator op, + int version), + const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op, + int version), + void* op_resolver_user_data) { + 
options->op_resolver_callbacks = {}; // Sets all fields to null. + options->op_resolver_callbacks.find_builtin_op_external = + find_builtin_op_external; + options->op_resolver_callbacks.find_custom_op_external = + find_custom_op_external; + options->op_resolver_callbacks.find_builtin_op = find_builtin_op; + options->op_resolver_callbacks.find_custom_op = find_custom_op; + options->op_resolver_callbacks.user_data = op_resolver_user_data; +} + void TfLiteInterpreterOptionsSetOpResolver( TfLiteInterpreterOptions* options, const TfLiteRegistration* (*find_builtin_op)(void* user_data, diff --git a/tensorflow/lite/core/c/c_api_experimental.h b/tensorflow/lite/core/c/c_api_experimental.h index f87e0226baf62a..f766931931bfc2 100644 --- a/tensorflow/lite/core/c/c_api_experimental.h +++ b/tensorflow/lite/core/c/c_api_experimental.h @@ -101,6 +101,7 @@ TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp( /// The `TfLiteInterpreterOptionsSetOpResolverExternal` function provides an /// alternative method for registering builtin ops and/or custom ops, by /// providing operator resolver callbacks. Unlike using +/// `TfLiteInterpreterOptionsAddRegistrationExternal`, /// `TfLiteInterpreterOptionsAddBuiltinOp` and/or /// `TfLiteInterpreterOptionsAddAddCustomOp`, these let you register all the /// operators in a single call. @@ -126,6 +127,34 @@ void TfLiteInterpreterOptionsSetOpResolverExternal( int version), void* op_resolver_user_data); +/// \private +/// Registers callbacks for resolving builtin or custom operators. +/// +/// This combines the effects of TfLiteInterpreterOptionsSetOpResolverExternal +/// and TfLiteInterpreterOptionsSetOpResolver. The callbacks that return +/// TfLiteRegistrationExternal will be called first, but if they return a +/// TfLiteRegistrationExternal object that has no methods set, then +/// the callbacks that return a TfLiteRegistration will be called to get +/// the methods. +/// +/// WARNING: This function is experimental and subject to change. +/// +/// WARNING: This function is not an official part of the API, +/// and should not be used by apps. It is intended for use only from +/// TF Lite itself. +void TfLiteInterpreterOptionsSetOpResolverExternalWithFallback( + TfLiteInterpreterOptions* options, + const TfLiteRegistrationExternal* (*find_builtin_op_external)( + void* user_data, int op, int version), + const TfLiteRegistrationExternal* (*find_custom_op_external)( + void* user_data, const char* custom_op, int version), + const TfLiteRegistration* (*find_builtin_op)(void* user_data, + TfLiteBuiltinOperator op, + int version), + const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op, + int version), + void* op_resolver_user_data); + /// Registers callbacks for resolving builtin or custom operators. /// /// The `TfLiteInterpreterOptionsSetOpResolver` function provides an alternative diff --git a/tensorflow/lite/core/c/c_api_experimental_test.cc b/tensorflow/lite/core/c/c_api_experimental_test.cc index f1d045d737c748..9b05252e1ed139 100644 --- a/tensorflow/lite/core/c/c_api_experimental_test.cc +++ b/tensorflow/lite/core/c/c_api_experimental_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/lite/core/c/c_api_experimental.h" -#include - #include +#include +#include #include #include @@ -26,6 +26,7 @@ limitations under the License. 
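// Illustrative sketch only (not part of the patch): one way the new
// *WithFallback registration declared above can be wired up, assuming the
// legacy TfLiteRegistration resolvers `MyFindBuiltinOp` / `MyFindCustomOp`
// used by the tests below. The *_external callbacks are consulted first;
// returning NULL (or a registration with no methods set) defers lookup to
// the fallback resolvers.
static const TfLiteRegistrationExternal* FindBuiltinExternal(void* user_data,
                                                             int op,
                                                             int version) {
  return NULL;  // Defer all builtin ops to the fallback resolver.
}
static const TfLiteRegistrationExternal* FindCustomExternal(
    void* user_data, const char* custom_op, int version) {
  return NULL;  // Defer all custom ops to the fallback resolver.
}
static void RegisterOpsWithFallback(TfLiteInterpreterOptions* options,
                                    void* user_data) {
  TfLiteInterpreterOptionsSetOpResolverExternalWithFallback(
      options, FindBuiltinExternal, FindCustomExternal, MyFindBuiltinOp,
      MyFindCustomOp, user_data);
}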
#include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/core/c/c_api.h" +#include "tensorflow/lite/core/c/c_api_opaque.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/delegate_test_util.h" #include "tensorflow/lite/testing/util.h" @@ -206,7 +207,72 @@ const TfLiteRegistrationExternal* MyFindCustomOpExternal(void*, return nullptr; } -// Test using TfLiteInterpreterCreateWithSelectedOps. +TfLiteStatus SinhPrepareOpaque(TfLiteOpaqueContext*, TfLiteOpaqueNode*) { + return kTfLiteOk; +} + +TfLiteStatus SinhEvalOpaque(TfLiteOpaqueContext* context, + TfLiteOpaqueNode* node) { + EXPECT_EQ(1, TfLiteOpaqueNodeNumberOfInputs(node)); + const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0); + size_t input_bytes = TfLiteOpaqueTensorByteSize(input); + const void* data_ptr = TfLiteOpaqueTensorData(input); + float input_value; + std::memcpy(&input_value, data_ptr, input_bytes); + + EXPECT_EQ(1, TfLiteOpaqueNodeNumberOfOutputs(node)); + TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0); + float output_value = std::sinh(input_value); + TfLiteOpaqueTensorCopyFromBuffer(output, &output_value, sizeof(output_value)); + return kTfLiteOk; +} + +TfLiteStatus SinhPrepare(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; } + +TfLiteStatus SinhEval(TfLiteContext* context, TfLiteNode* node) { + EXPECT_EQ(1, node->inputs->size); + const TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + size_t input_bytes = TfLiteTensorByteSize(input); + const void* data_ptr = TfLiteTensorData(input); + float input_value; + std::memcpy(&input_value, data_ptr, input_bytes); + + EXPECT_EQ(1, node->outputs->size); + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + float output_value = std::sinh(input_value); + TfLiteTensorCopyFromBuffer(output, &output_value, sizeof(output_value)); + return kTfLiteOk; +} + +const TfLiteRegistrationExternal* SinhFindCustomOpExternal( + void*, const char* custom_op, int version) { + if (absl::string_view(custom_op) == "Sinh" && version == 1) { + static TfLiteRegistrationExternal* registration = []() { + TfLiteRegistrationExternal* reg = + TfLiteRegistrationExternalCreate(kTfLiteBuiltinCustom, "Sinh", 1); + TfLiteRegistrationExternalSetPrepare(reg, &SinhPrepareOpaque); + TfLiteRegistrationExternalSetInvoke(reg, &SinhEvalOpaque); + return reg; + }(); + return registration; + } + return nullptr; +} + +const TfLiteRegistration* SinhFindCustomOp(void*, const char* custom_op, + int version) { + if (absl::string_view(custom_op) == "Sinh" && version == 1) { + static const TfLiteRegistration registration{/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/SinhPrepare, + /*invoke=*/SinhEval}; + return ®istration; + } + return nullptr; +} + +// Test using TfLiteInterpreterOptionsSetOpResolverExternal and +// TfLiteInterpreterCreateWithSelectedOps. TEST(CApiExperimentalTest, SetOpResolverExternal) { TfLiteModel* model = TfLiteModelCreateFromFile( tensorflow::GetDataDependencyFilepath("tensorflow/lite/testdata/add.bin") @@ -233,6 +299,171 @@ TEST(CApiExperimentalTest, SetOpResolverExternal) { TfLiteModelDelete(model); } +// Test using TfLiteInterpreterOptionsSetOpResolverExternalWithFallback and +// TfLiteInterpreterCreateWithSelectedOps, for a builtin op, for the normal +// case where the op is found with the primary op resolver callback that returns +// a TfLiteRegistrationExternal pointer. 
+TEST(CApiExperimentalTest, + SetOpResolverExternalWithFallback_BuiltinOp_NormalCase) { + TfLiteModel* model = TfLiteModelCreateFromFile( + tensorflow::GetDataDependencyFilepath("tensorflow/lite/testdata/add.bin") + .c_str()); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + + OpResolverData my_data; + TfLiteInterpreterOptionsSetOpResolverExternalWithFallback( + options, MyFindBuiltinOpExternal, MyFindCustomOpExternal, + [](void* user_data, TfLiteBuiltinOperator op, + int version) -> const TfLiteRegistration* { return nullptr; }, + [](void* user_data, const char* custom_op, + int version) -> const TfLiteRegistration* { return nullptr; }, + &my_data); + EXPECT_FALSE(my_data.called_for_add); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreateWithSelectedOps(model, options); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterResetVariableTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); + EXPECT_TRUE(my_data.called_for_add); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + +// Test using TfLiteInterpreterOptionsSetOpResolverExternalWithFallback and +// TfLiteInterpreterCreateWithSelectedOps, for a builtin op, for the fallback +// case where the op is found with the secondary op resolver callback that +// returns a TfLiteRegistration pointer. +TEST(CApiExperimentalTest, + SetOpResolverExternalWithFallback_BuiltinOp_FallbackCase) { + TfLiteModel* model = TfLiteModelCreateFromFile( + tensorflow::GetDataDependencyFilepath("tensorflow/lite/testdata/add.bin") + .c_str()); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + + OpResolverData my_data; + TfLiteInterpreterOptionsSetOpResolverExternalWithFallback( + options, + [](void* user_data, int op, + int version) -> const TfLiteRegistrationExternal* { return nullptr; }, + [](void* user_data, const char* custom_op, + int version) -> const TfLiteRegistrationExternal* { return nullptr; }, + MyFindBuiltinOp, MyFindCustomOp, &my_data); + EXPECT_FALSE(my_data.called_for_add); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreateWithSelectedOps(model, options); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterResetVariableTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); + EXPECT_TRUE(my_data.called_for_add); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + +// Test using TfLiteInterpreterOptionsSetOpResolverExternalWithFallback and +// TfLiteInterpreterCreateWithSelectedOps, for a custom op, for the normal +// case where the op is found with the primary op resolver callback that returns +// a TfLiteRegistrationExternal pointer. 
+TEST(CApiExperimentalTest, + SetOpResolverExternalWithFallback_CustomOp_NormalCase) { + TfLiteModel* model = + TfLiteModelCreateFromFile(tensorflow::GetDataDependencyFilepath( + "tensorflow/lite/testdata/custom_sinh.bin") + .c_str()); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + + OpResolverData my_data; + TfLiteInterpreterOptionsSetOpResolverExternalWithFallback( + options, MyFindBuiltinOpExternal, SinhFindCustomOpExternal, + [](void* user_data, TfLiteBuiltinOperator op, + int version) -> const TfLiteRegistration* { return nullptr; }, + [](void* user_data, const char* custom_op, + int version) -> const TfLiteRegistration* { return nullptr; }, + &my_data); + EXPECT_FALSE(my_data.called_for_add); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreateWithSelectedOps(model, options); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + + TfLiteTensor* input_tensor = TfLiteInterpreterGetInputTensor(interpreter, 0); + const float input_value = 1.0f; + TfLiteTensorCopyFromBuffer(input_tensor, &input_value, sizeof(float)); + + EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); + + const TfLiteTensor* output_tensor = + TfLiteInterpreterGetOutputTensor(interpreter, 0); + float output_value; + TfLiteTensorCopyToBuffer(output_tensor, &output_value, sizeof(float)); + EXPECT_EQ(output_value, std::sinh(input_value)); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + +// Test using TfLiteInterpreterOptionsSetOpResolverExternalWithFallback and +// TfLiteInterpreterCreateWithSelectedOps, for a custom op, for the fallback +// case where the op is found with the secondary op resolver callback that +// returns a TfLiteRegistration pointer. 
+TEST(CApiExperimentalTest, + SetOpResolverExternalWithFallback_CustomOp_FallbackCase) { + TfLiteModel* model = + TfLiteModelCreateFromFile(tensorflow::GetDataDependencyFilepath( + "tensorflow/lite/testdata/custom_sinh.bin") + .c_str()); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + + OpResolverData my_data; + TfLiteInterpreterOptionsSetOpResolverExternalWithFallback( + options, + [](void* user_data, int op, + int version) -> const TfLiteRegistrationExternal* { return nullptr; }, + [](void* user_data, const char* custom_op, + int version) -> const TfLiteRegistrationExternal* { return nullptr; }, + MyFindBuiltinOp, SinhFindCustomOp, &my_data); + EXPECT_FALSE(my_data.called_for_add); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreateWithSelectedOps(model, options); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + + TfLiteTensor* input_tensor = TfLiteInterpreterGetInputTensor(interpreter, 0); + const float input_value = 1.0f; + TfLiteTensorCopyFromBuffer(input_tensor, &input_value, sizeof(float)); + + EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); + EXPECT_FALSE(my_data.called_for_add); + + const TfLiteTensor* output_tensor = + TfLiteInterpreterGetOutputTensor(interpreter, 0); + float output_value; + TfLiteTensorCopyToBuffer(output_tensor, &output_value, sizeof(float)); + EXPECT_EQ(output_value, std::sinh(input_value)); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + void AllocateAndSetInputs(TfLiteInterpreter* interpreter) { std::array input_dims = {2}; ASSERT_EQ(TfLiteInterpreterResizeInputTensor( diff --git a/tensorflow/lite/core/c/c_api_opaque.cc b/tensorflow/lite/core/c/c_api_opaque.cc index 631412b7aea59a..f889f5a5899899 100644 --- a/tensorflow/lite/core/c/c_api_opaque.cc +++ b/tensorflow/lite/core/c/c_api_opaque.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string_util.h" namespace { @@ -36,6 +37,13 @@ const TfLiteTensor* Convert(const TfLiteOpaqueTensor* opaque_tensor) { return reinterpret_cast(opaque_tensor); } +TfLiteTensor* Convert(TfLiteOpaqueTensor* opaque_tensor) { + // The following cast is safe only because this code is part of the + // TF Lite runtime implementation. Apps using TF Lite should not rely on + // TfLiteOpaqueTensor and TfLiteTensor being equivalent. + return reinterpret_cast(opaque_tensor); +} + const TfLiteNode* Convert(const TfLiteOpaqueNode* opaque_node) { // The following cast is safe only because this code is part of the // TF Lite runtime implementation. 
Apps using TF Lite should not rely on @@ -168,6 +176,37 @@ TfLiteStatus TfLiteOpaqueTensorCopyToBuffer( output_data_size); } +int TfLiteOpaqueTensorGetStringCount(const TfLiteOpaqueTensor* tensor) { + return tflite::GetStringCount(Convert(tensor)); +} + +TfLiteStatus TfLiteOpaqueTensorGetString(const TfLiteOpaqueTensor* tensor, + int index, const char** str, + int* len) { + tflite::StringRef str_ref = tflite::GetString(Convert(tensor), index); + *str = str_ref.str; + *len = str_ref.len; + return kTfLiteOk; +} + +TfLiteStatus TfLiteOpaqueTensorWriteStrings(TfLiteOpaqueTensor* tensor, + const char* const* str_array, + int str_array_len, + const int* str_n_len) { + tflite::DynamicBuffer buf; + for (int i = 0; i < str_array_len; ++i) { + buf.AddString(str_array[i], str_n_len[i]); + } + buf.WriteToTensorAsVector(Convert(tensor)); + return kTfLiteOk; +} + +TfLiteStatus TfLiteOpaqueTensorWriteString(TfLiteOpaqueTensor* tensor, + const char* str, const int len) { + TfLiteOpaqueTensorWriteStrings(tensor, &str, 1, &len); + return kTfLiteOk; +} + const TfLiteOpaqueTensor* TfLiteOpaqueNodeGetInput( const TfLiteOpaqueContext* opaque_context, const TfLiteOpaqueNode* opaque_node, int index) { diff --git a/tensorflow/lite/core/c/c_api_opaque.h b/tensorflow/lite/core/c/c_api_opaque.h index ff70c4304401e4..44c95a51d4ece0 100644 --- a/tensorflow/lite/core/c/c_api_opaque.h +++ b/tensorflow/lite/core/c/c_api_opaque.h @@ -113,6 +113,64 @@ TFL_CAPI_EXPORT extern TfLiteStatus TfLiteOpaqueTensorCopyToBuffer( const TfLiteOpaqueTensor* opaque_tensor, void* output_data, size_t output_data_size); +// Returns the number of strings stored in the provided 'tensor'. Returns -1 in +// case of failure. +int TfLiteOpaqueTensorGetStringCount(const TfLiteOpaqueTensor* tensor); + +// Stores the address of the n-th (denoted by the provided 'index') string +// contained in the provided 'tensor' in the provided '*str' pointer. Stores +// the length of the string in the provided '*len' argument. +// +// Returns 'kTfLiteOk' if '*str' and '*len' have been set successfully. Any +// other return value indicates a failure, which leaves '*str' and '*len' in an +// unspecified state. +// +// The range of valid indices is defined by the half open interval [0, N), +// where N == TfLiteOpaqueTensorGetStringCount(tensor). +// +// Note that 'str' is not guaranteed to be null-terminated. Also note that this +// function will not create a copy of the underlying string data. The data is +// owned by the 'tensor'. +TfLiteStatus TfLiteOpaqueTensorGetString(const TfLiteOpaqueTensor* tensor, + int index, const char** str, int* len); + +// Writes the array of strings specified by 'str_array' into +// the specified 'tensor'. The strings provided via the 'str_array' are being +// copied into the 'tensor'. Returns 'kTfLiteOk' in case of success. Any other +// return value indicates a failure. +// +// The provided 'str_array_len' must denote the length of 'str_array' +// and 'str_n_len[i]' must denote the length of the i-th string. +// +// The provided strings don't need to be null terminated and may contain +// embedded null characters. The amount of bytes copied into the 'tensor' is +// entirely determined by 'str_n_len[i]' and it is the caller's responsibility +// to set this value correctly to avoid undefined behavior. +// +// Also note that calling 'TfLiteOpaqueTensorWriteStrings' deallocates any +// previously stored data in the 'tensor'. 
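// A minimal usage sketch (not part of the patch) for the string accessors
// above, assuming `in` and `out` are TfLiteOpaqueTensor* handles obtained via
// TfLiteOpaqueNodeGetInput / TfLiteOpaqueNodeGetOutput inside a kernel:
static void CopyFirstString(const TfLiteOpaqueTensor* in,
                            TfLiteOpaqueTensor* out) {
  if (TfLiteOpaqueTensorGetStringCount(in) < 1) return;
  const char* str = NULL;
  int len = 0;
  if (TfLiteOpaqueTensorGetString(in, 0, &str, &len) != kTfLiteOk) return;
  // `str` is owned by `in` and is not null-terminated; the write below
  // replaces any strings previously stored in `out`.
  TfLiteOpaqueTensorWriteString(out, str, len);
}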
+TfLiteStatus TfLiteOpaqueTensorWriteStrings(TfLiteOpaqueTensor* tensor, + const char* const* str_array, + int str_array_len, + const int* str_n_len); + +// Writes the string pointed to by the provided 'str' pointer of length 'len' +// into the provided 'tensor'. The string provided via 'str' is +// copied into the 'tensor'. Returns 'kTfLiteOk' in case of success. Any +// other return value indicates a failure. +// +// Note that calling 'TfLiteOpaqueTensorWriteString' deallocates any +// previously stored data in the 'tensor'. E.g. suppose 't' denotes a +// 'TfLiteOpaqueTensor*', then calling 'TfLiteOpaqueTensorWriteString(t, "AB", +// 2)' followed by a call to 'TfLiteOpaqueTensorWriteString(t, "CD", 2)' will +// lead to 't' containing 'CD', not 'ABCD'. +// +// 'TfLiteOpaqueTensorWriteString' is a convenience function for the use case +// of writing a single string to a tensor and its effects are identical to +// calling 'TfLiteOpaqueTensorWriteStrings' with an array of a single string. +TfLiteStatus TfLiteOpaqueTensorWriteString(TfLiteOpaqueTensor* tensor, + const char* str, int len); + // -------------------------------------------------------------------------- // Accessors for TfLiteOpaqueNode. diff --git a/tensorflow/lite/core/c/c_api_test.cc b/tensorflow/lite/core/c/c_api_test.cc index 152f0a66c4fa26..e17611860a76f9 100644 --- a/tensorflow/lite/core/c/c_api_test.cc +++ b/tensorflow/lite/core/c/c_api_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/delegates/delegate_test_util.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/string_util.h" #include "tensorflow/lite/testing/util.h" namespace { @@ -1648,6 +1649,196 @@ TEST(CApiSimple, OpaqueApiAccessors) { EXPECT_TRUE(delegate_kernel_invoked); } +TEST(CApiSimple, OpaqueApiAccessorsStrings) { + ::tflite::Interpreter interpreter; + interpreter.AddTensors(3); + std::vector dims = {1}; + TfLiteQuantizationParams quant{}; + interpreter.SetTensorParametersReadWrite(0, kTfLiteString, "a", dims, quant, + /*is_variable=*/false); + interpreter.SetTensorParametersReadWrite(1, kTfLiteString, "b", dims, quant, + /*is_variable=*/false); + interpreter.SetTensorParametersReadWrite(2, kTfLiteString, "c", dims, quant, + /*is_variable=*/false); + + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({2}); + const char* initial_data = ""; + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver; + TfLiteAddParams* builtin_data = + reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + builtin_data->activation = kTfLiteActNone; + builtin_data->pot_scale_int16 = false; + const TfLiteRegistration* registration = + resolver.FindOp(::tflite::BuiltinOperator_ADD, 1); + interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0, builtin_data, + registration); + + TfLiteOpaqueDelegateBuilder opaque_delegate_builder{}; + opaque_delegate_builder.flags = kTfLiteDelegateFlagsAllowDynamicTensors; + bool delegate_kernel_invoked = false; + opaque_delegate_builder.data = &delegate_kernel_invoked; + opaque_delegate_builder.Prepare = [](TfLiteOpaqueContext* context, + TfLiteOpaqueDelegate* delegate, + void* data) -> TfLiteStatus { + TfLiteRegistrationExternal* registration = TfLiteRegistrationExternalCreate( + kTfLiteBuiltinDelegate, "my delegate", 123); + TfLiteRegistrationExternalSetInit( + registration, + [](TfLiteOpaqueContext* opaque_context, const char* buffer, + size_t length) -> void* { + const TfLiteOpaqueDelegateParams* params = + 
reinterpret_cast(buffer); + EXPECT_EQ(2, params->input_tensors->size); + TfLiteOpaqueTensor* opaque_input_tensor = + TfLiteOpaqueContextGetOpaqueTensor( + opaque_context, params->input_tensors->data[0]); + EXPECT_EQ(1, TfLiteOpaqueTensorNumDims(opaque_input_tensor)); + EXPECT_EQ(1, TfLiteOpaqueTensorDim(opaque_input_tensor, 0)); + EXPECT_EQ(kTfLiteDynamic, + TfLiteOpaqueTensorGetAllocationType(opaque_input_tensor)); + + bool* delegate_kernel_invoked = + static_cast(params->delegate_data); + *delegate_kernel_invoked = true; + return nullptr; + }); + + TfLiteRegistrationExternalSetPrepare( + registration, + [](TfLiteOpaqueContext* context, + TfLiteOpaqueNode* node) -> TfLiteStatus { return kTfLiteOk; }); + + TfLiteRegistrationExternalSetInvoke( + registration, + [](TfLiteOpaqueContext* context, + TfLiteOpaqueNode* node) -> TfLiteStatus { + const TfLiteOpaqueTensor* input0 = + TfLiteOpaqueNodeGetInput(context, node, 0); + + EXPECT_EQ(TfLiteOpaqueTensorGetStringCount(input0), 4); + const char* input0_string2 = nullptr; + int input0_string2_len = -1; + EXPECT_EQ(kTfLiteOk, + TfLiteOpaqueTensorGetString(input0, 2, &input0_string2, + &input0_string2_len)); + EXPECT_EQ(std::string(input0_string2, input0_string2_len), "F"); + EXPECT_EQ(1, input0_string2_len); + + const TfLiteOpaqueTensor* input1 = + TfLiteOpaqueNodeGetInput(context, node, 1); + EXPECT_EQ(TfLiteOpaqueTensorGetStringCount(input1), 1); + const char* input1_string0 = nullptr; + int input1_string0_len = -1; + EXPECT_EQ(kTfLiteOk, + TfLiteOpaqueTensorGetString(input1, 0, &input1_string0, + &input1_string0_len)); + EXPECT_EQ(std::string(input1_string0, input1_string0_len), "XYZ"); + EXPECT_EQ(3, input1_string0_len); + + TfLiteOpaqueTensor* opaque_output0 = + TfLiteOpaqueNodeGetOutput(context, node, 0); + + // + // First use 'TfLiteOpaqueTensorWriteString' to check that we can copy + // a string from an input tensor to an output tensor. + // + EXPECT_EQ(kTfLiteOk, + TfLiteOpaqueTensorWriteString( + opaque_output0, input0_string2, input0_string2_len)); + const char* output_str_from_opaque_tensor = nullptr; + int output_str_from_opaque_tensor_len = -1; + EXPECT_EQ(kTfLiteOk, + TfLiteOpaqueTensorGetString( + opaque_output0, 0, &output_str_from_opaque_tensor, + &output_str_from_opaque_tensor_len)); + EXPECT_EQ(std::string(output_str_from_opaque_tensor, + output_str_from_opaque_tensor_len), + "F"); + EXPECT_EQ(1, output_str_from_opaque_tensor_len); + + // + // Then perform the 'actual' ADD operation of adding the input tensor + // string to the output tensor. 
+ // + std::vector str_array; + std::vector str_array_len; + for (int i = 0; i < TfLiteOpaqueTensorGetStringCount(input0); ++i) { + const char* input_string = nullptr; + int input_string_len = -1; + EXPECT_EQ(kTfLiteOk, + TfLiteOpaqueTensorGetString(input0, i, &input_string, + &input_string_len)); + str_array.push_back(input_string); + str_array_len.push_back(input_string_len); + } + str_array.push_back(input1_string0); + str_array_len.push_back(input1_string0_len); + + EXPECT_EQ(kTfLiteOk, TfLiteOpaqueTensorWriteStrings( + opaque_output0, str_array.data(), + str_array.size(), str_array_len.data())); + return kTfLiteOk; + }); + + TfLiteIntArray* execution_plan{}; + TfLiteOpaqueContextGetExecutionPlan(context, &execution_plan); + TfLiteOpaqueContextReplaceNodeSubsetsWithDelegateKernels( + context, registration, execution_plan, delegate); + return kTfLiteOk; + }; + + TfLiteDelegate my_delegate{}; + my_delegate.opaque_delegate_builder = &opaque_delegate_builder; + EXPECT_EQ(kTfLiteOk, interpreter.ModifyGraphWithDelegate(&my_delegate)); + EXPECT_TRUE(delegate_kernel_invoked); + EXPECT_EQ(kTfLiteOk, interpreter.AllocateTensors()); + + // + // Load input tensors with string data. + // + TfLiteTensor* t0 = interpreter.tensor(0); + tflite::DynamicBuffer buf0; + const char* raw_buf_with_embedded_null = "DDD\0EEE"; + const char* raw_buf_without_embedded_null = "12345678"; + std::vector t0_strings{ + "ABC", + std::string(raw_buf_with_embedded_null, raw_buf_with_embedded_null + 6), + "F", + std::string(raw_buf_without_embedded_null, + raw_buf_without_embedded_null + 4)}; + for (const std::string& s : t0_strings) { + ASSERT_EQ(buf0.AddString(s.data(), s.size()), kTfLiteOk); + } + buf0.WriteToTensorAsVector(t0); + + TfLiteTensor* t1 = interpreter.tensor(1); + char s1[] = "XYZ"; + tflite::DynamicBuffer buf1; + ASSERT_EQ(buf1.AddString(s1, 3), kTfLiteOk); + buf1.WriteToTensorAsVector(t1); + + // + // Invoke the interpreter, so that the input tensor strings get copied to the + // output tensor. + // + EXPECT_EQ(kTfLiteOk, interpreter.Invoke()); + + // + // Check that the output tensor stores the combination of the input strings. + // + const std::vector expected_strings{ + "ABC", + std::string(raw_buf_with_embedded_null, raw_buf_with_embedded_null + 6), + "F", "1234", "XYZ"}; + TfLiteTensor* t2 = interpreter.tensor(2); + EXPECT_EQ(tflite::GetStringCount(t2), expected_strings.size()); + for (int i = 0; i < tflite::GetStringCount(t2); ++i) { + tflite::StringRef str_ref = tflite::GetString(t2, i); + EXPECT_EQ(std::string(str_ref.str, str_ref.len), expected_strings[i]); + } +} + void AddNode( tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates* resolver, ::tflite::Interpreter* interpreter) { diff --git a/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/BUILD b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/BUILD new file mode 100644 index 00000000000000..2f9af85ed4d3cc --- /dev/null +++ b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/BUILD @@ -0,0 +1,68 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +load("//tensorflow/lite/experimental/acceleration/mini_benchmark:build_defs.bzl", "cc_library_with_forced_in_process_benchmark_variant") +load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "libjpeg_handle_deps", "minibenchmark_visibility_allowlist") + +default_visibility_group = [ + "//tensorflow/lite/experimental/acceleration/mini_benchmark:__subpackages__", +] + minibenchmark_visibility_allowlist() + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = default_visibility_group, + licenses = ["notice"], +) + +cc_library_with_forced_in_process_benchmark_variant( + name = "c_api", + srcs = ["c_api.cc"], + hdrs = ["c_api.h"], + in_process_deps = [ + "//tensorflow/lite/experimental/acceleration/mini_benchmark:blocking_validator_runner", + ], + deps = [ + ":c_api_types", + "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/acceleration/configuration/c:delegate_plugin", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:benchmark_result_evaluator", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:status_codes", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:validator_runner_entrypoint", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:validator_runner_options", + "@flatbuffers", + ], +) + +cc_test( + name = "c_api_test", + srcs = ["c_api_test.cc"], + deps = [ + ":c_api", + "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:embedded_mobilenet_model", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:embedded_mobilenet_validation_model", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:embedded_simple_addition_model", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:mini_benchmark_test_helper", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:status_codes", + "@com_google_googletest//:gtest_main", + "@flatbuffers", + "@flatbuffers//:runtime_cc", + ] + libjpeg_handle_deps(), +) + +cc_library( + name = "c_api_types", + hdrs = ["c_api_types.h"], + visibility = ["//visibility:private"], +) diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.cc b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc similarity index 97% rename from tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.cc rename to tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc index 95c32f45bd6ff5..51c927836135dc 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.cc +++ b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h" +#include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h" #include #include @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/benchmark_result_evaluator.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/blocking_validator_runner.h" -#include "tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_types.h" +#include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_types.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner_options.h" diff --git a/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h new file mode 100644 index 00000000000000..ed5d17d62beba6 --- /dev/null +++ b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h @@ -0,0 +1,81 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_CORE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ +#define TENSORFLOW_LITE_CORE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ + +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// APIs of TfLiteMiniBenchmarkResult. +typedef struct TfLiteMiniBenchmarkResult TfLiteMiniBenchmarkResult; +int TfLiteMiniBenchmarkResultInitStatus(TfLiteMiniBenchmarkResult* result); +uint8_t* TfLiteMiniBenchmarkResultFlatBufferData( + TfLiteMiniBenchmarkResult* result); +size_t TfLiteMiniBenchmarkResultFlatBufferDataSize( + TfLiteMiniBenchmarkResult* result); +// Free memory allocated with `result`. +void TfLiteMiniBenchmarkResultFree(TfLiteMiniBenchmarkResult* result); + +// APIs of TfLiteMiniBenchmarkCustomValidationInfo. +typedef struct TfLiteMiniBenchmarkCustomValidationInfo + TfLiteMiniBenchmarkCustomValidationInfo; +void TfLiteMiniBenchmarkCustomValidationInfoSetBuffer( + TfLiteMiniBenchmarkCustomValidationInfo* custom_validation, int batch_size, + uint8_t* buffer, size_t* buffer_dim, int buffer_dim_size); +void TfLiteMiniBenchmarkCustomValidationInfoSetAccuracyValidator( + TfLiteMiniBenchmarkCustomValidationInfo* custom_validation, + void* accuracy_validator_user_data, + bool (*accuracy_validator_func)(void* user_data, + uint8_t* benchmark_result_data, + int benchmark_result_data_size)); + +// APIs of TfLiteMiniBenchmarkSettings. 
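// A typical call sequence, shown as a sketch for orientation only (assumes
// `data` and `size` describe a serialized settings flatbuffer accepted by
// TfLiteMiniBenchmarkSettingsSetFlatBufferData):
//
//   TfLiteMiniBenchmarkSettings* settings = TfLiteMiniBenchmarkSettingsCreate();
//   TfLiteMiniBenchmarkSettingsSetFlatBufferData(settings, data, size);
//   TfLiteMiniBenchmarkResult* result =
//       TfLiteBlockingValidatorRunnerTriggerValidation(settings);
//   // Check TfLiteMiniBenchmarkResultInitStatus(result) against the codes in
//   // the mini-benchmark status_codes.h, then parse
//   // TfLiteMiniBenchmarkResultFlatBufferData(result) if it is non-null.
//   TfLiteMiniBenchmarkResultFree(result);
//   TfLiteMiniBenchmarkSettingsFree(settings);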
+typedef struct TfLiteMiniBenchmarkSettings TfLiteMiniBenchmarkSettings; +TfLiteMiniBenchmarkSettings* TfLiteMiniBenchmarkSettingsCreate(); +TfLiteMiniBenchmarkCustomValidationInfo* +TfLiteMiniBenchmarkSettingsCustomValidationInfo( + TfLiteMiniBenchmarkSettings* settings); +void TfLiteMiniBenchmarkSettingsSetFlatBufferData( + TfLiteMiniBenchmarkSettings* settings, uint8_t* flatbuffer_data, + size_t flatbuffer_data_size); +void TfLiteMiniBenchmarkSettingsSetErrorReporter( + TfLiteMiniBenchmarkSettings* settings, void* error_reporter_user_data, + int (*error_reporter_func)(void* user_data, const char* format, + va_list args)); +void TfLiteMiniBenchmarkSettingsFree(TfLiteMiniBenchmarkSettings* settings); + +// Others. +// Trigger validation for `settings` and return the validation result. +// This returns a pointer, that you must free using +// TfLiteMiniBenchmarkResultFree(). +TfLiteMiniBenchmarkResult* TfLiteBlockingValidatorRunnerTriggerValidation( + TfLiteMiniBenchmarkSettings* settings); + +// This function is a private function that shouldn't be considered as part of +// the APIs. +// TODO: b/290615172 - Remove the function from this header. +void TfLiteMiniBenchmarkSettingsSetGpuPluginHandle( + TfLiteMiniBenchmarkSettings* settings, void* gpu_plugin_handle); + +#ifdef __cplusplus +} // extern "C". +#endif +#endif // TENSORFLOW_LITE_CORE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_test.cc b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_test.cc similarity index 99% rename from tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_test.cc rename to tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_test.cc index 7e251af1d168b2..14ed33a048cdb4 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_test.cc +++ b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h" +#include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h" #include diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_types.h b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_types.h similarity index 91% rename from tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_types.h rename to tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_types.h index bdcef41b9752cf..adace96a21ff94 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_types.h +++ b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_types.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_TYPES_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_TYPES_H_ +#ifndef TENSORFLOW_LITE_CORE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_TYPES_H_ +#define TENSORFLOW_LITE_CORE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_TYPES_H_ #include #include @@ -87,4 +87,4 @@ struct TfLiteMiniBenchmarkSettings { } // extern "C". #endif -#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_TYPES_H_ +#endif // TENSORFLOW_LITE_CORE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_TYPES_H_ diff --git a/tensorflow/lite/core/shims/cc_library_with_tflite.bzl b/tensorflow/lite/core/shims/cc_library_with_tflite.bzl index 5133db9b4134da..7ed021d9d5c9d6 100644 --- a/tensorflow/lite/core/shims/cc_library_with_tflite.bzl +++ b/tensorflow/lite/core/shims/cc_library_with_tflite.bzl @@ -7,7 +7,7 @@ load( "tflite_custom_c_library", "tflite_jni_binary", ) -load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@build_bazel_rules_android//android:rules.bzl", "android_binary", "android_library") load("@bazel_skylib//rules:build_test.bzl", "build_test") def _concat(lists): @@ -93,6 +93,34 @@ def android_library_with_tflite( **kwargs ) +def android_binary_with_tflite( + name, + deps = [], + tflite_deps = [], + **kwargs): + """Defines an android_binary that uses the TFLite shims. + + This is a hook to allow applying different build flags (etc.) + for targets that use the TFLite shims. + + Note that this build rule doesn't itself add any dependencies on + TF Lite; this macro should normally be used in conjunction with a + direct or indirect 'tflite_deps' dependency on one of the "shim" + library targets from //tensorflow/lite/core/shims:*. + + Args: + name: as for android_binary. + deps: as for android_binary. + tflite_deps: dependencies on rules that are themselves defined using + 'cc_library_with_tflite' / 'android_library_with_tflite'. + **kwargs: Additional android_binary parameters. 
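  Example (the target and dependency names below are illustrative only):

      android_binary_with_tflite(
          name = "example_app",
          srcs = ["ExampleActivity.java"],
          manifest = "AndroidManifest.xml",
          tflite_deps = ["//tensorflow/lite/core/shims:some_shim_library"],
      )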
+ """ + android_binary( + name = name, + deps = deps + tflite_deps, + **kwargs + ) + def cc_library_with_tflite( name, srcs = [], diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 0de9243868f19a..0f21ec72e5908f 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -1407,6 +1407,7 @@ TfLiteStatus Subgraph::MayAllocateOpOutput(TfLiteNode* node) { if (ShouldOptimizeMemoryForLargeTensors()) { for (int i = 0; i < node->outputs->size; ++i) { int tensor_index = node->outputs->data[i]; + if (tensor_index == kTfLiteOptionalTensor) continue; TfLiteTensor* tensor = &context_.tensors[tensor_index]; if (tensor->data.raw == nullptr && tensor->allocation_type == kTfLiteDynamic) { diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD index 0dafc00ba0c70d..92ef4a610c6dd1 100644 --- a/tensorflow/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -10,7 +10,7 @@ load( load("//tensorflow/lite:build_def.bzl", "tflite_copts") load("//tensorflow/lite:special_rules.bzl", "internal_visibility_allowlist") load("//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_cc_library", "tflite_flex_shared_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") default_visibility = [ "//tensorflow/compiler/mlir/lite:__subpackages__", @@ -37,7 +37,7 @@ cc_library( name = "buffer_map", srcs = ["buffer_map.cc"], hdrs = ["buffer_map.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_opts_nortti_if_lite_protos(), features = tf_features_nolayering_check_if_ios(), deps = [ @@ -58,7 +58,7 @@ cc_library( name = "buffer_map_util", srcs = ["buffer_map_util.cc"], hdrs = ["buffer_map_util.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_opts_nortti_if_lite_protos(), features = tf_features_nolayering_check_if_ios(), deps = [ @@ -106,7 +106,7 @@ tf_cc_test( # ) tflite_flex_cc_library( name = "delegate", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], ) @@ -133,7 +133,7 @@ cc_library( srcs = [ "delegate_symbol.cc", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tflite_copts(), visibility = ["//visibility:public"], deps = [ @@ -155,7 +155,7 @@ cc_library( hdrs = [ "delegate.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tflite_copts() + tf_opts_nortti_if_android(), features = tf_features_nolayering_check_if_ios(), visibility = ["//visibility:public"], @@ -208,7 +208,7 @@ cc_library( name = "delegate_data", srcs = ["delegate_data.cc"], hdrs = ["delegate_data.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_opts_nortti_if_android(), features = tf_features_nolayering_check_if_ios(), visibility = ["//visibility:public"], @@ -259,7 +259,7 @@ tf_cc_test( cc_library( name = "subgraph_resource", hdrs = ["subgraph_resource.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), features = tf_features_nolayering_check_if_ios(), deps = [ "//tensorflow/lite:cc_api_experimental", @@ -312,7 +312,7 @@ cc_library( name = "util", srcs = ["util.cc"], hdrs = 
["util.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), features = tf_features_nolayering_check_if_ios(), #TODO(b/206038955): Consider restrict the visibility to '//third_party/fcp/client:__subpackages__'. visibility = ["//visibility:public"], @@ -360,7 +360,7 @@ cc_library( "allowlisted_flex_ops.h", "allowlisted_flex_ops_internal.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), features = tf_features_nolayering_check_if_ios(), visibility = internal_visibility_allowlist(), deps = if_mobile([ @@ -403,7 +403,7 @@ cc_library( cc_library( name = "tflite_subgraph_execute", srcs = ["tflite_subgraph_execute.cc"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_opts_nortti_if_android(), features = tf_features_nolayering_check_if_ios(), deps = [ diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc index de0f698ddedd6a..7b4dc17298f1c7 100644 --- a/tensorflow/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -85,7 +85,7 @@ class OpInputs { } forwardable_.resize(inputs_.size()); } - ~OpInputs() {} + ~OpInputs() = default; int Size() const { return inputs_.size(); } @@ -438,7 +438,7 @@ tensorflow::Status DelegateKernel::ExecuteOpKernelRunner( } DelegateKernel::DelegateKernel() : op_data_(new OpData) {} -DelegateKernel::~DelegateKernel() {} +DelegateKernel::~DelegateKernel() = default; TfLiteStatus DelegateKernel::Init(TfLiteContext* context, const TfLiteDelegateParams* params) { @@ -572,20 +572,24 @@ TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) { tensor_ref_count[tensor_index] += 2; } - const bool shapes_are_valid = - (ValidateOutputTensorShapeConsistency(context) == kTfLiteOk); - if (shapes_are_valid) { - TFLITE_LOG(tflite::TFLITE_LOG_INFO, - "FlexDelegate: All tensor shapes are consistent."); - } else { - TFLITE_LOG(tflite::TFLITE_LOG_WARNING, - "FlexDelegate: Some tensor shapes are inconsistent."); + // Output shapes which may have initially been inferable may no longer be + // after ResizeInputTensor has been called, so it must be checked again. + if (shapes_are_valid_) { + shapes_are_valid_ = + (ValidateOutputTensorShapeConsistency(context) == kTfLiteOk); + if (shapes_are_valid_) { + TFLITE_LOG(tflite::TFLITE_LOG_INFO, + "FlexDelegate: All tensor shapes are consistent."); + } else { + TFLITE_LOG(tflite::TFLITE_LOG_WARNING, + "FlexDelegate: Some tensor shapes are inconsistent."); + } } // All output tensors are allocated by TensorFlow, so we mark them as // kTfLiteDynamic. for (auto tensor_index : op_data_->subgraph_outputs) { - if (!shapes_are_valid) { + if (!shapes_are_valid_) { SetTensorToDynamic(&context->tensors[tensor_index]); } ++tensor_ref_count[tensor_index]; diff --git a/tensorflow/lite/delegates/flex/kernel.h b/tensorflow/lite/delegates/flex/kernel.h index fabb8367284306..ee162148af5094 100644 --- a/tensorflow/lite/delegates/flex/kernel.h +++ b/tensorflow/lite/delegates/flex/kernel.h @@ -60,6 +60,10 @@ class DelegateKernel : public SimpleDelegateKernelInterface { const std::map& GetTensorReleaseMap() const; std::unique_ptr op_data_; + + // Indicates that the output shapes may be inferred using the input shapes and + // May be allocated during Prepare. 
+ bool shapes_are_valid_ = true; }; } // namespace flex diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 40f95082cd86cf..875c2a4f3da7df 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -240,6 +240,7 @@ cc_library( }) + [ ":api", ":delegate_options", + ":tflite_profile", "//tensorflow/lite:kernel_api", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/core/async:backend_async_kernel_interface", @@ -267,3 +268,14 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "tflite_profile", + srcs = ["tflite_profile.cc"], + hdrs = ["tflite_profile.h"], + deps = [ + "//tensorflow/lite/core/api", + "//tensorflow/lite/delegates/gpu/common/task:profiling_info", + "@com_google_absl//absl/time", + ], +) diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 50ab40d61f8206..d710059af90886 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -36,6 +36,7 @@ cc_library( ":tensor", ":tensor_type_util", "//tensorflow/lite/delegates/gpu:api", + "//tensorflow/lite/delegates/gpu:tflite_profile", "//tensorflow/lite/delegates/gpu/cl/kernels:converter", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:model", diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc index 21462b111af1de..490836435a02df 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.cc +++ b/tensorflow/lite/delegates/gpu/cl/api.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/tflite_profile.h" #ifdef CL_DELEGATE_ALLOW_GL #include @@ -454,6 +455,7 @@ class InferenceRunnerImpl : public CLInferenceRunner { #endif ) : queue_(environment->queue()), + profiling_queue_(environment->profiling_queue()), context_(std::move(context)) #ifdef CL_DELEGATE_ALLOW_GL , @@ -555,8 +557,14 @@ class InferenceRunnerImpl : public CLInferenceRunner { } absl::Status RunWithoutExternalBufferCopy() override { - RETURN_IF_ERROR(context_->AddToQueue(queue_)); - clFlush(queue_->queue()); + if (IsTfLiteProfilerActive()) { + ProfilingInfo profiling_info; + RETURN_IF_ERROR(context_->Profile(profiling_queue_, &profiling_info)); + AddTfLiteProfilerEvents(&profiling_info); + } else { + RETURN_IF_ERROR(context_->AddToQueue(queue_)); + clFlush(queue_->queue()); + } return absl::OkStatus(); } @@ -585,6 +593,7 @@ class InferenceRunnerImpl : public CLInferenceRunner { } CLCommandQueue* queue_; + ProfilingCommandQueue* profiling_queue_; std::unique_ptr context_; #ifdef CL_DELEGATE_ALLOW_GL std::unique_ptr gl_interop_fabric_; diff --git a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc index 117467691754dc..2abe35c7a248f9 100644 --- a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc +++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include +#include "absl/strings/match.h" #include "absl/strings/str_replace.h" #include "absl/strings/substitute.h" #include "tensorflow/lite/delegates/gpu/common/access_type.h" @@ -606,7 +607,7 @@ GPUOperation CreateGpuOperation(const OperationDef& definition, op.elementwise_code_ = std::move(descriptor.code); op.elementwise_ = true; if (definition.src_tensors.size() > 1 && - op.elementwise_code_.find("in2_value")) { + absl::StrContains(op.elementwise_code_, "in2_value")) { const auto second_tensor_def = definition.src_tensors[1]; if (NeedsBroadcast(second_tensor_def, second_shape)) { const std::string x_coord = second_shape.w == 1 ? "0" : "X_COORD"; diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index e0ba598c843fe4..f51e01fe036a9a 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" #include "tensorflow/lite/delegates/gpu/common/quantization_util.h" #include "tensorflow/lite/delegates/gpu/delegate_options.h" +#include "tensorflow/lite/delegates/gpu/tflite_profile.h" #include "tensorflow/lite/delegates/serialization.h" #if defined(__ANDROID__) @@ -171,7 +172,7 @@ class Delegate { delegate_.CopyFromBufferHandle = nullptr; delegate_.CopyToBufferHandle = nullptr; delegate_.FreeBufferHandle = nullptr; - delegate_.flags = kTfLiteDelegateFlagsNone; + delegate_.flags = kTfLiteDelegateFlagsPerOperatorProfiling; options_ = options ? *options : TfLiteGpuDelegateOptionsV2Default(); if (options_.max_delegated_partitions <= 0) { options_.max_delegated_partitions = 1; @@ -1496,6 +1497,8 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { telemetry::TelemetryReportDelegateSettings( context, "GpuDelegate::DelegatePrepare", telemetry::TelemetrySource::TFLITE_GPU, delegate_setting); + + SetTfLiteProfiler(context->profiler); return status; } diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc index 3d21a0aee8e4c5..fb30986290b4b4 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc @@ -15,17 +15,12 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h" -#include -#include -#include -#include #include #include #include #include #include -#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/lite/delegates/gpu/common/convert.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -37,38 +32,48 @@ namespace gl { namespace { -bool IsApplyMaskSupported(const NodeShader::GenerationContext& ctx) { - if (ctx.input_shapes.size() != 2) return false; - - // [H, W, C] x [H, W, 0][0] - if (ctx.input_shapes[0][1] == ctx.input_shapes[1][1] && - ctx.input_shapes[0][2] == ctx.input_shapes[1][2] && - ctx.input_shapes[1][3] == 1) { - return true; +// Returns the coordinate to iterate over the second runtime tensor. 
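// For example (illustrative shapes): with input_shapes[0] = [1, 8, 8, 4] and
// input_shapes[1] = [1, 8, 8, 1], the H and W dimensions match, so the default
// coordinates ("gid.y" and "gid.x") are used, while the channel dimension of
// the second tensor is 1, so its coordinate becomes "0", i.e. the single
// channel is broadcast.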
+absl::Status GetCoordinate(const NodeShader::GenerationContext& ctx, int dim, + const std::string& default_coord, + std::string* coord) { + std::string result; + if (ctx.input_shapes[1][dim] == 1 && ctx.input_shapes[0][dim] != 1) { + result = "0"; + } else if (ctx.input_shapes[0][dim] == ctx.input_shapes[1][dim]) { + result = default_coord; + } else { + return absl::InvalidArgumentError( + absl::StrCat("Second runtime tensor dimension ", dim, + " must either match " + "first tensor's dimensions or be 1.")); } - - // [H, W, C] x [H, W, C] - if (ctx.input_shapes[0] == ctx.input_shapes[1]) return true; - - // [H, W, C] x [0, 0, C] - return ctx.input_shapes[1][1] == 1 && ctx.input_shapes[1][2] == 1 && - ctx.input_shapes[0][3] == ctx.input_shapes[1][3]; + *coord = result; + return absl::OkStatus(); } -absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx, - GeneratedCode* generated_code) { - std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * "; - if (ctx.input_shapes[1][3] == 1) { - // [H, W, C] x [H, W, 0][0] - absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;"); - } else if (ctx.input_shapes[0][1] == ctx.input_shapes[1][1] && - ctx.input_shapes[0][2] == ctx.input_shapes[1][2]) { - // [H, W, C] x [H, W, C] - absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;"); - } else { - // [H, W, C] x [0, 0, C] - absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;"); +absl::Status GenerateMultiplyRuntimeTensorCode( + const NodeShader::GenerationContext& ctx, GeneratedCode* generated_code) { + std::string x_coord, y_coord, z_coord; + RETURN_IF_ERROR( + GetCoordinate(ctx, /*dim=*/2, /*default_coord=*/"gid.x", &x_coord)); + RETURN_IF_ERROR( + GetCoordinate(ctx, /*dim=*/1, /*default_coord=*/"gid.y", &y_coord)); + RETURN_IF_ERROR( + GetCoordinate(ctx, /*dim=*/3, /*default_coord=*/"gid.z", &z_coord)); + + std::string source = + absl::StrCat("vec4 input1_value = $input_data_1[", x_coord, ", ", y_coord, + ", ", z_coord, "]$;"); + // Single channel mask support. Without this duplication, the rest of channels + // will be zeros, which will make the mul operation produce incorrect result. + if (ctx.input_shapes[1][3] == 1 && ctx.input_shapes[0][3] != 1) { + absl::StrAppend( + &source, + "\ninput1_value = vec4(input1_value.x, input1_value.x, input1_value.x, " + "input1_value.x);\n"); } + absl::StrAppend( + &source, "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * input1_value;"); *generated_code = { /*parameters=*/{}, @@ -83,7 +88,7 @@ absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx, return absl::OkStatus(); } -absl::Status GenerateMultiplyScalarCode( +absl::Status GenerateMultiplyConstantTensorCode( const NodeShader::GenerationContext& ctx, GeneratedCode* generated_code) { const auto& attr = std::any_cast(ctx.op_attr); @@ -123,6 +128,26 @@ absl::Status GenerateMultiplyScalarCode( } if (std::holds_alternative>(attr.param)) { + bool single_channel_mask = + std::get>(attr.param).shape.c == 1; + std::string source; + if (single_channel_mask) { + source = "vec4 const_val = $hwc_buffer[gid.x, gid.y, 0]$;"; + // Single channel mask support. Without this duplication, the rest of + // channels will be zeros, which will make the mul operation produce + // incorrect result. 
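      // (Reading a single-channel constant through $hwc_buffer$ appears to
      // populate only the .x lane of the resulting vec4, so the value is
      // splatted across all four lanes before the multiply below.)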
+ if (ctx.input_shapes[0][3] != 1) { + absl::StrAppend( + &source, + "\nconst_val = vec4(const_val.x, const_val.x, const_val.x, " + "const_val.x);\n"); + } + } else { + source = "vec4 const_val = $hwc_buffer[gid.x, gid.y, gid.z]$;"; + } + + absl::StrAppend(&source, "value_0 *= const_val;"); + *generated_code = { /*parameters=*/{}, /*objects=*/ @@ -140,7 +165,8 @@ absl::Status GenerateMultiplyScalarCode( static_cast(ctx.input_shapes[0][1]), DivideRoundUp(static_cast(ctx.input_shapes[0][3]), 4)), /*workgroup=*/uint3(), - /*source_code=*/"value_0 *= $hwc_buffer[gid.x, gid.y, gid.z]$;", + /*source_code=*/ + std::move(source), /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; @@ -154,10 +180,10 @@ class Multiply : public NodeShader { public: absl::Status GenerateCode(const GenerationContext& ctx, GeneratedCode* generated_code) const final { - if (IsApplyMaskSupported(ctx)) { - return GenerateApplyMaskCode(ctx, generated_code); + if (ctx.input_shapes.size() == 2) { + return GenerateMultiplyRuntimeTensorCode(ctx, generated_code); } else { - return GenerateMultiplyScalarCode(ctx, generated_code); + return GenerateMultiplyConstantTensorCode(ctx, generated_code); } } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mul_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mul_test.cc index 3d931df45247f4..e8379610912c5b 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/mul_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/mul_test.cc @@ -31,27 +31,59 @@ namespace gpu { namespace gl { namespace { -TEST(MulTest, Scalar) { +TEST(MulTest, ConstantTensorMatchingShape) { TensorRef input; input.type = DataType::FLOAT32; input.ref = 0; - input.shape = BHWC(1, 2, 2, 1); + input.shape = BHWC(1, 1, 2, 2); TensorRef output; output.type = DataType::FLOAT32; output.ref = 1; - output.shape = BHWC(1, 2, 2, 1); + output.shape = input.shape; ElementwiseAttributes attr; - attr.param = 2.f; + Tensor tensor_3d; + tensor_3d.shape.h = input.shape.h; + tensor_3d.shape.w = input.shape.w; + tensor_3d.shape.c = input.shape.c; + tensor_3d.id = 2; + tensor_3d.data = {-2, 2, -3, 3}; + attr.param = std::move(tensor_3d); SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output}); ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4})); ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); - EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 6, 8})); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {-2, 4, -9, 12})); +} + +TEST(MulTest, ConstantTensorSingleChannel) { + TensorRef input; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 1, 2, 2); + + TensorRef output; + output.type = DataType::FLOAT32; + output.ref = 1; + output.shape = input.shape; + + ElementwiseAttributes attr; + Tensor tensor_3d; + tensor_3d.shape.h = input.shape.h; + tensor_3d.shape.w = input.shape.w; + tensor_3d.shape.c = 1; + tensor_3d.id = 2; + tensor_3d.data = {-2, 2}; + attr.param = std::move(tensor_3d); + + SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output}); + ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4})); + ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {-2, -4, 6, 8})); } -TEST(MulTest, Linear) { +TEST(MulTest, ConstantTensorLinear) { TensorRef input; input.type = DataType::FLOAT32; input.ref = 0; @@ -60,7 +92,7 @@ TEST(MulTest, Linear) { TensorRef output; output.type = DataType::FLOAT32; output.ref = 1; - output.shape = BHWC(1, 1, 2, 2); + output.shape = input.shape; 
ElementwiseAttributes attr; Tensor tensor; @@ -75,33 +107,76 @@ TEST(MulTest, Linear) { EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 6, 6, 12})); } -TEST(MulTest, ConstTensor3D) { +TEST(MulTest, ConstantTensorScalar) { TensorRef input; input.type = DataType::FLOAT32; input.ref = 0; - input.shape = BHWC(1, 1, 2, 2); + input.shape = BHWC(1, 2, 2, 1); TensorRef output; output.type = DataType::FLOAT32; output.ref = 1; - output.shape = BHWC(1, 1, 2, 2); + output.shape = input.shape; ElementwiseAttributes attr; - Tensor tensor_3d; - tensor_3d.shape.h = 1; - tensor_3d.shape.w = 2; - tensor_3d.shape.c = 2; - tensor_3d.id = 2; - tensor_3d.data = {-2, 2, -3, 3}; - attr.param = std::move(tensor_3d); + attr.param = 2.f; SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output}); ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4})); ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); - EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {-2, 4, -9, 12})); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 6, 8})); +} + +TEST(MulTest, RuntimeTensorMatchingShapeNonOnes) { + TensorRef input; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 2, 2, 2); + + TensorRef mask; + mask.type = DataType::FLOAT32; + mask.ref = 1; + mask.shape = input.shape; + + TensorRef output; + output.type = DataType::FLOAT32; + output.ref = 2; + output.shape = input.shape; + + SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, + {output}); + ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4, -1, -2, -3, -4})); + ASSERT_TRUE(model.PopulateTensor(1, {5, 6, 7, 8, 9, 10, 11, 12})); + ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {5, 12, 21, 32, -9, -20, -33, -48})); +} + +TEST(MulTest, RuntimeTensorMatchingShapeHeightOne) { + TensorRef input; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 1, 2, 2); + + TensorRef mask; + mask.type = DataType::FLOAT32; + mask.ref = 1; + mask.shape = input.shape; + + TensorRef output; + output.type = DataType::FLOAT32; + output.ref = 2; + output.shape = input.shape; + + SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, + {output}); + ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4})); + ASSERT_TRUE(model.PopulateTensor(1, {1, 2, 3, 4})); + ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 4, 9, 16})); } -TEST(MulTest, MaskChannel1) { +TEST(MulTest, RuntimeTensorSingleChannel) { TensorRef input; input.type = DataType::FLOAT32; input.ref = 0; @@ -110,12 +185,12 @@ TEST(MulTest, MaskChannel1) { TensorRef mask; mask.type = DataType::FLOAT32; mask.ref = 1; - mask.shape = BHWC(1, 1, 2, 1); + mask.shape = BHWC(1, input.shape.h, input.shape.w, 1); TensorRef output; output.type = DataType::FLOAT32; output.ref = 2; - output.shape = BHWC(1, 1, 2, 2); + output.shape = input.shape; SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output}); @@ -125,7 +200,7 @@ TEST(MulTest, MaskChannel1) { EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 9, 12})); } -TEST(MulTest, MaskChannelEqualsToInputChannel) { +TEST(MulTest, RuntimeTensorLinear) { TensorRef input; input.type = DataType::FLOAT32; input.ref = 0; @@ -134,19 +209,43 @@ TEST(MulTest, MaskChannelEqualsToInputChannel) { TensorRef mask; mask.type = DataType::FLOAT32; mask.ref = 1; - mask.shape = BHWC(1, 1, 2, 2); + mask.shape = BHWC(1, 1, 1, 
input.shape.c); TensorRef output; output.type = DataType::FLOAT32; output.ref = 2; - output.shape = BHWC(1, 1, 2, 2); + output.shape = input.shape; SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output}); ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4})); - ASSERT_TRUE(model.PopulateTensor(1, {1, 2, 3, 4})); + ASSERT_TRUE(model.PopulateTensor(1, {1, 2})); ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); - EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 4, 9, 16})); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 4, 3, 8})); +} + +TEST(MulTest, RuntimeTensorScalar) { + TensorRef input; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 1, 2, 2); + + TensorRef mask; + mask.type = DataType::FLOAT32; + mask.ref = 1; + mask.shape = BHWC(1, 1, 1, 1); + + TensorRef output; + output.type = DataType::FLOAT32; + output.ref = 2; + output.shape = input.shape; + + SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, + {output}); + ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4})); + ASSERT_TRUE(model.PopulateTensor(1, {5})); + ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {5, 10, 15, 20})); } } // namespace diff --git a/tensorflow/lite/delegates/gpu/tflite_profile.cc b/tensorflow/lite/delegates/gpu/tflite_profile.cc new file mode 100644 index 00000000000000..f0b95553845db4 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/tflite_profile.cc @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/delegates/gpu/tflite_profile.h" + +#include "absl/time/time.h" +#include "tensorflow/lite/core/api/profiler.h" + +namespace tflite { +namespace gpu { + +static void* s_profiler = nullptr; + +bool IsTfLiteProfilerActive() { return s_profiler != nullptr; } + +void SetTfLiteProfiler(void* profiler) { s_profiler = profiler; } + +void* GetTfLiteProfiler() { return s_profiler; } + +void AddTfLiteProfilerEvents(tflite::gpu::ProfilingInfo* profiling_info) { + tflite::Profiler* profile = + reinterpret_cast(GetTfLiteProfiler()); + if (profile == nullptr) return; + + int node_index = 0; + for (const auto& dispatch : profiling_info->dispatches) { + profile->AddEvent( + dispatch.label.c_str(), + Profiler::EventType::DELEGATE_PROFILED_OPERATOR_INVOKE_EVENT, + absl::ToDoubleMicroseconds(dispatch.duration), node_index++); + } +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/tflite_profile.h b/tensorflow/lite/delegates/gpu/tflite_profile.h new file mode 100644 index 00000000000000..6e9d7310ffa04c --- /dev/null +++ b/tensorflow/lite/delegates/gpu/tflite_profile.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_TFLITE_PROFILE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_TFLITE_PROFILE_H_ + +#include "tensorflow/lite/delegates/gpu/common/task/profiling_info.h" + +namespace tflite { +namespace gpu { + +// Returns if TFLite Profiler is active. +bool IsTfLiteProfilerActive(); + +// Save the given TFLite Profiler object (from TfLiteContext) for op profiling. +void SetTfLiteProfiler(void* profiler); + +// Returns saved TFLite Profiler object. +void* GetTfLiteProfiler(); + +// Generate TFLite Profiler events with the given ProfilingInfo object. +void AddTfLiteProfilerEvents(tflite::gpu::ProfilingInfo* profiling_info); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_TFLITE_PROFILE_H_ diff --git a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.cc b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.cc index 4546fe0d4ff2bb..a561992ab54ec2 100644 --- a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.cc +++ b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/lite/experimental/acceleration/compatibility/database_generated.h" #include "tensorflow/lite/experimental/acceleration/compatibility/devicedb.h" #include "tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility_binary.h" +#include "tensorflow/lite/experimental/acceleration/compatibility/variables.h" namespace tflite { namespace acceleration { @@ -111,14 +112,7 @@ gpu::CompatibilityStatus GPUCompatibilityList::GetStatus( CanonicalizeValues(&variables); if (!database_) return gpu::CompatibilityStatus::kUnknown; UpdateVariablesFromDatabase(&variables, *database_); - const std::string& status = variables[gpu::kStatus]; - if (status == gpu::kStatusSupported) { - return gpu::CompatibilityStatus::kSupported; - } else if (status == gpu::kStatusUnsupported) { - return gpu::CompatibilityStatus::kUnsupported; - } else { - return gpu::CompatibilityStatus::kUnknown; - } + return StringToCompatibilityStatus(variables[gpu::kStatus]); } TfLiteGpuDelegateOptionsV2 GPUCompatibilityList::GetBestOptionsFor( @@ -156,5 +150,29 @@ std::map GPUCompatibilityList::InfosToMap( return variables; } +// static +std::string GPUCompatibilityList::CompatibilityStatusToString( + gpu::CompatibilityStatus status) { + switch (status) { + case gpu::CompatibilityStatus::kSupported: + return gpu::kStatusSupported; + case gpu::CompatibilityStatus::kUnsupported: + return gpu::kStatusUnsupported; + case gpu::CompatibilityStatus::kUnknown: + return gpu::kStatusUnknown; + } +} + +// static +gpu::CompatibilityStatus GPUCompatibilityList::StringToCompatibilityStatus( + absl::string_view status) { + if (status == gpu::kStatusSupported) { + return gpu::CompatibilityStatus::kSupported; + } else if (status == gpu::kStatusUnsupported) { + return gpu::CompatibilityStatus::kUnsupported; + } + return gpu::CompatibilityStatus::kUnknown; +} + } // namespace acceleration } // namespace tflite diff --git a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.h b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.h index de2b66c5d7f2b8..59f73c2c9a7759 100644 --- a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.h +++ b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/delegate_options.h" #include "tensorflow/lite/experimental/acceleration/compatibility/android_info.h" @@ -113,6 +114,16 @@ class GPUCompatibilityList { const AndroidInfo& android_info, const ::tflite::gpu::GpuInfo& gpu_info) const; + // Converts the compatibility status enum value to the corresponding status + // string. + static std::string CompatibilityStatusToString( + gpu::CompatibilityStatus status); + + // Converts the status string to the corresponding compatibility status enum + // value. 
+ static gpu::CompatibilityStatus StringToCompatibilityStatus( + absl::string_view status); + protected: const DeviceDatabase* database_; diff --git a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility_test.cc b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility_test.cc index c7427cea792dde..d7d4538c94c718 100644 --- a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility_test.cc +++ b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility_test.cc @@ -152,4 +152,34 @@ TEST(GPUCompatibility, CreationWithNullCompatibilityListFlatbuffer) { EXPECT_EQ(list, nullptr); } +TEST(GPUCompatibility, ConvertCompatibilityStatusToStringCorrectly) { + EXPECT_EQ( + tflite::acceleration::GPUCompatibilityList::CompatibilityStatusToString( + tflite::acceleration::gpu::CompatibilityStatus::kSupported), + tflite::acceleration::gpu::kStatusSupported); + EXPECT_EQ( + tflite::acceleration::GPUCompatibilityList::CompatibilityStatusToString( + tflite::acceleration::gpu::CompatibilityStatus::kUnsupported), + tflite::acceleration::gpu::kStatusUnsupported); + EXPECT_EQ( + tflite::acceleration::GPUCompatibilityList::CompatibilityStatusToString( + tflite::acceleration::gpu::CompatibilityStatus::kUnknown), + tflite::acceleration::gpu::kStatusUnknown); +} + +TEST(GPUCompatibility, ConvertStringToCompatibilityStatusCorrectly) { + EXPECT_EQ( + tflite::acceleration::GPUCompatibilityList::StringToCompatibilityStatus( + tflite::acceleration::gpu::kStatusSupported), + tflite::acceleration::gpu::CompatibilityStatus::kSupported); + EXPECT_EQ( + tflite::acceleration::GPUCompatibilityList::StringToCompatibilityStatus( + tflite::acceleration::gpu::kStatusUnsupported), + tflite::acceleration::gpu::CompatibilityStatus::kUnsupported); + EXPECT_EQ( + tflite::acceleration::GPUCompatibilityList::StringToCompatibilityStatus( + tflite::acceleration::gpu::kStatusUnknown), + tflite::acceleration::gpu::CompatibilityStatus::kUnknown); +} + } // namespace diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD index 4d6f631e7d5a85..1a6282228fd76e 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD @@ -150,6 +150,9 @@ cc_library( ":libjpeg_handle_hdr", "//tensorflow/lite/core/c:c_api_types", ] + libjpeg_deps(), + # Some targets only have an implicit dependency on LibjpegHandle. + # This avoids warnings about backwards references when linking. 
+ alwayslink = True, ) cc_library( @@ -451,6 +454,7 @@ cc_library( "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", "//tensorflow/lite/tools:model_loader", + "//tensorflow/lite/tools/benchmark:register_custom_op", ], ) diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/build_defs.bzl b/tensorflow/lite/experimental/acceleration/mini_benchmark/build_defs.bzl index dbab6f2bb7af66..1c2a1a3f8cb561 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/build_defs.bzl +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/build_defs.bzl @@ -21,7 +21,17 @@ load( load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "add_suffix") load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "libjpeg_handle_deps") -def embedded_binary(name, binary, array_variable_name, testonly = False): +def _concat(lists): + """Concatenate a list of lists, without requiring the inner lists to be iterable. + + This allows the inner lists to be obtained by calls to select(). + """ + result = [] + for selected_list in lists: + result = result + selected_list + return result + +def embedded_binary(name, binary, array_variable_name, testonly = False, exec_properties = None): """Create a cc_library that embeds a binary as constant data. Args: @@ -55,6 +65,7 @@ def embedded_binary(name, binary, array_variable_name, testonly = False): srcs = [cc_name], hdrs = [h_name], testonly = testonly, + exec_properties = exec_properties, ) def validation_model( @@ -173,33 +184,66 @@ def validation_test(name, validation_model, tags = [], copts = [], deps = []): ], "//conditions:default": [], }) + libjpeg_handle_deps(), + linkstatic = 1, ) def cc_library_with_forced_in_process_benchmark_variant( name, deps = [], + forced_in_process_deps = [], in_process_deps = [], + non_in_process_deps_selects = [], **kwargs): """Defines a cc_library that optionally forces benchmark runs in process. This generates two cc_library target. The first one runs the benchmark in a separate process on Android, while it runs the benchmark in process on all - other platforms. The second one, which has "_in_process" appended to the - name, forces benchmark runs in process. + other platforms. It doesn't have TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS + defined. + The second one, which has "_in_process" appended to the name, forces + benchmark runs in process on all platforms. It has + TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS defined. + + The default option for MiniBenchmark is to run the benchmark in a separate + process on Android, as this is safer than running the benchmark in the app + process. However, forcing the benchmark to run in-process on Android allows + the benchmark to reuse the same TF Lite runtime that is initialized in the + application process. These two variants may use different dependencies. + For example, the in-process variant uses the statically linked libjpeg + handle, while the other variant uses the dynamically linked libjpeg handle + on Android to minimize binary size. + + This build rule ensures that the dependencies listed in + "forced_in_process_deps" are added only when + TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS is defined, that the dependencies + listed in "non_in_process_deps_selects" are added only when + TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS is NOT defined, and that + TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS is defined automatically when + using the "_in_process" target. 
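    For example (all target and dependency names below are illustrative only):

        cc_library_with_forced_in_process_benchmark_variant(
            name = "validator_runner_impl",
            srcs = ["validator_runner_impl.cc"],
            in_process_deps = [":blocking_validator_runner"],
            forced_in_process_deps = [":libjpeg_handle_static_link"],
            non_in_process_deps_selects = [{
                "//tensorflow:android": [":libjpeg_handle_dynamic_link"],
                "//conditions:default": [":libjpeg_handle_static_link"],
            }],
        )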
+ Args: name: determines the name used for the generated cc_library targets. + forced_in_process_deps: dependencies that will be enabled only when the + benchmark is forced to run in-process on all platforms. This should be + used for dependencies arising from code inside + '#ifdef TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS'. deps: dependencies that will be unconditionally included in the deps of the generated cc_library targets. in_process_deps: dependencies on rules that are themselves defined using 'cc_library_with_forced_in_process_benchmark_variant'. Must be iterable, so cannot be computed by calling 'select'. + non_in_process_deps_selects: A list of dictionaries that will be + converted to dependencies with select on rules. The dependencies will + be enabled only when the benchmark runs in a separate process on + Android. This should be used for dependencies arising from code inside + '#ifndef TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS'. **kwargs: Additional cc_library parameters. """ native.cc_library( name = name, - deps = deps + in_process_deps + [ + deps = deps + in_process_deps + _concat([select(map) for map in non_in_process_deps_selects]) + [ clean_dep("//tensorflow/lite/experimental/acceleration/mini_benchmark:tflite_acceleration_in_process_default"), ], **kwargs @@ -208,7 +252,7 @@ def cc_library_with_forced_in_process_benchmark_variant( in_process_deps_renamed = [add_suffix(in_process_dep, "_in_process") for in_process_dep in in_process_deps] native.cc_library( name = name + "_in_process", - deps = deps + in_process_deps_renamed + [ + deps = deps + in_process_deps_renamed + forced_in_process_deps + [ clean_dep("//tensorflow/lite/experimental/acceleration/mini_benchmark:tflite_acceleration_in_process_enable"), ], **kwargs diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/BUILD b/tensorflow/lite/experimental/acceleration/mini_benchmark/c/BUILD index b6e36c51f11a96..e55041460e02b5 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/BUILD +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/c/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -load("//tensorflow/lite/experimental/acceleration/mini_benchmark:build_defs.bzl", "cc_library_with_forced_in_process_benchmark_variant") -load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "libjpeg_handle_deps", "minibenchmark_visibility_allowlist") +load("//tensorflow/lite:build_def.bzl", "tflite_cc_library_with_c_headers_test") +load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite_with_c_headers_test") +load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "minibenchmark_visibility_allowlist") default_visibility_group = [ "//tensorflow/lite/experimental/acceleration/mini_benchmark:__subpackages__", @@ -25,44 +26,19 @@ package( licenses = ["notice"], ) -cc_library_with_forced_in_process_benchmark_variant( +# This target runs MiniBenchmark in a separate processon Android, while it runs MiniBenchmark +# in-process on all other platforms. 
+cc_library_with_tflite_with_c_headers_test( name = "c_api", - srcs = ["c_api.cc"], hdrs = ["c_api.h"], - in_process_deps = [ - "//tensorflow/lite/experimental/acceleration/mini_benchmark:blocking_validator_runner", - ], - deps = [ - ":c_api_types", - "//tensorflow/lite/acceleration/configuration:configuration_fbs", - "//tensorflow/lite/acceleration/configuration/c:delegate_plugin", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:benchmark_result_evaluator", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:status_codes", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:validator_runner_entrypoint", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:validator_runner_options", - "@flatbuffers", - ], + deps = ["//tensorflow/lite/core/experimental/acceleration/mini_benchmark/c:c_api"], ) -cc_test( - name = "c_api_test", - srcs = ["c_api_test.cc"], +# This target forces MiniBenchmark to run in-process on all platforms including Android. +tflite_cc_library_with_c_headers_test( + name = "c_api_in_process", + hdrs = ["c_api.h"], deps = [ - ":c_api", - "//tensorflow/lite/acceleration/configuration:configuration_fbs", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:embedded_mobilenet_model", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:embedded_mobilenet_validation_model", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:embedded_simple_addition_model", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:mini_benchmark_test_helper", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:status_codes", - "@com_google_googletest//:gtest_main", - "@flatbuffers", - "@flatbuffers//:runtime_cc", - ] + libjpeg_handle_deps(), -) - -cc_library( - name = "c_api_types", - hdrs = ["c_api_types.h"], - visibility = ["//visibility:private"], + "//tensorflow/lite/core/experimental/acceleration/mini_benchmark/c:c_api_in_process", + ], ) diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h b/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h index 2d68200d457461..e62b599d7e5294 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,62 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ -#include -#include -#include +#include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h" // IWYU pragma: export -#ifdef __cplusplus -extern "C" { -#endif - -// APIs of TfLiteMiniBenchmarkResult. -typedef struct TfLiteMiniBenchmarkResult TfLiteMiniBenchmarkResult; -int TfLiteMiniBenchmarkResultInitStatus(TfLiteMiniBenchmarkResult* result); -uint8_t* TfLiteMiniBenchmarkResultFlatBufferData( - TfLiteMiniBenchmarkResult* result); -size_t TfLiteMiniBenchmarkResultFlatBufferDataSize( - TfLiteMiniBenchmarkResult* result); -// Free memory allocated with `result`. -void TfLiteMiniBenchmarkResultFree(TfLiteMiniBenchmarkResult* result); - -// APIs of TfLiteMiniBenchmarkCustomValidationInfo. 
-typedef struct TfLiteMiniBenchmarkCustomValidationInfo - TfLiteMiniBenchmarkCustomValidationInfo; -void TfLiteMiniBenchmarkCustomValidationInfoSetBuffer( - TfLiteMiniBenchmarkCustomValidationInfo* custom_validation, int batch_size, - uint8_t* buffer, size_t* buffer_dim, int buffer_dim_size); -void TfLiteMiniBenchmarkCustomValidationInfoSetAccuracyValidator( - TfLiteMiniBenchmarkCustomValidationInfo* custom_validation, - void* accuracy_validator_user_data, - bool (*accuracy_validator_func)(void* user_data, - uint8_t* benchmark_result_data, - int benchmark_result_data_size)); - -// APIs of TfLiteMiniBenchmarkSettings. -typedef struct TfLiteMiniBenchmarkSettings TfLiteMiniBenchmarkSettings; -TfLiteMiniBenchmarkSettings* TfLiteMiniBenchmarkSettingsCreate(); -TfLiteMiniBenchmarkCustomValidationInfo* -TfLiteMiniBenchmarkSettingsCustomValidationInfo( - TfLiteMiniBenchmarkSettings* settings); -void TfLiteMiniBenchmarkSettingsSetFlatBufferData( - TfLiteMiniBenchmarkSettings* settings, uint8_t* flatbuffer_data, - size_t flatbuffer_data_size); -void TfLiteMiniBenchmarkSettingsSetErrorReporter( - TfLiteMiniBenchmarkSettings* settings, void* error_reporter_user_data, - int (*error_reporter_func)(void* user_data, const char* format, - va_list args)); -void TfLiteMiniBenchmarkSettingsSetGpuPluginHandle( - TfLiteMiniBenchmarkSettings* settings, void* gpu_plugin_handle); -void TfLiteMiniBenchmarkSettingsFree(TfLiteMiniBenchmarkSettings* settings); - -// Others. -// Trigger validation for `settings` and return the validation result. -// This returns a pointer, that you must free using -// TfLiteMiniBenchmarkResultFree(). -TfLiteMiniBenchmarkResult* TfLiteBlockingValidatorRunnerTriggerValidation( - TfLiteMiniBenchmarkSettings* settings); - -#ifdef __cplusplus -} // extern "C". -#endif #endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/BUILD b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/BUILD index 08f0e0a52fdf59..ca96f4c8d67a8b 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/BUILD +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
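For readers tracking the relocation: the declarations removed from this header now live in the core copy, which the slimmed-down file re-exports through the IWYU pragma, so existing call sites keep compiling. Below is a minimal usage sketch of that C API as declared above; building the MinibenchmarkSettings flatbuffer is assumed to happen elsewhere, and whether validation runs in-process or in a separate process on Android depends on whether the build links `:c_api` or `:c_api_in_process`.

```c++
// Sketch only: drives the functions declared in the (relocated) header.
// settings_fbs must point at a serialized MinibenchmarkSettings flatbuffer.
#include <cstddef>
#include <cstdint>
#include <cstdio>

#include "tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h"

void RunBlockingValidation(uint8_t* settings_fbs, size_t settings_fbs_size) {
  TfLiteMiniBenchmarkSettings* settings = TfLiteMiniBenchmarkSettingsCreate();
  TfLiteMiniBenchmarkSettingsSetFlatBufferData(settings, settings_fbs,
                                               settings_fbs_size);
  TfLiteMiniBenchmarkResult* result =
      TfLiteBlockingValidatorRunnerTriggerValidation(settings);
  // InitStatus reports whether the run could be set up; the flatbuffer data
  // holds the serialized validation results.
  std::printf("init status: %d, result bytes: %zu\n",
              TfLiteMiniBenchmarkResultInitStatus(result),
              TfLiteMiniBenchmarkResultFlatBufferDataSize(result));
  // The result must be freed with TfLiteMiniBenchmarkResultFree().
  TfLiteMiniBenchmarkResultFree(result);
  TfLiteMiniBenchmarkSettingsFree(settings);
}
```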
# ============================================================================== -load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "libjpeg_handle_deps") +load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "libjpeg_handle_deps", "register_selected_ops_deps") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -89,7 +89,7 @@ cc_binary( "//tensorflow/lite/tools:command_line_flags", "@com_google_absl//absl/strings", "@flatbuffers", - ] + libjpeg_handle_deps(), + ] + libjpeg_handle_deps() + register_selected_ops_deps(), ) cc_library( @@ -125,5 +125,5 @@ cc_test( "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/tools:model_loader", "@com_google_googletest//:gtest_main", - ] + libjpeg_handle_deps(), + ] + libjpeg_handle_deps() + register_selected_ops_deps(), ) diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/embedder_main.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/embedder_main.cc index 0287f3d1dabf7c..0e11653bb8d1dd 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/embedder_main.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/embedder_main.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg_register.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/embedder.h" #include "tensorflow/lite/schema/reflection/schema_generated.h" +#include "tensorflow/lite/tools/benchmark/register_custom_op.h" #include "tensorflow/lite/tools/command_line_flags.h" namespace tflite { @@ -116,6 +117,9 @@ int RunEmbedder(const EmbedderOptions& options) { resolver.AddCustom( "validation/decode_jpeg", ::tflite::acceleration::decode_jpeg_kernel::Register_DECODE_JPEG(), 1); + + RegisterSelectedOps(&resolver); + auto status = embedder.CreateModelWithEmbeddedValidation(&fbb, &resolver); if (!status.ok()) { std::cerr << "Creating model with embedded validation failed: " diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc index 55cb8db57950da..377c281068bed7 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc @@ -226,7 +226,7 @@ TEST_F(LocalizerValidationRegressionTest, NnapiSl) { } #endif // ENABLE_NNAPI_SL_TEST -TEST_F(LocalizerValidationRegressionTest, Gpu) { +TEST_F(LocalizerValidationRegressionTest, GpuAny) { AndroidInfo android_info; auto status = RequestAndroidInfo(&android_info); ASSERT_TRUE(status.ok()); @@ -237,7 +237,47 @@ TEST_F(LocalizerValidationRegressionTest, Gpu) { fbb_.Finish(CreateComputeSettings(fbb_, ExecutionPreference_ANY, CreateTFLiteSettings(fbb_, Delegate_GPU))); #ifdef __ANDROID__ - CheckValidation("GPU"); + CheckValidation("GPUANY"); +#endif // __ANDROID__ +} + +TEST_F(LocalizerValidationRegressionTest, GpuOpenGL) { + AndroidInfo android_info; + auto status = RequestAndroidInfo(&android_info); + ASSERT_TRUE(status.ok()); + if (android_info.is_emulator) { + std::cerr << "Skipping GPU on emulator\n"; + return; + } + fbb_.Finish(CreateComputeSettings( + fbb_, ExecutionPreference_ANY, + CreateTFLiteSettings( + fbb_, Delegate_GPU, 0, + CreateGPUSettings(fbb_, /* allow_precision_loss */ false, 
+ /* allow_quantized_inference */ true, + GPUBackend_OPENGL)))); +#ifdef __ANDROID__ + CheckValidation("GPUOPENGL"); +#endif // __ANDROID__ +} + +TEST_F(LocalizerValidationRegressionTest, GpuOpenCL) { + AndroidInfo android_info; + auto status = RequestAndroidInfo(&android_info); + ASSERT_TRUE(status.ok()); + if (android_info.is_emulator) { + std::cerr << "Skipping GPU on emulator\n"; + return; + } + fbb_.Finish(CreateComputeSettings( + fbb_, ExecutionPreference_ANY, + CreateTFLiteSettings( + fbb_, Delegate_GPU, 0, + CreateGPUSettings(fbb_, /* allow_precision_loss */ false, + /* allow_quantized_inference */ true, + GPUBackend_OPENCL)))); +#ifdef __ANDROID__ + CheckValidation("GPUOPENCL"); #endif // __ANDROID__ } diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl b/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl index c2e0c24bdbef75..aa16873972d0a1 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl @@ -31,5 +31,12 @@ def libjpeg_handle_deps(): def minibenchmark_visibility_allowlist(): """Returns a list of packages that can depend on mini_benchmark.""" return [ + "//tensorflow/lite/core/experimental/acceleration/mini_benchmark/c:__subpackages__", "//tensorflow/lite/tools/benchmark/experimental/delegate_performance:__subpackages__", ] + +def register_selected_ops_deps(): + """Return a list of dependencies for registering selected ops.""" + return [ + clean_dep("//tensorflow/lite/tools/benchmark:register_custom_op"), + ] diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.cc index a9eee994832f49..98851a68bae1f1 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/tools/benchmark/register_custom_op.h" #include "tensorflow/lite/tools/model_loader.h" #ifndef TEMP_FAILURE_RETRY @@ -331,6 +332,8 @@ MinibenchmarkStatus Validator::CreateInterpreter(int* delegate_error_out, "validation/decode_jpeg", ::tflite::acceleration::decode_jpeg_kernel::Register_DECODE_JPEG(), 1); + RegisterSelectedOps(resolver_.get()); + tflite::InterpreterBuilder builder(*model_loader_->GetModel(), *resolver_); // Add delegate if not running on CPU. if (delegate_ != nullptr) { diff --git a/tensorflow/lite/g3doc/guide/inference.md b/tensorflow/lite/g3doc/guide/inference.md index 64054f6cc5330a..2ea6ec267157e6 100644 --- a/tensorflow/lite/g3doc/guide/inference.md +++ b/tensorflow/lite/g3doc/guide/inference.md @@ -616,16 +616,15 @@ running inference in different languages. All the examples assume that the input shape is defined as `[1/None, 10]`, and need to be resized to `[3, 10]`. -
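Stepping back to the custom-op plumbing above: validator.cc and embedder_main.cc now call RegisterSelectedOps(), and register_selected_ops_deps() links the register_custom_op target whose default definition (added later in this diff) is a weak no-op. A build whose models need extra ops can link a strong definition along these lines; Register_MY_CUSTOM_OP and the op name are hypothetical placeholders, not part of this change:

```c++
// Sketch of a strong RegisterSelectedOps definition. Linking it overrides the
// weak no-op in tools/benchmark/register_custom_op.cc, so the mini-benchmark
// validator and the embedder pick up the custom op without further changes.
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/tools/benchmark/register_custom_op.h"

// Hypothetical registration function provided by your custom-op library.
extern TfLiteRegistration* Register_MY_CUSTOM_OP();

void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {
  resolver->AddCustom("MyCustomOp", Register_MY_CUSTOM_OP());
}
```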
- -###### C++ {.new-tab} +C++ example: ```c++ // Resize input tensors before allocate tensors interpreter->ResizeInputTensor(/*tensor_index=*/0, std::vector{3,10}); interpreter->AllocateTensors(); ``` -###### Python {.new-tab} + +Python example: ```python # Load the TFLite model in TFLite Interpreter diff --git a/tensorflow/lite/kernels/subgraph_test_util.cc b/tensorflow/lite/kernels/subgraph_test_util.cc index f961dae3124ace..809c185621e015 100644 --- a/tensorflow/lite/kernels/subgraph_test_util.cc +++ b/tensorflow/lite/kernels/subgraph_test_util.cc @@ -317,20 +317,32 @@ void SubgraphBuilder::BuildXNNPACKSubgraph(Subgraph* subgraph) { } void SubgraphBuilder::BuildInputIsOutputSubgraph(Subgraph* subgraph) { - enum { kInputCounter, kInputValue, kOutputCounter, kTensorCount }; + enum { + kInputCounter, + kInputValue0, + kInputOutput, + kOutputCounter, + kOutputValue0, + kConstRhs, + kTensorCount + }; int first_new_tensor_index; ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), kTfLiteOk); ASSERT_EQ(first_new_tensor_index, 0); - ASSERT_EQ(subgraph->SetInputs({kInputCounter, kInputValue}), kTfLiteOk); - ASSERT_EQ(subgraph->SetOutputs({kOutputCounter, kInputValue}), kTfLiteOk); + ASSERT_EQ(subgraph->SetInputs({kInputCounter, kInputValue0, kInputOutput}), + kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs({kOutputCounter, kOutputValue0, kInputOutput}), + kTfLiteOk); for (int i = 0; i < kTensorCount; ++i) { SetupTensor(subgraph, i, kTfLiteInt32); } + CreateConstantTensor(subgraph, kConstRhs, {1}, {1}); - AddAddNode(subgraph, kInputCounter, kInputValue, kOutputCounter); + AddAddNode(subgraph, kInputCounter, kConstRhs, kOutputCounter); + AddAddNode(subgraph, kInputValue0, kInputOutput, kOutputValue0); } void SubgraphBuilder::BuildInputIsDifferentOutputSubgraph(Subgraph* subgraph) { diff --git a/tensorflow/lite/kernels/while.cc b/tensorflow/lite/kernels/while.cc index 35e0305dda8dd4..fd2251f10cfc6f 100644 --- a/tensorflow/lite/kernels/while.cc +++ b/tensorflow/lite/kernels/while.cc @@ -251,6 +251,12 @@ TfLiteStatus Prepare_impl(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output)); } + // Prepare and check the body subgraph. + TF_LITE_ENSURE_OK( + context, CopyTensorsShapeAndType( + context, this_subgraph, TfLiteIntArrayView(node->inputs), + body_subgraph, body_subgraph->inputs(), true)); + // Detect when a WHILE input is read only. const std::vector input_tensors_count = this_subgraph->GetInputTensorsCount(); @@ -265,18 +271,15 @@ TfLiteStatus Prepare_impl(TfLiteContext* context, TfLiteNode* node) { body_subgraph->tensor(body_subgraph->inputs()[i]); if (body_input->type == kTfLiteString) continue; if (IsResourceOrVariant(body_input)) continue; + TfLiteTensor* this_output = + this_subgraph->tensor(node->outputs->data[i]); + TfLiteTensorDataFree(this_output); node->outputs->data[i] = kTfLiteOptionalTensor; body_input->allocation_type = kTfLiteCustom; } } } - // Prepare and check the body subgraph. 
- TF_LITE_ENSURE_OK( - context, CopyTensorsShapeAndType( - context, this_subgraph, TfLiteIntArrayView(node->inputs), - body_subgraph, body_subgraph->inputs(), true)); - for (int i = 0; i < num_inputs; ++i) { TfLiteTensor* body_input = body_subgraph->tensor(body_subgraph->inputs()[i]); @@ -300,7 +303,7 @@ TfLiteStatus Prepare_impl(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, !IsDynamicTensor(body_output)); if (!TfLiteIntArrayEqual(body_input->dims, body_output->dims)) { // Don't unnecessarily set an output to dynamic when one of input/output - // is a scalar and the other an tensor of size 1. + // is a scalar and the other a tensor of size 1. // If both tensors are scalars or both tensors have shape [1], then // TfLiteIntArrayEqual would return true. We want to detect when one // tensor is a scalar and the other has shape [1], so the total number @@ -340,6 +343,9 @@ TfLiteStatus Prepare_impl(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { Subgraph* this_subgraph = reinterpret_cast(context->impl_); if (this_subgraph->ShouldOptimizeMemoryForLargeTensors()) { + OpData* op_data = reinterpret_cast(node->user_data); + // Call Prepare to ensure input shapes are propagated to the body subgraph. + op_data->subgraphs_prepared = false; // Apply lazy initialization of WHILE kernel. // Just make node output tensors dynamic. int num_outputs = node->outputs->size; @@ -354,10 +360,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return Prepare_impl(context, node); } -TfLiteStatus Prepare_lazy(TfLiteContext* context, TfLiteNode* node) { - return Prepare_impl(context, node); -} - // Evaluate cond subgraph and set the result. TfLiteStatus Eval_cond_subgraph(TfLiteContext* context, Subgraph* cond_subgraph, bool cond_has_dynamic_output_tensors, @@ -584,7 +586,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get(); if (op_data->subgraphs_prepared == false) { - TF_LITE_ENSURE_OK(context, Prepare_lazy(context, node)); + TF_LITE_ENSURE_OK(context, Prepare_impl(context, node)); } else { TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors()); TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors()); diff --git a/tensorflow/lite/kernels/while_test.cc b/tensorflow/lite/kernels/while_test.cc index f4e21513abbc20..8cc2df70233e08 100644 --- a/tensorflow/lite/kernels/while_test.cc +++ b/tensorflow/lite/kernels/while_test.cc @@ -77,24 +77,27 @@ TEST_F(WhileTest, TestWithXNNPACK) { TEST_F(WhileTest, TestInputIsOutput) { interpreter_ = std::make_unique(); AddSubgraphs(2); - builder_->BuildLargeLessEqualCondSubgraph(interpreter_->subgraph(1), 3, 2); + builder_->BuildLargeLessEqualCondSubgraph(interpreter_->subgraph(1), 3, 3); builder_->BuildInputIsOutputSubgraph(interpreter_->subgraph(2)); - builder_->BuildMultiInputWhileSubgraph(&interpreter_->primary_subgraph(), 2); + builder_->BuildMultiInputWhileSubgraph(&interpreter_->primary_subgraph(), 3); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), kTfLiteOk); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}), + kTfLiteOk); ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1}); FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + 
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1}); ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); CheckIntTensor(output0, {1}, {4}); TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); - CheckIntTensor(output1, {1}, {1}); + CheckIntTensor(output1, {1}, {4}); ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); @@ -216,30 +219,47 @@ TEST_F(WhileTest, TestAllCases) { } TEST_F(WhileTest, TestStaticUnconsumedOutputs) { - interpreter_ = std::make_unique(); - AddSubgraphs(2); - builder_->BuildLargeLessEqualCondSubgraph(interpreter_->subgraph(1), 3, 2); - builder_->BuildInputIsOutputSubgraph(interpreter_->subgraph(2)); - builder_->BuildMultiInputWhileSubgraphWithUnconsumedOutput( - &interpreter_->primary_subgraph(), 2); + for (bool dynamic_tensors : {true, false}) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildLargeLessEqualCondSubgraph(interpreter_->subgraph(1), 3, 3); + builder_->BuildInputIsOutputSubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputWhileSubgraphWithUnconsumedOutput( + &interpreter_->primary_subgraph(), 3); - ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), - kTfLiteOk); - ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), - kTfLiteOk); - ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); - FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1}); - FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {2}); + InterpreterOptions options; + if (dynamic_tensors) { + options.OptimizeMemoryForLargeTensors(1); + interpreter_->ApplyOptions(&options); + } - ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); - TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); - CheckIntTensor(output0, {1}, {5}); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {2}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {2}); - ASSERT_EQ(interpreter_->subgraph(2)->tensor(1)->data.data, - interpreter_->tensor(interpreter_->inputs()[1])->data.data); - ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); - ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); - ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {4}); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output1, {1}, {8}); + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {2}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {2, 2}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + CheckIntTensor(output1, {2}, {8, 8}); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + } } // Test a body subgraph which triggers the reallocation of an inplace output diff --git a/tensorflow/lite/python/convert.py 
b/tensorflow/lite/python/convert.py index 28d8f7629be940..0bfe04c903c8d2 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -577,7 +577,6 @@ def build_conversion_flags( enable_mlir_variable_quantization=False, disable_fuse_mul_and_fc=False, quantization_options: Optional[quant_opts_pb2.QuantizationOptions] = None, - enable_hlo_to_tf_conversion=False, mlir_dump_dir=None, mlir_dump_pass_regex=None, mlir_dump_func_regex=None, @@ -686,9 +685,6 @@ def build_conversion_flags( a custom method, and allows finer, modular control. This option will override any other existing quantization flags. We plan on gradually migrating all quantization-related specs into this option. - enable_hlo_to_tf_conversion: Enable HLO to TF conversion in the Converter. - Set this to False by default as this may increase the conversion time if - set otherwise. mlir_dump_dir: A string specifying the target directory to output MLIR dumps produced during conversion. If populated, enables MLIR dumps. mlir_dump_pass_regex: A string containing a regular expression for filtering @@ -797,8 +793,6 @@ def build_conversion_flags( if quantization_options: conversion_flags.quantization_options.CopyFrom(quantization_options) - conversion_flags.enable_hlo_to_tf_conversion = enable_hlo_to_tf_conversion - # Transfer debug options. Check for existence before populating in order to # leverage defaults specified in proto definition. if mlir_dump_dir is not None: diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index f414485a6ecedd..014dfe73fe5761 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -635,7 +635,6 @@ def __init__(self): self._experimental_enable_dynamic_update_slice = False self._experimental_preserve_assert_op = False self._experimental_guarantee_all_funcs_one_use = False - self._experimental_enable_hlo_to_tf_conversion = False # When the value is true, the MLIR quantantizer triggers dynamic range # quantization in MLIR instead of the old quantizer. Used only if @@ -790,9 +789,6 @@ def _get_base_converter_args(self): "allow_all_select_tf_ops": self._experimental_allow_all_select_tf_ops, "disable_fuse_mul_and_fc": self._experimental_disable_fuse_mul_and_fc, "quantization_options": self._experimental_quantization_options, - "enable_hlo_to_tf_conversion": ( - self._experimental_enable_hlo_to_tf_conversion - ), "mlir_dump_dir": self.mlir_dump_dir, "mlir_dump_pass_regex": self.mlir_dump_pass_regex, "mlir_dump_func_regex": self.mlir_dump_func_regex, @@ -2131,6 +2127,11 @@ def from_keras_model(cls, model): return TFLiteKerasModelConverterV2(model) @classmethod + @_deprecation.deprecated( + None, + "Use `jax2tf.convert` and (`lite.TFLiteConverter.from_saved_model`" + " or `lite.TFLiteConverter.from_concrete_functions`) instead.", + ) def experimental_from_jax(cls, serving_funcs, inputs): # Experimental API, subject to changes. # TODO(b/197690428): Currently only support single function. diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index 376f4d8e168526..c819e3c67df39f 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -173,7 +173,7 @@ def set_tensor_shapes(tensors, shapes): """Sets Tensor shape for each tensor if the shape is defined. Args: - tensors: TensorFlow ops.Tensor. + tensors: TensorFlow tensor.Tensor. shapes: Dict of strings representing input tensor names to list of integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). 
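Circling back to the WHILE kernel changes: the reworked TestStaticUnconsumedOutputs covers both the static and the dynamic-tensor paths by toggling OptimizeMemoryForLargeTensors, which is also the condition that resets subgraphs_prepared and routes Eval through Prepare_impl. An application opts into that same path roughly as follows; this is a sketch, and the 1-byte threshold mirrors the test rather than being a recommendation:

```c++
// Sketch: enable dynamic allocation for large tensors, which also defers full
// preparation of WHILE body subgraphs until the first Invoke().
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_options.h"

void EnableLargeTensorMemoryOptimization(tflite::Interpreter* interpreter) {
  tflite::InterpreterOptions options;
  options.OptimizeMemoryForLargeTensors(/*value=*/1);  // threshold in bytes
  interpreter->ApplyOptions(&options);
}
```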
diff --git a/tensorflow/lite/signature_runner.cc b/tensorflow/lite/signature_runner.cc index 3223aefa3cce36..521f5abaaf3eb4 100644 --- a/tensorflow/lite/signature_runner.cc +++ b/tensorflow/lite/signature_runner.cc @@ -85,4 +85,26 @@ TfLiteStatus SignatureRunner::Invoke() { return kTfLiteOk; } +TfLiteStatus SignatureRunner::SetCustomAllocationForInputTensor( + const char* input_name, const TfLiteCustomAllocation& allocation, + int64_t flags) { + const auto& it = signature_def_->inputs.find(input_name); + if (it == signature_def_->inputs.end()) { + subgraph_->ReportError("Input name %s was not found", input_name); + return kTfLiteError; + } + return subgraph_->SetCustomAllocationForTensor(it->second, allocation, flags); +} + +TfLiteStatus SignatureRunner::SetCustomAllocationForOutputTensor( + const char* output_name, const TfLiteCustomAllocation& allocation, + int64_t flags) { + const auto& it = signature_def_->outputs.find(output_name); + if (it == signature_def_->outputs.end()) { + subgraph_->ReportError("Output name %s was not found", output_name); + return kTfLiteError; + } + return subgraph_->SetCustomAllocationForTensor(it->second, allocation, flags); +} + } // namespace tflite diff --git a/tensorflow/lite/signature_runner.h b/tensorflow/lite/signature_runner.h index ae904e99edd1ca..165c98ef82bca7 100644 --- a/tensorflow/lite/signature_runner.h +++ b/tensorflow/lite/signature_runner.h @@ -145,6 +145,56 @@ class SignatureRunner { /// WARNING: This is an experimental API and subject to change. TfLiteStatus Cancel() { return subgraph_->Cancel(); } + /// \brief Assigns (or reassigns) a custom memory allocation for the given + /// tensor name. `flags` is a bitmask, see TfLiteCustomAllocationFlags. + /// The runtime does NOT take ownership of the underlying memory. + /// + /// NOTE: User needs to call AllocateTensors() after this. + /// Invalid/insufficient buffers will cause an error during AllocateTensors or + /// Invoke (in case of dynamic shapes in the graph). + /// + /// Parameters should satisfy the following conditions: + /// 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent + /// In general, this is true for I/O tensors & variable tensors. + /// 2. allocation->data has the appropriate permissions for runtime access + /// (Read-only for inputs, Read-Write for others), and outlives + /// Interpreter. + /// 3. allocation->bytes >= tensor->bytes. + /// This condition is checked again if any tensors are resized. + /// 4. allocation->data should be aligned to kDefaultTensorAlignment + /// defined in lite/util.h. (Currently 64 bytes) + /// This check is skipped if kTfLiteCustomAllocationFlagsSkipAlignCheck is + /// set through `flags`. + /// \warning This is an experimental API and subject to change. \n + TfLiteStatus SetCustomAllocationForInputTensor( + const char* input_name, const TfLiteCustomAllocation& allocation, + int64_t flags = kTfLiteCustomAllocationFlagsNone); + + /// \brief Assigns (or reassigns) a custom memory allocation for the given + /// tensor name. `flags` is a bitmask, see TfLiteCustomAllocationFlags. + /// The runtime does NOT take ownership of the underlying memory. + /// + /// NOTE: User needs to call AllocateTensors() after this. + /// Invalid/insufficient buffers will cause an error during AllocateTensors or + /// Invoke (in case of dynamic shapes in the graph). + /// + /// Parameters should satisfy the following conditions: + /// 1. 
tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent + /// In general, this is true for I/O tensors & variable tensors. + /// 2. allocation->data has the appropriate permissions for runtime access + /// (Read-only for inputs, Read-Write for others), and outlives + /// Interpreter. + /// 3. allocation->bytes >= tensor->bytes. + /// This condition is checked again if any tensors are resized. + /// 4. allocation->data should be aligned to kDefaultTensorAlignment + /// defined in lite/util.h. (Currently 64 bytes) + /// This check is skipped if kTfLiteCustomAllocationFlagsSkipAlignCheck is + /// set through `flags`. + /// \warning This is an experimental API and subject to change. \n + TfLiteStatus SetCustomAllocationForOutputTensor( + const char* output_name, const TfLiteCustomAllocation& allocation, + int64_t flags = kTfLiteCustomAllocationFlagsNone); + private: // The life cycle of SignatureRunner depends on the life cycle of Subgraph, // which is owned by an Interpreter. Therefore, the Interpreter will takes the diff --git a/tensorflow/lite/simple_planner.cc b/tensorflow/lite/simple_planner.cc index 3cf26384966dfd..9e24ad0660c7b8 100644 --- a/tensorflow/lite/simple_planner.cc +++ b/tensorflow/lite/simple_planner.cc @@ -101,7 +101,9 @@ TfLiteStatus SimplePlanner::PlanAllocations() { // artificially adding one to their ref-counts so they are never selected // for deallocation. for (int tensor_index : graph_info_->outputs()) { - refcounts[tensor_index]++; + if (tensor_index != kTfLiteOptionalTensor) { + refcounts[tensor_index]++; + } } // Variable tensors also should be ensured to be never overwritten and need to diff --git a/tensorflow/lite/simple_planner_test.cc b/tensorflow/lite/simple_planner_test.cc index 4e3f7e06186629..0b49600f569d39 100644 --- a/tensorflow/lite/simple_planner_test.cc +++ b/tensorflow/lite/simple_planner_test.cc @@ -365,5 +365,24 @@ TEST_F(SimplePlannerTest, SimpleGraphWithPersistentResetAllocationsAfter) { EXPECT_TRUE(tensor5_ptr == (*graph.tensors())[5].data.raw); } +TEST_F(SimplePlannerTest, SimpleGraphOptionalOutput) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4, 5}, {}}, // Second op + {{4, 5}, {3}, {}} // Third op + }, + {-1, 3}); + SetGraph(&graph); + Execute(0, 10); + + EXPECT_TRUE(IsAllocated(1)); + EXPECT_TRUE(IsAllocated(2)); + EXPECT_TRUE(IsAllocated(3)); + EXPECT_TRUE(IsAllocated(4)); + EXPECT_TRUE(IsAllocated(5)); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index 774f849c702b68..a75d7351a8d593 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -1,5 +1,4 @@ -load("//tensorflow:tensorflow.default.bzl", "tf_py_strict_test") -load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library") +load("//tensorflow:tensorflow.bzl", "py_binary", "tf_py_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -48,6 +47,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/lite:model_builder", "//tensorflow/lite/core/api", "//tensorflow/lite/core/c:common", "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter", @@ -78,7 +78,7 @@ cc_library( ) # Compatibility stub. Remove when internal customers moved. 
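Returning to the SignatureRunner additions: the two new methods forward to Subgraph::SetCustomAllocationForTensor using the tensor indices looked up from the signature def, so they behave like the existing index-based API but take signature tensor names. A usage sketch under assumptions follows; the signature key "serving_default", the tensor names "x" and "y", and the 4096-byte buffers are placeholders, and buffer ownership and cleanup are omitted for brevity:

```c++
// Sketch: bind caller-owned, 64-byte-aligned buffers to a signature's input
// and output tensors, then allocate and run. Buffers must outlive the
// interpreter and be at least tensor->bytes in size.
#include <cstdlib>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/signature_runner.h"

bool BindCustomBuffersAndRun(tflite::Interpreter* interpreter) {
  tflite::SignatureRunner* runner =
      interpreter->GetSignatureRunner("serving_default");
  if (runner == nullptr) return false;

  constexpr size_t kBytes = 4096;  // placeholder size
  TfLiteCustomAllocation in_alloc{std::aligned_alloc(64, kBytes), kBytes};
  TfLiteCustomAllocation out_alloc{std::aligned_alloc(64, kBytes), kBytes};

  if (runner->SetCustomAllocationForInputTensor("x", in_alloc) != kTfLiteOk ||
      runner->SetCustomAllocationForOutputTensor("y", out_alloc) != kTfLiteOk) {
    return false;
  }
  // AllocateTensors() must be called after assigning custom allocations.
  return runner->AllocateTensors() == kTfLiteOk &&
         runner->Invoke() == kTfLiteOk;
}
```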
-py_strict_library( +py_library( name = "tensorflow_wrap_toco", srcs = ["tensorflow_wrap_toco.py"], srcs_version = "PY3", @@ -92,7 +92,7 @@ py_strict_library( ], ) -py_strict_binary( +py_binary( name = "toco_from_protos", srcs = ["toco_from_protos.py"], python_version = "PY3", @@ -106,7 +106,7 @@ py_strict_binary( ], ) -tf_py_strict_test( +tf_py_test( name = "toco_from_protos_test", srcs = ["toco_from_protos_test.py"], python_version = "PY3", @@ -114,8 +114,7 @@ tf_py_strict_test( "no_oss", ], deps = [ - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/core:protos_all_py", + "//tensorflow:tensorflow_py", "//tensorflow/lite/toco:model_flags_proto_py", "//tensorflow/lite/toco:toco_flags_proto_py", "//tensorflow/python/platform:resource_loader", diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index c6339e81cb0080..48af2bdce7cb9a 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/lite/toco/python/toco_python_api.h" #include -#include #include #include #include @@ -32,9 +31,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -327,28 +326,28 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, auto tflite_model = std::make_unique(); model->GetModel()->UnPackTo(tflite_model.get(), nullptr); - tflite::TensorType inference_tensor_type = + const tflite::TensorType inference_tensor_type = FromTocoDataTypeToTflitToTensorType(inference_type); - tflite::TensorType input_type = + const tflite::TensorType input_type = FromTocoDataTypeToTflitToTensorType(input_data_type); - tflite::TensorType output_type = + const tflite::TensorType output_type = FromTocoDataTypeToTflitToTensorType(output_data_type); - flatbuffers::FlatBufferBuilder builder; + std::string output_model; + const absl::string_view input_model_buffer(buf, length); auto status = mlir::lite::QuantizeModel( - *tflite_model, input_type, output_type, inference_tensor_type, {}, - disable_per_channel, fully_quantize, &builder, error_reporter.get(), - enable_numeric_verify, enable_whole_model_verify, + input_model_buffer, input_type, output_type, inference_tensor_type, + /*operator_names=*/{}, disable_per_channel, fully_quantize, output_model, + error_reporter.get(), enable_numeric_verify, enable_whole_model_verify, /*legacy_float_scale=*/true, denylisted_ops, denylisted_nodes, enable_variable_quantization); - if (status != kTfLiteOk) { error_reporter->exception(); return nullptr; } - return tflite::python_utils::ConvertToPyString( - reinterpret_cast(builder.GetCurrentBufferPointer()), - builder.GetSize()); + + return tflite::python_utils::ConvertToPyString(output_model.data(), + output_model.size()); } PyObject* MlirSparsifyModel(PyObject* data) { diff --git a/tensorflow/lite/toco/toco_flags.proto b/tensorflow/lite/toco/toco_flags.proto index c8cad79d3bacfd..1421f614f82ccb 100644 --- 
a/tensorflow/lite/toco/toco_flags.proto +++ b/tensorflow/lite/toco/toco_flags.proto @@ -320,7 +320,8 @@ message TocoFlags { // Flag to enable hlo to tf conversion. // This is useful to exercise StableHLO -> HLO -> TF -> TFLite path. - optional bool enable_hlo_to_tf_conversion = 55 [default = false]; + optional bool enable_hlo_to_tf_conversion = 55 + [default = false, deprecated = true]; // Additional parameters for controlling debug facilities. optional tensorflow.converter.DebugOptions debug_options = 56; diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index 95ca87a9e870dc..3592da77f5bf0c 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -256,6 +256,24 @@ cc_library( ], ) +cc_library( + name = "register_custom_op", + srcs = [ + "register_custom_op.cc", + ], + hdrs = [ + "register_custom_op.h", + ], + copts = common_copts, + deps = [ + "//tensorflow/lite:op_resolver", + "@com_google_absl//absl/base:core_headers", + ], + alwayslink = 1, +) + +exports_files(["register_custom_op.h"]) + cc_library( name = "benchmark_utils", srcs = [ diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/BUILD b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/BUILD index 0ef82065e3cf93..163907fe32df63 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/BUILD +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/BUILD @@ -2,7 +2,8 @@ # Delegate Performance Benchmark (DPB) Android app. # This provides model-level latency & accuracy testings for delegates, on Android. -load("@build_bazel_rules_android//android:rules.bzl", "android_binary", "android_library") +load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "android_binary_with_tflite", "android_library_with_tflite") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -20,6 +21,7 @@ android_library( name = "benchmark_accuracy_impl", srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyImpl.java"], deps = [ + ":benchmark_accuracy", ":benchmark_report", ":csv_writer", ":delegate_performance_benchmark_utils", @@ -32,6 +34,11 @@ android_library( ], ) +android_library( + name = "benchmark_accuracy", + srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracy.java"], +) + android_library( name = "benchmark_latency_activity", srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkLatencyActivity.java"], @@ -56,7 +63,7 @@ android_library( ], ) -android_library( +android_library_with_tflite( name = "benchmark_report", srcs = [ "src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkReport.java", @@ -74,11 +81,11 @@ android_library( srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkResultType.java"], ) -android_library( +android_library_with_tflite( name = "csv_writer", srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/CsvWriter.java"], + tflite_deps = [":benchmark_report"], deps = [ - ":benchmark_report", ":delegate_metrics_entry", ":metrics_entry", ":model_benchmark_report_interface", @@ -95,16 +102,18 @@ android_library( ], ) -android_library( +android_library_with_tflite( name = "delegate_performance_benchmark_lib", srcs = [ 
"src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyActivity.java", "src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkLatencyActivity.java", ], + tflite_deps = [ + "//tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native:benchmark_native", + ], deps = [ ":benchmark_accuracy_impl", ":benchmark_latency_impl", - "//tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native:benchmark_native", ], ) @@ -123,11 +132,11 @@ android_library( ], ) -android_library( +android_library_with_tflite( name = "html_writer", srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/HtmlWriter.java"], + tflite_deps = [":benchmark_report"], deps = [ - ":benchmark_report", ":benchmark_result_type", ":delegate_metrics_entry", ":metrics_entry", @@ -136,10 +145,10 @@ android_library( ], ) -android_library( +android_library_with_tflite( name = "json_writer", srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/JsonWriter.java"], - deps = [":benchmark_report"], + tflite_deps = [":benchmark_report"], ) android_library( @@ -148,7 +157,7 @@ android_library( deps = [":benchmark_result_type"], ) -android_library( +android_library_with_tflite( name = "model_benchmark_report", srcs = [ "src/main/java/org/tensorflow/lite/benchmark/delegateperformance/AccuracyBenchmarkReport.java", @@ -204,7 +213,7 @@ android_library( ) # The main test app. -android_binary( +android_binary_with_tflite( name = "delegate_performance_benchmark", assets = [ "//tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/models:accuracy_models", @@ -221,8 +230,8 @@ android_binary( # can't be built. We need to prevent the build system from trying to # use the target in that case. tags = ["manual"], - visibility = ["//visibility:public"], - deps = [ + tflite_deps = [ ":delegate_performance_benchmark_lib", ], + visibility = ["//visibility:public"], ) diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracy.java b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracy.java new file mode 100644 index 00000000000000..90ef295cb09bc8 --- /dev/null +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracy.java @@ -0,0 +1,30 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.benchmark.delegateperformance; + +import android.content.Context; + +/** Interface for Delegate Performance Accuracy Benchmark. */ +public interface BenchmarkAccuracy { + /** + * Initializes and runs the accuracy benchmark. 
+ * + * @param context the context to use for finding the test models and exporting reports + * @param tfliteSettingsJsonFiles the list of paths to delegate JSON configurations + * @return {@code true} if the benchmark was successfully initialized and executed. Otherwise, + * returns {@code false}. + */ + boolean benchmark(Context context, String[] tfliteSettingsJsonFiles); +} diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyActivity.java b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyActivity.java index 9645a94cd1cfc4..93f8c92797bde3 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyActivity.java +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyActivity.java @@ -41,14 +41,10 @@ public void onCreate(Bundle savedInstanceState) { Intent intent = getIntent(); Bundle bundle = intent.getExtras(); String[] tfliteSettingsJsonFiles = bundle.getStringArray(TFLITE_SETTINGS_FILES_INTENT_KEY_0); - BenchmarkAccuracyImpl impl = - new BenchmarkAccuracyImpl(getApplicationContext(), tfliteSettingsJsonFiles); - - if (impl.initialize()) { - impl.benchmark(); - } else { - Log.e(TAG, "Failed to initialize the accuracy benchmarking."); + if (!new BenchmarkAccuracyImpl().benchmark(getApplicationContext(), tfliteSettingsJsonFiles)) { + Log.i(TAG, "Accuracy benchmark failed."); } + finish(); } } diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyImpl.java b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyImpl.java index 997b959e6c756d..7ce74f13c4c39a 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyImpl.java +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyImpl.java @@ -54,49 +54,25 @@ * configuration and relative performance differences as percentages in HTML. * */ -public class BenchmarkAccuracyImpl { +public class BenchmarkAccuracyImpl implements BenchmarkAccuracy { private static final String TAG = "TfLiteAccuracyImpl"; private static final String ACCURACY_FOLDER_NAME = "accuracy"; - private final Context context; - private final String[] tfliteSettingsJsonFiles; - private final BenchmarkReport report; + private Context context; + private String[] tfliteSettingsJsonFiles; + private BenchmarkReport report; - public BenchmarkAccuracyImpl(Context context, String[] tfliteSettingsJsonFiles) { - this.context = context; - this.tfliteSettingsJsonFiles = tfliteSettingsJsonFiles; - this.report = BenchmarkReport.create(); - } - - /** - * Initializes the test environment. Checks the validity of input arguments and creates the result - * folder. - * - *

Returns {@code true} if the initialization was successful. Otherwise, returns {@code false}. - */ - public boolean initialize() { - if (tfliteSettingsJsonFiles == null || tfliteSettingsJsonFiles.length == 0) { - Log.e(TAG, "No TFLiteSettings file provided."); - return false; - } - - try { - // Creates root result folder. - String resultFolderPath = - DelegatePerformanceBenchmark.createResultFolder( - context.getFilesDir(), ACCURACY_FOLDER_NAME); - report.addWriter(JsonWriter.create(resultFolderPath)); - report.addWriter(CsvWriter.create(resultFolderPath)); - report.addWriter(HtmlWriter.create(resultFolderPath)); - } catch (IOException e) { - Log.e(TAG, "Failed to create result folder", e); + @Override + public boolean benchmark(Context context, String[] tfliteSettingsJsonFiles) { + if (!initialize(context, tfliteSettingsJsonFiles)) { + Log.e(TAG, "Failed to initialize accuracy benchmark."); return false; } - return true; + return benchmarkDelegatesAndExportReport(); } - public void benchmark() { + private boolean benchmarkDelegatesAndExportReport() { Log.i( TAG, "Running accuracy benchmark with TFLiteSettings JSON files: " @@ -105,14 +81,14 @@ public void benchmark() { DelegatePerformanceBenchmark.loadTfLiteSettingsList(tfliteSettingsJsonFiles); if (tfliteSettingsList.size() < 2) { Log.e(TAG, "Failed to load the TFLiteSettings JSON file."); - return; + return false; } String[] assets; try { assets = context.getAssets().list(ACCURACY_FOLDER_NAME); } catch (IOException e) { Log.e(TAG, "Failed to list files from assets folder.", e); - return; + return false; } for (String asset : assets) { if (!asset.endsWith(".tflite")) { @@ -127,7 +103,7 @@ public void benchmark() { context.getFilesDir(), ACCURACY_FOLDER_NAME + "/" + modelName); } catch (IOException e) { Log.e(TAG, "Failed to create result folder for " + modelName + ". Exiting application.", e); - return; + return false; } try (AssetFileDescriptor modelFileDescriptor = context.getAssets().openFd(ACCURACY_FOLDER_NAME + "/" + asset)) { @@ -148,7 +124,7 @@ public void benchmark() { AccuracyBenchmarkReport.create(modelName, rawDelegateMetricsEntries)); } catch (IOException e) { Log.e(TAG, "Failed to open assets file " + asset, e); - return; + return false; } } // Computes the aggregated results and export the report to local files. @@ -158,5 +134,36 @@ public void benchmark() { TAG, String.format( "Accuracy benchmark result for %s: %s.", testTarget.filePath(), report.result())); + return true; + } + + /** + * Initializes the test environment. Checks the validity of input arguments and creates the result + * folder. + * + * @return {@code true} if the initialization was successful. Otherwise, returns {@code false}. + */ + private boolean initialize(Context context, String[] tfliteSettingsJsonFiles) { + if (tfliteSettingsJsonFiles == null || tfliteSettingsJsonFiles.length == 0) { + Log.e(TAG, "No TFLiteSettings file provided."); + return false; + } + this.context = context; + this.tfliteSettingsJsonFiles = tfliteSettingsJsonFiles; + report = BenchmarkReport.create(); + + try { + // Creates root result folder. 
+ String resultFolderPath = + DelegatePerformanceBenchmark.createResultFolder( + context.getFilesDir(), ACCURACY_FOLDER_NAME); + report.addWriter(JsonWriter.create(resultFolderPath)); + report.addWriter(CsvWriter.create(resultFolderPath)); + report.addWriter(HtmlWriter.create(resultFolderPath)); + } catch (IOException e) { + Log.e(TAG, "Failed to create result folder", e); + return false; + } + return true; } } diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD index ce2e1e300a64ee..5209526ab6712c 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD @@ -1,9 +1,8 @@ # Description: # Holds the native layer of the app. -load("//tensorflow/lite:build_def.bzl", "tflite_jni_binary") -load("//tensorflow:tensorflow.bzl", "clean_dep") load("//tensorflow/lite/tools/benchmark/experimental/delegate_performance/android:build_defs.bzl", "accuracy_benchmark_extra_deps", "latency_benchmark_extra_deps") +load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -11,11 +10,11 @@ package( licenses = ["notice"], ) -tflite_jni_binary( +jni_binary_with_tflite( name = "libdelegate_performance_benchmark.so", srcs = ["delegate_performance_benchmark_jni.cc"], + tflite_deps = [":accuracy_benchmark"], deps = [ - ":accuracy_benchmark", ":latency_benchmark", "//tensorflow/lite/acceleration/configuration:configuration_fbs", "//tensorflow/lite/delegates/utils/experimental/stable_delegate:tflite_settings_json_parser", @@ -46,7 +45,7 @@ cc_library( ] + latency_benchmark_extra_deps(), ) -cc_library( +cc_library_with_tflite( name = "accuracy_benchmark", srcs = ["accuracy_benchmark.cc"], hdrs = ["accuracy_benchmark.h"], @@ -54,6 +53,7 @@ cc_library( ":status_codes", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/acceleration/configuration:gpu_plugin", "//tensorflow/lite/acceleration/configuration:stable_delegate_plugin", "//tensorflow/lite/acceleration/configuration:xnnpack_plugin", "//tensorflow/lite/core/acceleration/configuration:nnapi_plugin", @@ -65,18 +65,10 @@ cc_library( "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:tool_params", "@flatbuffers", - ] + select({ - # On Android, as the validation runs in a separate process as a - # different binary, any TFLite delegates to be validated need to - # include corresponding delegate plugins. 
- clean_dep("//tensorflow:android"): [ - "//tensorflow/lite/acceleration/configuration:gpu_plugin", - ], - "//conditions:default": [], - }) + accuracy_benchmark_extra_deps(), + ] + accuracy_benchmark_extra_deps(), ) -cc_library( +cc_library_with_tflite( name = "benchmark_native", - srcs = ["libdelegate_performance_benchmark.so"], + tflite_jni_binaries = [":libdelegate_performance_benchmark.so"], ) diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/delegate_performance_benchmark_jni.cc b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/delegate_performance_benchmark_jni.cc index e538cc1688478d..87fe702e3dac19 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/delegate_performance_benchmark_jni.cc +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/delegate_performance_benchmark_jni.cc @@ -24,7 +24,10 @@ limitations under the License. #include "tensorflow/lite/delegates/utils/experimental/stable_delegate/tflite_settings_json_parser.h" #include "tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto/delegate_performance.pb.h" #include "tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.h" + +#ifndef TFLITE_WITH_STABLE_ABI #include "tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/latency_benchmark.h" +#endif // !TFLITE_WITH_STABLE_ABI namespace { @@ -67,6 +70,10 @@ Java_org_tensorflow_lite_benchmark_delegateperformance_DelegatePerformanceBenchm JNIEnv* env, jclass clazz, jobjectArray args_obj, jbyteArray tflite_settings_byte_array, jstring tflite_settings_path_obj, jint model_fd, jlong model_offset, jlong model_size) { + tflite::proto::benchmark::LatencyResults results; + +// The latency benchmark doesn't support TF Lite with the stable ABI path. +#ifndef TFLITE_WITH_STABLE_ABI std::vector args = toStringVector(env, args_obj); const char* tflite_settings_path_chars = env->GetStringUTFChars(tflite_settings_path_obj, nullptr); @@ -76,16 +83,15 @@ Java_org_tensorflow_lite_benchmark_delegateperformance_DelegatePerformanceBenchm flatbuffers::GetRoot( reinterpret_cast(tflite_settings_bytes)); - tflite::proto::benchmark::LatencyResults results = - tflite::benchmark::latency::Benchmark( - *tflite_settings, tflite_settings_path_chars, - static_cast(model_fd), static_cast(model_offset), - static_cast(model_size), args); + results = tflite::benchmark::latency::Benchmark( + *tflite_settings, tflite_settings_path_chars, static_cast(model_fd), + static_cast(model_offset), static_cast(model_size), args); env->ReleaseByteArrayElements(tflite_settings_byte_array, tflite_settings_bytes, JNI_ABORT); env->ReleaseStringUTFChars(tflite_settings_path_obj, tflite_settings_path_chars); +#endif // !TFLITE_WITH_STABLE_ABI return CppProtoToBytes(env, results); } diff --git a/tensorflow/lite/tools/benchmark/register_custom_op.cc b/tensorflow/lite/tools/benchmark/register_custom_op.cc new file mode 100644 index 00000000000000..3592663d3f2f9d --- /dev/null +++ b/tensorflow/lite/tools/benchmark/register_custom_op.cc @@ -0,0 +1,23 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/base/attributes.h" +#include "tensorflow/lite/op_resolver.h" + +// Version with Weak linker attribute doing nothing: if someone links this +// library with another definition of this function (presumably to actually +// register custom ops), that version will be used instead. +void ABSL_ATTRIBUTE_WEAK +RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {} diff --git a/tensorflow/lite/tools/benchmark/register_custom_op.h b/tensorflow/lite/tools/benchmark/register_custom_op.h new file mode 100644 index 00000000000000..9278e31a43fbe7 --- /dev/null +++ b/tensorflow/lite/tools/benchmark/register_custom_op.h @@ -0,0 +1,23 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_REGISTER_CUSTOM_OP_H_ +#define TENSORFLOW_LITE_TOOLS_BENCHMARK_REGISTER_CUSTOM_OP_H_ + +#include "tensorflow/lite/op_resolver.h" + +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); + +#endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_REGISTER_CUSTOM_OP_H_ diff --git a/tensorflow/lite/tools/versioning/BUILD b/tensorflow/lite/tools/versioning/BUILD index 74875c04112619..a071a2615e9a93 100644 --- a/tensorflow/lite/tools/versioning/BUILD +++ b/tensorflow/lite/tools/versioning/BUILD @@ -2,7 +2,7 @@ load( "//tensorflow:tensorflow.bzl", "tf_cc_test", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -20,7 +20,7 @@ cc_library( "op_version.h", "runtime_version.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":op_signature", "//tensorflow/core:tflite_portable_logging", @@ -62,7 +62,7 @@ cc_library( hdrs = [ "op_signature.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/lite:stderr_reporter", "//tensorflow/lite/core/api", @@ -101,7 +101,7 @@ cc_library( hdrs = [ "gpu_compatibility.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":op_signature", "//tensorflow/lite:builtin_op_data", diff --git a/tensorflow/lite/util.h b/tensorflow/lite/util.h index 2ba25f84588dfc..6e8264974501ce 100644 --- a/tensorflow/lite/util.h +++ b/tensorflow/lite/util.h @@ -22,6 +22,7 @@ 
limitations under the License. #define TENSORFLOW_LITE_UTIL_H_ #include +#include #include #include diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index aed2a433199a7e..9ccc484ca3a247 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -99,6 +99,7 @@ tensorflow/lite/delegates/hexagon/hexagon_nn/BUILD: tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/BUILD: tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external.cc: tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h: +tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h: tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg.h: tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl: tensorflow/lite/interpreter.h: diff --git a/tensorflow/py.default.bzl b/tensorflow/py.default.bzl new file mode 100644 index 00000000000000..bad528e901bbd1 --- /dev/null +++ b/tensorflow/py.default.bzl @@ -0,0 +1,12 @@ +"""Shims for loading the plain Python rules. + +These are used to make internal/external code transformations managable. Once +Tensorflow is loading the Python rules directly from rules_python, these shims +can be removed. +""" + +py_test = native.py_test + +py_binary = native.py_binary + +py_library = native.py_library diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f23e2de226a2f9..18716ed6ea152f 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -182,12 +182,14 @@ py_library( "//tensorflow/python/framework:config", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:extension_type", + "//tensorflow/python/framework:flexible_dtypes", "//tensorflow/python/framework:for_generated_wrappers", "//tensorflow/python/framework:graph_util", "//tensorflow/python/framework:kernels", "//tensorflow/python/framework:subscribe", "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:test_ops", # TODO(b/183988750): Break testing code out into separate rule. 
+ "//tensorflow/python/framework:weak_tensor", "//tensorflow/python/grappler:tf_cluster", "//tensorflow/python/grappler:tf_item", "//tensorflow/python/grappler:tf_optimizer", @@ -245,6 +247,8 @@ py_library( "//tensorflow/python/ops:tensor_array_ops", "//tensorflow/python/ops:uniform_quant_ops_gen", "//tensorflow/python/ops:variable_v1", + "//tensorflow/python/ops:weak_tensor_ops", + "//tensorflow/python/ops:weak_tensor_test_util", "//tensorflow/python/ops:weights_broadcast_ops", "//tensorflow/python/ops:while_loop", "//tensorflow/python/ops:while_v2", @@ -344,7 +348,12 @@ py_library( ":no_contrib", ":tf2", "//tensorflow/core/function/trace_type", + "//tensorflow/python/compiler/mlir", + "//tensorflow/python/compiler/xla", + "//tensorflow/python/compiler/xla:compiler_py", "//tensorflow/python/data", + "//tensorflow/python/debug/lib:check_numerics_callback", + "//tensorflow/python/debug/lib:dumping_callback", "//tensorflow/python/distribute", "//tensorflow/python/distribute:merge_call_interim", "//tensorflow/python/distribute:multi_process_runner", @@ -352,6 +361,7 @@ py_library( "//tensorflow/python/distribute:sharded_variable", "//tensorflow/python/distribute/failure_handling:failure_handling_lib", "//tensorflow/python/distribute/failure_handling:preemption_watcher", + "//tensorflow/python/dlpack", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:monitoring", @@ -373,6 +383,7 @@ py_library( "//tensorflow/python/ops:composite_tensor_ops", "//tensorflow/python/ops:cond_v2", "//tensorflow/python/ops:cudnn_rnn_ops_gen", + "//tensorflow/python/ops:debug_ops_gen", "//tensorflow/python/ops:gradient_checker_v2", "//tensorflow/python/ops:image_ops", "//tensorflow/python/ops:initializers_ns", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 411e32a440f901..bfb0114b305461 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -72,22 +72,6 @@ from tensorflow.python.util.all_util import make_all from tensorflow.python.util.tf_export import tf_export -# TensorFlow Debugger (tfdbg). -from tensorflow.python.debug.lib import check_numerics_callback -from tensorflow.python.debug.lib import dumping_callback -from tensorflow.python.ops import gen_debug_ops - -# DLPack -from tensorflow.python.dlpack.dlpack import from_dlpack -from tensorflow.python.dlpack.dlpack import to_dlpack - -# XLA JIT compiler APIs. -from tensorflow.python.compiler.xla import jit -from tensorflow.python.compiler.xla import xla - -# MLIR APIs. -from tensorflow.python.compiler.mlir import mlir - # Update dispatch decorator docstrings to contain lists of registered APIs. # (This should come after any imports that register APIs.) 
from tensorflow.python.util import dispatch diff --git a/tensorflow/python/autograph/BUILD b/tensorflow/python/autograph/BUILD index 01306a77aa1018..af9ec45ab93dfd 100644 --- a/tensorflow/python/autograph/BUILD +++ b/tensorflow/python/autograph/BUILD @@ -13,7 +13,6 @@ py_strict_library( srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/python/autograph/converters:__init__", "//tensorflow/python/autograph/core:converter", "//tensorflow/python/autograph/impl:api", "//tensorflow/python/autograph/lang:directives", diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index 2261a63a93c29e..d45a9a330ea7a1 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -5,15 +5,6 @@ package( licenses = ["notice"], ) -py_strict_library( - name = "__init__", - srcs = ["__init__.py"], - visibility = ["//tensorflow:__subpackages__"], - deps = [ - ":list_comprehensions", - ], -) - py_strict_library( name = "slices", srcs = ["slices.py"], @@ -40,17 +31,6 @@ py_strict_library( ], ) -py_strict_library( - name = "list_comprehensions", - srcs = ["list_comprehensions.py"], - visibility = ["//tensorflow:__subpackages__"], - deps = [ - "//tensorflow/python/autograph/core:converter", - "//tensorflow/python/autograph/pyct:templates", - "@gast_archive//:gast", - ], -) - py_strict_library( name = "logical_expressions", srcs = ["logical_expressions.py"], @@ -104,28 +84,6 @@ py_strict_library( ], ) -py_strict_library( - name = "control_flow_deprecated_py2", - srcs = ["control_flow_deprecated_py2.py"], - visibility = ["//tensorflow:__subpackages__"], - deps = [ - "//tensorflow/python/autograph/core:converter", - "//tensorflow/python/autograph/lang:directives", - "//tensorflow/python/autograph/pyct:anno", - "//tensorflow/python/autograph/pyct:ast_util", - "//tensorflow/python/autograph/pyct:cfg", - "//tensorflow/python/autograph/pyct:parser", - "//tensorflow/python/autograph/pyct:qual_names", - "//tensorflow/python/autograph/pyct:templates", - "//tensorflow/python/autograph/pyct/static_analysis:activity", - "//tensorflow/python/autograph/pyct/static_analysis:annos", - "//tensorflow/python/autograph/pyct/static_analysis:liveness", - "//tensorflow/python/autograph/pyct/static_analysis:reaching_definitions", - "//tensorflow/python/autograph/pyct/static_analysis:reaching_fndefs", - "@gast_archive//:gast", - ], -) - py_strict_library( name = "directives", srcs = ["directives.py"], @@ -352,18 +310,6 @@ py_strict_test( ], ) -py_strict_test( - name = "list_comprehensions_test", - srcs = ["list_comprehensions_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":list_comprehensions", - "//tensorflow/python/autograph/core:test_lib", - "//tensorflow/python/platform:client_testlib", - ], -) - py_strict_test( name = "lists_test", srcs = ["lists_test.py"], @@ -376,7 +322,7 @@ py_strict_test( "//tensorflow/python/autograph/lang:directives", "//tensorflow/python/autograph/lang:special_functions", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops_stack", "//tensorflow/python/ops:list_ops", "//tensorflow/python/platform:client_testlib", diff --git a/tensorflow/python/autograph/converters/control_flow_deprecated_py2.py b/tensorflow/python/autograph/converters/control_flow_deprecated_py2.py deleted file mode 100644 index 6fa1deee76b61f..00000000000000 --- 
a/tensorflow/python/autograph/converters/control_flow_deprecated_py2.py +++ /dev/null @@ -1,635 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Handles control flow statements: while, for, if. - -Python 2 compatibility version. Not maintained. -""" - -import gast - -from tensorflow.python.autograph.core import converter -from tensorflow.python.autograph.lang import directives -from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import ast_util -from tensorflow.python.autograph.pyct import cfg -from tensorflow.python.autograph.pyct import parser -from tensorflow.python.autograph.pyct import qual_names -from tensorflow.python.autograph.pyct import templates -from tensorflow.python.autograph.pyct.static_analysis import activity -from tensorflow.python.autograph.pyct.static_analysis import annos -from tensorflow.python.autograph.pyct.static_analysis import liveness -from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions -from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs - - -# TODO(mdan): Refactor functions to make them smaller. 
- - -class ControlFlowTransformer(converter.Base): - """Transforms control flow structures like loops an conditionals.""" - - def _create_cond_branch(self, body_name, aliased_orig_names, - aliased_new_names, body, returns): - if len(returns) == 1: - template = """ - return retval - """ - return_stmt = templates.replace(template, retval=returns[0]) - else: - template = """ - return (retvals,) - """ - return_stmt = templates.replace(template, retvals=returns) - - if aliased_orig_names: - alias_declarations = [] - for new_name, old_name in zip(aliased_new_names, aliased_orig_names): - template = """ - try: - aliased_new_name = aliased_orig_name - except NameError: - aliased_new_name = ag__.Undefined(symbol_name) - """ - - alias_declarations.extend( - templates.replace( - template, - aliased_new_name=new_name, - aliased_orig_name=old_name, - symbol_name=gast.Constant(str(old_name), kind=None))) - - template = """ - def body_name(): - alias_declarations - body - return_stmt - """ - return templates.replace( - template, - alias_declarations=alias_declarations, - body_name=body_name, - body=body, - return_stmt=return_stmt) - else: - template = """ - def body_name(): - body - return_stmt - """ - return templates.replace( - template, body_name=body_name, body=body, return_stmt=return_stmt) - - def _create_cond_expr(self, results, test, body_name, orelse_name, - state_getter_name, state_setter_name, - basic_symbol_names, composite_symbol_names): - if results is not None: - template = """ - results = ag__.if_stmt(test, body_name, orelse_name, - state_getter_name, state_setter_name, - (basic_symbol_names,), - (composite_symbol_names,)) - """ - return templates.replace( - template, - test=test, - results=results, - body_name=body_name, - orelse_name=orelse_name, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names) - else: - template = """ - ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name, - (basic_symbol_names,), (composite_symbol_names,)) - """ - return templates.replace( - template, - test=test, - body_name=body_name, - orelse_name=orelse_name, - getter_name=state_getter_name, - setter_name=state_setter_name, - basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names) - - def _fmt_symbols(self, symbol_set): - if not symbol_set: - return 'no variables' - return ', '.join(map(str, symbol_set)) - - def _determine_aliased_symbols(self, scope, node_defined_in): - modified_live = scope.modified & node_defined_in - # Composite symbols are handled elsewhere see _create_state_functions - return {s for s in modified_live if not s.is_composite()} - - def _create_state_functions(self, composites, state_getter_name, - state_setter_name): - - if composites: - composite_tuple = tuple(composites) - - template = """ - def state_getter_name(): - return composite_tuple, - def state_setter_name(vals): - composite_tuple, = vals - """ - node = templates.replace( - template, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - composite_tuple=composite_tuple) - else: - template = """ - def state_getter_name(): - return () - def state_setter_name(_): - pass - """ - node = templates.replace( - template, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name) - - return node - - def _create_loop_options(self, node): - if not anno.hasanno(node, anno.Basic.DIRECTIVES): - return gast.Dict([], []) - - loop_directives = 
anno.getanno(node, anno.Basic.DIRECTIVES) - if directives.set_loop_options not in loop_directives: - return gast.Dict([], []) - - opts_dict = loop_directives[directives.set_loop_options] - str_keys, values = zip(*opts_dict.items()) - keys = [gast.Constant(s, kind=None) for s in str_keys] - values = list(values) # ast and gast don't play well with tuples. - return gast.Dict(keys, values) - - def _create_undefined_assigns(self, undefined_symbols): - assignments = [] - for s in undefined_symbols: - template = ''' - var = ag__.Undefined(symbol_name) - ''' - assignments += templates.replace( - template, - var=s, - symbol_name=gast.Constant(s.ssf(), kind=None)) - return assignments - - def visit_If(self, node): - body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE) - defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) - live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) - - # Note: this information needs to be extracted before the body conversion - # that happens in the call to generic_visit below, because the conversion - # generates nodes that lack static analysis annotations. - need_alias_in_body = self._determine_aliased_symbols( - body_scope, defined_in) - need_alias_in_orelse = self._determine_aliased_symbols( - orelse_scope, defined_in) - - node = self.generic_visit(node) - - modified_in_cond = body_scope.modified | orelse_scope.modified - returned_from_cond = set() - composites = set() - for s in modified_in_cond: - if s in live_out and not s.is_composite(): - returned_from_cond.add(s) - if s.is_composite(): - # Special treatment for compound objects, always return them. - # This allows special handling within the if_stmt itself. - # For example, in TensorFlow we need to restore the state of composite - # symbols to ensure that only effects from the executed branch are seen. - composites.add(s) - - created_in_body = body_scope.modified & returned_from_cond - defined_in - created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in - - basic_created_in_body = tuple( - s for s in created_in_body if not s.is_composite()) - basic_created_in_orelse = tuple( - s for s in created_in_orelse if not s.is_composite()) - - # These variables are defined only in a single branch. This is fine in - # Python so we pass them through. Another backend, e.g. Tensorflow, may need - # to handle these cases specially or throw an Error. - possibly_undefined = (set(basic_created_in_body) ^ - set(basic_created_in_orelse)) - - # Alias the closure variables inside the conditional functions, to allow - # the functions access to the respective variables. - # We will alias variables independently for body and orelse scope, - # because different branches might write different variables. 
- aliased_body_orig_names = tuple(need_alias_in_body) - aliased_orelse_orig_names = tuple(need_alias_in_orelse) - aliased_body_new_names = tuple( - self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced) - for s in aliased_body_orig_names) - aliased_orelse_new_names = tuple( - self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced) - for s in aliased_orelse_orig_names) - - alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) - alias_orelse_map = dict( - zip(aliased_orelse_orig_names, aliased_orelse_new_names)) - - node_body = ast_util.rename_symbols(node.body, alias_body_map) - node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map) - - cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced) - body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced) - orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced) - all_referenced = body_scope.referenced | orelse_scope.referenced - state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced) - state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced) - - returned_from_cond = tuple(returned_from_cond) - composites = tuple(composites) - - if returned_from_cond: - if len(returned_from_cond) == 1: - cond_results = returned_from_cond[0] - else: - cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None) - - returned_from_body = tuple( - alias_body_map[s] if s in need_alias_in_body else s - for s in returned_from_cond) - returned_from_orelse = tuple( - alias_orelse_map[s] if s in need_alias_in_orelse else s - for s in returned_from_cond) - - else: - # When the cond would return no value, we leave the cond called without - # results. That in turn should trigger the side effect guards. The - # branch functions will return a dummy value that ensures cond - # actually has some return value as well. - cond_results = None - # TODO(mdan): Replace with None once side_effect_guards is retired. - returned_from_body = (templates.replace_as_expression( - 'ag__.match_staging_level(1, cond_var_name)', - cond_var_name=cond_var_name),) - returned_from_orelse = (templates.replace_as_expression( - 'ag__.match_staging_level(1, cond_var_name)', - cond_var_name=cond_var_name),) - - cond_assign = self.create_assignment(cond_var_name, node.test) - body_def = self._create_cond_branch( - body_name, - aliased_orig_names=aliased_body_orig_names, - aliased_new_names=aliased_body_new_names, - body=node_body, - returns=returned_from_body) - orelse_def = self._create_cond_branch( - orelse_name, - aliased_orig_names=aliased_orelse_orig_names, - aliased_new_names=aliased_orelse_new_names, - body=node_orelse, - returns=returned_from_orelse) - undefined_assigns = self._create_undefined_assigns(possibly_undefined) - composite_defs = self._create_state_functions( - composites, state_getter_name, state_setter_name) - - basic_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in returned_from_cond) - composite_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in composites) - - cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name, - orelse_name, state_getter_name, - state_setter_name, basic_symbol_names, - composite_symbol_names) - - if_ast = ( - undefined_assigns + composite_defs + body_def + orelse_def + - cond_assign + cond_expr) - return if_ast - - def _get_basic_loop_vars(self, modified_symbols, live_in, live_out): - # The loop variables corresponding to simple symbols (e.g. `x`). 
- basic_loop_vars = [] - for s in modified_symbols: - if s.is_composite(): - # TODO(mdan): Raise an error when this happens for a TF loop. - continue - # Variables not live into or out of the loop are considered local to the - # loop. - if s not in live_in and s not in live_out: - continue - basic_loop_vars.append(s) - return frozenset(basic_loop_vars) - - def _get_composite_loop_vars(self, modified_symbols, live_in): - # The loop variables corresponding to composite symbols (e.g. `self.x`). - composite_loop_vars = [] - for s in modified_symbols: - if not s.is_composite(): - continue - # Mutations made to objects created inside the loop will appear as writes - # to composite symbols. Because these mutations appear as modifications - # made to composite symbols, we check whether the composite's parent is - # actually live into the loop. - # Example: - # while cond: - # x = Foo() - # x.foo = 2 * x.foo # x.foo is live into the loop, but x is not. - # - # Note that some parents might not be symbols - for example, in x['foo'], - # 'foo' is a parent, but it's a literal, not a symbol. We don't check the - # liveness of literals. - support_set_symbols = tuple( - sss for sss in s.support_set if sss.is_symbol()) - if not all(sss in live_in for sss in support_set_symbols): - continue - composite_loop_vars.append(s) - return frozenset(composite_loop_vars) - - def _get_loop_vars(self, node, modified_symbols): - body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) - live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN) - live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) - reserved_symbols = body_scope.referenced - - basic_loop_vars = self._get_basic_loop_vars( - modified_symbols, live_in, live_out) - composite_loop_vars = self._get_composite_loop_vars( - modified_symbols, live_in) - - # Variable that are used or defined inside the loop, but not defined - # before entering the loop. Only simple variables must be defined. The - # composite ones will be implicitly checked at runtime. - undefined_lives = basic_loop_vars - defined_in - - return (basic_loop_vars, composite_loop_vars, reserved_symbols, - undefined_lives) - - def _loop_var_constructs(self, basic_loop_vars): - loop_vars = tuple(basic_loop_vars) - loop_vars_ast_tuple = gast.Tuple([n.ast() for n in loop_vars], None) - - if len(loop_vars) == 1: - loop_vars = loop_vars[0] - - return loop_vars, loop_vars_ast_tuple - - def visit_While(self, node): - node = self.generic_visit(node) - - (basic_loop_vars, composite_loop_vars, reserved_symbols, - possibly_undefs) = self._get_loop_vars( - node, - anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified) - loop_vars, loop_vars_ast_tuple = self._loop_var_constructs( - basic_loop_vars) - - state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols) - state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols) - state_functions = self._create_state_functions( - composite_loop_vars, state_getter_name, state_setter_name) - - basic_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars) - composite_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars) - - opts = self._create_loop_options(node) - - # TODO(mdan): Use a single template. - # If the body and test functions took a single tuple for loop_vars, instead - # of *loop_vars, then a single template could be used. 
- if loop_vars: - template = """ - state_functions - def body_name(loop_vars): - body - return loop_vars, - def test_name(loop_vars): - return test - loop_vars_ast_tuple = ag__.while_stmt( - test_name, - body_name, - state_getter_name, - state_setter_name, - (loop_vars,), - (basic_symbol_names,), - (composite_symbol_names,), - opts) - """ - node = templates.replace( - template, - loop_vars=loop_vars, - loop_vars_ast_tuple=loop_vars_ast_tuple, - test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols), - test=node.test, - body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols), - body=node.body, - state_functions=state_functions, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names, - opts=opts) - else: - template = """ - state_functions - def body_name(): - body - return () - def test_name(): - return test - ag__.while_stmt( - test_name, - body_name, - state_getter_name, - state_setter_name, - (), - (), - (composite_symbol_names,), - opts) - """ - node = templates.replace( - template, - test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols), - test=node.test, - body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols), - body=node.body, - state_functions=state_functions, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - composite_symbol_names=composite_symbol_names, - opts=opts) - - undefined_assigns = self._create_undefined_assigns(possibly_undefs) - return undefined_assigns + node - - def visit_For(self, node): - node = self.generic_visit(node) - - (basic_loop_vars, composite_loop_vars, - reserved_symbols, possibly_undefs) = self._get_loop_vars( - node, (anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified - | anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE).modified)) - loop_vars, loop_vars_ast_tuple = self._loop_var_constructs( - basic_loop_vars) - body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols) - - state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols) - state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols) - state_functions = self._create_state_functions( - composite_loop_vars, state_getter_name, state_setter_name) - - if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST): - extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST) - extra_test_name = self.ctx.namer.new_symbol( - 'extra_test', reserved_symbols) - template = """ - def extra_test_name(loop_vars): - return extra_test_expr - """ - extra_test_function = templates.replace( - template, - extra_test_name=extra_test_name, - loop_vars=loop_vars, - extra_test_expr=extra_test) - else: - extra_test_name = parser.parse_expression('None') - extra_test_function = [] - - # Workaround for PEP-3113 - # iterates_var holds a single variable with the iterates, which may be a - # tuple. 
- iterates_var_name = self.ctx.namer.new_symbol( - 'iterates', reserved_symbols) - template = """ - iterates = iterates_var_name - """ - iterate_expansion = templates.replace( - template, - iterates=node.target, - iterates_var_name=iterates_var_name) - - undefined_assigns = self._create_undefined_assigns(possibly_undefs) - - basic_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars) - composite_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars) - - opts = self._create_loop_options(node) - - # TODO(mdan): Use a single template. - # If the body and test functions took a single tuple for loop_vars, instead - # of *loop_vars, then a single template could be used. - if loop_vars: - template = """ - undefined_assigns - state_functions - def body_name(iterates_var_name, loop_vars): - iterate_expansion - body - return loop_vars, - extra_test_function - loop_vars_ast_tuple = ag__.for_stmt( - iter_, - extra_test_name, - body_name, - state_getter_name, - state_setter_name, - (loop_vars,), - (basic_symbol_names,), - (composite_symbol_names,), - opts) - """ - return templates.replace( - template, - undefined_assigns=undefined_assigns, - loop_vars=loop_vars, - loop_vars_ast_tuple=loop_vars_ast_tuple, - iter_=node.iter, - iterate_expansion=iterate_expansion, - iterates_var_name=iterates_var_name, - extra_test_name=extra_test_name, - extra_test_function=extra_test_function, - body_name=body_name, - body=node.body, - state_functions=state_functions, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names, - opts=opts) - else: - template = """ - undefined_assigns - state_functions - def body_name(iterates_var_name): - iterate_expansion - body - return () - extra_test_function - ag__.for_stmt( - iter_, - extra_test_name, - body_name, - state_getter_name, - state_setter_name, - (), - (), - (composite_symbol_names,), - opts) - """ - return templates.replace( - template, - undefined_assigns=undefined_assigns, - iter_=node.iter, - iterate_expansion=iterate_expansion, - iterates_var_name=iterates_var_name, - extra_test_name=extra_test_name, - extra_test_function=extra_test_function, - body_name=body_name, - body=node.body, - state_functions=state_functions, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - composite_symbol_names=composite_symbol_names, - opts=opts) - - -class AnnotatedDef(reaching_definitions.Definition): - - def __init__(self): - super(AnnotatedDef, self).__init__() - self.directives = {} - - -def transform(node, ctx): - graphs = cfg.build(node) - node = qual_names.resolve(node) - node = activity.resolve(node, ctx, None) - node = reaching_definitions.resolve(node, ctx, graphs) - node = reaching_fndefs.resolve(node, ctx, graphs) - node = liveness.resolve(node, ctx, graphs) - - node = ControlFlowTransformer(ctx).visit(node) - return node diff --git a/tensorflow/python/autograph/converters/list_comprehensions.py b/tensorflow/python/autograph/converters/list_comprehensions.py deleted file mode 100644 index 8e8b97d03cc43d..00000000000000 --- a/tensorflow/python/autograph/converters/list_comprehensions.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Lowers list comprehensions into for and if statements. - -Example: - - result = [x * x for x in xs] - -becomes - - result = [] - for x in xs: - elt = x * x - result.append(elt) -""" - -import gast - -from tensorflow.python.autograph.core import converter -from tensorflow.python.autograph.pyct import templates - - -# TODO(mdan): This should covert directly to operator calls. - - -class ListCompTransformer(converter.Base): - """Lowers list comprehensions into standard control flow.""" - - def visit_Assign(self, node): - if not isinstance(node.value, gast.ListComp): - return self.generic_visit(node) - if len(node.targets) > 1: - raise NotImplementedError('multiple assignments') - - target, = node.targets - list_comp_node = node.value - - template = """ - target = [] - """ - initialization = templates.replace(template, target=target) - - template = """ - target.append(elt) - """ - body = templates.replace(template, target=target, elt=list_comp_node.elt) - - for gen in reversed(list_comp_node.generators): - for gen_if in reversed(gen.ifs): - template = """ - if test: - body - """ - body = templates.replace(template, test=gen_if, body=body) - template = """ - for target in iter_: - body - """ - body = templates.replace( - template, iter_=gen.iter, target=gen.target, body=body) - - return initialization + body - - -def transform(node, ctx): - return ListCompTransformer(ctx).visit(node) diff --git a/tensorflow/python/autograph/converters/list_comprehensions_test.py b/tensorflow/python/autograph/converters/list_comprehensions_test.py deleted file mode 100644 index 630aad030c1e0a..00000000000000 --- a/tensorflow/python/autograph/converters/list_comprehensions_test.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for list_comprehensions module.""" - -from tensorflow.python.autograph.converters import list_comprehensions -from tensorflow.python.autograph.core import converter_testing -from tensorflow.python.platform import test - - -class ListCompTest(converter_testing.TestCase): - - def assertTransformedEquivalent(self, f, *inputs): - tr = self.transform(f, list_comprehensions) - self.assertEqual(f(*inputs), tr(*inputs)) - - def test_basic(self): - - def f(l): - s = [e * e for e in l] - return s - - self.assertTransformedEquivalent(f, []) - self.assertTransformedEquivalent(f, [1, 2, 3]) - - def test_multiple_generators(self): - - def f(l): - s = [e * e for sublist in l for e in sublist] # pylint:disable=g-complex-comprehension - return s - - self.assertTransformedEquivalent(f, []) - self.assertTransformedEquivalent(f, [[1], [2], [3]]) - - def test_cond(self): - - def f(l): - s = [e * e for e in l if e > 1] - return s - - self.assertTransformedEquivalent(f, []) - self.assertTransformedEquivalent(f, [1, 2, 3]) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/autograph/converters/lists_test.py b/tensorflow/python/autograph/converters/lists_test.py index 43dfa5f48e1622..a6613f47ab7d16 100644 --- a/tensorflow/python/autograph/converters/lists_test.py +++ b/tensorflow/python/autograph/converters/lists_test.py @@ -20,7 +20,7 @@ from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.lang import special_functions from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import list_ops from tensorflow.python.platform import test @@ -37,7 +37,7 @@ def f(): tl = tr() # Empty tensor lists cannot be evaluated or stacked. - self.assertIsInstance(tl, ops.Tensor) + self.assertIsInstance(tl, tensor.Tensor) self.assertEqual(tl.dtype, dtypes.variant) def test_initialized_list(self): diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD index cde1fbf8bf2daf..765ab5fa24b67c 100644 --- a/tensorflow/python/autograph/operators/BUILD +++ b/tensorflow/python/autograph/operators/BUILD @@ -154,7 +154,7 @@ py_strict_test( ":data_structures", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:list_ops", "//tensorflow/python/ops:tensor_array_ops", diff --git a/tensorflow/python/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py index 599d0a21e10ef5..707406b9651cda 100644 --- a/tensorflow/python/autograph/operators/data_structures_test.py +++ b/tensorflow/python/autograph/operators/data_structures_test.py @@ -17,7 +17,7 @@ from tensorflow.python.autograph.operators import data_structures from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops @@ -30,7 +30,7 @@ def test_new_list_empty(self): l = data_structures.new_list() # Can't evaluate an empty list. 
# TODO(mdan): sess.run should allow tf.variant maybe? - self.assertTrue(isinstance(l, ops.Tensor)) + self.assertTrue(isinstance(l, tensor.Tensor)) def test_new_list_tensor(self): l = data_structures.new_list([3, 4, 5]) diff --git a/tensorflow/python/autograph/utils/BUILD b/tensorflow/python/autograph/utils/BUILD index 254881dd92eb8a..d758c28801c315 100644 --- a/tensorflow/python/autograph/utils/BUILD +++ b/tensorflow/python/autograph/utils/BUILD @@ -19,7 +19,7 @@ py_strict_library( srcs = ["tensor_list.py"], visibility = ["//tensorflow:__subpackages__"], deps = [ - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:list_ops", "//tensorflow/python/ops:tensor_array_ops", ], @@ -51,6 +51,7 @@ py_strict_library( visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", "//tensorflow/python/ops:math_ops_gen", diff --git a/tensorflow/python/autograph/utils/misc.py b/tensorflow/python/autograph/utils/misc.py index 7404ea5ec75c0a..d14b4758aba03f 100644 --- a/tensorflow/python/autograph/utils/misc.py +++ b/tensorflow/python/autograph/utils/misc.py @@ -15,6 +15,7 @@ """Miscellaneous utilities that don't fit anywhere else.""" from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops @@ -36,7 +37,7 @@ def alias_tensors(*args): """ def alias_if_tensor(a): - return array_ops.identity(a) if isinstance(a, ops.Tensor) else a + return array_ops.identity(a) if isinstance(a, tensor.Tensor) else a # TODO(mdan): Recurse into containers? # TODO(mdan): Anything we can do about variables? Fake a scope reuse? diff --git a/tensorflow/python/autograph/utils/tensor_list.py b/tensorflow/python/autograph/utils/tensor_list.py index c91b9be2868dac..c8bdf3ae982982 100644 --- a/tensorflow/python/autograph/utils/tensor_list.py +++ b/tensorflow/python/autograph/utils/tensor_list.py @@ -14,7 +14,7 @@ # ============================================================================== """A typed list in Python.""" -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops @@ -28,7 +28,7 @@ def dynamic_list_append(target, element): # It may be possible to use TensorList alone if the loop body will not # require wrapping it, although we'd have to think about an autoboxing # mechanism for lists received as parameter. - if isinstance(target, ops.Tensor): + if isinstance(target, tensor.Tensor): return list_ops.tensor_list_push_back(target, element) # Python targets (including TensorList): fallback to their original append. 
diff --git a/tensorflow/python/client/BUILD b/tensorflow/python/client/BUILD index 3c037c04be4e95..d9c0ad4432b068 100644 --- a/tensorflow/python/client/BUILD +++ b/tensorflow/python/client/BUILD @@ -273,6 +273,7 @@ py_strict_library( "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", "//tensorflow/python/framework:stack", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:session_ops", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/training/experimental:mixed_precision_global_state", diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 6c68d0c17595f0..8c0fec1591bf79 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -35,6 +35,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import stack +from tensorflow.python.framework import tensor from tensorflow.python.ops import session_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.experimental import mixed_precision_global_state @@ -502,7 +503,7 @@ def __init__(self, graph, fetches, feeds, feed_handles=None): self._fetches.append(fetch) self._ops.append(False) # Remember the fetch if it is for a tensor handle. - if (isinstance(fetch, ops.Tensor) and + if (isinstance(fetch, tensor.Tensor) and (fetch.op.type == 'GetSessionHandle' or fetch.op.type == 'GetSessionHandleV2')): self._fetch_handles[fetch.ref()] = fetch.op.inputs[0].dtype @@ -1158,7 +1159,7 @@ def _feed_fn(feed, feed_val): raise TypeError( f'Cannot interpret feed_dict key as Tensor: {e.args[0]}') - if isinstance(subfeed_val, ops.Tensor): + if isinstance(subfeed_val, tensor.Tensor): raise TypeError( 'The value of a feed cannot be a tf.Tensor object. Acceptable ' 'feed values include Python scalars, strings, lists, numpy ' @@ -1322,7 +1323,7 @@ def _single_operation_run(): self._call_tf_sessionrun(None, {}, [], target_list, None) return _single_operation_run - elif isinstance(fetches, ops.Tensor): + elif isinstance(fetches, tensor.Tensor): # Special case for fetching a single tensor, because the # function can return the result of `TF_Run()` directly. assert len(fetch_list) == 1 diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc index 46ebe0d5a3b527..61551f0f8c1e6f 100644 --- a/tensorflow/python/client/tf_session_wrapper.cc +++ b/tensorflow/python/client/tf_session_wrapper.cc @@ -1618,12 +1618,10 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { // Note: users should prefer using tf.cast or equivalent, and only when // it's infeasible to set the type via OpDef's type constructor and // inference function. 
- m.def("SetFullType", [](PyGraph* graph, TF_Operation* op, - const std::string& serialized_full_type) { - tensorflow::FullTypeDef proto; - proto.ParseFromString(serialized_full_type); - tensorflow::SetFullType(graph->tf_graph(), op, proto); - }); + m.def("SetFullType", + [](PyGraph* graph, TF_Operation* op, const TF_Buffer* full_type_proto) { + tensorflow::SetFullType(graph->tf_graph(), op, full_type_proto); + }); m.def( "TF_LoadLibrary", diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD index 5d6100d4f01210..404927a33501fe 100644 --- a/tensorflow/python/compat/BUILD +++ b/tensorflow/python/compat/BUILD @@ -21,6 +21,7 @@ py_strict_library( "//tensorflow/python/data/ops:readers", "//tensorflow/python/eager:monitoring", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:control_flow_v2_toggles", "//tensorflow/python/ops:variable_scope", diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 97bc7f8d44c1c1..a8256da81c5d6c 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 10) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 17) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compat/v2_compat.py b/tensorflow/python/compat/v2_compat.py index 179f64008961dc..481c3c69e34855 100644 --- a/tensorflow/python/compat/v2_compat.py +++ b/tensorflow/python/compat/v2_compat.py @@ -23,6 +23,7 @@ from tensorflow.python.data.ops import readers from tensorflow.python.eager import monitoring from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import variable_scope @@ -60,7 +61,7 @@ def enable_v2_behavior(): ops.enable_eager_execution() tensor_shape.enable_v2_tensorshape() # Also switched by tf2 variable_scope.enable_resource_variables() - ops.enable_tensor_equality() + tensor.enable_tensor_equality() # Enables TensorArrayV2 and control flow V2. control_flow_v2_toggles.enable_control_flow_v2() # Make sure internal uses of tf.data symbols map to V2 versions. @@ -105,7 +106,7 @@ def disable_v2_behavior(): ops.disable_eager_execution() tensor_shape.disable_v2_tensorshape() # Also switched by tf2 variable_scope.disable_resource_variables() - ops.disable_tensor_equality() + tensor.disable_tensor_equality() # Disables TensorArrayV2 and control flow V2. control_flow_v2_toggles.disable_control_flow_v2() # Make sure internal uses of tf.data symbols map to V1 versions. 
diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD index 89ee69d2637dc1..f3fd845ff53b10 100644 --- a/tensorflow/python/compiler/tensorrt/BUILD +++ b/tensorflow/python/compiler/tensorrt/BUILD @@ -49,6 +49,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/grappler:tf_optimizer", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:resource_variable_ops_gen", diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index 392869eee6d0f5..2f80952e303852 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import importer from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.grappler import tf_optimizer from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -562,7 +563,7 @@ def _add_nodes_denylist(self): collection_def = self._grappler_meta_graph_def.collection_def["train_op"] denylist = collection_def.node_list.value for i in self._nodes_denylist: - if isinstance(i, ops.Tensor): + if isinstance(i, tensor.Tensor): denylist.append(_to_bytes(i.name)) else: denylist.append(_to_bytes(i)) @@ -692,7 +693,7 @@ def calibrate(self, for k, v in input_map_fn().items(): if not isinstance(k, str): raise ValueError("Keys of input_map_fn must be of type str") - if not isinstance(v, ops.Tensor): + if not isinstance(v, tensor.Tensor): raise ValueError("Values of input_map_fn must be of type tf.Tensor") self._calibration_graph = ops.Graph() diff --git a/tensorflow/python/compiler/xla/experimental/BUILD b/tensorflow/python/compiler/xla/experimental/BUILD index 595897108d19c8..018e4daedb1292 100644 --- a/tensorflow/python/compiler/xla/experimental/BUILD +++ b/tensorflow/python/compiler/xla/experimental/BUILD @@ -29,7 +29,7 @@ py_strict_test( "//tensorflow/compiler/xla:xla_data_proto_py", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//third_party/py/numpy", diff --git a/tensorflow/python/compiler/xla/experimental/xla_sharding_test.py b/tensorflow/python/compiler/xla/experimental/xla_sharding_test.py index 2f0281e99b21de..924f737c81fb48 100644 --- a/tensorflow/python/compiler/xla/experimental/xla_sharding_test.py +++ b/tensorflow/python/compiler/xla/experimental/xla_sharding_test.py @@ -22,7 +22,7 @@ from tensorflow.python.compiler.xla.experimental import xla_sharding from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -90,7 +90,7 @@ def test_tile_annotates_tensor_correctly(self): def tile_helper(tensor): self.assertIsNone(xla_sharding.get_tensor_sharding(tensor)) tiled_tensor = xla_sharding.tile(tensor, np.array([2, 1, 6])) - self.assertIsInstance(tiled_tensor, ops.Tensor) + self.assertIsInstance(tiled_tensor, tensor_lib.Tensor) 
tiled_sharding = xla_sharding.get_tensor_sharding(tiled_tensor) tile_shape = xla_sharding.get_sharding_tile_shape(tiled_sharding) # This is the shape of the tile assignment [2, 1, 6] @@ -108,7 +108,7 @@ def test_split_annotates_tensor_correctly(self): def split_helper(tensor): self.assertIsNone(xla_sharding.get_tensor_sharding(tensor)) split_tensor = xla_sharding.split(tensor, 2, 3) - self.assertIsInstance(split_tensor, ops.Tensor) + self.assertIsInstance(split_tensor, tensor_lib.Tensor) split_sharding = xla_sharding.get_tensor_sharding(split_tensor) split_shape = xla_sharding.get_sharding_tile_shape(split_sharding) expected_shape = [1, 1, 3] diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index d58641c73509fe..f659a77f5914fa 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -313,7 +313,7 @@ tf_py_strict_test( "//tensorflow/python/framework:combinations", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:io_ops", "//tensorflow/python/ops:parsing_ops", "//tensorflow/python/platform:client_testlib", @@ -537,6 +537,7 @@ tf_py_strict_test( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:parsing_ops", "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/platform:client_testlib", diff --git a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py index 6d7b26ce88c10a..bfba1b69547d5d 100644 --- a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py @@ -25,7 +25,7 @@ from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test @@ -214,7 +214,7 @@ def testDropFinalBatch(self, batch_size, num_epochs): batch_size=batch_size, drop_final_batch=True) for tensor in nest.flatten(outputs): - if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. + if isinstance(tensor, tensor_lib.Tensor): # Guard against SparseTensor. 
self.assertEqual(tensor.shape[0], batch_size) @combinations.generate(test_base.default_test_combinations()) @@ -227,7 +227,7 @@ def testIndefiniteRepeatShapeInference(self): for shape, clazz in zip( nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)), nest.flatten(dataset_ops.get_legacy_output_classes(dataset))): - if issubclass(clazz, ops.Tensor): + if issubclass(clazz, tensor_lib.Tensor): self.assertEqual(32, shape[0]) @combinations.generate(test_base.default_test_combinations()) diff --git a/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py index 373aebc2d3e231..28e9c1379b9c73 100644 --- a/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.ops import parsing_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test @@ -96,8 +97,10 @@ def _test(self, # Check shapes; if serialized is a Tensor we need its size to # properly check. batch_size = ( - self.evaluate(input_tensor).size if isinstance(input_tensor, ops.Tensor) - else np.asarray(input_tensor).size) + self.evaluate(input_tensor).size + if isinstance(input_tensor, tensor.Tensor) + else np.asarray(input_tensor).size + ) for k, f in feature_val.items(): if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None: self.assertEqual( diff --git a/tensorflow/python/data/experimental/kernel_tests/service/BUILD b/tensorflow/python/data/experimental/kernel_tests/service/BUILD index 3fbe3508facd08..6c379d28da5808 100644 --- a/tensorflow/python/data/experimental/kernel_tests/service/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/service/BUILD @@ -247,6 +247,9 @@ tf_py_strict_test( size = "medium", srcs = ["distributed_save_ft_test.py"], shard_count = 17, + tags = [ + "no_mac", # TODO(b/290355883): Fix the flakyness in macos + ], deps = [ ":test_base", "//tensorflow/python/data/experimental/ops:distributed_save_op", diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD index d2018a755e031f..9f96277f5dcd24 100644 --- a/tensorflow/python/data/experimental/ops/BUILD +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -81,7 +81,7 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:experimental_dataset_ops_gen", "//tensorflow/python/ops:string_ops", @@ -243,7 +243,7 @@ py_strict_library( ":cardinality", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:experimental_dataset_ops_gen", "//tensorflow/python/ops:lookup_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py index d2732ed973944c..3a003314ec49d5 100644 --- 
a/tensorflow/python/data/experimental/ops/data_service_ops.py +++ b/tensorflow/python/data/experimental/ops/data_service_ops.py @@ -30,7 +30,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.ops import string_ops @@ -200,7 +200,7 @@ def _get_compression_proto(compression): def _to_tensor(dataset_id): """Converts `dataset_id` to Tensor.""" - if isinstance(dataset_id, ops.Tensor): + if isinstance(dataset_id, tensor.Tensor): return dataset_id if isinstance(dataset_id, str) or isinstance(dataset_id, bytes): return ops.convert_to_tensor( @@ -212,7 +212,7 @@ def _to_tensor(dataset_id): def _to_string(dataset_id): """Converts `dataset_id` to string.""" - if isinstance(dataset_id, ops.Tensor): + if isinstance(dataset_id, tensor.Tensor): return (dataset_id if dataset_id.dtype == dtypes.string else string_ops.as_string(dataset_id)) return (dataset_id.decode() @@ -334,7 +334,7 @@ def __init__(self, uncompress_func = structured_function.StructuredFunctionWrapper( lambda x: compression_ops.uncompress(x, output_spec=element_spec), transformation_name="DataServiceDataset.uncompress()", - input_structure=tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)) + input_structure=tensor.TensorSpec(shape=(), dtype=dtypes.variant)) cross_trainer_cache_options = ( cross_trainer_cache._to_proto().SerializeToString() if cross_trainer_cache else None) @@ -1004,7 +1004,8 @@ def _get_element_spec(): else: protocol, address = _parse_service(service) if job_name is not None: - if not isinstance(job_name, str) and not isinstance(job_name, ops.Tensor): + if not isinstance(job_name, str) and not isinstance( + job_name, tensor.Tensor): raise ValueError( "`job_name` must be a string or Tensor, but `job_name` was of type " f"{type(job_name)}. job_name={job_name}.") diff --git a/tensorflow/python/data/experimental/ops/lookup_ops.py b/tensorflow/python/data/experimental/ops/lookup_ops.py index 6fc0fef4761cf0..aef2902813eca1 100644 --- a/tensorflow/python/data/experimental/ops/lookup_ops.py +++ b/tensorflow/python/data/experimental/ops/lookup_ops.py @@ -17,7 +17,7 @@ from tensorflow.python.data.experimental.ops.cardinality import assert_cardinality from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops @@ -35,10 +35,10 @@ def _check_table_initializer_element_spec(element_spec): f"{len(element_spec)} components instead of two " "(key, value) components. 
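For readers unfamiliar with the dataset_id round-trip that the `data_service_ops.py` hunk above touches (`_to_tensor` / `_to_string` accept the id either as a string/bytes value or as a Tensor), here is a minimal usage sketch. The dispatcher address is hypothetical, and nothing below runs without an actual tf.data service dispatcher at that address.

```python
import tensorflow as tf

# Hypothetical dispatcher address; a real one comes from a running
# tf.data.experimental.service.DispatchServer.
service = "grpc://localhost:5000"

dataset = tf.data.Dataset.range(10)

# register_dataset returns a dataset_id, which the code above accepts either
# as a string/bytes value or as a Tensor (see _to_tensor/_to_string).
dataset_id = tf.data.experimental.service.register_dataset(service, dataset)

ds = tf.data.experimental.service.from_dataset_id(
    processing_mode="parallel_epochs",
    service=service,
    dataset_id=dataset_id,
    element_spec=dataset.element_spec,
    job_name="shared_job",  # must be a str or Tensor, per the check above
)
```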
Full dataset element spec: " f"{element_spec}.") - if not isinstance(element_spec[0], tensor_spec.TensorSpec): + if not isinstance(element_spec[0], tensor.TensorSpec): raise ValueError(base_error + "However, the given dataset produces " f"non-Tensor keys of type {type(element_spec[0])}.") - if not isinstance(element_spec[1], tensor_spec.TensorSpec): + if not isinstance(element_spec[1], tensor.TensorSpec): raise ValueError(base_error + "However, the given dataset produces " f"non-Tensor values of type {type(element_spec[1])}.") if element_spec[0].shape.rank not in (None, 0): @@ -163,14 +163,14 @@ def table_from_dataset(dataset=None, if num_oov_buckets < 0: raise ValueError("`num_oov_buckets` must be greater than or equal to 0, " f"got {num_oov_buckets}.") - if (not isinstance(vocab_size, ops.Tensor) and vocab_size is not None and + if (not isinstance(vocab_size, tensor.Tensor) and vocab_size is not None and vocab_size < 1): raise ValueError(f"`vocab_size` must be greater than 0, got {vocab_size}.") if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype): raise TypeError("`key_dtype` must be either an integer or string type, " f"but got {key_dtype}") if vocab_size is not None: - if isinstance(vocab_size, ops.Tensor): + if isinstance(vocab_size, tensor.Tensor): vocab_size = math_ops.cast(vocab_size, dtypes.int64) dataset = dataset.take(vocab_size) dataset = dataset.apply(assert_cardinality(vocab_size)) diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 5768e0605fe9ed..c30eb5106394bd 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -172,6 +172,7 @@ py_strict_library( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:lookup_ops", "//tensorflow/python/ops:variables", "//tensorflow/python/ops/ragged:ragged_tensor_value", @@ -622,7 +623,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:function", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:data_flow_ops", diff --git a/tensorflow/python/data/kernel_tests/checkpoint_test_base.py b/tensorflow/python/data/kernel_tests/checkpoint_test_base.py index 5f42c9fc0053ff..21b9266c90b806 100644 --- a/tensorflow/python/data/kernel_tests/checkpoint_test_base.py +++ b/tensorflow/python/data/kernel_tests/checkpoint_test_base.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_tensor_value @@ -43,7 +44,7 @@ def remove_variants(get_next_op): """Remove variants from a nest structure, so sess.run will execute.""" def _remove_variant(x): - if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant: + if isinstance(x, tensor.Tensor) and x.dtype == dtypes.variant: return () else: return x diff --git a/tensorflow/python/data/kernel_tests/from_tensors_test.py b/tensorflow/python/data/kernel_tests/from_tensors_test.py index 7584d6bd2357c1..aa564d371421b7 100644 --- 
a/tensorflow/python/data/kernel_tests/from_tensors_test.py +++ b/tensorflow/python/data/kernel_tests/from_tensors_test.py @@ -14,8 +14,9 @@ # ============================================================================== """Tests for `tf.data.Dataset.from_tensors().""" import collections -from absl.testing import parameterized +import dataclasses +from absl.testing import parameterized import numpy as np from tensorflow.core.protobuf import config_pb2 @@ -45,6 +46,22 @@ attr = None +@dataclasses.dataclass +class MaskedTensor: + mask: bool + value: np.ndarray + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value,) + return metadata, components + + def __tf_unflatten__(self, metadata, components): + mask = metadata[0] + value = components[0] + return MaskedTensor(mask=mask, value=value) + + class FromTensorsTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) @@ -151,6 +168,12 @@ class Foo: dataset = dataset_ops.Dataset.from_tensors(element) self.assertDatasetProduces(dataset, expected_output=[element]) + @combinations.generate(test_base.default_test_combinations()) + def testFromTensorsDataclass(self): + mt = MaskedTensor(mask=True, value=np.array([1])) + dataset = dataset_ops.Dataset.from_tensors(mt) + self.assertDatasetProduces(dataset, expected_output=[mt]) + @combinations.generate(test_base.default_test_combinations()) def testFromTensorsMixedRagged(self): components = (np.array(1), np.array([1, 2, 3]), np.array(37.0), diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index e5cd37bb6e1db0..8d1e384e033683 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -35,7 +35,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops @@ -808,9 +808,8 @@ def testRepeatedGetNextWarning(self): combinations.times( test_base.default_test_combinations(), combinations.combine( - expected_element_structure=tensor_spec.TensorSpec([], - dtypes.float32), - expected_output_classes=ops.Tensor, + expected_element_structure=tensor.TensorSpec([], dtypes.float32), + expected_output_classes=tensor.Tensor, expected_output_types=dtypes.float32, expected_output_shapes=[[]]))) def testTensorIteratorStructure(self, expected_element_structure, @@ -872,13 +871,13 @@ def tf_value_fn(): combinations.combine( expected_element_structure={ "a": - tensor_spec.TensorSpec([], dtypes.float32), - "b": (tensor_spec.TensorSpec([1], dtypes.string), - tensor_spec.TensorSpec([], dtypes.string)) + tensor.TensorSpec([], dtypes.float32), + "b": (tensor.TensorSpec([1], dtypes.string), + tensor.TensorSpec([], dtypes.string)) }, expected_output_classes={ - "a": ops.Tensor, - "b": (ops.Tensor, ops.Tensor) + "a": tensor.Tensor, + "b": (tensor.Tensor, tensor.Tensor) }, expected_output_types={ "a": dtypes.float32, @@ -973,7 +972,7 @@ def finalize_fn(n): @def_function.function def fn(): - output_signature = tensor_spec.TensorSpec((), dtypes.int64) + output_signature = tensor.TensorSpec((), dtypes.int64) dataset = from_generator_op._GeneratorDataset(1, init_fn, next_fn, 
finalize_fn, output_signature) diff --git a/tensorflow/python/data/kernel_tests/map_test.py b/tensorflow/python/data/kernel_tests/map_test.py index 608764cff7ae6a..a949be7b1893d2 100644 --- a/tensorflow/python/data/kernel_tests/map_test.py +++ b/tensorflow/python/data/kernel_tests/map_test.py @@ -14,6 +14,7 @@ # ============================================================================== """Tests for `tf.data.Dataset.map()`.""" import collections +import dataclasses import functools import threading import time @@ -135,6 +136,59 @@ def __init__(self): pass +@dataclasses.dataclass +class MyDataclass: + value1: ops.Tensor + value2: ops.Tensor + + def __tf_flatten__(self): + metadata = tuple() + components = (self.value1, self.value2) + return metadata, components + + @classmethod + def __tf_unflatten__(cls, metadata, components): + del metadata + return cls(value1=components[0], value2=components[1]) + + +@dataclasses.dataclass +class MaskedTensor: + mask: bool + value: ops.Tensor + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value,) + return metadata, components + + @classmethod + def __tf_unflatten__(cls, metadata, components): + mask = metadata[0] + value = components[0] + return MaskedTensor(mask=mask, value=value) + + +@dataclasses.dataclass +class NestedMaskedTensor: + mask: bool + value: MaskedTensor + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value,) + return metadata, components + + @classmethod + def __tf_unflatten__(cls, metadata, components): + mask = metadata[0] + value = components[0] + return NestedMaskedTensor(mask=mask, value=value) + + def __eq__(self, other): + return self.mask == other.mask and self.value == other.value + + class MapTest(test_base.DatasetTestBase, parameterized.TestCase): def _map_dataset_factory(self, components, apply_map, count): @@ -547,6 +601,118 @@ def testMapDict(self, apply_map): self.assertDatasetProduces( dataset, expected_output=[i * 2 + i**2 for i in range(10)]) + @combinations.generate(_test_combinations()) + def testMapDataclass(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, lambda x: MyDataclass(value1=x, value2=2 * x)) + dataset = apply_map(dataset, lambda x: x.value1 + x.value2) + self.assertDatasetProduces( + dataset, + expected_output=[3 * x for x in range(10)], + ) + + @combinations.generate(_test_combinations()) + def testMapMaskedTensor(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, lambda x: MaskedTensor(mask=True, value=x)) + dataset = apply_map(dataset, lambda x: 3 * x.value) + self.assertDatasetProduces( + dataset, + expected_output=[3 * x for x in range(10)], + ) + + @combinations.generate(_test_combinations()) + def testMapDataclassWithInputAndOutput(self, apply_map): + dataset = dataset_ops.Dataset.from_tensors(MyDataclass(value1=1, value2=2)) + dataset = apply_map(dataset, lambda x: (x.value1 * 5, x.value2)) + dataset = apply_map( + dataset, lambda x, y: MaskedTensor(mask=True, value=x + y) + ) + dataset = apply_map( + dataset, lambda m: NestedMaskedTensor(mask=False, value=m) + ) + self.assertDatasetProduces( + dataset, + expected_output=[ + NestedMaskedTensor( + mask=False, value=MaskedTensor(mask=True, value=7) + ) + ], + ) + + @combinations.generate(_test_combinations()) + def testMapListOfDataclassObjects(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + + # Creates a list of dataclass objects. 
+ dataset = apply_map( + dataset, + lambda x: [ # pylint: disable=g-long-lambda + MyDataclass(value1=x, value2=1), + MyDataclass(value1=2, value2=2 * x), + ], + ) + + # Takes a list of dataclass objects as input. + dataset = apply_map(dataset, lambda *x: x[0].value1 + x[1].value2) + + self.assertDatasetProduces( + dataset, + expected_output=[3 * x for x in range(10)], + ) + + @combinations.generate(_test_combinations()) + def testMapDictOfDataclassValues(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + + # Creates a dict of {str -> dataclass}. + dataset = apply_map( + dataset, + lambda x: { # pylint: disable=g-long-lambda + "a": MyDataclass(value1=x, value2=1), + "b": MyDataclass(value1=2, value2=2 * x), + }, + ) + # Takes a dict of dataclass values as input. + dataset = apply_map(dataset, lambda x: x["a"].value1 + x["b"].value2) + self.assertDatasetProduces( + dataset, + expected_output=[3 * x for x in range(10)], + ) + + @combinations.generate(_test_combinations()) + def testMapNestedMaskedTensorWithDataclassInput(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, lambda x: MaskedTensor(mask=True, value=x)) + dataset = apply_map( + dataset, + # Takes a MaskedTensor as input. + lambda x: NestedMaskedTensor(mask=False, value=x), + ) + dataset = apply_map(dataset, lambda x: 5 * x.value.value) + self.assertDatasetProduces( + dataset, + expected_output=[5 * x for x in range(10)], + ) + + @combinations.generate(_test_combinations()) + def testMapNestedMaskedTensorWithDataclassOutput(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map( + dataset, + lambda x: NestedMaskedTensor( # pylint: disable=g-long-lambda + mask=False, value=MaskedTensor(mask=True, value=x) + ), + ) + + # Return a MaskedTensor as the return value. 
+ dataset = apply_map(dataset, lambda x: x.value) + dataset = apply_map(dataset, lambda x: 7 * x.value) + self.assertDatasetProduces( + dataset, + expected_output=[7 * x for x in range(10)], + ) + @combinations.generate(_test_combinations()) def testMapNamedtuple(self, apply_map): # construct dataset of tuples diff --git a/tensorflow/python/data/kernel_tests/zip_test.py b/tensorflow/python/data/kernel_tests/zip_test.py index c4f270f59648b2..c81c70a2cd8bd6 100644 --- a/tensorflow/python/data/kernel_tests/zip_test.py +++ b/tensorflow/python/data/kernel_tests/zip_test.py @@ -14,8 +14,9 @@ # ============================================================================== """Tests for `tf.data.Dataset.zip()`.""" import collections -from absl.testing import parameterized +import dataclasses +from absl.testing import parameterized import numpy as np from tensorflow.python.data.experimental.ops import random_access @@ -42,6 +43,23 @@ def _dataset_factory(components): return dataset_ops.Dataset.zip(datasets) +@dataclasses.dataclass +class MaskedNdarrayPair: + mask: bool + value1: np.ndarray + value2: np.ndarray + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value1, self.value2) + return metadata, components + + def __tf_unflatten__(self, metadata, components): + mask = metadata[0] + value1, value2 = components + return MaskedNdarrayPair(mask=mask, value1=value1, value2=value2) + + class ZipTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) @@ -112,6 +130,21 @@ def testNamedTuple(self): expected = [Foo(x=0, y=3), Foo(x=1, y=4), Foo(x=2, y=5)] self.assertDatasetProduces(dataset, expected) + @combinations.generate(test_base.default_test_combinations()) + def testDataclass(self): + mtp = MaskedNdarrayPair( + mask=True, + value1=dataset_ops.Dataset.range(3), + value2=dataset_ops.Dataset.range(3, 6), + ) + dataset = dataset_ops.Dataset.zip(mtp) + expected = [ + MaskedNdarrayPair(mask=True, value1=0, value2=3), + MaskedNdarrayPair(mask=True, value1=1, value2=4), + MaskedNdarrayPair(mask=True, value1=2, value2=5), + ] + self.assertDatasetProduces(dataset, expected) + @combinations.generate(test_base.default_test_combinations()) def testAttrs(self): if attr is None: diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index b02e4d4985d959..0ca307c5b58c10 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -119,6 +119,7 @@ py_strict_library( "//tensorflow/python/framework:random_seed", "//tensorflow/python/framework:smart_cond", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", @@ -178,6 +179,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_spec", @@ -192,7 +194,6 @@ py_strict_library( "//tensorflow/python/util:_pywrap_utils", "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", - "//tensorflow/python/util:lazy_loader", "//tensorflow/python/util:nest", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", diff --git a/tensorflow/python/data/ops/iterator_ops.py 
b/tensorflow/python/data/ops/iterator_ops.py index 5354345577e9c2..82defdbf3210ff 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -30,8 +30,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_utils from tensorflow.python.ops import gen_dataset_ops @@ -219,7 +219,7 @@ def from_structure(output_types, tensor_shape.as_shape, output_shapes) if output_classes is None: - output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) + output_classes = nest.map_structure(lambda _: tensor.Tensor, output_types) nest.assert_same_structure(output_types, output_shapes) output_structure = structure.convert_legacy_structure( output_types, output_shapes, output_classes) @@ -293,7 +293,7 @@ def from_string_handle(string_handle, tensor_shape.as_shape, output_shapes) if output_classes is None: - output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) + output_classes = nest.map_structure(lambda _: tensor.Tensor, output_types) nest.assert_same_structure(output_types, output_shapes) output_structure = structure.convert_legacy_structure( output_types, output_shapes, output_classes) @@ -930,7 +930,7 @@ def _serialize(self): @property def _component_specs(self): - return (tensor_spec.TensorSpec([], dtypes.resource),) + return (tensor.TensorSpec([], dtypes.resource),) def _to_components(self, value): return (value._iterator_resource,) # pylint: disable=protected-access diff --git a/tensorflow/python/data/ops/ragged_batch_op.py b/tensorflow/python/data/ops/ragged_batch_op.py index 02a147e796d375..bb886ca7bbd5ff 100644 --- a/tensorflow/python/data/ops/ragged_batch_op.py +++ b/tensorflow/python/data/ops/ragged_batch_op.py @@ -17,8 +17,7 @@ from tensorflow.python.data.ops import structured_function from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.ops.ragged import ragged_tensor @@ -56,7 +55,7 @@ def __init__(self, input_dataset, row_splits_dtype, name=None): # corresponding RaggedTensorSpec. def to_ragged_spec(spec): """Returns the new spec based on RaggedTensors.""" - if (not isinstance(spec, tensor_spec.TensorSpec) or + if (not isinstance(spec, tensor.TensorSpec) or spec.shape.rank is None or spec.shape.is_fully_defined()): return spec @@ -80,12 +79,12 @@ def to_ragged_spec(spec): # RaggedTensorSpec._from_tensor_list. 
def to_ragged_variant(value): """Re-encode Tensors as RaggedTensors.""" - if (not isinstance(value, ops.Tensor) or + if (not isinstance(value, tensor.Tensor) or value.shape.rank is None or value.shape.is_fully_defined()): return value else: - spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value)) + spec = to_ragged_spec(tensor.TensorSpec.from_tensor(value)) if spec._ragged_rank > 0: # pylint: disable=protected-access value = ragged_tensor.RaggedTensor.from_tensor( value, ragged_rank=spec._ragged_rank) # pylint: disable=protected-access diff --git a/tensorflow/python/data/ops/sample_from_datasets_op.py b/tensorflow/python/data/ops/sample_from_datasets_op.py index 38c53bfb072050..29fc0d627d1436 100644 --- a/tensorflow/python/data/ops/sample_from_datasets_op.py +++ b/tensorflow/python/data/ops/sample_from_datasets_op.py @@ -19,6 +19,7 @@ from tensorflow.python.data.ops import map_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_stateless_random_ops from tensorflow.python.ops import math_ops @@ -48,7 +49,7 @@ def _skip_datasets_with_zero_weight(datasets, weights): logits = [[1.0] * len(datasets)] else: - if isinstance(weights, ops.Tensor): + if isinstance(weights, tensor.Tensor): if not weights.shape.is_compatible_with([len(datasets)]): raise ValueError(f"Invalid `weights`. The shape of `weights` " f"should be compatible with `[len(datasets)]` " @@ -62,7 +63,7 @@ def _skip_datasets_with_zero_weight(datasets, weights): # Use the given `weights` as the probability of choosing the respective # input. - if not isinstance(weights, ops.Tensor): + if not isinstance(weights, tensor.Tensor): datasets, weights = _skip_datasets_with_zero_weight(datasets, weights) weights = ops.convert_to_tensor(weights, name="weights") if weights.dtype not in (dtypes.float32, dtypes.float64): diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD index 126c2d12967616..abb9dbcded44f9 100644 --- a/tensorflow/python/data/util/BUILD +++ b/tensorflow/python/data/util/BUILD @@ -26,6 +26,7 @@ py_strict_test( "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/framework:combinations", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", @@ -43,8 +44,8 @@ py_strict_library( deps = [ ":nest", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:sparse_ops", ], @@ -63,8 +64,8 @@ py_strict_test( "//tensorflow/python/framework:combinations", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", @@ -80,8 +81,8 @@ py_strict_library( "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - 
"//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:type_spec", "//tensorflow/python/framework:type_spec_registry", "//tensorflow/python/ops:resource_variable_ops", @@ -91,6 +92,7 @@ py_strict_library( "//tensorflow/python/types:internal", "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", + "//tensorflow/python/util:nest_util", "//tensorflow/python/util:tf_export", "@wrapt", ], @@ -112,8 +114,8 @@ py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:tensor_array_ops", "//tensorflow/python/ops:variables", diff --git a/tensorflow/python/data/util/nest_test.py b/tensorflow/python/data/util/nest_test.py index edc19d3496c8b7..cdd0f01a938cd8 100644 --- a/tensorflow/python/data/util/nest_test.py +++ b/tensorflow/python/data/util/nest_test.py @@ -15,13 +15,16 @@ """Tests for utilities working with arbitrarily nested structures.""" import collections -import numpy as np +import dataclasses + from absl.testing import parameterized +import numpy as np from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.util import nest from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -29,6 +32,22 @@ from tensorflow.python.platform import test +@dataclasses.dataclass +class MaskedTensor: + mask: bool + value: ops.Tensor + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value,) + return metadata, components + + def __tf_unflatten__(self, metadata, components): + mask = metadata[0] + value = components[0] + return MaskedTensor(mask=mask, value=value) + + class NestTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) @@ -66,6 +85,89 @@ def testFlattenAndPack(self): with self.assertRaises(ValueError): nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) + @combinations.generate(test_base.default_test_combinations()) + def testDataclassIsNested(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + self.assertTrue(nest.is_nested(mt)) + + @combinations.generate(test_base.default_test_combinations()) + def testFlattenDataclass(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + leaves = nest.flatten(mt) + self.assertLen(leaves, 1) + self.assertAllEqual(leaves[0], [1]) + + @combinations.generate(test_base.default_test_combinations()) + def testPackDataclass(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + leaves = nest.flatten(mt) + reconstructed_mt = nest.pack_sequence_as(mt, leaves) + self.assertIsInstance(reconstructed_mt, MaskedTensor) + self.assertEqual(reconstructed_mt.mask, mt.mask) + self.assertAllEqual(reconstructed_mt.value, mt.value) + + mt2 = MaskedTensor(mask=False, value=constant_op.constant([2])) + reconstructed_mt = nest.pack_sequence_as(mt2, leaves) + self.assertIsInstance(reconstructed_mt, MaskedTensor) + self.assertFalse(reconstructed_mt.mask) + self.assertAllEqual(reconstructed_mt.value, [1]) + + 
@combinations.generate(test_base.default_test_combinations()) + def testDataclassMapStructure(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + mt_doubled = nest.map_structure(lambda x: x * 2, mt) + self.assertIsInstance(mt_doubled, MaskedTensor) + self.assertEqual(mt_doubled.mask, True) + self.assertAllEqual(mt_doubled.value, [2]) + + @combinations.generate(test_base.default_test_combinations()) + def testDataclassAssertSameStructure(self): + mt1 = MaskedTensor(mask=True, value=constant_op.constant([1])) + mt2 = MaskedTensor(mask=False, value=constant_op.constant([2])) + nest.assert_same_structure(mt1, mt2) + + mt3 = (1, 2) + + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + TypeError, + "don't have the same nested structure", + ): + nest.assert_same_structure(mt1, mt3) + + class SubMaskedTensor(MaskedTensor): + pass + + mt_subclass = SubMaskedTensor(mask=True, value=constant_op.constant([1])) + nest.assert_same_structure(mt1, mt_subclass, check_types=False) + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + TypeError, + "don't have the same sequence type", + ): + nest.assert_same_structure(mt1, mt_subclass) + + @combinations.generate(test_base.default_test_combinations()) + def testDataclassAssertShallowStructure(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + structure1 = ("a", "b") + structure2 = (mt, "c") + nest.assert_shallow_structure(structure1, structure2) + + structure3 = (mt, "d", "e") + + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + ValueError, + "don't have the same sequence length", + ): + nest.assert_shallow_structure(structure1, structure3) + + structure4 = {"a": mt, "b": "c"} + nest.assert_shallow_structure(structure1, structure4, check_types=False) + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + TypeError, + "don't have the same sequence type", + ): + nest.assert_shallow_structure(structure1, structure4) + @combinations.generate(test_base.default_test_combinations()) def testFlattenDictOrder(self): """`flatten` orders dicts by key, including OrderedDicts.""" diff --git a/tensorflow/python/data/util/sparse.py b/tensorflow/python/data/util/sparse.py index f56a905058a0ff..1f1fc794b1ebae 100644 --- a/tensorflow/python/data/util/sparse.py +++ b/tensorflow/python/data/util/sparse.py @@ -15,8 +15,8 @@ """Python dataset sparse tensor utility functions.""" from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import sparse_ops @@ -107,7 +107,7 @@ def get_classes(tensors): """ return nest.pack_sequence_as(tensors, [ sparse_tensor.SparseTensor - if isinstance(tensor, sparse_tensor.SparseTensor) else ops.Tensor + if isinstance(tensor, sparse_tensor.SparseTensor) else tensor_lib.Tensor for tensor in nest.flatten(tensors) ]) diff --git a/tensorflow/python/data/util/sparse_test.py b/tensorflow/python/data/util/sparse_test.py index eb1c592b975b20..c82067650d221d 100644 --- a/tensorflow/python/data/util/sparse_test.py +++ b/tensorflow/python/data/util/sparse_test.py @@ -24,8 +24,8 @@ from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from 
tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import test @@ -39,11 +39,11 @@ def _test_any_sparse_combinations(): cases = [("TestCase_0", lambda: (), False), - ("TestCase_1", lambda: (ops.Tensor), False), - ("TestCase_2", lambda: (((ops.Tensor))), False), - ("TestCase_3", lambda: (ops.Tensor, ops.Tensor), False), + ("TestCase_1", lambda: (tensor.Tensor), False), + ("TestCase_2", lambda: (((tensor.Tensor))), False), + ("TestCase_3", lambda: (tensor.Tensor, tensor.Tensor), False), ("TestCase_4", lambda: - (ops.Tensor, sparse_tensor.SparseTensor), True), + (tensor.Tensor, sparse_tensor.SparseTensor), True), ("TestCase_5", lambda: (sparse_tensor.SparseTensor, sparse_tensor.SparseTensor), True), ("TestCase_6", lambda: (((sparse_tensor.SparseTensor))), True)] @@ -62,7 +62,8 @@ def _test_as_dense_shapes_combinations(): cases = [ ("TestCase_0", lambda: (), lambda: (), lambda: ()), - ("TestCase_1", lambda: tensor_shape.TensorShape([]), lambda: ops.Tensor, + ("TestCase_1", lambda: tensor_shape.TensorShape([]), + lambda: tensor.Tensor, lambda: tensor_shape.TensorShape([])), ( "TestCase_2", @@ -71,7 +72,7 @@ def _test_as_dense_shapes_combinations(): lambda: tensor_shape.unknown_shape() # pylint: disable=unnecessary-lambda ), ("TestCase_3", lambda: (tensor_shape.TensorShape([])), lambda: - (ops.Tensor), lambda: (tensor_shape.TensorShape([]))), + (tensor.Tensor), lambda: (tensor_shape.TensorShape([]))), ( "TestCase_4", lambda: (tensor_shape.TensorShape([])), @@ -79,9 +80,9 @@ def _test_as_dense_shapes_combinations(): lambda: (tensor_shape.unknown_shape()) # pylint: disable=unnecessary-lambda ), ("TestCase_5", lambda: (tensor_shape.TensorShape([]), ()), lambda: - (ops.Tensor, ()), lambda: (tensor_shape.TensorShape([]), ())), + (tensor.Tensor, ()), lambda: (tensor_shape.TensorShape([]), ())), ("TestCase_6", lambda: ((), tensor_shape.TensorShape([])), lambda: - ((), ops.Tensor), lambda: ((), tensor_shape.TensorShape([]))), + ((), tensor.Tensor), lambda: ((), tensor_shape.TensorShape([]))), ("TestCase_7", lambda: (tensor_shape.TensorShape([]), ()), lambda: (sparse_tensor.SparseTensor, ()), lambda: (tensor_shape.unknown_shape(), ())), @@ -90,14 +91,14 @@ def _test_as_dense_shapes_combinations(): (), tensor_shape.unknown_shape())), ("TestCase_9", lambda: (tensor_shape.TensorShape([]), (), tensor_shape.TensorShape([])), lambda: - (ops.Tensor, (), ops.Tensor), lambda: + (tensor.Tensor, (), tensor.Tensor), lambda: (tensor_shape.TensorShape([]), (), tensor_shape.TensorShape([]))), ("TestCase_10", lambda: (tensor_shape.TensorShape([]), (), tensor_shape.TensorShape([])), lambda: (sparse_tensor.SparseTensor, (), sparse_tensor.SparseTensor), lambda: (tensor_shape.unknown_shape(), (), tensor_shape.unknown_shape())), ("TestCase_11", lambda: ((), tensor_shape.TensorShape([]), ()), lambda: - ((), ops.Tensor, ()), lambda: ((), tensor_shape.TensorShape([]), ())), + ((), tensor.Tensor, ()), lambda: ((), tensor_shape.TensorShape([]), ())), ("TestCase_12", lambda: ((), tensor_shape.TensorShape([]), ()), lambda: ((), sparse_tensor.SparseTensor, ()), lambda: ((), tensor_shape.unknown_shape(), ())) @@ -118,29 +119,30 @@ def reduce_fn(x, y): def _test_as_dense_types_combinations(): cases = [ ("TestCase_0", lambda: (), lambda: (), lambda: ()), - ("TestCase_1", lambda: dtypes.int32, lambda: ops.Tensor, + ("TestCase_1", lambda: dtypes.int32, lambda: 
tensor.Tensor, lambda: dtypes.int32), ("TestCase_2", lambda: dtypes.int32, lambda: sparse_tensor.SparseTensor, lambda: dtypes.variant), - ("TestCase_3", lambda: (dtypes.int32), lambda: (ops.Tensor), lambda: + ("TestCase_3", lambda: (dtypes.int32), lambda: (tensor.Tensor), lambda: (dtypes.int32)), ("TestCase_4", lambda: (dtypes.int32), lambda: (sparse_tensor.SparseTensor), lambda: (dtypes.variant)), ("TestCase_5", lambda: (dtypes.int32, ()), lambda: - (ops.Tensor, ()), lambda: (dtypes.int32, ())), + (tensor.Tensor, ()), lambda: (dtypes.int32, ())), ("TestCase_6", lambda: ((), dtypes.int32), lambda: - ((), ops.Tensor), lambda: ((), dtypes.int32)), + ((), tensor.Tensor), lambda: ((), dtypes.int32)), ("TestCase_7", lambda: (dtypes.int32, ()), lambda: (sparse_tensor.SparseTensor, ()), lambda: (dtypes.variant, ())), ("TestCase_8", lambda: ((), dtypes.int32), lambda: ((), sparse_tensor.SparseTensor), lambda: ((), dtypes.variant)), ("TestCase_9", lambda: (dtypes.int32, (), dtypes.int32), lambda: - (ops.Tensor, (), ops.Tensor), lambda: (dtypes.int32, (), dtypes.int32)), + (tensor.Tensor, (), tensor.Tensor), + lambda: (dtypes.int32, (), dtypes.int32)), ("TestCase_10", lambda: (dtypes.int32, (), dtypes.int32), lambda: (sparse_tensor.SparseTensor, (), sparse_tensor.SparseTensor), lambda: (dtypes.variant, (), dtypes.variant)), ("TestCase_11", lambda: ((), dtypes.int32, ()), lambda: - ((), ops.Tensor, ()), lambda: ((), dtypes.int32, ())), + ((), tensor.Tensor, ()), lambda: ((), dtypes.int32, ())), ("TestCase_12", lambda: ((), dtypes.int32, ()), lambda: ((), sparse_tensor.SparseTensor, ()), lambda: ((), dtypes.variant, ())), ] @@ -163,11 +165,12 @@ def _test_get_classes_combinations(): ("TestCase_1", lambda: sparse_tensor.SparseTensor( indices=[[0]], values=[1], dense_shape=[1]), lambda: sparse_tensor.SparseTensor), - ("TestCase_2", lambda: constant_op.constant([1]), lambda: ops.Tensor), + ("TestCase_2", lambda: constant_op.constant([1]), lambda: tensor.Tensor), ("TestCase_3", lambda: (sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1])), lambda: (sparse_tensor.SparseTensor)), - ("TestCase_4", lambda: (constant_op.constant([1])), lambda: (ops.Tensor)), + ("TestCase_4", lambda: (constant_op.constant([1])), + lambda: (tensor.Tensor)), ("TestCase_5", lambda: (sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1]), ()), lambda: (sparse_tensor.SparseTensor, ())), @@ -176,19 +179,19 @@ def _test_get_classes_combinations(): sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1])), lambda: ((), sparse_tensor.SparseTensor)), ("TestCase_7", lambda: (constant_op.constant([1]), ()), lambda: - (ops.Tensor, ())), + (tensor.Tensor, ())), ("TestCase_8", lambda: ((), constant_op.constant([1])), lambda: - ((), ops.Tensor)), + ((), tensor.Tensor)), ("TestCase_9", lambda: (sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1]), (), constant_op.constant([1])), lambda: (sparse_tensor.SparseTensor, - (), ops.Tensor)), + (), tensor.Tensor)), ("TestCase_10", lambda: ((), sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1]), ()), lambda: ((), sparse_tensor.SparseTensor, ())), ("TestCase_11", lambda: ((), constant_op.constant([1]), ()), lambda: - ((), ops.Tensor, ())), + ((), tensor.Tensor, ())), ] def reduce_fn(x, y): diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index 14e9cb0e4ff9a4..43dfb7456c05e7 100644 --- a/tensorflow/python/data/util/structure.py +++ 
b/tensorflow/python/data/util/structure.py @@ -23,8 +23,8 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec_registry from tensorflow.python.ops import resource_variable_ops @@ -34,6 +34,7 @@ from tensorflow.python.types import internal from tensorflow.python.util import deprecation from tensorflow.python.util.compat import collections_abc +from tensorflow.python.util.nest_util import CustomNestProtocol from tensorflow.python.util.tf_export import tf_export @@ -41,7 +42,7 @@ @tf_export(v1=["data.experimental.TensorStructure"]) @deprecation.deprecated(None, "Use `tf.TensorSpec` instead.") def _TensorStructure(dtype, shape): - return tensor_spec.TensorSpec(shape, dtype) + return tensor_lib.TensorSpec(shape, dtype) @tf_export(v1=["data.experimental.SparseTensorStructure"]) @@ -171,8 +172,8 @@ def convert_legacy_structure(output_types, output_shapes, output_classes): flat_ret.append(flat_class) elif issubclass(flat_class, sparse_tensor.SparseTensor): flat_ret.append(sparse_tensor.SparseTensorSpec(flat_shape, flat_type)) - elif issubclass(flat_class, ops.Tensor): - flat_ret.append(tensor_spec.TensorSpec(flat_shape, flat_type)) + elif issubclass(flat_class, tensor_lib.Tensor): + flat_ret.append(tensor_lib.TensorSpec(flat_shape, flat_type)) elif issubclass(flat_class, tensor_array_ops.TensorArray): # We sneaked the dynamic_size and infer_shape into the legacy shape. flat_ret.append( @@ -493,6 +494,12 @@ def type_spec_from_value(element, use_fallback=True): type_spec_from_value(getattr(element, a.name)) for a in attrs ]) + if isinstance(element, CustomNestProtocol): + # pylint: disable=protected-access + metadata, children = element.__tf_flatten__() + return element.__tf_unflatten__(metadata, type_spec_from_value(children)) + # pylint: enable=protected-access + if use_fallback: # As a fallback try converting the element to a tensor. 
try: diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py index b2d33c8247ef48..43fb07ca22fa36 100644 --- a/tensorflow/python/data/util/structure_test.py +++ b/tensorflow/python/data/util/structure_test.py @@ -15,11 +15,12 @@ """Tests for utilities working with arbitrarily nested structures.""" import collections +import dataclasses import functools +from absl.testing import parameterized import numpy as np import wrapt -from absl.testing import parameterized from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops @@ -30,8 +31,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables @@ -50,7 +51,7 @@ def _test_flat_structure_combinations(): cases = [ ("Tensor", lambda: constant_op.constant(37.0), - lambda: tensor_spec.TensorSpec, lambda: [dtypes.float32], lambda: [[]]), + lambda: tensor.TensorSpec, lambda: [dtypes.float32], lambda: [[]]), ("TensorArray", lambda: tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3,), size=0), lambda: tensor_array_ops.TensorArraySpec, lambda: [dtypes.variant], @@ -336,8 +337,8 @@ def reduce_fn(x, y): def _test_convert_legacy_structure_combinations(): cases = [ - (dtypes.float32, tensor_shape.TensorShape([]), ops.Tensor, - tensor_spec.TensorSpec([], dtypes.float32)), + (dtypes.float32, tensor_shape.TensorShape([]), tensor.Tensor, + tensor.TensorSpec([], dtypes.float32)), (dtypes.int32, tensor_shape.TensorShape([2, 2]), sparse_tensor.SparseTensor, sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)), @@ -369,13 +370,13 @@ def _test_convert_legacy_structure_combinations(): "a": tensor_shape.TensorShape([]), "b": (tensor_shape.TensorShape([2, 2]), tensor_shape.TensorShape([])) }, { - "a": ops.Tensor, - "b": (sparse_tensor.SparseTensor, ops.Tensor) + "a": tensor.Tensor, + "b": (sparse_tensor.SparseTensor, tensor.Tensor) }, { "a": - tensor_spec.TensorSpec([], dtypes.float32), + tensor.TensorSpec([], dtypes.float32), "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32), - tensor_spec.TensorSpec([], dtypes.string)) + tensor.TensorSpec([], dtypes.string)) }) ] @@ -392,10 +393,10 @@ def reduce_fn(x, y): def _test_batch_combinations(): cases = [ - (tensor_spec.TensorSpec([], dtypes.float32), 32, - tensor_spec.TensorSpec([32], dtypes.float32)), - (tensor_spec.TensorSpec([], dtypes.float32), None, - tensor_spec.TensorSpec([None], dtypes.float32)), + (tensor.TensorSpec([], dtypes.float32), 32, + tensor.TensorSpec([32], dtypes.float32)), + (tensor.TensorSpec([], dtypes.float32), None, + tensor.TensorSpec([None], dtypes.float32)), (sparse_tensor.SparseTensorSpec([None], dtypes.float32), 32, sparse_tensor.SparseTensorSpec([32, None], dtypes.float32)), (sparse_tensor.SparseTensorSpec([4], dtypes.float32), None, @@ -406,14 +407,14 @@ def _test_batch_combinations(): ragged_tensor.RaggedTensorSpec([None, 4, None], dtypes.float32, 2)), ({ "a": - tensor_spec.TensorSpec([], dtypes.float32), + tensor.TensorSpec([], dtypes.float32), "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32), - tensor_spec.TensorSpec([], dtypes.string)) + tensor.TensorSpec([], dtypes.string)) 
}, 128, { "a": - tensor_spec.TensorSpec([128], dtypes.float32), + tensor.TensorSpec([128], dtypes.float32), "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32), - tensor_spec.TensorSpec([128], dtypes.string)) + tensor.TensorSpec([128], dtypes.string)) }), ] @@ -429,10 +430,10 @@ def reduce_fn(x, y): def _test_unbatch_combinations(): cases = [ - (tensor_spec.TensorSpec([32], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.float32)), - (tensor_spec.TensorSpec([None], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.float32)), + (tensor.TensorSpec([32], dtypes.float32), + tensor.TensorSpec([], dtypes.float32)), + (tensor.TensorSpec([None], dtypes.float32), + tensor.TensorSpec([], dtypes.float32)), (sparse_tensor.SparseTensorSpec([32, None], dtypes.float32), sparse_tensor.SparseTensorSpec([None], dtypes.float32)), (sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32), @@ -443,14 +444,14 @@ def _test_unbatch_combinations(): ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)), ({ "a": - tensor_spec.TensorSpec([128], dtypes.float32), + tensor.TensorSpec([128], dtypes.float32), "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32), - tensor_spec.TensorSpec([None], dtypes.string)) + tensor.TensorSpec([None], dtypes.string)) }, { "a": - tensor_spec.TensorSpec([], dtypes.float32), + tensor.TensorSpec([], dtypes.float32), "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32), - tensor_spec.TensorSpec([], dtypes.string)) + tensor.TensorSpec([], dtypes.string)) }), ] @@ -493,6 +494,22 @@ def reduce_fn(x, y): return functools.reduce(reduce_fn, cases, []) +@dataclasses.dataclass +class MaskedTensor: + mask: bool + value: ops.Tensor + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value,) + return metadata, components + + def __tf_unflatten__(self, metadata, components): + mask = metadata[0] + value = components[0] + return MaskedTensor(mask=mask, value=value) + + # TODO(jsimsa): Add tests for OptionalStructure and DatasetStructure. 
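# Sketch of the dispatch that the structure.py hunk above adds: when a value
# implements the custom nest protocol, type_spec_from_value() can flatten it,
# build specs for the tensor components, and rebuild the same dataclass with
# specs in place of tensors. The helper below is illustrative only (it is not
# the TensorFlow implementation) and assumes eager tensors.
import dataclasses
import tensorflow as tf


@dataclasses.dataclass
class SketchMaskedTensor:
  mask: bool
  value: tf.Tensor

  def __tf_flatten__(self):
    return (self.mask,), (self.value,)

  @classmethod
  def __tf_unflatten__(cls, metadata, components):
    return cls(mask=metadata[0], value=components[0])


def sketch_spec_from_value(element):
  """Mirrors the CustomNestProtocol branch: metadata stays, tensors -> specs."""
  metadata, components = element.__tf_flatten__()
  component_specs = tuple(tf.TensorSpec.from_tensor(c) for c in components)
  return type(element).__tf_unflatten__(metadata, component_specs)


mt = SketchMaskedTensor(mask=True, value=tf.constant([1]))
spec = sketch_spec_from_value(mt)
# spec.mask is still True; spec.value is TensorSpec(shape=(1,), dtype=tf.int32),
# which is the equality the testDataclasses case below checks via
# structure.type_spec_from_value.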
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase): @@ -838,9 +855,9 @@ def testConvertLegacyStructureFail(self): @combinations.generate(test_base.default_test_combinations()) def testNestedNestedStructure(self): - s = (tensor_spec.TensorSpec([], dtypes.int64), - (tensor_spec.TensorSpec([], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.string))) + s = (tensor.TensorSpec([], dtypes.int64), + (tensor.TensorSpec([], dtypes.float32), + tensor.TensorSpec([], dtypes.string))) int64_t = constant_op.constant(37, dtype=dtypes.int64) float32_t = constant_op.constant(42.0) @@ -917,7 +934,7 @@ def testToBatchedTensorList(self, value_fn, element_0_fn): def testDatasetSpecConstructor(self): rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32) st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32) - t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string) + t_spec = tensor.TensorSpec([10, 8], dtypes.string) element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec} ds_struct = dataset_ops.DatasetSpec(element_spec, [5]) self.assertEqual(ds_struct._element_spec, element_spec) @@ -929,7 +946,7 @@ def testCustomMapping(self): elem = CustomMap(foo=constant_op.constant(37.)) spec = structure.type_spec_from_value(elem) self.assertIsInstance(spec, CustomMap) - self.assertEqual(spec["foo"], tensor_spec.TensorSpec([], dtypes.float32)) + self.assertEqual(spec["foo"], tensor.TensorSpec([], dtypes.float32)) @combinations.generate(test_base.default_test_combinations()) def testObjectProxy(self): @@ -955,6 +972,23 @@ def testTypeSpecNotCompatible(self): self.assertEqual(test_obj, test_obj.most_specific_compatible_shape(test_obj)) + @combinations.generate(test_base.default_test_combinations()) + def testDataclasses(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + + mt_type_spec = structure.type_spec_from_value(mt) + self.assertEqual(mt_type_spec.mask, mt.mask) + self.assertEqual( + mt_type_spec.value, structure.type_spec_from_value(mt.value) + ) + + mt2 = MaskedTensor(mask=True, value=constant_op.constant([2])) + mt3 = MaskedTensor(mask=False, value=constant_op.constant([1])) + mt2_type_spec = structure.type_spec_from_value(mt2) + mt3_type_spec = structure.type_spec_from_value(mt3) + self.assertEqual(mt_type_spec, mt2_type_spec) + self.assertNotEqual(mt_type_spec, mt3_type_spec) + class CustomMap(collections_abc.Mapping): """Custom, immutable map.""" diff --git a/tensorflow/python/debug/cli/BUILD b/tensorflow/python/debug/cli/BUILD index 98f8badc98f18a..f3b8a53513de72 100644 --- a/tensorflow/python/debug/cli/BUILD +++ b/tensorflow/python/debug/cli/BUILD @@ -78,7 +78,8 @@ py_strict_library( ":debugger_cli_common", ":tensor_format", "//tensorflow/python/debug/lib:common", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:gfile", "//third_party/py/numpy", diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py index 3c5e21a2ff0bac..69466d446c1469 100644 --- a/tensorflow/python/debug/cli/cli_shared.py +++ b/tensorflow/python/debug/cli/cli_shared.py @@ -22,6 +22,7 @@ from tensorflow.python.debug.cli import tensor_format from tensorflow.python.debug.lib import common from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import variables from tensorflow.python.platform import 
gfile @@ -402,7 +403,9 @@ def get_run_short_description(run_call_count, description = "run #%d: " % run_call_count - if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)): + if isinstance( + fetches, (tensor_lib.Tensor, ops.Operation, variables.Variable) + ): description += "1 fetch (%s); " % common.get_graph_element_name(fetches) else: # Could be (nested) list, tuple, dict or namedtuple. diff --git a/tensorflow/python/debug/lib/BUILD b/tensorflow/python/debug/lib/BUILD index c6cdfb38552ccd..37c99b30dd2056 100644 --- a/tensorflow/python/debug/lib/BUILD +++ b/tensorflow/python/debug/lib/BUILD @@ -149,6 +149,7 @@ py_strict_library( ":debug_data", ":debug_graphs", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops_gen", "//tensorflow/python/ops:variables", ], @@ -418,6 +419,7 @@ cuda_py_strict_test( "//tensorflow/core:protos_all_py", "//tensorflow/python/client:session", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/lib/io:lib", "//tensorflow/python/ops:gradients_impl", diff --git a/tensorflow/python/debug/lib/debug_gradients.py b/tensorflow/python/debug/lib/debug_gradients.py index 8d202c9e1a0e9b..529264b2c5525f 100644 --- a/tensorflow/python/debug/lib/debug_gradients.py +++ b/tensorflow/python/debug/lib/debug_gradients.py @@ -20,6 +20,7 @@ from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.lib import debug_graphs from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import variables @@ -332,7 +333,7 @@ def gradient_tensors(self): return self._gradient_tensors def _get_tensor_name(self, tensor): - if isinstance(tensor, (ops.Tensor, variables.Variable)): + if isinstance(tensor, (tensor_lib.Tensor, variables.Variable)): return tensor.name elif isinstance(tensor, str): return tensor diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py index f84be38e0451e8..a3321be710dd8e 100644 --- a/tensorflow/python/debug/lib/debug_gradients_test.py +++ b/tensorflow/python/debug/lib/debug_gradients_test.py @@ -23,6 +23,7 @@ from tensorflow.python.debug.lib import debug_gradients from tensorflow.python.debug.lib import debug_utils from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gradients_impl @@ -68,17 +69,17 @@ def testIdentifyGradientGivesCorrectTensorObjectWithoutContextManager(self): # Fetch the gradient tensor with the x-tensor object. w_grad = grad_debugger.gradient_tensor(self.w) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) # Fetch the gradient tensor with the x-tensor's name. w_grad = grad_debugger.gradient_tensor(self.w.name) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) # Fetch the gradient tensor with the x-tensor name. 
w_grad = grad_debugger.gradient_tensor(self.w.name) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) def testIdentifyGradientGivesCorrectTensorObjectWithTfGradients(self): @@ -99,17 +100,17 @@ def testIdentifyGradientGivesCorrectTensorObjectWithTfGradients(self): # Fetch the gradient tensor with the x-tensor object. w_grad = grad_debugger.gradient_tensor(self.w) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) # Fetch the gradient tensor with the x-tensor's name. w_grad = grad_debugger.gradient_tensor(self.w.name) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) # Fetch the gradient tensor with the x-tensor name. w_grad = grad_debugger.gradient_tensor(self.w.name) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) def testCallingIdentifyGradientTwiceWithTheSameGradientsDebuggerErrors(self): @@ -137,8 +138,8 @@ def testIdentifyGradientWorksOnMultipleLosses(self): dz1_dy = grad_debugger_1.gradient_tensor(y) dz2_dy = grad_debugger_2.gradient_tensor(y) - self.assertIsInstance(dz1_dy, ops.Tensor) - self.assertIsInstance(dz2_dy, ops.Tensor) + self.assertIsInstance(dz1_dy, tensor.Tensor) + self.assertIsInstance(dz2_dy, tensor.Tensor) self.assertIsNot(dz1_dy, dz2_dy) self.sess.run(variables.global_variables_initializer()) @@ -187,7 +188,7 @@ def testIdentifyGradientTensorWorksWithGradientDescentOptimizer(self): # Fetch the gradient tensor with the x-tensor object. w_grad = grad_debugger.gradient_tensor(self.w) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) def testWatchGradientsByXTensorNamesWorks(self): @@ -209,11 +210,11 @@ def testWatchGradientsByXTensorNamesWorks(self): self.assertAllClose(2.0, self.sess.run(v_grad)) w_grad = grad_debugger.gradient_tensor(self.w) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) w_grad = grad_debugger.gradient_tensor("w:0") - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) def testWatchGradientsByXTensorNamesWorksWithoutContextManager(self): @@ -235,11 +236,11 @@ def testWatchGradientsByXTensorNamesWorksWithoutContextManager(self): self.assertAllClose(2.0, self.sess.run(v_grad)) w_grad = grad_debugger.gradient_tensor(self.w) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) w_grad = grad_debugger.gradient_tensor("w:0") - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) def testWatchGradientsWorksOnRefTensor(self): @@ -272,7 +273,7 @@ def testWatchGradientsWorksOnMultipleTensors(self): self.assertEqual(2, len(grad_debugger.gradient_tensors())) self.assertIs(u_grad, grad_debugger.gradient_tensor("u:0")) - self.assertIsInstance(grad_debugger.gradient_tensor("w:0"), ops.Tensor) + self.assertIsInstance(grad_debugger.gradient_tensor("w:0"), tensor.Tensor) self.sess.run(variables.global_variables_initializer()) self.assertAllClose(1.0, self.sess.run( @@ -317,8 +318,8 @@ def 
testWatchGradientsByTensorCanWorkOnMultipleLosses(self): dz1_dy = grad_debugger_1.gradient_tensor(y) dz2_dy = grad_debugger_2.gradient_tensor(y) - self.assertIsInstance(dz1_dy, ops.Tensor) - self.assertIsInstance(dz2_dy, ops.Tensor) + self.assertIsInstance(dz1_dy, tensor.Tensor) + self.assertIsInstance(dz2_dy, tensor.Tensor) self.assertIsNot(dz1_dy, dz2_dy) self.sess.run(variables.global_variables_initializer()) diff --git a/tensorflow/python/debug/wrappers/framework_test.py b/tensorflow/python/debug/wrappers/framework_test.py index a1492f3e5e1fd7..e14363c623a6b6 100644 --- a/tensorflow/python/debug/wrappers/framework_test.py +++ b/tensorflow/python/debug/wrappers/framework_test.py @@ -320,7 +320,7 @@ def testUsingWrappedSessionShouldWorkAsContextManager(self): with wrapper as sess: self.assertAllClose([[3.0], [4.0]], self._s) self.assertEqual(1, self._observer["on_run_start_count"]) - self.assertEqual(self._s, self._observer["run_fetches"]) + self.assertEqual([self._s], self._observer["run_fetches"]) self.assertEqual(1, self._observer["on_run_end_count"]) self.assertAllClose( @@ -337,7 +337,7 @@ def testUsingWrappedSessionShouldSupportEvalWithAsDefault(self): with wrapper.as_default(): foo = constant_op.constant(42, name="foo") self.assertEqual(42, self.evaluate(foo)) - self.assertEqual(foo, self._observer["run_fetches"]) + self.assertEqual([foo], self._observer["run_fetches"]) def testWrapperShouldSupportSessionClose(self): wrapper = TestDebugWrapperSession(self._sess, self._dump_root, diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 9459c5c8416b65..1886daa0638b90 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -48,6 +48,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:kernels", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", @@ -158,6 +159,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", @@ -172,7 +174,6 @@ py_strict_library( "//tensorflow/python/trackable:base", "//tensorflow/python/types:distribute", "//tensorflow/python/util:deprecation", - "//tensorflow/python/util:lazy_loader", "//tensorflow/python/util:nest", "//tensorflow/python/util:tf_decorator", "//tensorflow/python/util:tf_export", @@ -894,7 +895,7 @@ py_strict_library( srcs_version = "PY3", deps = [ ":distribute_lib", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", ], ) @@ -929,6 +930,7 @@ py_strict_library( "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:type_spec", @@ -979,6 +981,7 @@ distribute_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", 
"//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/ops:variables", @@ -999,8 +1002,8 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:handle_data_util", "//tensorflow/python/ops:lookup_ops", @@ -1325,6 +1328,7 @@ distribute_py_strict_test( "//tensorflow/python/framework:extension_type", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", @@ -1423,6 +1427,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:collective_ops", "//tensorflow/python/ops:cond", @@ -1462,6 +1467,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:type_spec", @@ -1600,6 +1606,7 @@ distribute_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", @@ -2003,6 +2010,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:test_lib", @@ -2046,6 +2054,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:custom_gradient", "//tensorflow/python/ops:math_ops", @@ -2159,6 +2168,11 @@ cuda_py_strict_test( tpu_py_strict_test( name = "collective_all_reduce_strategy_test_tpu", srcs = ["collective_all_reduce_strategy_test.py"], + # copybara:uncomment_begin + # args = [ + # "--tpu_use_tfrt=false", #TODO(b/227404010): Remove once the bug is fixed. + # ], + # copybara:uncomment_end # FIXME(b/227404010): On TFRT TPU, eager CollectiveReduceV2 is broken. 
disable_tfrt = True, main = "collective_all_reduce_strategy_test.py", @@ -2240,6 +2254,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", @@ -2454,6 +2469,7 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:config", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", "//tensorflow/python/util:nest", @@ -2575,6 +2591,7 @@ distribute_py_strict_test( "multi_and_single_gpu", "no_oss", # TODO(b/249822228) "noasan", # TODO(b/237407459) + "nomsan", # TODO(b/290745680) "notpu", "notsan", # Tsan failure doesn't seem to be caused by TF. ], diff --git a/tensorflow/python/distribute/coordinator/BUILD b/tensorflow/python/distribute/coordinator/BUILD index aa9f71c740449f..a2dc307a546684 100644 --- a/tensorflow/python/distribute/coordinator/BUILD +++ b/tensorflow/python/distribute/coordinator/BUILD @@ -96,6 +96,7 @@ distribute_py_strict_test( "notpu", "notsan", # TODO(b/171040359): Flaky timeout, even if maximum shards ], + xla_enable_strict_auto_jit = False, # TODO(b/291174864) xla_tags = [ "no_cuda_asan", # Race condition on async test ], @@ -152,6 +153,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:check_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py index 879c9e24921e4c..ca4bbbc2d8c2c1 100644 --- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py +++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py @@ -766,11 +766,16 @@ def _log_ps_failure_and_raise(self, e, ps_index): raise PSUnavailableError(e) def _get_task_states(self): + """Get task states and reset to None if coordination service is down.""" try: self._task_states = context.context().get_task_states( [("worker", self._num_workers), ("ps", self._num_ps)] ) - except errors.UnavailableError: + except (errors.UnavailableError, errors.InternalError) as e: + if isinstance( + e, errors.InternalError + ) and "coordination service is not enabled" not in str(e).lower(): + raise # Coordination service is down self._task_states = None with self._next_task_state_cond: diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test_base.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test_base.py index 510f04f29bbb31..ca68282ca0ba3f 100644 --- a/tensorflow/python/distribute/coordinator/fault_tolerance_test_base.py +++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test_base.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops @@ -540,14 +541,14 @@ def worker_fn(): # Attempt to fetch before killing worker task should succeed. 
fetched = remote_value.get()[0] - self.assertIsInstance(fetched, ops.Tensor) + self.assertIsInstance(fetched, tensor.Tensor) self.assertEqual(fetched.device, "/job:chief/replica:0/task:0/device:CPU:0") self.assertEqual((1, -1), remote_value.get()) remote_value.get()[0].numpy() # As well as the remote tensors that point to worker0 or worker1. values = remote_value._values[0] - self.assertIsInstance(values, ops.Tensor) + self.assertIsInstance(values, tensor.Tensor) self.assertRegex(values.device, "/job:worker/replica:0/task:[0-1]/device:CPU:0") self.assertEqual((1, -1), remote_value._values) diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py index 0e60c337a8dbfd..002cb1d41e8070 100644 --- a/tensorflow/python/distribute/cross_device_ops.py +++ b/tensorflow/python/distribute/cross_device_ops.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import kernels from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -59,7 +60,8 @@ def check_destinations(destinations): """ # Calling bool() on a ResourceVariable is not allowed. if isinstance(destinations, - (resource_variable_ops.BaseResourceVariable, ops.Tensor)): + (resource_variable_ops.BaseResourceVariable, + tensor_lib.Tensor)): return bool(destinations.device) return bool(destinations) @@ -68,9 +70,9 @@ def validate_destinations(destinations): """Validates the `destination` is one of expected types.""" if not isinstance( destinations, - (value_lib.DistributedValues, ops.Tensor, indexed_slices.IndexedSlices, - ps_values.AggregatingVariable, six.string_types, - tpu_values.TPUMirroredVariable + (value_lib.DistributedValues, tensor_lib.Tensor, + indexed_slices.IndexedSlices, ps_values.AggregatingVariable, + six.string_types, tpu_values.TPUMirroredVariable )) and not resource_variable_ops.is_resource_variable(destinations): raise ValueError("destinations must be one of a `DistributedValues` object," " a tf.Variable object, or a device string.") diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py index dc0c5aad701ba9..dca6886ba25619 100644 --- a/tensorflow/python/distribute/cross_device_ops_test.py +++ b/tensorflow/python/distribute/cross_device_ops_test.py @@ -42,6 +42,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops from tensorflow.python.ops import cond @@ -204,7 +205,7 @@ def as_list(self, value): Returns: A list of `Tensor` or `IndexedSlices`. 
""" - if isinstance(value, ops.Tensor): + if isinstance(value, tensor_lib.Tensor): return [value] elif isinstance(value, IndexedSlices): return [value] diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 66d1bb2ac4b1ff..4bd0cb86e345ed 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -214,6 +214,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -3926,7 +3927,7 @@ class ReplicaContextV1(ReplicaContextBase): def _batch_reduce_destination(x): """Returns the destinations for batch all-reduce.""" - if isinstance(x, ops.Tensor): + if isinstance(x, tensor_lib.Tensor): # If this is a one device strategy. return x.device else: diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index 38b8c95529af1a..9b999e882184ee 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -47,6 +47,7 @@ from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util as framework_test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -1982,7 +1983,7 @@ def testDistributeDatasetFromFunctionNested(self, distribution): num_replicas_in_sync=num_workers)) class InnerType(extension_type.ExtensionType): - tensor: ops.Tensor + tensor: tensor.Tensor class OuterType(extension_type.ExtensionType): inner: InnerType diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py index 39912b62cb4b55..584565e38b7b5d 100644 --- a/tensorflow/python/distribute/mirrored_strategy_test.py +++ b/tensorflow/python/distribute/mirrored_strategy_test.py @@ -46,6 +46,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util as util @@ -1548,14 +1549,14 @@ def f(): def _replica_id(): replica_id = distribute_lib.get_replica_context().replica_id_in_sync_group - if not isinstance(replica_id, ops.Tensor): + if not isinstance(replica_id, tensor_lib.Tensor): replica_id = constant_op.constant(replica_id) return array_ops.identity(replica_id) def _replica_id_as_int(): replica_id = distribute_lib.get_replica_context().replica_id_in_sync_group - if isinstance(replica_id, ops.Tensor): + if isinstance(replica_id, tensor_lib.Tensor): replica_id = tensor_util.constant_value(replica_id) return replica_id diff --git a/tensorflow/python/distribute/mirrored_variable_test.py b/tensorflow/python/distribute/mirrored_variable_test.py index 169ba5211397c0..89bae84219418c 100644 --- a/tensorflow/python/distribute/mirrored_variable_test.py +++ b/tensorflow/python/distribute/mirrored_variable_test.py @@ -31,6 +31,7 @@ from 
tensorflow.python.framework import errors from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import math_ops @@ -46,7 +47,7 @@ def _replica_id(): replica_id = distribute_lib.get_replica_context().replica_id_in_sync_group - if not isinstance(replica_id, ops.Tensor): + if not isinstance(replica_id, tensor.Tensor): replica_id = constant_op.constant(replica_id) return replica_id diff --git a/tensorflow/python/distribute/parallel_device/BUILD b/tensorflow/python/distribute/parallel_device/BUILD index 3341e9557669f8..de6f9a3220a08d 100644 --- a/tensorflow/python/distribute/parallel_device/BUILD +++ b/tensorflow/python/distribute/parallel_device/BUILD @@ -27,6 +27,7 @@ py_strict_library( "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:variables", "//tensorflow/python/tpu/ops", diff --git a/tensorflow/python/distribute/parallel_device/parallel_device.py b/tensorflow/python/distribute/parallel_device/parallel_device.py index c3255c57aba9b2..771925efa2372c 100644 --- a/tensorflow/python/distribute/parallel_device/parallel_device.py +++ b/tensorflow/python/distribute/parallel_device/parallel_device.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables from tensorflow.python.tpu.ops import tpu_ops @@ -83,8 +84,14 @@ def __init__(self, components): def _pack_tensor(self, *tensors): """Helper to pack plain-old-tensors, not structures or composites.""" for tensor in tensors: - if not isinstance(tensor, (ops.Tensor, composite_tensor.CompositeTensor, - variables.Variable)): + if not isinstance( + tensor, + ( + tensor_lib.Tensor, + composite_tensor.CompositeTensor, + variables.Variable, + ), + ): raise ValueError( ("Every component to pack onto the ParallelDevice must already be " "a tensor, got {}. 
Consider running `tf.constant` or " @@ -129,10 +136,15 @@ def pack(self, tensors): def _unpack_tensor(self, parallel_tensor): """Helper to unpack a single tensor.""" - if not isinstance(parallel_tensor, ( - ops.Tensor, composite_tensor.CompositeTensor, variables.Variable)): - raise ValueError( - "Expected a tensor, got {}.".format(parallel_tensor)) + if not isinstance( + parallel_tensor, + ( + tensor_lib.Tensor, + composite_tensor.CompositeTensor, + variables.Variable, + ), + ): + raise ValueError("Expected a tensor, got {}.".format(parallel_tensor)) with ops.device(self._name): return tpu_ops.tpu_replicated_output( parallel_tensor, num_replicas=len(self.components)) diff --git a/tensorflow/python/distribute/parameter_server_strategy_test.py b/tensorflow/python/distribute/parameter_server_strategy_test.py index d1818933264f52..3ce0aaa3e2d279 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_test.py +++ b/tensorflow/python/distribute/parameter_server_strategy_test.py @@ -41,6 +41,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -60,7 +61,7 @@ def _get_replica_id_integer(): replica_id = distribute_lib.get_replica_context().replica_id_in_sync_group - if isinstance(replica_id, ops.Tensor): + if isinstance(replica_id, tensor.Tensor): replica_id = tensor_util.constant_value(replica_id) return replica_id diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py index 2c39613505a27c..8bf08b2dd8efdc 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_v2.py +++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py @@ -824,7 +824,11 @@ def _create_variable_round_robin(self, next_creator, **kwargs): with ops.device("/job:ps/task:%d/device:CPU:0" % (self._variable_count % self._num_ps)): var = next_creator(**kwargs) - logging.debug( + log_method = ( + logging.info if os.getenv("TF_PSS_VERBOSE_VARIABLE_PLACEMENT") + else logging.debug + ) + log_method( "Creating variable (name:%s, shape:%r) on " "/job:ps/task:%d/device:CPU:0", var.name, var.shape, (self._variable_count % self._num_ps)) diff --git a/tensorflow/python/distribute/ps_values.py b/tensorflow/python/distribute/ps_values.py index 3a866516b2faa1..73b49c8937fc3e 100644 --- a/tensorflow/python/distribute/ps_values.py +++ b/tensorflow/python/distribute/ps_values.py @@ -30,8 +30,8 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion_registry -from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import handle_data_util from tensorflow.python.ops import lookup_ops @@ -494,7 +494,7 @@ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): @classmethod def _overload_overloadable_operators(cls): """Register overloads for all operators.""" - for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + for operator in tensor.Tensor.OVERLOADABLE_OPERATORS: # Overloading __eq__ or __ne__ does not work as expected. 
if operator == "__eq__" or operator == "__ne__": continue @@ -502,8 +502,8 @@ def _overload_overloadable_operators(cls): @classmethod def _tensor_overload_operator(cls, operator): - """Delegate an operator overload to `ops.Tensor`.""" - tensor_operator = getattr(ops.Tensor, operator) + """Delegate an operator overload to `tensor.Tensor`.""" + tensor_operator = getattr(tensor.Tensor, operator) def _operator(v, *args, **kwargs): return tensor_operator(v.value(), *args, **kwargs) # pylint: disable=protected-access @@ -655,7 +655,7 @@ def read_all(self): return [wv.get() for wv in self._per_worker_vars._values] # pylint: disable=protected-access -class PerWorkerVariableSpec(tensor_spec.TensorSpec): +class PerWorkerVariableSpec(tensor.TensorSpec): def __init__(self, value=None, name=None): super().__init__(value.shape, value.dtype, name=name) self._value = value @@ -745,7 +745,7 @@ def closure(): else: return self._coordinator_instance.resource_handle - return closure, tensor_spec.TensorSpec([], dtype=dtypes.resource) + return closure, tensor.TensorSpec([], dtype=dtypes.resource) def _maybe_build_distributed_table(self): """Create table objects and resources on each worker if hasn't been created.""" @@ -871,7 +871,7 @@ def closure(): return self._coordinator_instance.resource_handle - return closure, tensor_spec.TensorSpec(shape=(), dtype=dtypes.resource) + return closure, tensor.TensorSpec(shape=(), dtype=dtypes.resource) def __setattr__(self, name, value): if name in TRACKABLE_RESOURCE_METHODS: diff --git a/tensorflow/python/distribute/sharded_variable.py b/tensorflow/python/distribute/sharded_variable.py index f65677920f4579..68c05c4d867fb3 100644 --- a/tensorflow/python/distribute/sharded_variable.py +++ b/tensorflow/python/distribute/sharded_variable.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices as indexed_slices_lib from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import type_spec @@ -93,7 +94,6 @@ class FixedShardsPartitioner(Partitioner): >>> # use in ParameterServerStrategy >>> # strategy = tf.distribute.experimental.ParameterServerStrategy( >>> # cluster_resolver=cluster_resolver, variable_partitioner=partitioner) - """ def __init__(self, num_shards): @@ -134,10 +134,9 @@ class MinSizePartitioner(Partitioner): >>> # cluster_resolver=cluster_resolver, variable_partitioner=partitioner) """ - def __init__(self, - min_shard_bytes=256 << 10, - max_shards=1, - bytes_per_string=16): + def __init__( + self, min_shard_bytes=256 << 10, max_shards=1, bytes_per_string=16 + ): """Creates a new `MinSizePartitioner`. Args: @@ -147,14 +146,19 @@ def __init__(self, an estimate of how large each string is. """ if min_shard_bytes < 1: - raise ValueError('Argument `min_shard_bytes` must be positive. ' - f'Received: {min_shard_bytes}') + raise ValueError( + 'Argument `min_shard_bytes` must be positive. ' + f'Received: {min_shard_bytes}' + ) if max_shards < 1: - raise ValueError('Argument `max_shards` must be positive. ' - f'Received: {max_shards}') + raise ValueError( + f'Argument `max_shards` must be positive. Received: {max_shards}' + ) if bytes_per_string < 1: - raise ValueError('Argument `bytes_per_string` must be positive. 
' - f'Received: {bytes_per_string}') + raise ValueError( + 'Argument `bytes_per_string` must be positive. ' + f'Received: {bytes_per_string}' + ) self._min_shard_bytes = min_shard_bytes self._max_shards = max_shards self._bytes_per_string = bytes_per_string @@ -164,7 +168,8 @@ def __call__(self, shape, dtype, axis=0): max_partitions=self._max_shards, axis=axis, min_slice_size=self._min_shard_bytes, - bytes_per_string_element=self._bytes_per_string)(shape, dtype) + bytes_per_string_element=self._bytes_per_string, + )(shape, dtype) @tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[]) @@ -207,14 +212,19 @@ def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16): an estimate of how large each string is. """ if max_shard_bytes < 1: - raise ValueError('Argument `max_shard_bytes` must be positive. ' - f'Received {max_shard_bytes}') + raise ValueError( + 'Argument `max_shard_bytes` must be positive. ' + f'Received {max_shard_bytes}' + ) if max_shards and max_shards < 1: - raise ValueError('Argument `max_shards` must be positive. ' - f'Received {max_shards}') + raise ValueError( + f'Argument `max_shards` must be positive. Received {max_shards}' + ) if bytes_per_string < 1: - raise ValueError('Argument `bytes_per_string` must be positive. ' - f'Received: {bytes_per_string}') + raise ValueError( + 'Argument `bytes_per_string` must be positive. ' + f'Received: {bytes_per_string}' + ) self._max_shard_bytes = max_shard_bytes self._max_shards = max_shards @@ -225,7 +235,8 @@ def __call__(self, shape, dtype, axis=0): max_shard_bytes=self._max_shard_bytes, max_shards=self._max_shards, bytes_per_string_element=self._bytes_per_string, - axis=axis)(shape, dtype) + axis=axis, + )(shape, dtype) class ShardedVariableSpec(type_spec.TypeSpec): @@ -264,7 +275,6 @@ class ShardedVariableMixin(trackable.Trackable): def __init__(self, variables, name='ShardedVariable'): """Treats `variables` as shards of a larger Variable. - Example: ``` @@ -287,16 +297,22 @@ def __init__(self, variables, name='ShardedVariable'): self._variables = variables self._name = name - if not isinstance(variables, Sequence) or not variables or any( - not isinstance(v, variables_lib.Variable) for v in variables): - raise TypeError('Argument `variables` should be a non-empty list of ' - f'`variables.Variable`s. Received {variables}') + if ( + not isinstance(variables, Sequence) + or not variables + or any(not isinstance(v, variables_lib.Variable) for v in variables) + ): + raise TypeError( + 'Argument `variables` should be a non-empty list of ' + f'`variables.Variable`s. Received {variables}' + ) var_dtypes = {v.dtype for v in variables} if len(var_dtypes) > 1: raise ValueError( 'All elements in argument `variables` must have the same dtype. ' - f'Received dtypes: {[v.dtype for v in variables]}') + f'Received dtypes: {[v.dtype for v in variables]}' + ) first_var = variables[0] self._dtype = first_var.dtype @@ -307,10 +323,12 @@ def __init__(self, variables, name='ShardedVariable'): raise ValueError( 'All elements in argument `variables` must have the same shapes ' 'except for the first axis. 
' - f'Received shapes: {[v.shape for v in variables]}') + f'Received shapes: {[v.shape for v in variables]}' + ) first_dim = sum(int(v.shape.as_list()[0]) for v in variables) - self._shape = tensor_shape.TensorShape([first_dim] + - first_var.shape.as_list()[1:]) + self._shape = tensor_shape.TensorShape( + [first_dim] + first_var.shape.as_list()[1:] + ) for v in variables: v._sharded_container = weakref.ref(self) @@ -321,7 +339,8 @@ def __init__(self, variables, name='ShardedVariable'): for i in range(1, len(variables)): # Always partition on the first axis. Offsets on other axes are 0. self._var_offsets[i][0] += ( - self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0]) + self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0] + ) save_slice_info = [v._get_save_slice_info() for v in variables] # pylint: disable=protected-access if any(slice_info is not None for slice_info in save_slice_info): @@ -329,16 +348,20 @@ def __init__(self, variables, name='ShardedVariable'): '`SaveSliceInfo` should not be set for all elements in argument ' '`variables`. `ShardedVariable` will infer `SaveSliceInfo` according ' 'to the order of the elements `variables`. ' - f'Received save slice info {save_slice_info}') + f'Received save slice info {save_slice_info}' + ) # We create an uninitialized saving_variable with the full shape, which can # be later captured in signatures so that the signatures can treat this # ShardedVariable as one single variable. self._saving_variable = resource_variable_ops.UninitializedVariable( - shape=self._shape, dtype=self._dtype, name=self._name, + shape=self._shape, + dtype=self._dtype, + name=self._name, trainable=self._variables[0].trainable, synchronization=variables_lib.VariableSynchronization.NONE, - aggregation=variables_lib.VariableAggregation.NONE) + aggregation=variables_lib.VariableAggregation.NONE, + ) def __iter__(self): """Return an iterable for accessing the underlying sharded variables.""" @@ -365,9 +388,14 @@ def __getitem__(self, slice_spec): # TODO(b/177482728): Support tensor input. # TODO(b/177482728): Support slice assign, similar to variable slice assign. 
- if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and - slice_spec.dtype == dtypes.bool) or - (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)): + if ( + isinstance(slice_spec, bool) + or ( + isinstance(slice_spec, tensor_lib.Tensor) + and slice_spec.dtype == dtypes.bool + ) + or (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool) + ): tensor = _var_to_tensor(self) return array_ops.boolean_mask(tensor=tensor, mask=slice_spec) @@ -385,30 +413,36 @@ def __getitem__(self, slice_spec): if s.step is not None and s.step < 0: values.reverse() if not values: - return constant_op.constant([], - dtype=self._dtype, - shape=((0,) + self._shape[1:])) + return constant_op.constant( + [], dtype=self._dtype, shape=((0,) + self._shape[1:]) + ) return array_ops.concat(values, axis=0) elif s is Ellipsis: - return array_ops.concat([var[slice_spec] for var in self._variables], - axis=0) + return array_ops.concat( + [var[slice_spec] for var in self._variables], axis=0 + ) elif s is array_ops.newaxis: - return array_ops.concat([var[slice_spec[1:]] for var in self._variables], - axis=0)[array_ops.newaxis] + return array_ops.concat( + [var[slice_spec[1:]] for var in self._variables], axis=0 + )[array_ops.newaxis] else: - if isinstance(s, ops.Tensor): + if isinstance(s, tensor_lib.Tensor): raise TypeError( - 'ShardedVariable: using Tensor for indexing is not allowed.') + 'ShardedVariable: using Tensor for indexing is not allowed.' + ) if s < 0: s += self._shape[0] if s < 0 or s >= self._shape[0]: raise IndexError( - f'ShardedVariable: slice index {s} of dimension 0 out of bounds.') + f'ShardedVariable: slice index {s} of dimension 0 out of bounds.' + ) for i in range(len(self._variables)): - if i == len(self._variables) - 1 or (s > self._var_offsets[i][0] and - s < self._var_offsets[i + 1][0]): - return self._variables[i][(s - self._var_offsets[i][0],) + - slice_spec[1:]] + if i == len(self._variables) - 1 or ( + s > self._var_offsets[i][0] and s < self._var_offsets[i + 1][0] + ): + return self._variables[i][ + (s - self._var_offsets[i][0],) + slice_spec[1:] + ] def _decompose_slice_spec(self, slice_spec): """Decompose a global slice_spec into a list of per-variable slice_spec. @@ -441,11 +475,15 @@ def _decompose_slice_spec(self, slice_spec): v1[returned[1]] = [5] v2[returned[2]] = [9, 7] """ - if isinstance(slice_spec.start, ops.Tensor) or isinstance( - slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor): + if ( + isinstance(slice_spec.start, tensor_lib.Tensor) + or isinstance(slice_spec.stop, tensor_lib.Tensor) + or isinstance(slice_spec.step, tensor_lib.Tensor) + ): raise TypeError( 'ShardedVariable: using Tensor in slice_spec is not allowed. Please ' - 'file a feature request with the TensorFlow team.') + 'file a feature request with the TensorFlow team.' + ) result = [] # Normalize start, end and stop. 
@@ -479,7 +517,9 @@ def _decompose_slice_spec(self, slice_spec): var_start = self._var_offsets[i][0] var_end = ( self._var_offsets[i + 1][0] - if i < len(self._var_offsets) - 1 else self._shape[0]) + if i < len(self._var_offsets) - 1 + else self._shape[0] + ) if cur < var_start: cur += slice_step * int(math.ceil((var_start - cur) / slice_step)) if cur >= var_end or cur >= slice_end: @@ -493,7 +533,9 @@ def _decompose_slice_spec(self, slice_spec): var_start = self._var_offsets[i][0] var_end = ( self._var_offsets[i + 1][0] - if i < len(self._var_offsets) - 1 else self._shape[0]) + if i < len(self._var_offsets) - 1 + else self._shape[0] + ) if cur >= var_end: cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step)) if cur < var_start or cur <= slice_end: @@ -513,8 +555,11 @@ def _decompose_slice_spec(self, slice_spec): @property def _type_spec(self): return ShardedVariableSpec( - *(resource_variable_ops.VariableSpec(v.shape, v.dtype) - for v in self._variables)) + *( + resource_variable_ops.VariableSpec(v.shape, v.dtype) + for v in self._variables + ) + ) @property def variables(self): @@ -546,13 +591,15 @@ def assign(self, value, use_locking=None, name=None, read_value=True): def assign_add(self, delta, use_locking=False, name=None, read_value=True): for i, v in enumerate(self._variables): v.assign_add( - array_ops.slice(delta, self._var_offsets[i], v.shape.as_list())) + array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()) + ) return self def assign_sub(self, delta, use_locking=False, name=None, read_value=True): for i, v in enumerate(self._variables): v.assign_sub( - array_ops.slice(delta, self._var_offsets[i], v.shape.as_list())) + array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()) + ) return self def _decompose_indices(self, indices): @@ -560,7 +607,8 @@ def _decompose_indices(self, indices): if indices.shape.rank != 1: raise ValueError( 'ShardedVariable: indices must be 1D Tensor for sparse operations. 
' - f'Received shape: {indices.shape}') + f'Received shape: {indices.shape}' + ) base = self._shape[0] // len(self._variables) extra = self._shape[0] % len(self._variables) @@ -573,7 +621,8 @@ def _decompose_indices(self, indices): if expect_first_dim != actual_first_dim: raise NotImplementedError( 'scater_xxx ops are not supported in ShardedVariale that does not ' - 'conform to "div" sharding') + 'conform to "div" sharding' + ) # For index that falls into the partition that has extra 1, assignment is # `index // (base + 1)` (no less than `(indices - extra) // base`) @@ -585,30 +634,35 @@ def _decompose_indices(self, indices): # base = 10, extra = 2, partitions: [0, 11), [11, 22), [22, 32) # index = 10 -> partition_assigment = 0 # index = 22 -> partition_assiment = 2 - partition_assignments = math_ops.maximum(indices // (base + 1), - (indices - extra) // base) - local_indices = array_ops.where(partition_assignments < extra, - indices % (base + 1), - (indices - extra) % base) + partition_assignments = math_ops.maximum( + indices // (base + 1), (indices - extra) // base + ) + local_indices = array_ops.where( + partition_assignments < extra, + indices % (base + 1), + (indices - extra) % base, + ) # For whatever reason `dynamic_partition` only supports int32 partition_assignments = math_ops.cast(partition_assignments, dtypes.int32) - per_var_indices = data_flow_ops.dynamic_partition(local_indices, - partition_assignments, - len(self._variables)) + per_var_indices = data_flow_ops.dynamic_partition( + local_indices, partition_assignments, len(self._variables) + ) return per_var_indices, partition_assignments def _decompose_indexed_slices(self, indexed_slices): """Decompose a global `IndexedSlices` into a list of per-variable ones.""" per_var_indices, partition_assignments = self._decompose_indices( - indexed_slices.indices) - per_var_values = data_flow_ops.dynamic_partition(indexed_slices.values, - partition_assignments, - len(self._variables)) + indexed_slices.indices + ) + per_var_values = data_flow_ops.dynamic_partition( + indexed_slices.values, partition_assignments, len(self._variables) + ) return [ indexed_slices_lib.IndexedSlices( - values=per_var_values[i], indices=per_var_indices[i]) + values=per_var_values[i], indices=per_var_indices[i] + ) for i in range(len(self._variables)) ] @@ -720,24 +774,32 @@ def _saveable_factory(name=self.name): full_name=self.name, full_shape=self.shape.as_list(), var_offset=copy.copy(var_offset), - var_shape=v.shape.as_list()) + var_shape=v.shape.as_list(), + ) saveables.append( saveable_object_util.ResourceVariableSaveable( - v, save_slice_info.spec, name)) + v, save_slice_info.spec, name + ) + ) var_offset[0] += int(v.shape[0]) return saveables return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} - def _export_to_saved_model_graph(self, object_map, tensor_map, - options, **kwargs): + def _export_to_saved_model_graph( + self, object_map, tensor_map, options, **kwargs + ): """For implementing `Trackable`.""" resource_list = [] for v in self._variables + [self._saving_variable]: - resource_list.extend(v._export_to_saved_model_graph( # pylint:disable=protected-access - object_map, tensor_map, options, **kwargs)) - object_map[self] = ShardedVariable([object_map[self._saving_variable]], - name=self.name) + resource_list.extend( + v._export_to_saved_model_graph( # pylint:disable=protected-access + object_map, tensor_map, options, **kwargs + ) + ) + object_map[self] = ShardedVariable( + [object_map[self._saving_variable]], name=self.name + ) return 
resource_list @property @@ -828,7 +890,7 @@ def _type_spec(self): @classmethod def _overload_all_operators(cls): """Register overloads for all operators.""" - for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + for operator in tensor_lib.Tensor.OVERLOADABLE_OPERATORS: if operator == '__getitem__': continue @@ -836,16 +898,17 @@ def _overload_all_operators(cls): @classmethod def _overload_operator(cls, operator): - """Delegate an operator overload to `ops.Tensor`.""" - tensor_operator = getattr(ops.Tensor, operator) + """Delegate an operator overload to `tensor_lib.Tensor`.""" + tensor_operator = getattr(tensor_lib.Tensor, operator) def _operator(v, *args, **kwargs): return tensor_operator(_var_to_tensor(v), *args, **kwargs) setattr(cls, operator, _operator) - def __tf_experimental_restore_capture__(self, concrete_function, - internal_capture): + def __tf_experimental_restore_capture__( + self, concrete_function, internal_capture + ): # Avoid restoring captures for functions that use ShardedVariable - the # layer will be recreated during Keras model loading # TODO(jmullenbach): support loading models with ShardedVariables using @@ -858,7 +921,8 @@ def _should_act_as_resource_variable(self): def _write_object_proto(self, proto, options): resource_variable_ops.write_object_proto_for_resource_variable( - self._saving_variable, proto, options, enforce_naming=False) + self._saving_variable, proto, options, enforce_naming=False + ) def _var_to_tensor(var, dtype=None, name=None, as_ref=False): @@ -867,10 +931,12 @@ def _var_to_tensor(var, dtype=None, name=None, as_ref=False): if dtype is not None and not dtype.is_compatible_with(var.dtype): raise ValueError( 'Incompatible type conversion requested to type {!r} for variable ' - 'of type {!r}'.format(dtype.name, var.dtype.name)) + 'of type {!r}'.format(dtype.name, var.dtype.name) + ) if as_ref: raise NotImplementedError( - "ShardedVariable doesn't support being used as a reference.") + "ShardedVariable doesn't support being used as a reference." + ) # We use op dispatch mechanism to override embedding_lookup ops when called # with ShardedVariable. This requires embedding_lookup ops to raise TypeError # when called with ShardedVariable. However since ShardedVariable can be @@ -885,32 +951,42 @@ def _var_to_tensor(var, dtype=None, name=None, as_ref=False): # TODO(chenkai): Find a more robust way to do this, which should not rely # on namescope. if 'embedding_lookup' in ops.get_name_scope(): - raise TypeError('Converting ShardedVariable to tensor in embedding lookup' - ' ops is disallowed.') + raise TypeError( + 'Converting ShardedVariable to tensor in embedding lookup' + ' ops is disallowed.' + ) return array_ops.concat(var.variables, axis=0) # Register a conversion function which reads the value of the variable, # allowing instances of the class to be used as tensors. tensor_conversion_registry.register_tensor_conversion_function( - ShardedVariable, _var_to_tensor) + ShardedVariable, _var_to_tensor +) ShardedVariable._overload_all_operators() # pylint: disable=protected-access # Override the behavior of embedding_lookup(sharded_variable, ...) 
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable) -def embedding_lookup(params, - ids, - partition_strategy='mod', - name=None, - validate_indices=True, - max_norm=None): +def embedding_lookup( + params, + ids, + partition_strategy='mod', + name=None, + validate_indices=True, + max_norm=None, +): if isinstance(params, list): params = params[0] - return embedding_ops.embedding_lookup(params.variables, ids, - partition_strategy, name, - validate_indices, max_norm) + return embedding_ops.embedding_lookup( + params.variables, + ids, + partition_strategy, + name, + validate_indices, + max_norm, + ) # Separately override safe_embedding_lookup_sparse, to avoid conversion of @@ -937,4 +1013,5 @@ def safe_embedding_lookup_sparse( name=name, partition_strategy=partition_strategy, max_norm=max_norm, - allow_fast_lookup=allow_fast_lookup) + allow_fast_lookup=allow_fast_lookup, + ) diff --git a/tensorflow/python/distribute/summary_op_util.py b/tensorflow/python/distribute/summary_op_util.py index 59e619a871e388..7ccb6a181bd206 100644 --- a/tensorflow/python/distribute/summary_op_util.py +++ b/tensorflow/python/distribute/summary_op_util.py @@ -16,7 +16,7 @@ from tensorflow.python.distribute import distribute_lib -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util @@ -39,6 +39,6 @@ def skip_summary(): # TODO(b/118385803): when replica_id of _TPUReplicaContext is properly # initialized, remember to change here as well. replica_id = replica_context.replica_id_in_sync_group - if isinstance(replica_id, ops.Tensor): + if isinstance(replica_id, tensor.Tensor): replica_id = tensor_util.constant_value(replica_id) return replica_id and replica_id > 0 diff --git a/tensorflow/python/distribute/test_util.py b/tensorflow/python/distribute/test_util.py index 866d86a4dd3f59..8dbf6c020c1242 100644 --- a/tensorflow/python/distribute/test_util.py +++ b/tensorflow/python/distribute/test_util.py @@ -32,6 +32,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import config from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack from tensorflow.python.util import nest @@ -142,7 +143,7 @@ def _op_dependencies(op): """Returns the data and control dependencies of a tf.Operation combined.""" deps = [] for node in itertools.chain(op.inputs, op.control_inputs): - if isinstance(node, ops.Tensor): + if isinstance(node, tensor.Tensor): node = node.op assert isinstance(node, ops.Operation) deps.append(node) diff --git a/tensorflow/python/distribute/v1/BUILD b/tensorflow/python/distribute/v1/BUILD index e88086280e7677..59f19db8b4e11c 100644 --- a/tensorflow/python/distribute/v1/BUILD +++ b/tensorflow/python/distribute/v1/BUILD @@ -37,6 +37,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:kernels", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:collective_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/distribute/v1/cross_device_ops_test.py b/tensorflow/python/distribute/v1/cross_device_ops_test.py index fa59aba7f52c79..360cec0bd5ce71 100644 --- a/tensorflow/python/distribute/v1/cross_device_ops_test.py +++ b/tensorflow/python/distribute/v1/cross_device_ops_test.py @@ -41,6 +41,7 @@ from 
tensorflow.python.framework import indexed_slices as indexed_slices_lib from tensorflow.python.framework import kernels from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops from tensorflow.python.ops import math_ops @@ -52,7 +53,7 @@ def _get_devices(devices): return tuple(device_util.resolve(d) for d in devices) elif isinstance(devices, value_lib.DistributedValues): return devices._devices - elif isinstance(devices, ops.Tensor): + elif isinstance(devices, tensor_lib.Tensor): return (device_util.resolve(devices.device),) return (device_util.resolve(devices),) @@ -422,7 +423,7 @@ def testReduceDistributedVariable(self, distribution, else: result = cross_device_ops_instance.reduce(reduce_util.ReduceOp.MEAN, v, v) for v in result.values: - self.assertIsInstance(v, ops.Tensor) + self.assertIsInstance(v, tensor_lib.Tensor) self.evaluate(variables.global_variables_initializer()) self.assertAllEqual(self.evaluate(result.values), [1.0, 1.0]) diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 5bf5e4ec52aca2..fa7456b19d4ea5 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec @@ -1040,7 +1041,7 @@ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): def __tf_tensor__(self, dtype: Optional[dtypes.DType] = None, - name: Optional[str] = None) -> ops.Tensor: + name: Optional[str] = None) -> tensor_lib.Tensor: return self._dense_var_to_tensor(dtype, name) def _export_to_saved_model_graph(self, diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index dd0add62d863a8..70cd0fb6a608a5 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -498,12 +499,12 @@ def testTensorConversion(self, distribution): _, replica_local = _make_replica_local( variable_scope.VariableAggregation.SUM, distribution) converted = ops.convert_to_tensor(replica_local, as_ref=False) - self.assertIsInstance(converted, ops.Tensor) + self.assertIsInstance(converted, tensor.Tensor) self.assertEqual(converted.dtype, replica_local.dtype) converted = ops.convert_to_tensor(replica_local, as_ref=True) # Resources variable are converted to tensors as well when as_ref is True. 
- self.assertIsInstance(converted, ops.Tensor) + self.assertIsInstance(converted, tensor.Tensor) self.assertEqual(converted.dtype, replica_local.dtype) @combinations.generate(combinations.combine( @@ -517,7 +518,7 @@ def testValueInCrossReplicaContext(self, distribution): value_list, replica_local = _make_replica_local( variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, distribution) - self.assertIsInstance(replica_local.value(), ops.Tensor) + self.assertIsInstance(replica_local.value(), tensor.Tensor) self.assertEqual(self.evaluate(replica_local.value()), self.evaluate(value_list[0].value())) diff --git a/tensorflow/python/distribute/values_v2_test.py b/tensorflow/python/distribute/values_v2_test.py index e7dcd958fe3240..0e4c3298654b5e 100644 --- a/tensorflow/python/distribute/values_v2_test.py +++ b/tensorflow/python/distribute/values_v2_test.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables as variables_lib @@ -329,7 +330,7 @@ def testSlice(self): # ==== Begin ResourceVariable interface === def testHandle(self): v = self.create_variable() - self.assertIsInstance(v.handle, ops.Tensor) + self.assertIsInstance(v.handle, tensor.Tensor) self.assertEqual(v.handle.dtype, dtypes.resource) def testInGraphMode(self): diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 1b1f562cce21df..4bfe08f5fdb5ec 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -573,6 +573,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:nn_ops", @@ -686,6 +687,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:type_spec", @@ -718,6 +720,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:handle_data_util", @@ -915,6 +918,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/layers", @@ -995,6 +999,7 @@ py_strict_library( deps = [ "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:op_selector", "//tensorflow/python/ops:resource_variable_ops", @@ -1039,8 +1044,8 @@ py_strict_library( "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", 
"//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/ops:variable_scope", diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 9c488d1b133526..16fc829ee6ca20 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec @@ -405,7 +406,7 @@ def decorated(*args, **kwds): def _ensure_unique_tensor_objects(parameter_positions, args): - """Make each of the parameter_positions in args a unique ops.Tensor object. + """Make each of the parameter_positions in args a unique tensor_lib.Tensor object. Ensure that each parameter is treated independently. For example: @@ -594,18 +595,18 @@ def _aggregate_grads(gradients): if len(gradients) == 1: return gradients[0] - if all(isinstance(g, ops.Tensor) for g in gradients): + if all(isinstance(g, tensor_lib.Tensor) for g in gradients): return gen_math_ops.add_n(gradients) else: assert all( - isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)) + isinstance(g, (tensor_lib.Tensor, indexed_slices.IndexedSlices)) for g in gradients) return backprop_util.AggregateIndexedSlicesGradients(gradients) def _num_elements(grad): """The number of elements in the `grad` tensor.""" - if isinstance(grad, ops.Tensor): + if isinstance(grad, tensor_lib.Tensor): shape_tuple = grad._shape_tuple() # pylint: disable=protected-access elif isinstance(grad, indexed_slices.IndexedSlices): shape_tuple = grad.values._shape_tuple() # pylint: disable=protected-access diff --git a/tensorflow/python/eager/backprop_util.py b/tensorflow/python/eager/backprop_util.py index b6509a48307a94..c4fe1158dc3c8e 100644 --- a/tensorflow/python/eager/backprop_util.py +++ b/tensorflow/python/eager/backprop_util.py @@ -19,6 +19,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import handle_data_util @@ -67,7 +68,7 @@ def IsTrainable(tensor_or_dtype): def FlattenNestedIndexedSlices(grad): assert isinstance(grad, indexed_slices.IndexedSlices) - if isinstance(grad.values, ops.Tensor): + if isinstance(grad.values, tensor_lib.Tensor): return grad else: assert isinstance(grad.values, indexed_slices.IndexedSlices) @@ -85,7 +86,7 @@ def AggregateIndexedSlicesGradients(grads): grads = [g for g in grads if g is not None] # If any gradient is a `Tensor`, sum them up and return a dense tensor # object. 
- if any(isinstance(g, ops.Tensor) for g in grads): + if any(isinstance(g, tensor_lib.Tensor) for g in grads): return math_ops.add_n(grads) # The following `_as_indexed_slices_list` casts ids of IndexedSlices into diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 64a2a6ad06d6ae..6e919d6deab965 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -35,6 +35,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -91,7 +92,7 @@ def _test_hashable(self, a, b, hashable): set([a, b]) def testEquality(self): - default = ops.Tensor._USE_EQUALITY + default = tensor_lib.Tensor._USE_EQUALITY try: def _v1_check(a, b): @@ -113,20 +114,20 @@ def _v2_check(a, b): constant_a = constant_op.constant(1.0) constant_b = constant_op.constant(1.0) - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() self._test_hashable(constant_a, constant_b, False) _v1_check(constant_a, constant_b) - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() _v2_check(constant_a, constant_b) self._test_hashable(constant_a, constant_b, False) variable_a = variables.Variable(1.0) variable_b = variables.Variable(1.0) - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() _v1_check(variable_a, variable_b) self._test_hashable(variable_a, variable_b, True) - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() _v2_check(variable_a, variable_b) self._test_hashable(variable_a, variable_b, False) @@ -137,12 +138,12 @@ def _v2_check(a, b): self._test_hashable(numpy_a, numpy_b, False) finally: if default: - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() else: - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() def testEqualityNan(self): - default = ops.Tensor._USE_EQUALITY + default = tensor_lib.Tensor._USE_EQUALITY try: def _v1_check(a, b): @@ -164,20 +165,20 @@ def _v2_check(a, b): constant_a = constant_op.constant(float('nan')) constant_b = constant_op.constant(float('nan')) - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() self._test_hashable(constant_a, constant_b, False) _v1_check(constant_a, constant_b) - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() _v2_check(constant_a, constant_b) self._test_hashable(constant_a, constant_b, False) variable_a = variables.Variable(float('nan')) variable_b = variables.Variable(float('nan')) - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() _v1_check(variable_a, variable_b) self._test_hashable(variable_a, variable_b, True) - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() _v2_check(variable_a, variable_b) self._test_hashable(variable_a, variable_b, False) @@ -187,12 +188,12 @@ def _v2_check(a, b): self._test_hashable(numpy_a, numpy_b, False) finally: if default: - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() else: - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() def testEqualityCompare(self): - default = ops.Tensor._USE_EQUALITY + default = tensor_lib.Tensor._USE_EQUALITY try: tf_a = constant_op.constant([1, 2]) @@ -202,7 +203,7 @@ def testEqualityCompare(self): np_b = np.array([1, 2]) np_c = 
np.array([1, 1]) - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() # We don't do element-wise comparison self.assertNotEqual(tf_a, tf_b) self.assertNotEqual(tf_a, tf_c) @@ -216,7 +217,7 @@ def testEqualityCompare(self): self.assertIn(tf_a, [tf_b, tf_a]) self.assertNotIn(tf_a, [tf_b, tf_c]) - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() # We do element-wise comparison but can't convert results array to bool with self.assertRaises(ValueError): bool(tf_a == tf_b) @@ -266,12 +267,12 @@ def testEqualityCompare(self): self.assertAllEqual(np.array(1) == np.array(2), False) finally: if default: - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() else: - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() def testEqualityBroadcast(self): - default = ops.Tensor._USE_EQUALITY + default = tensor_lib.Tensor._USE_EQUALITY try: tf_a = constant_op.constant([1, 1]) @@ -285,13 +286,13 @@ def testEqualityBroadcast(self): np_d = np.array([[1, 2], [1, 2]]) np_e = np.array([1, 1, 1]) - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() # We don't do element-wise comparison self.assertNotEqual(tf_a, tf_b) self.assertNotEqual(tf_a, tf_c) self.assertNotEqual(tf_a, tf_d) - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() # We do element-wise comparison but can't convert results array to bool with self.assertRaises(ValueError): bool(tf_a == tf_b) @@ -322,9 +323,9 @@ def testEqualityBroadcast(self): self.assertNotAllEqual(np_a, np_e) finally: if default: - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() else: - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() @test_util.disable_tfrt('Get execution mode not supported in TFRT.') def testContext(self): diff --git a/tensorflow/python/eager/lift_to_graph.py b/tensorflow/python/eager/lift_to_graph.py index 5e8be440cb4af6..7d7ac1b8e0dff2 100644 --- a/tensorflow/python/eager/lift_to_graph.py +++ b/tensorflow/python/eager/lift_to_graph.py @@ -19,6 +19,7 @@ from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import op_selector from tensorflow.python.ops import resource_variable_ops @@ -31,7 +32,7 @@ def _as_operation(op_or_tensor): - if isinstance(op_or_tensor, ops.Tensor): + if isinstance(op_or_tensor, tensor_lib.Tensor): return op_or_tensor.op return op_or_tensor diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index 1792b9fb659b8b..d1006b4ece3ef1 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.layers import core @@ -191,8 +192,8 @@ def ops_test(v1, v2): self.assertAllEqual((a >= b), np.greater_equal(v1, v2)) # TODO(b/120678848): Remove the else branch once we enable - # ops.Tensor._USE_EQUALITY by default. - if ops.Tensor._USE_EQUALITY: + # tensor.Tensor._USE_EQUALITY by default. 
+ if tensor.Tensor._USE_EQUALITY: self.assertAllEqual((a == b), np.equal(v1, v2)) self.assertAllEqual((a != b), np.not_equal(v1, v2)) else: diff --git a/tensorflow/python/eager/polymorphic_function/BUILD b/tensorflow/python/eager/polymorphic_function/BUILD index e020919e5f9d8f..6a093f105f51ca 100644 --- a/tensorflow/python/eager/polymorphic_function/BUILD +++ b/tensorflow/python/eager/polymorphic_function/BUILD @@ -105,8 +105,8 @@ py_strict_library( "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:type_spec", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:default_gradient", @@ -266,8 +266,8 @@ cuda_py_strict_test( "//tensorflow/python/framework:ops", "//tensorflow/python/framework:random_seed", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:test_ops", "//tensorflow/python/framework:type_spec", @@ -331,6 +331,10 @@ tf_py_strict_test( tf_xla_py_strict_test( name = "polymorphic_function_xla_jit_test", srcs = ["polymorphic_function_xla_jit_test.py"], + # copybara:uncomment_begin + # #TODO(b/185944215) # Remove once the bug is fixed. + # disable_tpu_tfrt = True, + # copybara:uncomment_end disabled_backends = [ "cpu_ondemand", ], @@ -352,7 +356,7 @@ tf_xla_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:collective_ops", @@ -443,7 +447,6 @@ tf_py_strict_test( python_version = "PY3", deps = [ ":function_type_utils", - ":polymorphic_function", ":tracing_compilation", "//tensorflow/core:protos_all_py", "//tensorflow/core/function/capture:capture_container", @@ -455,7 +458,7 @@ tf_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:test_ops", "//tensorflow/python/layers", @@ -478,7 +481,6 @@ tf_py_strict_test( "//tensorflow/python/saved_model:save", "//tensorflow/python/util:compat", "//tensorflow/python/util:nest", - "//tensorflow/python/util:tf_decorator", "@absl_py//absl/testing:parameterized", ], ) @@ -585,10 +587,8 @@ py_strict_library( ":composite_tensor_utils", "//tensorflow/core/function/polymorphism:function_type", "//tensorflow/core/function/trace_type", - "//tensorflow/python/framework:composite_tensor", - "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:type_spec", "//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/util:nest", @@ -651,7 +651,7 @@ tf_xla_py_strict_test( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + 
"//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_gen", diff --git a/tensorflow/python/eager/polymorphic_function/atomic_function.py b/tensorflow/python/eager/polymorphic_function/atomic_function.py index 3c4425f215012f..ae9ec964659a9f 100644 --- a/tensorflow/python/eager/polymorphic_function/atomic_function.py +++ b/tensorflow/python/eager/polymorphic_function/atomic_function.py @@ -234,6 +234,9 @@ def __call__(self, *args: core.Tensor) -> Sequence[core.Tensor]: if len(args) != expected_len: raise ValueError( f"Signature specifies {expected_len} arguments, got: {len(args)}." + f" Expected inputs: {self.cached_definition.signature.input_arg}." + f" Received inputs: {args}." + f" Function Type: {self.function_type!r}" ) with InterpolateRuntimeError(self): diff --git a/tensorflow/python/eager/polymorphic_function/compiler_ir_test.py b/tensorflow/python/eager/polymorphic_function/compiler_ir_test.py index e3b301307838b5..edc9ad50cb5e66 100644 --- a/tensorflow/python/eager/polymorphic_function/compiler_ir_test.py +++ b/tensorflow/python/eager/polymorphic_function/compiler_ir_test.py @@ -19,7 +19,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -33,11 +33,11 @@ class CompilerIrTest(xla_test.XLATestCase): def _compareTwoMethodsCompilerIROutput(self, f, args, kwargs): flat_args = list(args) + list(kwargs.values()) - if not all([isinstance(x, ops.Tensor) for x in flat_args]): + if not all([isinstance(x, tensor.Tensor) for x in flat_args]): self.skipTest('It only support args and kwargs are all tf.Tensor types.') - args_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, args) - kwargs_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, kwargs) + args_spec = nest.map_structure(tensor.TensorSpec.from_tensor, args) + kwargs_spec = nest.map_structure(tensor.TensorSpec.from_tensor, kwargs) hlo_1 = f.experimental_get_compiler_ir(*args, **kwargs)() hlo_2 = f.experimental_get_compiler_ir(*args_spec, **kwargs_spec)() @@ -105,7 +105,7 @@ def f(x): with self.assertRaisesRegex( ValueError, 'Only support static input shape but got' ): - args_spec = [tensor_spec.TensorSpec((None), dtype=dtypes.float32)] + args_spec = [tensor.TensorSpec((None), dtype=dtypes.float32)] concrete_fn = f.get_concrete_function(*args_spec) _ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo') @@ -117,7 +117,7 @@ def f2(x): return x[x[0] : 0] args = [ops.convert_to_tensor([1, 2, 3, 4])] - args_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, args) + args_spec = nest.map_structure(tensor.TensorSpec.from_tensor, args) concrete_fn = f2.get_concrete_function(*args_spec) if test_util.is_mlir_bridge_enabled(): with self.assertRaisesRegex( @@ -142,17 +142,17 @@ def f4(a, b): kwargs = {'b': a, 'a': b} kwargs_spec = nest.map_structure( - tensor_spec.TensorSpec.from_tensor, kwargs + tensor.TensorSpec.from_tensor, kwargs ) concrete_fn = f4.get_concrete_function(**kwargs_spec) captured_inputs = concrete_fn.captured_inputs captured_spec = compiler_ir.make_handledata_tensor_specs(captured_inputs) self.assertEqual(len(captured_spec), 2) self.assertEqual( - 
captured_spec[0], tensor_spec.TensorSpec((2), dtype=dtypes.float32) + captured_spec[0], tensor.TensorSpec((2), dtype=dtypes.float32) ) self.assertEqual( - captured_spec[1], tensor_spec.TensorSpec((1), dtype=dtypes.int32) + captured_spec[1], tensor.TensorSpec((1), dtype=dtypes.int32) ) def test_capture_variable_1(self): @@ -224,13 +224,13 @@ def fun_tf(x): return (x * v3 + t4 + v2) * v3 + t5 concrete_fn = fun_tf.get_concrete_function( - tensor_spec.TensorSpec((None,), dtype=dtypes.float32) + tensor.TensorSpec((None,), dtype=dtypes.float32) ) - x = tensor_spec.TensorSpec((10,), dtype=dtypes.float32) + x = tensor.TensorSpec((10,), dtype=dtypes.float32) hlo_1 = compiler_ir.from_concrete_function(concrete_fn, [x])(stage='hlo') self.assertIn('f32[10]', hlo_1) - x = tensor_spec.TensorSpec((20,), dtype=dtypes.float32) + x = tensor.TensorSpec((20,), dtype=dtypes.float32) hlo_2 = compiler_ir.from_concrete_function(concrete_fn, [x])(stage='hlo') self.assertIn('f32[20]', hlo_2) diff --git a/tensorflow/python/eager/polymorphic_function/concrete_function.py b/tensorflow/python/eager/polymorphic_function/concrete_function.py index 5461a54759b1f4..3f3cce06fb92a4 100644 --- a/tensorflow/python/eager/polymorphic_function/concrete_function.py +++ b/tensorflow/python/eager/polymorphic_function/concrete_function.py @@ -36,8 +36,8 @@ from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import default_gradient @@ -168,7 +168,7 @@ def _construct_forward_backward(self, num_doutputs): signature = [] for t in trainable_outputs: signature.append( - tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t))) + tensor_lib.TensorSpec(*default_gradient.shape_and_dtype(t))) def _backprop_function(*grad_ys): with ops.device(None): @@ -1177,7 +1177,7 @@ def _call_with_flat_signature(self, args, kwargs): for i, arg in enumerate(args): if not isinstance( - arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)): + arg, (tensor_lib.Tensor, resource_variable_ops.BaseResourceVariable)): raise TypeError(f"{self._flat_signature_summary()}: expected argument " f"#{i}(zero-based) to be a Tensor; " f"got {type(arg).__name__} ({arg}).") @@ -1391,7 +1391,7 @@ def bool_closure(): concrete_fn.replace_capture_with_deferred_capture( bool_captured_tensor, bool_closure, - spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)) + spec=tensor_lib.TensorSpec(shape=(), dtype=dtypes.bool)) print(concrete_fn()) # tf.Tensor([5.], shape=(1,), dtype=float32) ``` @@ -1651,7 +1651,7 @@ def pretty_printed_signature(self, verbose=True): def pretty_print_spec(spec): """Returns a string describing the spec for a single argument.""" - if isinstance(spec, tensor_spec.TensorSpec): + if isinstance(spec, tensor_lib.TensorSpec): return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape) elif nest.is_nested(spec): pieces = nest.flatten(spec, expand_composites=False) @@ -1762,7 +1762,7 @@ def _export_to_saved_model_graph(self, object_map, tensor_map, return [] -_pywrap_utils.RegisterType("Tensor", ops.Tensor) +_pywrap_utils.RegisterType("Tensor", tensor_lib.Tensor) _pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor) 
_pywrap_utils.RegisterType("IndexedSlices", indexed_slices.IndexedSlices) diff --git a/tensorflow/python/eager/polymorphic_function/function_type_utils.py b/tensorflow/python/eager/polymorphic_function/function_type_utils.py index 612caa4fb0c5ff..5da72faa326f52 100644 --- a/tensorflow/python/eager/polymorphic_function/function_type_utils.py +++ b/tensorflow/python/eager/polymorphic_function/function_type_utils.py @@ -23,7 +23,7 @@ from tensorflow.core.function import trace_type from tensorflow.core.function.polymorphism import function_type as function_type_lib from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import type_spec from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util import nest @@ -165,7 +165,7 @@ def to_input_signature(function_type): trace_type.InternalPlaceholderContext(unnest_only=True) ) if any( - not isinstance(arg, tensor_spec.TensorSpec) + not isinstance(arg, tensor.TensorSpec) for arg in nest.flatten([constraint], expand_composites=True) ): # input_signature only supports contiguous TensorSpec composites @@ -465,13 +465,13 @@ def _validate_signature(signature): ) if any( - not isinstance(arg, tensor_spec.TensorSpec) + not isinstance(arg, tensor.TensorSpec) for arg in nest.flatten(signature, expand_composites=True) ): bad_args = [ arg for arg in nest.flatten(signature, expand_composites=True) - if not isinstance(arg, tensor_spec.TensorSpec) + if not isinstance(arg, tensor.TensorSpec) ] raise TypeError( "input_signature must be a possibly nested sequence of " @@ -483,7 +483,7 @@ def _validate_signature(signature): def _to_tensor_or_tensor_spec(x): return ( x - if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) + if isinstance(x, (tensor.Tensor, tensor.TensorSpec)) else ops.convert_to_tensor(x) ) @@ -502,7 +502,7 @@ def _get_variable_specs(args): continue if isinstance(arg, resource_variable_ops.VariableSpec): variable_specs.append(arg) - elif not isinstance(arg, tensor_spec.TensorSpec): + elif not isinstance(arg, tensor.TensorSpec): # arg is a CompositeTensor spec. 
variable_specs.extend(_get_variable_specs(arg._component_specs)) # pylint: disable=protected-access return variable_specs diff --git a/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py b/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py index 12414894f7a004..e06f87a04fe6ab 100644 --- a/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py +++ b/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py @@ -52,8 +52,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.framework import type_spec @@ -110,7 +110,7 @@ def _spec_for_value(value): """Returns the (nested) TypeSpec for a value.""" if nest.is_nested(value): return nest.map_structure(_spec_for_value, value) - elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)): + elif isinstance(value, (tensor_lib.Tensor, composite_tensor.CompositeTensor)): return type_spec.type_spec_from_value(value) else: return value @@ -408,8 +408,8 @@ def testImplementsWorksWithTensorSpec(self): v = polymorphic_function.function( experimental_implements='func')(lambda x, y: x + y) v = v.get_concrete_function( - tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)) + tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32)) x = v(1., 2.) self.assertEqual(x.numpy(), 3.) 
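(Editorial aside, not part of the patch: the hunks in this test file all apply one mechanical substitution. Tensor and TensorSpec now live together in tensorflow.python.framework.tensor, so call sites alias that module as tensor_lib instead of importing ops and tensor_spec separately; the public tf.TensorSpec surface is untouched. Below is a minimal sketch of the pattern, assuming a TensorFlow build that ships the consolidated framework/tensor module as this diff does; it is an illustration of the substitution, not code from the patch.)

# Old internal imports being replaced throughout this diff:
#   from tensorflow.python.framework import ops
#   from tensorflow.python.framework import tensor_spec
#   spec = tensor_spec.TensorSpec(None, dtypes.float32)
#   isinstance(x, ops.Tensor)
#
# New pattern used by the updated call sites:
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor as tensor_lib

spec = tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32)

def is_tf_tensor(x):
  # tensor_lib.Tensor covers both graph tensors and eager tensors,
  # which is why the isinstance checks in these tests keep passing.
  return isinstance(x, tensor_lib.Tensor)

# Equivalent public-API usage, which the surrounding tests ultimately exercise:
import tensorflow as tf

@tf.function
def add(a, b):
  return a + b

concrete = add.get_concrete_function(
    tf.TensorSpec(shape=None, dtype=tf.float32),
    tf.TensorSpec(shape=None, dtype=tf.float32))
print(concrete(tf.constant(1.0), tf.constant(2.0)).numpy())  # 3.0

Tracing from TensorSpec objects rather than concrete tensors is what get_concrete_function does in the tests above and below; only the module that provides TensorSpec changes in this patch.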
@@ -546,21 +546,21 @@ def check_trace(x, expected_trace): check_trace( structured_tensor.StructuredTensor.from_pyval({'a': [1]}), structured_tensor.StructuredTensor.Spec._from_fields_and_rank( - fields={'a': tensor_spec.TensorSpec((1,), dtypes.int32)}, rank=0)) + fields={'a': tensor_lib.TensorSpec((1,), dtypes.int32)}, rank=0)) check_trace( structured_tensor.StructuredTensor.from_pyval({'b': [1]}), structured_tensor.StructuredTensor.Spec._from_fields_and_rank( - fields={'b': tensor_spec.TensorSpec((1,), dtypes.int32)}, rank=0)) + fields={'b': tensor_lib.TensorSpec((1,), dtypes.int32)}, rank=0)) check_trace( structured_tensor.StructuredTensor.from_pyval({'c': [1]}), structured_tensor.StructuredTensor.Spec._from_fields_and_rank( - fields={'c': tensor_spec.TensorSpec((1,), dtypes.int32)}, rank=0)) + fields={'c': tensor_lib.TensorSpec((1,), dtypes.int32)}, rank=0)) # But if we call again with only shape different, then do relax: check_trace( # relax & retrace structured_tensor.StructuredTensor.from_pyval({'a': [1, 2]}), structured_tensor.StructuredTensor.Spec._from_fields_and_rank( - fields={'a': tensor_spec.TensorSpec((None,), dtypes.int32)}, + fields={'a': tensor_lib.TensorSpec((None,), dtypes.int32)}, rank=0)) check_trace( # use relaxed graph structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3]}), None) @@ -593,13 +593,13 @@ def check_trace(x, expected_trace): check_trace( # shape=[1, 2]: retrace dataset_ops.make_one_shot_iterator(ds_1_2), iterator_ops.IteratorSpec( - tensor_spec.TensorSpec([1, 2], dtypes.float32))) + tensor_lib.TensorSpec([1, 2], dtypes.float32))) check_trace( # shape=[1, 2]: no retrace (use the [1, 2] graph) dataset_ops.make_one_shot_iterator(ds_1_2), None) check_trace( # shape=[2, 2]: relax to [None, 2] and retrace dataset_ops.make_one_shot_iterator(ds_2_2), iterator_ops.IteratorSpec( - tensor_spec.TensorSpec([None, 2], dtypes.float32))) + tensor_lib.TensorSpec([None, 2], dtypes.float32))) check_trace( # shape=[3, 2]: no retrace (use the [None, 2] graph) dataset_ops.make_one_shot_iterator(ds_3_2), None) check_trace( # shape=[4, 2]: no retrace (use the [None, 2] graph) @@ -607,7 +607,7 @@ def check_trace(x, expected_trace): check_trace( # shape=[2, 1]: relax to [None, None] and retrace dataset_ops.make_one_shot_iterator(ds_2_1), iterator_ops.IteratorSpec( - tensor_spec.TensorSpec([None, None], dtypes.float32))) + tensor_lib.TensorSpec([None, None], dtypes.float32))) def testCapturesVariables(self): a = variables.Variable(1.0, trainable=False) @@ -787,7 +787,7 @@ def sq(a): return matmul(a, a) sq_op = sq.get_concrete_function( - tensor_spec.TensorSpec((None, None), dtypes.float32)) + tensor_lib.TensorSpec((None, None), dtypes.float32)) self.assertEqual([None, None], sq_op.output_shapes.as_list()) t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) @@ -806,18 +806,16 @@ def sq(mats): ((a, b),) = mats return matmul(a, b) - sq_op_autonamed = sq.get_concrete_function([(tensor_spec.TensorSpec( - (None, None), - dtypes.float32), tensor_spec.TensorSpec((None, None), dtypes.float32))]) + sq_op_autonamed = sq.get_concrete_function([( + tensor_lib.TensorSpec((None, None), dtypes.float32), + tensor_lib.TensorSpec((None, None), dtypes.float32), + )]) self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list()) - sq_op = sq.get_concrete_function([(tensor_spec.TensorSpec((None, None), - dtypes.float32, - name='first_mat'), - tensor_spec.TensorSpec( - (None, None), - dtypes.float32, - name='second_mat'))]) + sq_op = sq.get_concrete_function([( + 
tensor_lib.TensorSpec((None, None), dtypes.float32, name='first_mat'), + tensor_lib.TensorSpec((None, None), dtypes.float32, name='second_mat'), + )]) self.assertEqual([None, None], sq_op.output_shapes.as_list()) t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) @@ -892,7 +890,7 @@ def testShareRendezvous(self): cpu = '/device:CPU:0' - signature = [tensor_spec.TensorSpec([], dtypes.int32)] + signature = [tensor_lib.TensorSpec([], dtypes.int32)] @polymorphic_function.function def send(): @@ -960,8 +958,8 @@ def a_times_b(inputs): t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) sq_op = a_times_b.get_concrete_function( pair( - dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')), - dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b')))) + dict(a=tensor_lib.TensorSpec([2, 2], dtypes.float32, 'a')), + dict(b=tensor_lib.TensorSpec([2, 2], dtypes.float32, 'b')))) self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) out = sq_op(a=t, b=t) self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) @@ -1116,7 +1114,7 @@ def testShapeInferenceForMoreSpecificInput(self): def f(a): return array_ops.reshape(a, [-1, 3]) - signature = [tensor_spec.TensorSpec(None, dtypes.float32)] + signature = [tensor_lib.TensorSpec(None, dtypes.float32)] compiled = polymorphic_function.function(f, input_signature=signature) @polymorphic_function.function @@ -1522,7 +1520,7 @@ def testConcreteFunctionType(self): def foo(x): return {'input': x, 'capture': y} - cf = foo.get_concrete_function(tensor_spec.TensorSpec([], dtypes.int32)) + cf = foo.get_concrete_function(tensor_lib.TensorSpec([], dtypes.int32)) x = constant_op.constant(2) output = cf(x) self.assertEqual(set(output.keys()), {'input', 'capture'}) @@ -1534,12 +1532,12 @@ def foo(x): self.assertEqual(parameters[0].name, 'x') self.assertEqual( parameters[0].type_constraint, - tensor_spec.TensorSpec([], dtypes.int32), + tensor_lib.TensorSpec([], dtypes.int32), ) captures = cf.function_type.captures self.assertLen(captures, 1) - self.assertEqual(captures[id(y)], tensor_spec.TensorSpec([], dtypes.int32)) + self.assertEqual(captures[id(y)], tensor_lib.TensorSpec([], dtypes.int32)) output = cf.function_type.output self.assertEqual(output, trace_type.from_value({'input': x, 'capture': y})) @@ -1551,8 +1549,8 @@ def testSequenceInputs(self): clipped_list, global_norm = clip_by_global_norm(t_list, constant_op.constant(.2)) for t in clipped_list: - self.assertIsInstance(t, ops.Tensor) - self.assertIsInstance(global_norm, ops.Tensor) + self.assertIsInstance(t, tensor_lib.Tensor) + self.assertIsInstance(global_norm, tensor_lib.Tensor) def testNestedSequenceInputs(self): @@ -1690,7 +1688,7 @@ def foo(a, b): del b # Signatures must consist exclusively of `TensorSpec` objects. 
- signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)] + signature = [(2, 3), tensor_lib.TensorSpec([2, 3], dtypes.float32)] with self.assertRaisesRegex(TypeError, 'input_signature.*nested sequence'): polymorphic_function.function(foo, input_signature=signature) @@ -1700,7 +1698,7 @@ def testInputsIncompatibleWithSignatureRaisesError(self): def foo(a): return a - signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] + signature = [tensor_lib.TensorSpec(shape=(2,), dtype=dtypes.float32)] defined = polymorphic_function.function(foo, input_signature=signature) # Valid call @@ -1729,7 +1727,7 @@ def foo(a): TypeError, r'Can not cast .*shape=\(3,\).* to .*shape=\(2,\).*' ): defined.get_concrete_function( - tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32)) + tensor_lib.TensorSpec(shape=(3,), dtype=dtypes.float32)) def testMismatchedConcreteSignatureRaisesError(self): @@ -1761,8 +1759,8 @@ def foo(a, training=True): return -1.0 * a signature = [ - tensor_spec.TensorSpec([], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.bool), + tensor_lib.TensorSpec([], dtypes.float32), + tensor_lib.TensorSpec([], dtypes.bool), ] defined = polymorphic_function.function(foo, input_signature=signature) a = constant_op.constant(1.0) @@ -1860,8 +1858,8 @@ def py_add(x, y): py_add(array_ops.ones([]), array_ops.ones([])) add = py_add.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32)) + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32)) @polymorphic_function.function def py_composite(x, y): @@ -1869,8 +1867,8 @@ def py_composite(x, y): py_composite(array_ops.ones([]), array_ops.ones([])) composite = py_composite.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32)) + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32)) with context.graph_mode(), self.cached_session(): with ops.get_default_graph().as_default(): @@ -2287,9 +2285,9 @@ def _uses_symbolic_shapes(w, x, y): return array_ops.reshape(y, [n, x_batch, -1]) conc = _uses_symbolic_shapes.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32)) + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32)) @polymorphic_function.function def _call_concrete(): @@ -2482,7 +2480,7 @@ def f(x, y): @test_util.run_in_graph_and_eager_modes def testConcreteFunctionMethodWithVarargs(self): - float32_scalar = tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32) + float32_scalar = tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32) class MyModel(module.Module): @@ -2808,8 +2806,8 @@ def f(x, y): return x * 10 + y conc = f.get_concrete_function( - x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'), - y=tensor_spec.TensorSpec(None, dtypes.int32, name='x')) + x=tensor_lib.TensorSpec(None, dtypes.int32, name='y'), + y=tensor_lib.TensorSpec(None, dtypes.int32, name='x')) result = conc(x=constant_op.constant(5), y=constant_op.constant(6)) self.assertAllEqual(result, 56) @@ -2886,7 +2884,7 @@ def func2(x, y=3, *args, **kwargs): def testPrettyPrintedExplicitSignatureWithKeywordArg(self): @polymorphic_function.function( - input_signature=[tensor_spec.TensorSpec(None)]) + input_signature=[tensor_lib.TensorSpec(None)]) def fn(a, b=1): return 
a + b @@ -3101,8 +3099,8 @@ def func_pos_3args(position_arg1, position_arg2, position_arg3): def testShapeInferencePropagateConstNestedStack(self): @polymorphic_function.function(input_signature=[ - tensor_spec.TensorSpec((None, None), dtype=dtypes.int32), - tensor_spec.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((None, None), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), ]) def f(x, s): old_shape = array_ops.shape(x) @@ -3111,7 +3109,7 @@ def f(x, s): return y @polymorphic_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32) + tensor_lib.TensorSpec(shape=(3, 6), dtype=dtypes.int32) ]) def g(x): y = f(x, s=5) @@ -3124,8 +3122,8 @@ def g(x): def testShapeInferencePropagateConstNestedUnstackStack(self): @polymorphic_function.function(input_signature=[ - tensor_spec.TensorSpec((None, None), dtype=dtypes.int32), - tensor_spec.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((None, None), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), ]) def f(x, s): s0, _ = array_ops_stack.unstack(array_ops.shape(x), axis=0) @@ -3134,7 +3132,7 @@ def f(x, s): return y @polymorphic_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32) + tensor_lib.TensorSpec(shape=(3, 6), dtype=dtypes.int32) ]) def g(x): y = f(x, s=5) @@ -3147,9 +3145,9 @@ def g(x): def testShapeInferencePropagateConstNestedConcat(self): @polymorphic_function.function(input_signature=[ - tensor_spec.TensorSpec((), dtype=dtypes.int32), - tensor_spec.TensorSpec((), dtype=dtypes.int32), - tensor_spec.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), ]) def f(d1, d2, d3): new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1) @@ -3167,9 +3165,9 @@ def g(): def testShapeInferencePropagateConstDoubleNested(self): @polymorphic_function.function(input_signature=[ - tensor_spec.TensorSpec((), dtype=dtypes.int32), - tensor_spec.TensorSpec((), dtype=dtypes.int32), - tensor_spec.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), ]) def f(d1, d2, d3): new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1) @@ -3417,8 +3415,8 @@ def apply(self, x): def testMethodExtensionType(self): class MaskedTensor(extension_type.ExtensionType): - values: ops.Tensor - mask: ops.Tensor + values: tensor_lib.Tensor + mask: tensor_lib.Tensor @polymorphic_function.function def with_default(self, default_value): @@ -3495,24 +3493,24 @@ def dynamic_unroll(core_fn, def test_unspecified_default_argument(self): wrapped = polymorphic_function.function( lambda x, y=2: x + y, - input_signature=[tensor_spec.TensorSpec((), dtypes.int32)]) + input_signature=[tensor_lib.TensorSpec((), dtypes.int32)]) self.assertEqual(3, wrapped(constant_op.constant(1)).numpy()) def test_concrete_function_from_signature(self): @polymorphic_function.function( - input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) + input_signature=[tensor_lib.TensorSpec(None, dtypes.float32)]) def compute(x): return 2. 
* x concrete = compute.get_concrete_function() self.assertAllClose(1., concrete(constant_op.constant(0.5))) concrete = compute.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32)) + tensor_lib.TensorSpec(None, dtypes.float32)) self.assertAllClose(4., concrete(constant_op.constant(2.))) signature_args, _ = concrete.structured_input_signature self.assertEqual(signature_args, - (tensor_spec.TensorSpec( + (tensor_lib.TensorSpec( None, dtypes.float32, name='x'),)) def testInputSignatureMissingTensorSpecsMethod(self): @@ -3539,7 +3537,7 @@ def f6(self, arg1, arg4=4, **kwargs): m = MyModule() tf_func_dec = polymorphic_function.function( - input_signature=(tensor_spec.TensorSpec([], dtypes.int32),)) + input_signature=(tensor_lib.TensorSpec([], dtypes.int32),)) error_message = 'input_signature missing type constraint' with self.assertRaisesRegex(TypeError, error_message): tf_func_dec(m.f1)(1, 2, 3) @@ -3560,7 +3558,7 @@ def f6(self, arg1, arg4=4, **kwargs): def testInputSignatureMissingTensorSpecsFunction(self): tf_func_dec = polymorphic_function.function( - input_signature=(tensor_spec.TensorSpec([], dtypes.int32),)) + input_signature=(tensor_lib.TensorSpec([], dtypes.int32),)) error_message = 'input_signature missing type constraint' # pylint: disable=unused-argument def f1(arg1, arg2, arg3): @@ -3600,7 +3598,7 @@ def f6(arg1, arg4=4, **kwargs): def testInputSignatureMissingTensorSpecsLambdaFunction(self): tf_func_dec = polymorphic_function.function( - input_signature=(tensor_spec.TensorSpec([], dtypes.int32),)) + input_signature=(tensor_lib.TensorSpec([], dtypes.int32),)) error_message = 'input_signature missing type constraint' with self.assertRaisesRegex(TypeError, error_message): tf_func_dec(lambda ar1, arg2, arg3: None)(1, 2, 3) @@ -3638,7 +3636,7 @@ def f(arg1, arg2, arg3, arg4=4): error_message = 'input_signature missing type constraint' tf_func_dec = polymorphic_function.function( - input_signature=(tensor_spec.TensorSpec([], dtypes.int32),) + input_signature=(tensor_lib.TensorSpec([], dtypes.int32),) ) with self.assertRaisesRegex(TypeError, error_message): tf_func_dec(functools.partial(f, 1))(2, 3) @@ -3690,20 +3688,20 @@ def f(x): return x conc = f.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32, 'y')) + tensor_lib.TensorSpec(None, dtypes.float32, 'y')) conc(y=constant_op.constant(3.0)) signature_args, _ = conc.structured_input_signature self.assertEqual('y', signature_args[0].name) # If name is not specified, the previously named one will be returned. - conc = f.get_concrete_function(tensor_spec.TensorSpec(None, dtypes.float32)) + conc = f.get_concrete_function(tensor_lib.TensorSpec(None, dtypes.float32)) conc(x=constant_op.constant(3.0)) signature_args, _ = conc.structured_input_signature self.assertEqual('y', signature_args[0].name) # New name will return updated signature. 
conc = f.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32, 'z') + tensor_lib.TensorSpec(None, dtypes.float32, 'z') ) conc(x=constant_op.constant(3.0)) signature_args, _ = conc.structured_input_signature @@ -3714,7 +3712,7 @@ def g(x): return x[0] conc = g.get_concrete_function( - [tensor_spec.TensorSpec(None, dtypes.float32, 'z'), 2]) + [tensor_lib.TensorSpec(None, dtypes.float32, 'z'), 2]) conc(z=constant_op.constant(3.0)) signature_args, _ = conc.structured_input_signature self.assertEqual('z', signature_args[0][0].name) @@ -3756,10 +3754,10 @@ def f(x, y): self.assertEqual( signatures_args, - set(((tensor_spec.TensorSpec([1, 2], dtypes.float32, name='x'), - tensor_spec.TensorSpec([1], dtypes.float32, name='y')), - (tensor_spec.TensorSpec([1, 3], dtypes.int32, name='x'), - tensor_spec.TensorSpec([1], dtypes.int32, name='y'))))) + set(((tensor_lib.TensorSpec([1, 2], dtypes.float32, name='x'), + tensor_lib.TensorSpec([1], dtypes.float32, name='y')), + (tensor_lib.TensorSpec([1, 3], dtypes.int32, name='x'), + tensor_lib.TensorSpec([1], dtypes.int32, name='y'))))) @test_util.assert_no_garbage_created def testFunctionReferenceCycles(self): @@ -3838,10 +3836,10 @@ def non_unique_arg_names(x, **kwargs): return a + b + c + d concrete = non_unique_arg_names.get_concrete_function( - (tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32)), - d=tensor_spec.TensorSpec(None, dtypes.float32)) + (tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32)), + d=tensor_lib.TensorSpec(None, dtypes.float32)) self.assertAllClose( 10., concrete(x=constant_op.constant(1.), @@ -3949,9 +3947,9 @@ def func(x): return 2 * x func_a = func.get_concrete_function( - tensor_spec.TensorSpec([None], dtypes.int32)) + tensor_lib.TensorSpec([None], dtypes.int32)) func_b = func.get_concrete_function( - tensor_spec.TensorSpec([None], dtypes.int32)) + tensor_lib.TensorSpec([None], dtypes.int32)) self.assertIs(func_a, func_b) @@ -4049,7 +4047,7 @@ def decorator(f): self.assertEqual(func().numpy(), 2) @parameterized.parameters(*itertools.product( - (None, (tensor_spec.TensorSpec([]),)), # input_signature + (None, (tensor_lib.TensorSpec([]),)), # input_signature (True, False), # autograph (None, converter.Feature.ALL), # autograph_options (None, 'foo.bar'), # implements @@ -4133,7 +4131,7 @@ def func(): self.assertEmpty(graph.captures) @parameterized.parameters(*itertools.product( - (None, (tensor_spec.TensorSpec([]),)), # input_signature + (None, (tensor_lib.TensorSpec([]),)), # input_signature (True, False), # autograph (None, converter.Feature.ALL), # autograph_options (None, 'foo.bar'), # implements @@ -4317,7 +4315,7 @@ def __call__(self, x): f_flexible = Foo() _ = f_flexible.__call__.get_concrete_function( - tensor_spec.TensorSpec(shape=[None], dtype=dtypes.int32)) + tensor_lib.TensorSpec(shape=[None], dtype=dtypes.int32)) tmp_dir = self.create_tempdir() save(f_flexible, tmp_dir.full_path) restored_f_flexible = load(tmp_dir.full_path) @@ -4386,14 +4384,14 @@ def testDouble(self, a): def test_tensor_shape_casted_to_specific(self): @polymorphic_function.function( - input_signature=[tensor_spec.TensorSpec([1])] + input_signature=[tensor_lib.TensorSpec([1])] ) def specific(x): self.assertEqual(x.shape, [1]) return x @polymorphic_function.function( - input_signature=[tensor_spec.TensorSpec(None)] + input_signature=[tensor_lib.TensorSpec(None)] ) 
def general(x): return specific(x) @@ -4572,7 +4570,7 @@ def closure(): concrete_fn.replace_capture_with_deferred_capture( concrete_fn.captured_inputs[1], closure, - spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), + spec=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), placeholder=concrete_fn.inputs[1]) self.assertAllEqual(concrete_fn(), 8.0) @@ -4589,7 +4587,7 @@ def testRaiseReplaceCaptureWithDeferredTypeSpecMismatch(self): def fn(): deferred_tensor = ops.get_default_graph().capture_call_time_value( lambda: value, - tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32)) + tensor_lib.TensorSpec(shape=(1,), dtype=dtypes.float32)) if bool_captured_tensor: return deferred_tensor else: @@ -4615,13 +4613,13 @@ def float_closure(): concrete_fn.replace_capture_with_deferred_capture( bool_captured_tensor, float_closure, - spec=tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32)) + spec=tensor_lib.TensorSpec(shape=(1,), dtype=dtypes.float32)) # Test replace without a placeholder concrete_fn.replace_capture_with_deferred_capture( bool_captured_tensor, bool_closure, - spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)) + spec=tensor_lib.TensorSpec(shape=(), dtype=dtypes.bool)) self.assertAllEqual(concrete_fn(), [5.]) @@ -4633,7 +4631,7 @@ def testConcreteFunctionSetExternalCapture(self): def fn(): deferred_tensor = ops.get_default_graph().capture_call_time_value( lambda: value, - tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32)) + tensor_lib.TensorSpec(shape=(1,), dtype=dtypes.float32)) return deferred_tensor + captured_tensor cf = fn.get_concrete_function() @@ -4656,7 +4654,7 @@ def testGraphReplaceCaptureAndSetExternalCapture(self): def fn(): deferred_tensor = ops.get_default_graph().capture_call_time_value( lambda: value, - tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32)) + tensor_lib.TensorSpec(shape=(1,), dtype=dtypes.float32)) if bool_captured_tensor: return deferred_tensor else: @@ -4673,7 +4671,7 @@ def closure(): concrete_fn.graph.replace_capture_with_deferred_capture( concrete_fn.captured_inputs[0], closure, - spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool), + spec=tensor_lib.TensorSpec(shape=(), dtype=dtypes.bool), placeholder=concrete_fn.inputs[1]) concrete_fn.set_external_captures([ @@ -4688,7 +4686,7 @@ def testDeferredCapture(self): @polymorphic_function.function def lazy_capture(x): y = ops.get_default_graph().capture_call_time_value( - lambda: value, tensor_spec.TensorSpec(None)) + lambda: value, tensor_lib.TensorSpec(None)) return x + y self.assertAllEqual(lazy_capture(2.0), 3.0) @@ -4703,7 +4701,7 @@ def testNestedDeferredCapture(self): @polymorphic_function.function def inner(x): y = ops.get_default_graph().capture_call_time_value( - lambda: value, tensor_spec.TensorSpec(None)) + lambda: value, tensor_lib.TensorSpec(None)) return x + y @polymorphic_function.function @@ -4723,7 +4721,7 @@ def testNestedDeferredCaptureInTFWhileLoop(self): @polymorphic_function.function def inner(x): y = ops.get_default_graph().capture_call_time_value( - lambda: value, tensor_spec.TensorSpec(None)) + lambda: value, tensor_lib.TensorSpec(None)) return x + y @polymorphic_function.function @@ -4752,15 +4750,15 @@ def testDeferredCaptureWithKey(self): @polymorphic_function.function def lazy_capture(x): w = ops.get_default_graph().capture_call_time_value( - lambda: value0, tensor_spec.TensorSpec(None), key=0) + lambda: value0, tensor_lib.TensorSpec(None), key=0) y = ops.get_default_graph().capture_call_time_value( - lambda: value1, 
tensor_spec.TensorSpec(None), key=1) + lambda: value1, tensor_lib.TensorSpec(None), key=1) def bad_closure(): raise ValueError('Should not run') z = ops.get_default_graph().capture_call_time_value( - bad_closure, tensor_spec.TensorSpec(None), key=1) + bad_closure, tensor_lib.TensorSpec(None), key=1) return x + y + w + z self.assertAllEqual(lazy_capture(2.0), 7.0) @@ -4774,7 +4772,7 @@ def testDeferredCaptureTypeError(self): @polymorphic_function.function def lazy_capture(x): y = ops.get_default_graph().capture_call_time_value( - lambda: value, tensor_spec.TensorSpec(())) + lambda: value, tensor_lib.TensorSpec(())) return x + y self.assertAllEqual(lazy_capture(2.0), 3.0) diff --git a/tensorflow/python/eager/polymorphic_function/polymorphic_function_xla_jit_test.py b/tensorflow/python/eager/polymorphic_function/polymorphic_function_xla_jit_test.py index 43f67f81870ce2..5598226ba6d966 100644 --- a/tensorflow/python/eager/polymorphic_function/polymorphic_function_xla_jit_test.py +++ b/tensorflow/python/eager/polymorphic_function/polymorphic_function_xla_jit_test.py @@ -24,7 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops @@ -49,11 +49,11 @@ class FunctionTest(xla_test.XLATestCase): def _compareTwoMethodsCompilerIROutput(self, f, args, kwargs): """Assert the two differnet methods (tensor_spec inputs or tensor inputs) experimental_get_compiler give same HLO text.""" flat_args = list(args) + list(kwargs.values()) - if not all([isinstance(x, ops.Tensor) for x in flat_args]): + if not all([isinstance(x, tensor.Tensor) for x in flat_args]): self.skipTest('It only support args and kwargs are all tf.Tensor types.') - args_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, args) - kwargs_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, kwargs) + args_spec = nest.map_structure(tensor.TensorSpec.from_tensor, args) + kwargs_spec = nest.map_structure(tensor.TensorSpec.from_tensor, kwargs) hlo_1 = f.experimental_get_compiler_ir(*args, **kwargs)() hlo_2 = f.experimental_get_compiler_ir(*args_spec, **kwargs_spec)() @@ -389,7 +389,7 @@ def g(x): def testWhileLoopWithUnmodifiedCarriedShape(self): with ops.device('device:{}:0'.format(self.device)): - signature = [tensor_spec.TensorSpec(shape=[None], dtype=dtypes.float32)] + signature = [tensor.TensorSpec(shape=[None], dtype=dtypes.float32)] # We define a signature that specifies unknown vector shape, then test # that tf.shape constness gets properly propagated into the while_loop @@ -407,7 +407,7 @@ def g(x): def testNestedWhileLoopWithUnmodifiedCarriedShape(self): with ops.device('device:{}:0'.format(self.device)): - signature = [tensor_spec.TensorSpec(shape=[None], dtype=dtypes.float32)] + signature = [tensor.TensorSpec(shape=[None], dtype=dtypes.float32)] @polymorphic_function.function( input_signature=signature, jit_compile=True) @@ -432,7 +432,7 @@ def outer(y, shp): def testNestedWhileLoopWithUnmodifiedCarriedShapeSlice(self): with ops.device('device:{}:0'.format(self.device)): signature = [ - tensor_spec.TensorSpec(shape=[None, None], dtype=dtypes.float32) + tensor.TensorSpec(shape=[None, None], dtype=dtypes.float32) ] @polymorphic_function.function( diff --git 
a/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py b/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py index d8dbd707d6b478..ec7c0da10db7e0 100644 --- a/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py +++ b/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py @@ -34,7 +34,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.layers import convolutional @@ -221,7 +221,7 @@ def f_py(): @test_util.run_v2_only def testCompilationNumpyArraysConvertedToTensors(self): def f(x): - self.assertIsInstance(x, ops.Tensor) + self.assertIsInstance(x, tensor_lib.Tensor) return x x = random_ops.random_uniform([2, 2]).numpy() @@ -280,7 +280,7 @@ def f(x, dtype): def testCompilationNumpyArraysConvertedToTensorsInKwargs(self): def f(**kwargs): x = kwargs.pop('x') - self.assertIsInstance(x, ops.Tensor) + self.assertIsInstance(x, tensor_lib.Tensor) return x x = random_ops.random_uniform([2, 2]).numpy() @@ -580,7 +580,7 @@ def foo(a): return a function_cache = function_cache_lib.FunctionCache() - signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] + signature = [tensor_lib.TensorSpec(shape=(2,), dtype=dtypes.float32)] defined = compiled_fn( foo, input_signature=signature, function_cache=function_cache ) @@ -592,7 +592,7 @@ def foo(a): self.assertAllEqual( a, defined.get_concrete_function( - tensor_spec.TensorSpec((2,), dtype=dtypes.float32) + tensor_lib.TensorSpec((2,), dtype=dtypes.float32) )(a), ) self.assertLen(function_cache, 1) @@ -601,7 +601,7 @@ def bar(a): self.assertEqual(a._shape_tuple(), (2, None)) return a - signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)] + signature = [tensor_lib.TensorSpec((2, None), dtypes.float32)] defined = compiled_fn(bar, input_signature=signature) a = array_ops.ones([2, 1]) out = defined(a) @@ -629,7 +629,7 @@ def f(*_args, **_kwargs): self.assertLen(function_cache, 2) def testInputSignatureWithCompatibleInputs(self): - rank2_spec = tensor_spec.TensorSpec( + rank2_spec = tensor_lib.TensorSpec( shape=(None, None), dtype=dtypes.float32 ) @@ -656,8 +656,8 @@ def expected_foo(a, b): @compiled_fn( input_signature=[ - [tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2, - tensor_spec.TensorSpec((1,), dtypes.float32), + [tensor_lib.TensorSpec((2, None), dtypes.float32)] * 2, + tensor_lib.TensorSpec((1,), dtypes.float32), ], function_cache=function_cache, ) @@ -707,9 +707,9 @@ def expected_bar(a): @compiled_fn( input_signature=[{ - 'a': tensor_spec.TensorSpec((2, None), dtypes.float32), - 'b': tensor_spec.TensorSpec((2, None), dtypes.float32), - 'c': tensor_spec.TensorSpec((1,), dtypes.float32), + 'a': tensor_lib.TensorSpec((2, None), dtypes.float32), + 'b': tensor_lib.TensorSpec((2, None), dtypes.float32), + 'c': tensor_lib.TensorSpec((1,), dtypes.float32), }] ) def bar(a): @@ -744,7 +744,7 @@ def foo(a, b): del b # Signatures must be either lists or tuples on their outermost levels. 
- signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)} + signature = {'t1': tensor_lib.TensorSpec([], dtypes.float32)} with self.assertRaisesRegex( TypeError, 'input_signature must be either a tuple or a list.*' ): @@ -755,8 +755,8 @@ def foo(a, b): return [a, b] signature = [ - [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2, - [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2, + [tensor_lib.TensorSpec((1,), dtypes.float32)] * 2, + [tensor_lib.TensorSpec((1,), dtypes.float32)] * 2, ] defined = compiled_fn(foo, input_signature=signature) a = array_ops.ones([1]) @@ -772,7 +772,7 @@ def foo(a, b): def testUnderspecifiedInputSignature(self): @compiled_fn( input_signature=[ - tensor_spec.TensorSpec([], dtypes.float32), + tensor_lib.TensorSpec([], dtypes.float32), ] ) def foo(a, training=True): @@ -794,7 +794,7 @@ def full_function(a, b, c=3.0): partial = functools.partial(full_function, 1, c=4) a, b, c = partial(2.0) - signature = [tensor_spec.TensorSpec([], dtypes.float32)] + signature = [tensor_lib.TensorSpec([], dtypes.float32)] defined = compiled_fn(partial, input_signature=signature) x = constant_op.constant(2.0) func_a, func_b, func_c = defined(x) @@ -808,8 +808,8 @@ def testInputSignatureWithKeywordPositionalArgs(self): @compiled_fn( input_signature=[ - tensor_spec.TensorSpec([], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.int64), + tensor_lib.TensorSpec([], dtypes.float32), + tensor_lib.TensorSpec([], dtypes.int64), ], function_cache=function_cache, ) @@ -848,8 +848,8 @@ def foo(a, b, **kwargs): x = compiled_fn( foo, input_signature=[ - tensor_spec.TensorSpec([], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.int32), + tensor_lib.TensorSpec([], dtypes.float32), + tensor_lib.TensorSpec([], dtypes.int32), ], ).get_concrete_function() result = x(constant_op.constant(5.0), constant_op.constant(5)) @@ -899,15 +899,15 @@ def f(rt): @test_util.run_v2_only def testInputSignatureWithKeywordOnlyArgs(self): def f(a, b, c=3, *, d=4): - self.assertIsInstance(a, ops.Tensor) - self.assertIsInstance(b, ops.Tensor) + self.assertIsInstance(a, tensor_lib.Tensor) + self.assertIsInstance(b, tensor_lib.Tensor) self.assertIsInstance(c, int) - self.assertIsInstance(d, (int, ops.Tensor)) + self.assertIsInstance(d, (int, tensor_lib.Tensor)) return a + b + c + d signature = [ - tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), - tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), + tensor_lib.TensorSpec(shape=[], dtype=dtypes.int32), + tensor_lib.TensorSpec(shape=[], dtype=dtypes.int32), ] defined = compiled_fn(f, input_signature=signature) self.assertEqual(defined(1, 2).numpy(), 10) @@ -935,8 +935,8 @@ def f(a, b, c=3, *, d=4): def testInputSignatureWithKeywordOnlyArgsNoDefaults(self): signature = [ - tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), - tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), + tensor_lib.TensorSpec(shape=[], dtype=dtypes.int32), + tensor_lib.TensorSpec(shape=[], dtype=dtypes.int32), ] def test_func(a, *, b): @@ -1104,8 +1104,8 @@ def py_add(x, y): py_add(array_ops.ones([]), array_ops.ones([])) add = py_add.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32), ) @compiled_fn( @@ -1116,8 +1116,8 @@ def py_composite(x, y): py_composite(array_ops.ones([]), array_ops.ones([])) composite = py_composite.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32), - 
tensor_spec.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32), ) with context.graph_mode(), self.cached_session(): @@ -1188,8 +1188,8 @@ def matmul(x, y): defun_matmul = compiled_fn( matmul, input_signature=[ - tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=(2, 2), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=(2, 2), dtype=dtypes.float32), ], function_cache=function_cache_lib.FunctionCache(), ) @@ -1450,7 +1450,7 @@ def defined(t): return t z = array_ops.zeros([2, 2]) - z_spec = tensor_spec.TensorSpec.from_tensor(z) + z_spec = tensor_lib.TensorSpec.from_tensor(z) self.assertIs( defined.get_concrete_function(z_spec), defined.get_concrete_function(z) ) @@ -1616,7 +1616,7 @@ def func(x): return array_ops.shape(x) @compiled_fn( - input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)] + input_signature=[tensor_lib.TensorSpec([None, None], dtypes.float32)] ) def calls_func(x): return func(x) @@ -2014,8 +2014,8 @@ def fn(a, b): fn(array_ops.ones([]), array_ops.ones([])) fn_op = fn.get_concrete_function( - tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=(None,), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), ) self.assertEqual(['a', 'b'], [inp.op.name for inp in fn_op.inputs]) self.assertEqual( @@ -2040,7 +2040,7 @@ def fn(a, b): fn(array_ops.ones([]), array_ops.ones([])) fn_op = fn.get_concrete_function( - tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=(None,), dtype=dtypes.float32), variables.Variable(1.0), ) self.assertEqual(['a', 'b'], [inp.op.name for inp in fn_op.inputs]) @@ -2060,8 +2060,8 @@ def fn(x, z=(1.0, 2.0), y=3.0): fn(array_ops.ones([])) fn_op = fn.get_concrete_function( - x=tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32), - y=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), + x=tensor_lib.TensorSpec(shape=(None,), dtype=dtypes.float32), + y=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), ) self.assertEqual(['x', 'y'], [inp.op.name for inp in fn_op.inputs]) self.assertEqual( @@ -2074,14 +2074,14 @@ def fn(x, z=(1.0, 2.0), y=3.0): fn_op2 = fn.get_concrete_function( z=( - tensor_spec.TensorSpec( + tensor_lib.TensorSpec( shape=(None,), dtype=dtypes.float32, name='z_first' ), - tensor_spec.TensorSpec( + tensor_lib.TensorSpec( shape=(), dtype=dtypes.float32, name='z_second' ), ), - y=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, name='custom'), + y=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='custom'), x=4.0, ) self.assertEqual( @@ -2094,14 +2094,14 @@ def fn(x, z=(1.0, 2.0), y=3.0): ) fn_op3 = fn.get_concrete_function( - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, name='custom'), + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='custom'), z=( - tensor_spec.TensorSpec( + tensor_lib.TensorSpec( shape=(None,), dtype=dtypes.float32, name='z1' ), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, name='z2'), + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='z2'), ), - y=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), + y=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), ) self.assertEqual( ['custom', 'z1', 'z2', 'y'], [inp.op.name for inp in fn_op3.inputs] @@ -2120,7 +2120,7 @@ def method(self, x): has_method = 
HasMethod() compiled_method = compiled_fn(has_method.method) class_op = compiled_method.get_concrete_function( - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32) + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32) ) self.assertEqual(['x'], [inp.op.name for inp in class_op.inputs]) self.assertEqual( @@ -2129,7 +2129,7 @@ def method(self, x): ) method_op = compiled_method.get_concrete_function( - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32) + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32) ) self.assertEqual(['x'], [inp.op.name for inp in method_op.inputs]) self.assertEqual( @@ -2141,7 +2141,7 @@ def method(self, x): # should always retrace? self.skipTest('Not working') method_op = has_method.method.get_concrete_function( - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, name='y') + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='y') ) self.assertEqual(['y'], [inp.op.name for inp in method_op.inputs]) self.assertEqual( @@ -2160,7 +2160,7 @@ def method(self, x): compiled_method = compiled_fn( has_method.method, input_signature=( - tensor_spec.TensorSpec(shape=None, dtype=dtypes.float64, name='y'), + tensor_lib.TensorSpec(shape=None, dtype=dtypes.float64, name='y'), ), ) @@ -2185,14 +2185,14 @@ def variadic_fn(x, *args, **kwargs): # Call the function to make def_function happy variadic_fn(array_ops.ones([]), array_ops.ones([])) variadic_op = variadic_fn.get_concrete_function( - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32, name='y'), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - tensor_spec.TensorSpec( + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32, name='y'), + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), + tensor_lib.TensorSpec( shape=(), dtype=dtypes.float32, name='second_variadic' ), - z=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - zz=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, name='cust'), + z=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), + zz=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='cust'), ) self.assertEqual( ['x', 'y', 'args_1', 'second_variadic', 'z', 'cust'], @@ -2206,10 +2206,10 @@ def variadic_fn(x, *args, **kwargs): def testVariadicInputSignature(self): @compiled_fn( input_signature=( - tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32, name='y'), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, name='z'), + tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32, name='y'), + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='z'), ), name='variadic_fn', ) diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index 6ffe1caf5c15e1..23935e6a126646 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -29,8 +29,8 @@ from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from 
tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -246,7 +246,7 @@ def _call_impl(self, args, kwargs): if self._signature is not None: args = list(args) for i, arg in enumerate(args): - if isinstance(self._signature[i], tensor_spec.DenseSpec): + if isinstance(self._signature[i], tensor_lib.DenseSpec): args[i] = ops.convert_to_tensor(arg, self._signature[i].dtype) return self._call_flat(args, self.captured_inputs) else: @@ -281,7 +281,7 @@ def prune(self, feeds, fetches, name=None, input_signature=None): flat_feeds = nest.flatten(feeds, expand_composites=True) flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds] for f in flat_feeds: - if not isinstance(f, ops.Tensor): + if not isinstance(f, tensor_lib.Tensor): raise ValueError("All memebers of argument `feeds` must be tensors. " f"Got {f} with type {type(f)}.") @@ -319,7 +319,8 @@ def _fetch_preprocessing_callback(fetch): else: operation_fetches.append(decoded) return decoded - elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)): + elif isinstance( + fetch, (tensor_lib.Tensor, composite_tensor.CompositeTensor)): tensor_fetches.append(fetch) return fetch else: diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index b1b54e34e068b3..9bb28913c72198 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -76,6 +76,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index 4613bf386228d9..c1e5e22867cf82 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -141,6 +141,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack @@ -1502,7 +1503,7 @@ def categorical_column_with_vocabulary_file_v2(key, 'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file) # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`. 
- if not isinstance(vocabulary_size, ops.Tensor) and vocabulary_size < 1: + if not isinstance(vocabulary_size, tensor_lib.Tensor) and vocabulary_size < 1: raise ValueError('Invalid vocabulary_size in {}.'.format(key)) if num_oov_buckets: if default_value is not None: diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 451ef1fd570dbe..d35df511994a84 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -275,6 +275,7 @@ py_strict_library( srcs_version = "PY3", deps = [ ":ops", + ":tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:tf_logging", @@ -400,6 +401,7 @@ py_strict_library( deps = [ ":dtypes", ":ops", + ":tensor", ":tensor_conversion_registry", ":tensor_shape", ":tensor_util", @@ -528,6 +530,7 @@ py_strict_library( ":dtypes", ":graph_to_function_def", ":ops", + ":tensor", "//tensorflow/core:protos_all_py", "//tensorflow/python/client:pywrap_tf_session", "//tensorflow/python/eager:context", @@ -665,6 +668,7 @@ py_strict_library( ":op_def_library_pybind", ":op_def_registry", ":ops", + ":tensor", ":tensor_shape", "//tensorflow/core:protos_all_py", "//tensorflow/core/config:flags_py", @@ -906,7 +910,7 @@ tf_py_strict_test( ":_pywrap_python_tensor_converter", ":constant_op", ":dtypes", - ":ops", + ":tensor", ":tensor_shape", ":test_lib", "//tensorflow/core:protos_all_py", @@ -1051,7 +1055,7 @@ tf_py_strict_test( deps = [ ":_pywrap_python_api_dispatcher", ":constant_op", - ":ops", + ":tensor", ":test_lib", "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/ops/ragged:ragged_tensor", @@ -1136,7 +1140,7 @@ tf_py_strict_test( ":constant_op", ":dtypes", ":indexed_slices", - ":ops", + ":tensor", ":test_lib", "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:context", @@ -1532,7 +1536,7 @@ py_strict_library( srcs_version = "PY3", visibility = visibility + ["//tensorflow_model_optimization:__subpackages__"], deps = [ - ":ops", + ":tensor", ":tensor_util", "//tensorflow/python/ops:cond", "//tensorflow/python/ops:control_flow_case", @@ -1603,6 +1607,8 @@ py_strict_library( ":errors", ":extension_type", ":ops", + ":tensor", + ":tensor_conversion_registry", "//tensorflow/python/eager:context", "//third_party/py/numpy", ], @@ -1614,13 +1620,12 @@ tf_py_strict_test( main = "weak_tensor_test.py", python_version = "PY3", srcs_version = "PY3", - tags = ["no_pip"], # weak_tensor_test is not available in pip. 
deps = [ ":constant_op", ":dtypes", ":errors", ":ops", - ":tensor_spec", + ":tensor", ":test_lib", ":weak_tensor", "//tensorflow/python/eager:backprop", @@ -1687,9 +1692,8 @@ py_strict_library( ":dtypes", ":extension_type_field", ":immutable_dict", - ":ops", + ":tensor", ":tensor_shape", - ":tensor_spec", ":type_spec", ":type_spec_registry", "//tensorflow/core:protos_all_py", @@ -1715,6 +1719,7 @@ py_strict_library( ":dtypes", ":immutable_dict", ":ops", + ":tensor", ":tensor_shape", ":type_spec", "//tensorflow/python/util:type_annotations", @@ -1814,10 +1819,11 @@ pytype_strict_library( name = "flexible_dtypes", srcs = ["flexible_dtypes.py"], deps = [ - ":constant_op", ":dtypes", ":ops", + ":tensor_shape", ":weak_tensor", + "//tensorflow/python/types:core", "//tensorflow/python/util:nest", "//third_party/py/numpy", ], @@ -1836,6 +1842,12 @@ py_strict_library( py_strict_library( name = "tensor", srcs = ["tensor.py"], + visibility = visibility + [ + "//tensorflow:internal", + "//tensorflow_models:__subpackages__", + "//third_party/mlperf:__subpackages__", + "//third_party/py/tf_slim:__subpackages__", + ], deps = [ ":common_shapes", ":dtypes", @@ -1955,6 +1967,7 @@ py_strict_library( ":ops", ":random_seed", ":sparse_tensor", + ":tensor", ":tensor_shape", ":tensor_util", ":tfrt_utils", @@ -2466,9 +2479,9 @@ tf_py_strict_test( ":indexed_slices", ":ops", ":sparse_tensor", + ":tensor", ":tensor_conversion_registry", ":tensor_shape", - ":tensor_spec", ":tensor_util", ":test_lib", ":test_ops", @@ -2570,8 +2583,8 @@ tf_py_strict_test( ":extension_type_field", ":immutable_dict", ":ops", + ":tensor", ":tensor_shape", - ":tensor_spec", ":test_lib", ":type_spec", ":type_spec_registry", @@ -2612,9 +2625,8 @@ tf_py_strict_test( ":constant_op", ":dtypes", ":extension_type_field", - ":ops", + ":tensor", ":tensor_shape", - ":tensor_spec", ":test_lib", "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/ops/ragged:ragged_tensor", @@ -2717,8 +2729,8 @@ tf_py_strict_test( ":errors", ":ops", ":sparse_tensor", + ":tensor", ":tensor_shape", - ":tensor_spec", ":test_lib", ":type_utils", "//tensorflow/core:protos_all_py", @@ -2830,6 +2842,7 @@ cuda_py_strict_test( ":indexed_slices", ":ops", ":random_seed", + ":tensor", ":test_lib", ":test_ops", "//tensorflow/core:protos_all_py", @@ -2921,8 +2934,8 @@ tf_py_strict_test( ":op_def_library", ":op_def_library_pybind", ":ops", + ":tensor", ":tensor_shape", - ":tensor_spec", ":test_lib", "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:def_function", @@ -3102,14 +3115,15 @@ py_strict_test( name = "flexible_dtypes_test", srcs = ["flexible_dtypes_test.py"], tags = [ - "no_pip", "no_windows", # TODO(b/286939592): Enable this test on Windows. 
], deps = [ ":constant_op", ":dtypes", + ":extension_type", ":flexible_dtypes", ":ops", + ":tensor", ":weak_tensor", "//tensorflow/python/ops:variables", "//tensorflow/python/ops:weak_tensor_test_util", diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index 20f314afc7b191..4004485469c33d 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -26,6 +26,7 @@ from tensorflow.python.eager import execute from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -314,7 +315,7 @@ def _constant_eager_impl(ctx, value, dtype, shape, verify_shape): def is_constant(tensor_or_op): - if isinstance(tensor_or_op, ops.Tensor): + if isinstance(tensor_or_op, tensor_lib.Tensor): op = tensor_or_op.op else: op = tensor_or_op @@ -400,7 +401,7 @@ class _ConstantTensorCodec: """Codec for Tensor.""" def can_encode(self, pyobj): - return isinstance(pyobj, ops.Tensor) + return isinstance(pyobj, tensor_lib.Tensor) def do_encode(self, tensor_value, encode_fn): """Returns an encoded `TensorProto` for the given `tf.Tensor`.""" diff --git a/tensorflow/python/framework/cpp_shape_inference.proto b/tensorflow/python/framework/cpp_shape_inference.proto index d2fd1f29f23b87..4272ed3f4dfa2b 100644 --- a/tensorflow/python/framework/cpp_shape_inference.proto +++ b/tensorflow/python/framework/cpp_shape_inference.proto @@ -9,6 +9,7 @@ import "tensorflow/core/framework/types.proto"; option cc_enable_arenas = true; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/python/framework/cpp_shape_inference_go_proto"; +// LINT.IfChange message CppShapeInferenceResult { message HandleShapeAndType { reserved 3; @@ -34,3 +35,4 @@ message CppShapeInferenceInputsNeeded { repeated int32 input_tensors_needed = 1; repeated int32 input_tensors_as_shapes_needed = 2; } +// LINT.ThenChange(//tensorflow/core/framework/cpp_shape_inference.proto) diff --git a/tensorflow/python/framework/extension_type.py b/tensorflow/python/framework/extension_type.py index 16ae82831121b4..d83e5b3b6d401e 100644 --- a/tensorflow/python/framework/extension_type.py +++ b/tensorflow/python/framework/extension_type.py @@ -25,9 +25,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import extension_type_field from tensorflow.python.framework import immutable_dict -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec_registry from tensorflow.python.ops import array_ops @@ -146,7 +145,7 @@ class ExtensionType( >>> class Toy(ExtensionType): ... name: str - ... price: ops.Tensor + ... price: tensor.Tensor ... features: typing.Mapping[str, tf.Tensor] >>> class ToyStore(ExtensionType): @@ -307,7 +306,7 @@ def __eq__(self, other): def __ne__(self, other): eq = self.__eq__(other) - if isinstance(eq, ops.Tensor): + if isinstance(eq, tensor.Tensor): return math_ops.logical_not(eq) else: return not eq @@ -448,7 +447,7 @@ def _to_components(self, value): # TypeSpec API. 
if self._tf_extension_type_is_packed: return value._tf_extension_type_packed_variant # pylint: disable=protected-access - tensor_or_composite = (ops.Tensor, composite_tensor.CompositeTensor) + tensor_or_composite = (tensor.Tensor, composite_tensor.CompositeTensor) # Retireve fields by the order of spec dict to preserve field ordering. This # is needed as nest.flatten would sort dictionary entries by key. value_tuple = tuple(value.__dict__[key] for key in self.__dict__) @@ -490,7 +489,7 @@ def _from_components(self, components): # TypeSpec API. @property def _component_specs(self): # TypeSpec API. if self._tf_extension_type_is_packed: - return tensor_spec.TensorSpec((), dtypes.variant) + return tensor.TensorSpec((), dtypes.variant) components = [] @@ -864,9 +863,9 @@ def _deserialize_for_reduce(value_type, serialization): def _replace_tensor_with_spec(value): - if isinstance(value, ops.Tensor): + if isinstance(value, tensor.Tensor): # Note: we intentionally exclude `value.name` from the `TensorSpec`. - return tensor_spec.TensorSpec(value.shape, value.dtype) + return tensor.TensorSpec(value.shape, value.dtype) if hasattr(value, '_type_spec'): return value._type_spec # pylint: disable=protected-access return value @@ -1265,7 +1264,7 @@ def _convert_anonymous_fields(value, for_spec=False): ) if ( - isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)) + isinstance(value, (tensor.Tensor, composite_tensor.CompositeTensor)) and not for_spec ): return value diff --git a/tensorflow/python/framework/extension_type_field.py b/tensorflow/python/framework/extension_type_field.py index 80774535f39421..afd84fb7d9d1e7 100644 --- a/tensorflow/python/framework/extension_type_field.py +++ b/tensorflow/python/framework/extension_type_field.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import immutable_dict from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import type_spec from tensorflow.python.util import type_annotations @@ -154,7 +155,7 @@ def validate_field_value_type(value_type, if value_type in (int, float, str, bytes, bool, None, _NoneType, dtypes.DType): return - elif (value_type in (ops.Tensor, tensor_shape.TensorShape) or + elif (value_type in (tensor.Tensor, tensor_shape.TensorShape) or (isinstance(value_type, type) and _issubclass(value_type, composite_tensor.CompositeTensor))): if in_mapping_key: @@ -287,7 +288,7 @@ def _convert_value(value, expected_type, path, if expected_type is None: expected_type = _NoneType - if expected_type is ops.Tensor: + if expected_type is tensor.Tensor: return _convert_tensor(value, path, context) elif (isinstance(expected_type, type) and _issubclass(expected_type, composite_tensor.CompositeTensor)): @@ -324,13 +325,13 @@ def _convert_tensor(value, path, context): """Converts `value` to a `Tensor`.""" if context == _ConversionContext.SPEC: if not (isinstance(value, type_spec.TypeSpec) and - value.value_type is ops.Tensor): + value.value_type is tensor.Tensor): raise TypeError( f'{"".join(path)}: expected a TensorSpec, got ' f'{type(value).__name__!r}') return value - if not isinstance(value, ops.Tensor): + if not isinstance(value, tensor.Tensor): if context == _ConversionContext.DEFAULT: # TODO(edloper): Convert the value to a numpy array? 
(Note: we can't just # use `np.array(value)`, since the default dtypes for TF and numpy are diff --git a/tensorflow/python/framework/extension_type_field_test.py b/tensorflow/python/framework/extension_type_field_test.py index a892ce9097df9f..f352c899d0e042 100644 --- a/tensorflow/python/framework/extension_type_field_test.py +++ b/tensorflow/python/framework/extension_type_field_test.py @@ -22,9 +22,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import extension_type_field -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -46,12 +45,12 @@ class ExtensionTypeFieldTest(test_util.TensorFlowTestCase, # Without default values: ('x', int), ('f', float), - ('t', ops.Tensor), + ('t', tensor.Tensor), # With default values: ('x', int, 33), ('y', float, 33.8), - ('t', ops.Tensor, [[1, 2], [3, 4]]), - ('t', ops.Tensor, lambda: constant_op.constant([[1, 2], [3, 4]])), + ('t', tensor.Tensor, [[1, 2], [3, 4]]), + ('t', tensor.Tensor, lambda: constant_op.constant([[1, 2], [3, 4]])), ('r', ragged_tensor.RaggedTensor, lambda: ragged_factory_ops.constant([[1, 2], [3]])), ('seq', typing.Tuple[typing.Union[int, float], ...], (33, 12.8, 9, 0)), @@ -75,7 +74,7 @@ def testConstruction( default = converted_default self.assertEqual(field.name, name) self.assertEqual(field.value_type, value_type) - if isinstance(default, (ops.Tensor, ragged_tensor.RaggedTensor)): + if isinstance(default, (tensor.Tensor, ragged_tensor.RaggedTensor)): self.assertAllEqual(field.default, default) else: self.assertEqual(field.default, default) @@ -91,13 +90,13 @@ def testConstruction( ('seq', _TUPLE[typing.Union[int, float], ...], [33, 12.8, 'zero'], (r'default value for seq\[2\]: expected ' r"typing.Union\[int, float\], got 'str'")), - ('t', tensor_spec.TensorSpec(None, dtypes.int32), + ('t', tensor.TensorSpec(None, dtypes.int32), lambda: constant_op.constant(0.0), 'Unsupported type annotation TensorSpec.*'), ('x', dict, {}, "In field 'x': Unsupported type annotation 'dict'"), ('y', typing.Union[int, list], 3, "In field 'y': Unsupported type annotation 'list'"), - ('z', typing.Mapping[ops.Tensor, int], {}, + ('z', typing.Mapping[tensor.Tensor, int], {}, "In field 'z': Mapping had a key 'Tensor' with type 'type'"), ]) def testConstructionError(self, name, value_type, default, error): @@ -150,7 +149,7 @@ class ValidateFieldPyTypeTest(test_util.TensorFlowTestCase, dict(tp=type(None)), dict(tp=dtypes.DType), dict(tp=tensor_shape.TensorShape), - dict(tp=ops.Tensor), + dict(tp=tensor.Tensor), dict(tp='A', allow_forward_references=True), # Generic types dict(tp=typing.Union[int, float]), @@ -185,7 +184,7 @@ def testValidPytype(self, tp, allow_forward_references=False): error="Unsupported type annotation 'dict'"), dict(tp='A', error='Unresolved forward reference .*'), dict(tp=typing.Union[int, 'A'], error='Unresolved forward reference .*'), - dict(tp=typing.Mapping[ops.Tensor, int], + dict(tp=typing.Mapping[tensor.Tensor, int], error="Mapping had a key 'Tensor' with type 'type'"), dict( tp=typing.Mapping[tensor_shape.TensorShape, int], @@ -223,8 +222,8 @@ def testConvertFieldsMismatch(self, field_values, error): ('foo', str), (None, None), 
(True, bool), - ([1, 2, 3], ops.Tensor), - (lambda: constant_op.constant([1, 2, 3]), ops.Tensor), + ([1, 2, 3], tensor.Tensor), + (lambda: constant_op.constant([1, 2, 3]), tensor.Tensor), (lambda: ragged_factory_ops.constant([[1, 2], [3]]), ragged_tensor.RaggedTensor), ([1, 2, 3], typing.Tuple[int, ...], (1, 2, 3)), @@ -252,7 +251,7 @@ def testConvertValue(self, value, value_type, expected=None): if expected is None: expected = value converted = extension_type_field._convert_value(value, value_type, ('x',)) - if isinstance(converted, (ops.Tensor, ragged_tensor.RaggedTensor)): + if isinstance(converted, (tensor.Tensor, ragged_tensor.RaggedTensor)): self.assertAllEqual(converted, expected) else: self.assertEqual(converted, expected) @@ -263,7 +262,7 @@ def testConvertValue(self, value, value_type, expected=None): ('foo', str), (None, None), (True, bool), - (tensor_spec.TensorSpec([5]), ops.Tensor), + (tensor.TensorSpec([5]), tensor.Tensor), (ragged_tensor.RaggedTensorSpec([5, None]), ragged_tensor.RaggedTensor), ([1, 2, 3], typing.Tuple[int, ...], (1, 2, 3)), ((1, 2, 3), typing.Tuple[int, int, int], (1, 2, 3)), @@ -292,7 +291,7 @@ def testConvertValueForSpec(self, value, value_type, expected=None): converted = extension_type_field._convert_value( value, value_type, ('x',), extension_type_field._ConversionContext.SPEC) - if isinstance(converted, (ops.Tensor, ragged_tensor.RaggedTensor)): + if isinstance(converted, (tensor.Tensor, ragged_tensor.RaggedTensor)): self.assertAllEqual(converted, expected) else: self.assertEqual(converted, expected) @@ -321,14 +320,14 @@ def testConvertFields(self): 'y', typing.Tuple[typing.Union[int, bool], ...]), extension_type_field.ExtensionTypeField( 'y', _TUPLE[typing.Union[int, bool], ...]), - extension_type_field.ExtensionTypeField('z', ops.Tensor) + extension_type_field.ExtensionTypeField('z', tensor.Tensor) ] field_values = {'x': 1, 'y': [1, True, 3], 'z': [[1, 2], [3, 4], [5, 6]]} extension_type_field.convert_fields(fields, field_values) self.assertEqual(set(field_values), set(['x', 'y', 'z'])) self.assertEqual(field_values['x'], 1) self.assertEqual(field_values['y'], (1, True, 3)) - self.assertIsInstance(field_values['z'], ops.Tensor) + self.assertIsInstance(field_values['z'], tensor.Tensor) self.assertAllEqual(field_values['z'], [[1, 2], [3, 4], [5, 6]]) def testConvertFieldsForSpec(self): @@ -338,18 +337,18 @@ def testConvertFieldsForSpec(self): 'y', typing.Tuple[typing.Union[int, bool], ...]), extension_type_field.ExtensionTypeField( 'y', _TUPLE[typing.Union[int, bool], ...]), - extension_type_field.ExtensionTypeField('z', ops.Tensor) + extension_type_field.ExtensionTypeField('z', tensor.Tensor) ] field_values = { 'x': 1, 'y': [1, True, 3], - 'z': tensor_spec.TensorSpec([5, 3]) + 'z': tensor.TensorSpec([5, 3]) } extension_type_field.convert_fields_for_spec(fields, field_values) self.assertEqual(set(field_values), set(['x', 'y', 'z'])) self.assertEqual(field_values['x'], 1) self.assertEqual(field_values['y'], (1, True, 3)) - self.assertEqual(field_values['z'], tensor_spec.TensorSpec([5, 3])) + self.assertEqual(field_values['z'], tensor.TensorSpec([5, 3])) if __name__ == '__main__': diff --git a/tensorflow/python/framework/extension_type_test.py b/tensorflow/python/framework/extension_type_test.py index 0e6bccb396c7af..0169690eaf3c33 100644 --- a/tensorflow/python/framework/extension_type_test.py +++ b/tensorflow/python/framework/extension_type_test.py @@ -34,8 +34,8 @@ from tensorflow.python.framework import extension_type_field from 
tensorflow.python.framework import immutable_dict from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec_registry @@ -64,8 +64,8 @@ class MaskedTensorV1(extension_type.ExtensionType): """Example subclass of ExtensionType, used for testing.""" - values: ops.Tensor - mask: ops.Tensor + values: tensor.Tensor + mask: tensor.Tensor class MaskedTensorV2(extension_type.ExtensionType): @@ -78,8 +78,8 @@ class MaskedTensorV2(extension_type.ExtensionType): __name__ = 'tf.test.MaskedTensorV2' - values: ops.Tensor - mask: ops.Tensor + values: tensor.Tensor + mask: tensor.Tensor def __repr__(self): if hasattr(self.values, 'numpy') and hasattr(self.mask, 'numpy'): @@ -117,7 +117,7 @@ def with_default(self, default): class SimpleExtensionType(extension_type.ExtensionType): - x: ops.Tensor + x: tensor.Tensor class Spec: @@ -145,8 +145,8 @@ class MaskedTensorV3(extension_type.BatchableExtensionType): __name__ = 'tf.test.MaskedTensorV3.Spec' - values: typing.Union[ops.Tensor, ragged_tensor.RaggedTensor] - mask: typing.Union[ops.Tensor, ragged_tensor.RaggedTensor] + values: typing.Union[tensor.Tensor, ragged_tensor.RaggedTensor] + mask: typing.Union[tensor.Tensor, ragged_tensor.RaggedTensor] def __init__(self, values, mask): if isinstance(values, ragged_tensor.RaggedTensor): @@ -182,12 +182,12 @@ class ForwardRefA(extension_type.ExtensionType): class ForwardRefB(extension_type.ExtensionType): z: 'ForwardRefB' - n: ops.Tensor + n: tensor.Tensor class ExtensionTypeWithTensorDefault(extension_type.ExtensionType): - x: ops.Tensor = 5 - y: ops.Tensor = ['a', 'b', 'c'] + x: tensor.Tensor = 5 + y: tensor.Tensor = ['a', 'b', 'c'] @test_util.run_all_in_graph_and_eager_modes @@ -198,9 +198,9 @@ def testAttributeAccessors(self): mt2 = extension_type.pack(mt1) for mt in [mt1, mt2]: - self.assertIsInstance(mt.values, ops.Tensor) + self.assertIsInstance(mt.values, tensor.Tensor) self.assertAllEqual(mt.values, [1, 2, 3, 4]) - self.assertIsInstance(mt.mask, ops.Tensor) + self.assertIsInstance(mt.mask, tensor.Tensor) self.assertAllEqual(mt.mask, [True, True, False, True]) def testAttributesAreImmutable(self): @@ -260,14 +260,16 @@ def testAsDict(self): def testConstructorSignature(self): class MyType(extension_type.ExtensionType): - x: ops.Tensor - y: ops.Tensor + x: tensor.Tensor + y: tensor.Tensor z: typing.Tuple[typing.Union[int, str], ...] 
= [1, 'two', 3] expected_parameters = [ tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD), - tf_inspect.Parameter('x', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor), - tf_inspect.Parameter('y', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor), + tf_inspect.Parameter( + 'x', POSITIONAL_OR_KEYWORD, annotation=tensor.Tensor), + tf_inspect.Parameter( + 'y', POSITIONAL_OR_KEYWORD, annotation=tensor.Tensor), tf_inspect.Parameter( 'z', POSITIONAL_OR_KEYWORD, @@ -284,7 +286,7 @@ def testConstructorSignatureWithKeywordOnlyArgs(self): class MyType(extension_type.ExtensionType): a: int b: str = 'Hello world' - c: ops.Tensor + c: tensor.Tensor expected_parameters = [ tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD), @@ -292,7 +294,7 @@ class MyType(extension_type.ExtensionType): tf_inspect.Parameter( 'b', POSITIONAL_OR_KEYWORD, annotation=str, default='Hello world' ), - tf_inspect.Parameter('c', KEYWORD_ONLY, annotation=ops.Tensor), + tf_inspect.Parameter('c', KEYWORD_ONLY, annotation=tensor.Tensor), ] expected_sig = tf_inspect.Signature( expected_parameters, return_annotation=MyType @@ -314,13 +316,14 @@ def testConstructorSignatureWithDefaultForTensorField(self): def testConstructorSignatureWithAnnotatedTensorField(self): class MyType(extension_type.ExtensionType): - a: typing_extensions.Annotated[ops.Tensor, 'metadata'] + a: typing_extensions.Annotated[tensor.Tensor, 'metadata'] b: typing_extensions.Annotated[str, 'metadata'] = 'Hello world' c: typing.Optional[typing_extensions.Annotated[int, 'metadata']] = None expected_parameters = [ tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD), - tf_inspect.Parameter('a', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor), + tf_inspect.Parameter( + 'a', POSITIONAL_OR_KEYWORD, annotation=tensor.Tensor), tf_inspect.Parameter( 'b', POSITIONAL_OR_KEYWORD, annotation=str, default='Hello world' ), @@ -348,9 +351,9 @@ class EmptyType(extension_type.ExtensionType): def testCustomConstrutor(self): class SummarizedTensor(extension_type.ExtensionType): - values: ops.Tensor - mean: ops.Tensor - max: ops.Tensor + values: tensor.Tensor + mean: tensor.Tensor + max: tensor.Tensor def __init__(self, values): self.values = ops.convert_to_tensor(values) @@ -363,7 +366,7 @@ def __init__(self, values): self.assertAllEqual(x.max, 6) class Node(extension_type.ExtensionType): - x: ops.Tensor + x: tensor.Tensor y: typing.Optional[str] = None children: typing.Tuple['ExtensionTypeTest.Node', ...] 
= () @@ -402,8 +405,8 @@ def __init__(self, foo): def testCustomValidate(self): class AlignedTensors(extension_type.ExtensionType): - x: ops.Tensor - y: ops.Tensor + x: tensor.Tensor + y: tensor.Tensor def __validate__(self): self.x.shape.assert_is_compatible_with(self.y.shape) @@ -417,8 +420,8 @@ def __validate__(self): def testEquals(self): class MyType(extension_type.ExtensionType): - values: ops.Tensor - score: ops.Tensor + values: tensor.Tensor + score: tensor.Tensor flavor: str x1 = MyType([1, 2], 8, 'blue') @@ -509,8 +512,8 @@ def fn_with_side_effect(mts): def testNestPackUnpack(self): class CandyStore(extension_type.ExtensionType): - name: ops.Tensor - prices: typing.Mapping[str, ops.Tensor] + name: tensor.Tensor + prices: typing.Mapping[str, tensor.Tensor] store = CandyStore('Yum', {'gum': [0.42, 0.48], 'chocolate': [0.83, 1.02]}) components = nest.flatten(store, expand_composites=True) @@ -702,13 +705,14 @@ def body(i, x): self.assertAllEqual(y.mask, [True, False, True, False]) def testNestedFields(self): - PossiblyRaggedTensor = typing.Union[ops.Tensor, ragged_tensor.RaggedTensor] + PossiblyRaggedTensor = typing.Union[ + tensor.Tensor, ragged_tensor.RaggedTensor] ToyFeatures = typing.Mapping[str, PossiblyRaggedTensor] class ToyInfo(extension_type.ExtensionType): version: str - toys: typing.Tuple[typing.Tuple[str, ops.Tensor, ToyFeatures], ...] - boxes: typing.Mapping[str, ops.Tensor] + toys: typing.Tuple[typing.Tuple[str, tensor.Tensor, ToyFeatures], ...] + boxes: typing.Mapping[str, tensor.Tensor] authors = [[b'A', b'Aardvark'], [b'Z', b'Zhook']] toys = [ @@ -720,10 +724,10 @@ class ToyInfo(extension_type.ExtensionType): self.assertEqual(toy_info.version, '1.0 alpha') self.assertEqual(toy_info.toys[0][0], 'car') - self.assertIsInstance(toy_info.toys[0][1], ops.Tensor) + self.assertIsInstance(toy_info.toys[0][1], tensor.Tensor) self.assertAllEqual(toy_info.toys[0][1], 1.0) self.assertEqual(set(toy_info.toys[0][2].keys()), {'size', 'color'}) - self.assertIsInstance(toy_info.toys[0][2]['size'], ops.Tensor) + self.assertIsInstance(toy_info.toys[0][2]['size'], tensor.Tensor) self.assertAllEqual(toy_info.toys[0][2]['size'], [8, 3, 2]) self.assertIsInstance( toy_info.toys[1][2]['authors'], ragged_tensor.RaggedTensor @@ -745,15 +749,15 @@ class ToyInfo(extension_type.ExtensionType): self.assertRegex(repr(toy_info), expected_repr) def testNestedExtensionTypes(self): - PossiblyMaskedTensor = typing.Union[ops.Tensor, MaskedTensorV1] + PossiblyMaskedTensor = typing.Union[tensor.Tensor, MaskedTensorV1] class Toy(extension_type.ExtensionType): name: str - price: ops.Tensor + price: tensor.Tensor features: typing.Mapping[str, PossiblyMaskedTensor] class Box(extension_type.ExtensionType): - contents: ops.Tensor + contents: tensor.Tensor class ToyInfo(extension_type.ExtensionType): version: str @@ -784,7 +788,7 @@ def fn(info): def testNestedCustomConstructor(self): class Toy(extension_type.ExtensionType): name: str - price: ops.Tensor + price: tensor.Tensor def __init__(self, name, price, discount=0): if discount: @@ -834,10 +838,10 @@ def testGetExtensionTypeFields(self): for fields in [fields_1, fields_2]: self.assertLen(fields, 2) self.assertEqual(fields[0].name, 'values') - self.assertEqual(fields[0].value_type, ops.Tensor) + self.assertEqual(fields[0].value_type, tensor.Tensor) self.assertEqual(fields[0].default, fields[0].NO_DEFAULT) self.assertEqual(fields[1].name, 'mask') - self.assertEqual(fields[1].value_type, ops.Tensor) + self.assertEqual(fields[1].value_type, tensor.Tensor) 
self.assertEqual(fields[1].default, fields[0].NO_DEFAULT) def testHasExtensionTypeField(self): @@ -866,7 +870,7 @@ def testForwardReferences(self): B._tf_extension_type_fields(), ( extension_type_field.ExtensionTypeField('z', B), - extension_type_field.ExtensionTypeField('n', ops.Tensor), + extension_type_field.ExtensionTypeField('n', tensor.Tensor), ), ) @@ -905,7 +909,7 @@ def testUnsupportedAnnotations(self): ): class MyType1(extension_type.ExtensionType): # pylint: disable=unused-variable - values: typing.List[ops.Tensor] + values: typing.List[tensor.Tensor] with self.assertRaisesRegex( TypeError, "In field 'xyz': Unsupported type annotation" @@ -955,8 +959,8 @@ def testExtensionTypeBaseConstructorRaisesException(self): class ExtensionTypeWithName(extension_type.ExtensionType): __name__ = 'tf.__test__.ExtensionTypeWithName' # For SavedModel - x: typing.Tuple[ops.Tensor, int] - y: ops.Tensor + x: typing.Tuple[tensor.Tensor, int] + y: tensor.Tensor def testSavedModelSupport(self): class TestModule(module.Module): @@ -985,16 +989,16 @@ def testPackedEncoding(self): mt2 = extension_type.pack(mt1) self.assertLen(nest.flatten(mt2, expand_composites=True), 1) - self.assertIsInstance(mt2.values, ops.Tensor) + self.assertIsInstance(mt2.values, tensor.Tensor) self.assertAllEqual(mt2.values, [1, 2, 3, 4]) - self.assertIsInstance(mt2.mask, ops.Tensor) + self.assertIsInstance(mt2.mask, tensor.Tensor) self.assertAllEqual(mt2.mask, [True, True, False, True]) mt3 = extension_type.unpack(mt2) self.assertLen(nest.flatten(mt3, expand_composites=True), 2) - self.assertIsInstance(mt3.values, ops.Tensor) + self.assertIsInstance(mt3.values, tensor.Tensor) self.assertAllEqual(mt3.values, [1, 2, 3, 4]) - self.assertIsInstance(mt3.mask, ops.Tensor) + self.assertIsInstance(mt3.mask, tensor.Tensor) self.assertAllEqual(mt3.mask, [True, True, False, True]) nest.assert_same_structure(mt1, mt3, expand_composites=True) @@ -1010,8 +1014,8 @@ def testPackedEncoding(self): def testSubclassing(self): class Instrument(extension_type.ExtensionType): - name: ops.Tensor - weight: ops.Tensor + name: tensor.Tensor + weight: tensor.Tensor needs_case: bool class StringInstrument(Instrument): @@ -1019,7 +1023,7 @@ class StringInstrument(Instrument): needs_case: bool = True # Override default value. class Violin(StringInstrument): - maker: ops.Tensor + maker: tensor.Tensor num_strings: int = 4 # Override default value. name: str = 'violin' # Override field type and default value. 
@@ -1030,10 +1034,10 @@ class Violin(StringInstrument): [ tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD), tf_inspect.Parameter( - 'name', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor + 'name', POSITIONAL_OR_KEYWORD, annotation=tensor.Tensor ), tf_inspect.Parameter( - 'weight', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor + 'weight', POSITIONAL_OR_KEYWORD, annotation=tensor.Tensor ), tf_inspect.Parameter( 'needs_case', @@ -1051,14 +1055,16 @@ class Violin(StringInstrument): tf_inspect.Parameter( 'name', POSITIONAL_OR_KEYWORD, annotation=str, default='violin' ), - tf_inspect.Parameter('weight', KEYWORD_ONLY, annotation=ops.Tensor), + tf_inspect.Parameter( + 'weight', KEYWORD_ONLY, annotation=tensor.Tensor), tf_inspect.Parameter( 'needs_case', KEYWORD_ONLY, annotation=bool, default=True ), tf_inspect.Parameter( 'num_strings', KEYWORD_ONLY, annotation=int, default=4 ), - tf_inspect.Parameter('maker', KEYWORD_ONLY, annotation=ops.Tensor), + tf_inspect.Parameter( + 'maker', KEYWORD_ONLY, annotation=tensor.Tensor), ], ) @@ -1131,8 +1137,8 @@ class ExtensionTypeSpecTest( ): def testSpecConstructor(self): - values_spec = tensor_spec.TensorSpec([4], dtypes.float32) - mask_spec = tensor_spec.TensorSpec([4], dtypes.bool) + values_spec = tensor.TensorSpec([4], dtypes.float32) + mask_spec = tensor.TensorSpec([4], dtypes.bool) mt_spec = MaskedTensorV1.Spec(values_spec, mask_spec) self.assertEqual(mt_spec.values, values_spec) self.assertEqual(mt_spec.mask, mask_spec) @@ -1142,8 +1148,8 @@ def testSpecConstructor(self): def testSpecConstructorSignature(self): class MyType(extension_type.ExtensionType): - x: ops.Tensor - y: ops.Tensor + x: tensor.Tensor + y: tensor.Tensor z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3] expected_parameters = [ @@ -1183,19 +1189,19 @@ def testSpecFromValue(self): mt = MaskedTensorV1([1.0, 2.0, 3.0, 4.0], [True, True, False, True]) mt_spec = MaskedTensorV1.Spec.from_value(mt) - expected_values_spec = tensor_spec.TensorSpec([4], dtypes.float32) - expected_mask_spec = tensor_spec.TensorSpec([4], dtypes.bool) + expected_values_spec = tensor.TensorSpec([4], dtypes.float32) + expected_mask_spec = tensor.TensorSpec([4], dtypes.bool) self.assertEqual(mt_spec.values, expected_values_spec) self.assertEqual(mt_spec.mask, expected_mask_spec) def testSpecSerialize(self): class Zoo(extension_type.ExtensionType): zookeepers: typing.Tuple[str, ...] - animals: typing.Mapping[str, typing.Mapping[str, ops.Tensor]] + animals: typing.Mapping[str, typing.Mapping[str, tensor.Tensor]] featurespec = { - 'size': tensor_spec.TensorSpec([3]), - 'weight': tensor_spec.TensorSpec([]), + 'size': tensor.TensorSpec([3]), + 'weight': tensor.TensorSpec([]), } zoo_spec = Zoo.Spec( zookeepers=['Zoey', 'Zack'], @@ -1222,7 +1228,7 @@ class Zoo(extension_type.ExtensionType): def testSpecComponents(self): class Zoo(extension_type.ExtensionType): zookeepers: typing.Tuple[str, ...] 
- animals: typing.Mapping[str, typing.Mapping[str, ops.Tensor]] + animals: typing.Mapping[str, typing.Mapping[str, tensor.Tensor]] zoo = Zoo( ['Zoey', 'Zack'], @@ -1247,17 +1253,17 @@ class Zoo(extension_type.ExtensionType): self.assertEqual( zoo_spec._component_specs, ( - tensor_spec.TensorSpec([3], dtypes.int32), - tensor_spec.TensorSpec([], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.float32), - tensor_spec.TensorSpec([3], dtypes.int32), - tensor_spec.TensorSpec([], dtypes.float32), + tensor.TensorSpec([3], dtypes.int32), + tensor.TensorSpec([], dtypes.float32), + tensor.TensorSpec([], dtypes.float32), + tensor.TensorSpec([3], dtypes.int32), + tensor.TensorSpec([], dtypes.float32), ), ) def testCopyAndPickle(self): - values_spec = tensor_spec.TensorSpec([4], dtypes.float32) - mask_spec = tensor_spec.TensorSpec([4], dtypes.bool) + values_spec = tensor.TensorSpec([4], dtypes.float32) + mask_spec = tensor.TensorSpec([4], dtypes.bool) mt_spec = MaskedTensorV1.Spec(values_spec, mask_spec) self.assertEqual(copy.copy(mt_spec), mt_spec) self.assertEqual(copy.deepcopy(mt_spec), mt_spec) @@ -1273,8 +1279,8 @@ class WeightedTensor(extension_type.ExtensionType): * Add method (with_shape). """ - values: ops.Tensor - weight: ops.Tensor # scalar + values: tensor.Tensor + weight: tensor.Tensor # scalar shape = property(lambda self: self.shape) dtype = property(lambda self: self.dtype) @@ -1286,8 +1292,8 @@ def __validate__(self): class Spec: def __init__(self, shape, dtype, weight_dtype=dtypes.float32): - self.values = tensor_spec.TensorSpec(shape, dtype) - self.weight = tensor_spec.TensorSpec([], weight_dtype) + self.values = tensor.TensorSpec(shape, dtype) + self.weight = tensor.TensorSpec([], weight_dtype) def __validate__(self): self.weight.shape.assert_has_rank(0) @@ -1376,7 +1382,7 @@ def testAttributeAccessors(self, fields): s = extension_type.AnonymousExtensionType(**fields) for name, value in fields.items(): actual = getattr(s, name) - if isinstance(actual, (ops.Tensor, ragged_tensor.RaggedTensor)): + if isinstance(actual, (tensor.Tensor, ragged_tensor.RaggedTensor)): self.assertAllEqual(actual, value) else: self.assertEqual(actual, value) @@ -1434,7 +1440,7 @@ def testReinterpret(self): lambda: extension_type.AnonymousExtensionType( values=constant_op.constant([1, 2, 3]) ), - ops.Tensor, + tensor.Tensor, ( 'reinterpret expects `new_type` to be a subclass of ' 'tf.ExtensionType; ' @@ -1464,8 +1470,8 @@ def f(x, y): y_mask = y.mask if isinstance(y, MaskedTensorV1) else True return MaskedTensorV1(x_values + y_values, x_mask & y_mask) - t_spec = tensor_spec.TensorSpec(None, dtypes.int32) - b_spec = tensor_spec.TensorSpec(None, dtypes.bool) + t_spec = tensor.TensorSpec(None, dtypes.int32) + b_spec = tensor.TensorSpec(None, dtypes.bool) mt_spec = MaskedTensorV1.Spec(values=t_spec, mask=b_spec) model = module.Module() model.f = def_function.function(f) @@ -1515,8 +1521,8 @@ def testFlatTensorSpecs(self): self.assertEqual( flat_specs, [ - tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.int32, name=None), - tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.bool, name=None), + tensor.TensorSpec(shape=(2,), dtype=dtypes.int32, name=None), + tensor.TensorSpec(shape=(2,), dtype=dtypes.bool, name=None), ], ) @@ -1546,7 +1552,7 @@ def testToLegacyOutputShapeMissing(self): def replace_tensors_with_placeholders(value): def repl(x): - if isinstance(x, ops.Tensor): + if isinstance(x, tensor.Tensor): return array_ops.placeholder_with_default(x, shape=None) else: return x diff --git 
a/tensorflow/python/framework/flexible_dtypes.py b/tensorflow/python/framework/flexible_dtypes.py index 5d5abc970ebbe1..909f731faa96e9 100644 --- a/tensorflow/python/framework/flexible_dtypes.py +++ b/tensorflow/python/framework/flexible_dtypes.py @@ -19,6 +19,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import weak_tensor +from tensorflow.python.framework.tensor_shape import TensorShape +from tensorflow.python.types import core as core_types from tensorflow.python.util import nest # PromoMode Enum that denotes safe and all mode. @@ -372,6 +374,32 @@ def _initialize(): ) +def _is_acceptable_input_type(x): + """Determines if x is an acceptable input type for auto dtype conversion semantics.""" + acceptable_types = [ + core_types.Tensor, + core_types.TensorProtocol, + int, + float, + bool, + str, + bytes, + complex, + tuple, + list, + np.ndarray, + np.generic, + dtypes.DType, + np.dtype, + TensorShape, + weak_tensor.WeakTensor, + ] + for t in acceptable_types: + if isinstance(x, t): + return True + return False + + def _get_dtype_and_weakness(x): """Returns a TF type and weak type information from x. @@ -438,12 +466,22 @@ def _result_type_impl(*arrays_and_dtypes): TypeError: when the promotion between the input dtypes is disabled in the current mode + + NotImplementedError: when arrays_and_dtypes contains an unsupported input + type (e.g. CompositeTensor). """ promo_safety_mode = ops.get_dtype_conversion_mode() - # Drop None inputs. - valid_arrays_and_dtypes = [ - inp for inp in arrays_and_dtypes if inp is not None - ] + # Drop None inputs and check if input type is supported. + valid_arrays_and_dtypes = [] + for inp in arrays_and_dtypes: + if inp is not None: + if _is_acceptable_input_type(inp): + valid_arrays_and_dtypes.append(inp) + else: + raise NotImplementedError( + 'Auto dtype conversion semantics does not support' + f' {type(inp)} type.' + ) dtypes_and_is_weak = [ _get_dtype_and_weakness(x) for x in nest.flatten(valid_arrays_and_dtypes) diff --git a/tensorflow/python/framework/flexible_dtypes_test.py b/tensorflow/python/framework/flexible_dtypes_test.py index dfd5ec3022621b..5205e9e9dd16d0 100644 --- a/tensorflow/python/framework/flexible_dtypes_test.py +++ b/tensorflow/python/framework/flexible_dtypes_test.py @@ -19,8 +19,10 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import extension_type from tensorflow.python.framework import flexible_dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import weak_tensor from tensorflow.python.ops import variables from tensorflow.python.ops import weak_tensor_test_util @@ -862,6 +864,18 @@ def testResultTypeEmptyInput(self): self.assertEqual(dtype, dtypes.float32) self.assertTrue(is_weak) + def testResultTypeUnsupportedInputType(self): + class MyTensor(extension_type.ExtensionType): + value: tensor.Tensor + + with DtypeConversionTestEnv('all'): + a = MyTensor(constant_op.constant(1)) + with self.assertRaisesRegex( + NotImplementedError, + f'Auto dtype conversion semantics does not support {type(a)} type.', + ): + _ = flexible_dtypes.result_type(a) + # Test v1 + v2 = v2 + v1. 
def testCommunicativity(self): with DtypeConversionTestEnv('all'): diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index b09c176e80eb78..632a0022ea2a12 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs @@ -706,7 +707,7 @@ def __call__(self, *args, **kwargs): args = list(args) for (i, x) in enumerate(args): x = ops.convert_to_tensor(x) - if not isinstance(x, ops.Tensor): + if not isinstance(x, tensor_lib.Tensor): raise ValueError(f"Expected a Tensor but got {x} with type {type(x)}.") input_types.append(x.dtype) args[i] = x diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index acef1db0607759..359e9f4f99af08 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import control_flow_util from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated_args @@ -194,7 +195,7 @@ def _ConvertInputMapValues(name, input_map): Raises: ValueError: if input map values cannot be converted due to empty name scope. """ - if not all(isinstance(v, ops.Tensor) for v in input_map.values()): + if not all(isinstance(v, tensor.Tensor) for v in input_map.values()): if name == '': # pylint: disable=g-explicit-bool-comparison raise ValueError( 'tf.import_graph_def() requires a non-empty `name` if `input_map` ' diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index b2ad0cf659dedb..1ac4f9b73460da 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import op_def_library_pybind from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import _pywrap_utils @@ -368,7 +369,7 @@ def _CanExtractAttrsFastPath(op_def, keywords): # Check if all inputs are already tf.Tensor for input_arg in op_def.input_arg: value = keywords.get(input_arg.name, None) - if not isinstance(value, ops.Tensor): + if not isinstance(value, tensor.Tensor): return False # Check that attrs are not `func` or `list(func)` type. 
@@ -452,7 +453,7 @@ def _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map, dtype = attrs[input_arg.type_attr] else: for t in values: - if isinstance(t, ops.Tensor): + if isinstance(t, tensor.Tensor): dtype = t.dtype break diff --git a/tensorflow/python/framework/op_def_library_test.py b/tensorflow/python/framework/op_def_library_test.py index 2f835ae47f9cf9..f77c85643524c7 100644 --- a/tensorflow/python/framework/op_def_library_test.py +++ b/tensorflow/python/framework/op_def_library_test.py @@ -25,8 +25,8 @@ from tensorflow.python.framework import op_def_library from tensorflow.python.framework import op_def_library_pybind from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest from tensorflow.python.util import compat @@ -417,7 +417,7 @@ def fn(x): def testAttrFuncWithFuncWithAttrs(self): with ops.Graph().as_default(): @def_function.function( - input_signature=(tensor_spec.TensorSpec(None, dtypes.float32),), + input_signature=(tensor.TensorSpec(None, dtypes.float32),), autograph=False, experimental_attributes={"_implements": 15}) def fn(x): @@ -1334,7 +1334,7 @@ def testStructuredOutputListAndSingle(self): self.assertIsInstance(a, list) self.assertEqual(n_a, len(a)) self.assertTrue(all(x.dtype == dtypes.int32 for x in a)) - self.assertIsInstance(b, ops.Tensor) + self.assertIsInstance(b, tensor.Tensor) self.assertEqual(dtypes.float32, b.dtype) def testStructuredOutputMultipleLists(self): diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index de09610cce1ded..e8f1545c0546bd 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -234,8 +234,6 @@ def value_text(tensor, is_repr=False): return text -enable_tensor_equality = tensor_lib.enable_tensor_equality -disable_tensor_equality = tensor_lib.disable_tensor_equality Tensor = tensor_lib.Tensor @@ -1600,8 +1598,8 @@ def experimental_set_type(self, type_proto): raise ValueError("error setting the type of ", self.name, ": expected TFT_UNSET or TFT_PRODUCT, got ", type_proto.type_id) - pywrap_tf_session.SetFullType(c_graph, self._c_op, - type_proto.SerializeToString()) # pylint:disable=protected-access + with c_api_util.tf_buffer(type_proto.SerializeToString()) as serialized: + pywrap_tf_session.SetFullType(c_graph, self._c_op, serialized) # pylint:disable=protected-access def run(self, feed_dict=None, session=None): """Runs this operation in a `Session`. 
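The framework changes above and below all apply the same migration: `isinstance(x, ops.Tensor)` checks become `isinstance(x, tensor.Tensor)` (or `tensor_lib.Tensor` where a local `tensor` name would collide), and `tensor_spec.TensorSpec` becomes `tensor.TensorSpec`. A minimal sketch of the new spellings follows; it is illustrative only, not part of the diff, and assumes the `tensorflow.python.framework.tensor` module whose visibility and imports this patch sets up.

# Illustrative sketch only (not part of the diff). Assumes
# tensorflow.python.framework.tensor exposes Tensor and TensorSpec, as the
# BUILD visibility and import changes in this patch establish.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor  # replaces ops.Tensor / tensor_spec.TensorSpec

t = constant_op.constant([1, 2, 3])

# Old spelling: isinstance(t, ops.Tensor)
assert isinstance(t, tensor.Tensor)

# Old spelling: tensor_spec.TensorSpec([3], dtypes.int32)
spec = tensor.TensorSpec([3], dtypes.int32)
assert spec.is_compatible_with(t)
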
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index aec2433043ddf6..70b61f699d768e 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -44,9 +44,9 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util @@ -589,27 +589,27 @@ def testSerialize(self, spec, expected): @parameterized.parameters([ (indexed_slices.IndexedSlicesSpec(dtype=dtypes.string), ( - tensor_spec.TensorSpec(None, dtypes.string), - tensor_spec.TensorSpec([None], dtypes.int64), + tensor_lib.TensorSpec(None, dtypes.string), + tensor_lib.TensorSpec([None], dtypes.int64), )), (indexed_slices.IndexedSlicesSpec( dtype=dtypes.string, dense_shape_dtype=dtypes.int32), ( - tensor_spec.TensorSpec(None, dtypes.string), - tensor_spec.TensorSpec([None], dtypes.int64), - tensor_spec.TensorSpec([None], dtypes.int32), + tensor_lib.TensorSpec(None, dtypes.string), + tensor_lib.TensorSpec([None], dtypes.int64), + tensor_lib.TensorSpec([None], dtypes.int32), )), (indexed_slices.IndexedSlicesSpec( shape=[5, 10, 15], dense_shape_dtype=dtypes.int32), ( - tensor_spec.TensorSpec([None, 10, 15], dtypes.float32), - tensor_spec.TensorSpec([None], dtypes.int64), - tensor_spec.TensorSpec([3], dtypes.int32), + tensor_lib.TensorSpec([None, 10, 15], dtypes.float32), + tensor_lib.TensorSpec([None], dtypes.int64), + tensor_lib.TensorSpec([3], dtypes.int32), )), (indexed_slices.IndexedSlicesSpec( shape=[5, 10, 15], dense_shape_dtype=dtypes.int32, indices_shape=[20]), ( - tensor_spec.TensorSpec([20, 10, 15], dtypes.float32), - tensor_spec.TensorSpec([20], dtypes.int64), - tensor_spec.TensorSpec([3], dtypes.int32), + tensor_lib.TensorSpec([20, 10, 15], dtypes.float32), + tensor_lib.TensorSpec([20], dtypes.int64), + tensor_lib.TensorSpec([3], dtypes.int32), )), ]) def testComponentSpecs(self, spec, expected): @@ -1447,10 +1447,10 @@ def testNodeDefArgs(self): g, "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32], name="myop3") - self.assertTrue(isinstance(t1, ops.Tensor)) + self.assertTrue(isinstance(t1, tensor_lib.Tensor)) self.assertTrue(isinstance(t2, list)) self.assertTrue(isinstance(t3, list)) - self.assertTrue(isinstance(t3[0], ops.Tensor)) + self.assertTrue(isinstance(t3[0], tensor_lib.Tensor)) self.assertEqual("myop1", t1._as_node_def_input()) self.assertEqual("myop2", t2[0]._as_node_def_input()) self.assertEqual("myop2:1", t2[1]._as_node_def_input()) @@ -2333,8 +2333,8 @@ def testMembershipAllowed(self): g = ops.Graph() t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2") - self.assertTrue(isinstance(t1, ops.Tensor)) - self.assertTrue(isinstance(t2, ops.Tensor)) + self.assertTrue(isinstance(t1, tensor_lib.Tensor)) + self.assertTrue(isinstance(t2, tensor_lib.Tensor)) self.assertTrue(t1 in [t1]) self.assertTrue(t1 not in [t2]) @@ -3623,7 +3623,7 @@ def testCompositeTensorConversion(self): self.assertIsInstance(y, _TupleTensor) self.assertLen(y, len(x)) for x_, y_ in zip(x, y): - 
self.assertIsInstance(y_, ops.Tensor) + self.assertIsInstance(y_, tensor_lib.Tensor) self.assertTrue(tensor_util.is_tf_type(y_)) self.assertAllEqual(x_, tensor_util.constant_value(y_)) @@ -3681,7 +3681,7 @@ def setUpInputShapes(self, pre_add_input_shapes): test_tensor_shape = [None, 1, 1, 1] @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=test_tensor_shape, dtype=dtypes.float32) + tensor_lib.TensorSpec(shape=test_tensor_shape, dtype=dtypes.float32) ]) def f(x): return array_ops.identity(x, name="output") diff --git a/tensorflow/python/framework/python_api_dispatcher_test.py b/tensorflow/python/framework/python_api_dispatcher_test.py index f179cbc9eceb31..a4ddb620f09aeb 100644 --- a/tensorflow/python/framework/python_api_dispatcher_test.py +++ b/tensorflow/python/framework/python_api_dispatcher_test.py @@ -19,7 +19,7 @@ from tensorflow.python.framework import _pywrap_python_api_dispatcher as dispatch from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -47,7 +47,7 @@ def testInstanceChecker(self): self.assertEqual(repr(int_checker), '') with self.subTest('tensor checker'): - tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor) + tensor_checker = dispatch.MakeInstanceChecker(tensor.Tensor) self.assertEqual(tensor_checker.Check(t), MATCH) self.assertEqual(tensor_checker.Check(3), NO_MATCH) self.assertEqual(tensor_checker.Check(3.0), NO_MATCH) @@ -119,7 +119,7 @@ def testUnionChecker(self): float_checker = dispatch.MakeInstanceChecker(float) str_checker = dispatch.MakeInstanceChecker(str) none_checker = dispatch.MakeInstanceChecker(type(None)) - tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor) + tensor_checker = dispatch.MakeInstanceChecker(tensor.Tensor) ragged_checker = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor) t = constant_op.constant([1, 2, 3]) @@ -159,7 +159,7 @@ def testUnionChecker(self): def testListChecker(self): int_checker = dispatch.MakeInstanceChecker(int) - tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor) + tensor_checker = dispatch.MakeInstanceChecker(tensor.Tensor) ragged_checker = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor) np_int_checker = dispatch.MakeInstanceChecker(np.integer) @@ -269,7 +269,7 @@ def testSimpleSignature(self): def testUnion(self): rt_checker = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor) - tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor) + tensor_checker = dispatch.MakeInstanceChecker(tensor.Tensor) rt_or_tensor = dispatch.MakeUnionChecker([rt_checker, tensor_checker]) checker = dispatch.PySignatureChecker([(0, rt_or_tensor), (1, rt_or_tensor)]) @@ -383,7 +383,7 @@ def testListAndUnionDispatch(self): (None,)) rt_checker = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor) - tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor) + tensor_checker = dispatch.MakeInstanceChecker(tensor.Tensor) rt_or_t = dispatch.MakeUnionChecker([rt_checker, tensor_checker]) list_of_rt_or_t = dispatch.MakeListChecker(rt_or_t) f1 = lambda x, ys, name=None: 'f1' diff --git a/tensorflow/python/framework/python_api_parameter_converter_test.py b/tensorflow/python/framework/python_api_parameter_converter_test.py index 9787b6c0c53478..e6a4c705195aa4 100644 --- 
a/tensorflow/python/framework/python_api_parameter_converter_test.py +++ b/tensorflow/python/framework/python_api_parameter_converter_test.py @@ -23,7 +23,7 @@ from tensorflow.python.framework import _pywrap_python_api_info from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.framework._pywrap_python_api_parameter_converter import Convert @@ -82,7 +82,7 @@ def assertParamsEqual(self, actual_params, expected_params): self.assertParamEqual(actual, expected) def assertParamEqual(self, actual, expected): - if isinstance(actual, ops.Tensor): + if isinstance(actual, tensor.Tensor): self.assertAllEqual(actual, expected) else: self.assertEqual(actual, expected) diff --git a/tensorflow/python/framework/python_tensor_converter_test.py b/tensorflow/python/framework/python_tensor_converter_test.py index 413770b973b7d5..3257b8e7de91f4 100644 --- a/tensorflow/python/framework/python_tensor_converter_test.py +++ b/tensorflow/python/framework/python_tensor_converter_test.py @@ -24,7 +24,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest @@ -47,7 +47,7 @@ def makePythonTensorConverter(self): def testConvertIntWithInferredDType(self): converter = self.makePythonTensorConverter() result, dtype, used_fallback = converter.Convert(12, types_pb2.DT_INVALID) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, 12) self.assertEqual(dtype, types_pb2.DT_INT32) self.assertEqual(used_fallback, not context.executing_eagerly()) @@ -55,7 +55,7 @@ def testConvertIntWithInferredDType(self): def testConvertIntWithExplicitDtype(self): converter = self.makePythonTensorConverter() result, dtype, used_fallback = converter.Convert(12, types_pb2.DT_INT64) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, 12) self.assertEqual(dtype, types_pb2.DT_INT64) self.assertEqual(used_fallback, not context.executing_eagerly()) @@ -74,7 +74,7 @@ def testConvertTensorWithInferredDType(self): converter = self.makePythonTensorConverter() result, dtype, used_fallback = converter.Convert( constant_op.constant([1, 2, 3]), types_pb2.DT_INVALID) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [1, 2, 3]) self.assertEqual(dtype, types_pb2.DT_INT32) self.assertFalse(used_fallback) @@ -83,7 +83,7 @@ def testConvertTensorWithExplicitDtype(self): converter = self.makePythonTensorConverter() result, dtype, used_fallback = converter.Convert( constant_op.constant([1, 2, 3], dtypes.int64), types_pb2.DT_INT64) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [1, 2, 3]) self.assertEqual(dtype, types_pb2.DT_INT64) self.assertFalse(used_fallback) @@ -101,7 +101,7 @@ def testConvertListWithInferredDType(self): converter = self.makePythonTensorConverter() result, dtype, used_fallback = converter.Convert([[1, 2, 3], [4, 5, 6]], 
types_pb2.DT_INVALID) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]]) self.assertEqual(dtype, types_pb2.DT_INT32) self.assertEqual(used_fallback, not context.executing_eagerly()) @@ -110,7 +110,7 @@ def testConvertListWithExplicitDtype(self): converter = self.makePythonTensorConverter() result, dtype, used_fallback = converter.Convert([[1, 2, 3], [4, 5, 6]], types_pb2.DT_INT64) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]]) self.assertEqual(dtype, types_pb2.DT_INT64) self.assertEqual(used_fallback, not context.executing_eagerly()) @@ -137,7 +137,7 @@ def testConvertNumpyArrayWithInferredDType(self): converter = self.makePythonTensorConverter() x = np.array([[1, 2, 3], [4, 5, 6]], np.int32) result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INVALID) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]]) self.assertEqual(dtype, types_pb2.DT_INT32) self.assertEqual(used_fallback, not context.executing_eagerly()) @@ -146,7 +146,7 @@ def testConvertNumpyArrayWithExplicitDtype(self): converter = self.makePythonTensorConverter() x = np.array([[1, 2, 3], [4, 5, 6]], np.int32) result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INT64) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]]) self.assertEqual(dtype, types_pb2.DT_INT64) self.assertEqual(used_fallback, not context.executing_eagerly()) @@ -173,7 +173,7 @@ def testConvertIndexedSlicesWithInferredDType(self): constant_op.constant([1], dtypes.int64, name="x_indices"), constant_op.constant([3, 3], dtypes.int64, name="x_shape")) result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INVALID) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [[0, 0, 0], [1, 2, 3], [0, 0, 0]]) self.assertEqual(dtype, types_pb2.DT_INT32) self.assertTrue(used_fallback) @@ -185,7 +185,7 @@ def testConvertIndexedSlicesWithExplicitDtype(self): constant_op.constant([1], dtypes.int64, name="x_indices"), constant_op.constant([3, 3], dtypes.int64, name="x_shape")) result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INT32) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [[0, 0, 0], [1, 2, 3], [0, 0, 0]]) self.assertEqual(dtype, types_pb2.DT_INT32) self.assertTrue(used_fallback) diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py index 67708b3aece98a..efaee2c1549111 100644 --- a/tensorflow/python/framework/smart_cond.py +++ b/tensorflow/python/framework/smart_cond.py @@ -14,7 +14,7 @@ # ============================================================================== """smart_cond and related utilities.""" -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import cond from tensorflow.python.ops import control_flow_case @@ -70,7 +70,7 @@ def smart_constant_value(pred): Raises: TypeError: If `pred` is not a Tensor or bool. 
""" - if isinstance(pred, ops.Tensor): + if isinstance(pred, tensor.Tensor): pred_value = tensor_util.constant_value(pred) # TODO(skyewm): consider folding this into tensor_util.constant_value. # pylint: disable=protected-access diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index deded692a80f58..fbd3fdc881aa29 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -24,8 +24,8 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.framework.type_utils import fulltypes_for_flat_tensors from tensorflow.python.ops import array_ops @@ -158,7 +158,7 @@ def test_simple(self): sp = sparse_tensor.SparseTensor(indices, values, dense_shape) self.assertIsInstance(sp.shape, tensor_shape.TensorShape) - self.assertIsInstance(sp.dense_shape, ops.Tensor) + self.assertIsInstance(sp.dense_shape, tensor_lib.Tensor) self.assertEqual(sp.shape.as_list(), [5, 5]) def test_unknown_shape(self): @@ -172,7 +172,7 @@ def my_func(dense_shape): return sp my_func.get_concrete_function( - dense_shape=tensor_spec.TensorSpec( + dense_shape=tensor_lib.TensorSpec( dtype=dtypes.int64, shape=[2,])) def test_partial_shape(self): @@ -188,7 +188,7 @@ def my_func(x): return sp my_func.get_concrete_function( - x=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[])) + x=tensor_lib.TensorSpec(dtype=dtypes.int64, shape=[])) def test_neg_shape(self): indices = [[0, 2]] @@ -211,7 +211,7 @@ def my_func(x): return sp my_func.get_concrete_function( - x=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[None, None])) + x=tensor_lib.TensorSpec(dtype=dtypes.int64, shape=[None, None])) def test_unknown_rank(self): @@ -224,7 +224,7 @@ def my_func(dense_shape): return sp my_func.get_concrete_function( - dense_shape=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[None])) + dense_shape=tensor_lib.TensorSpec(dtype=dtypes.int64, shape=[None])) @test_util.run_all_in_graph_and_eager_modes @@ -266,14 +266,14 @@ def testSerialize(self, st_spec, expected): @parameterized.parameters([ (sparse_tensor.SparseTensorSpec(dtype=dtypes.string), [ - tensor_spec.TensorSpec([None, None], dtypes.int64), - tensor_spec.TensorSpec([None], dtypes.string), - tensor_spec.TensorSpec([None], dtypes.int64) + tensor_lib.TensorSpec([None, None], dtypes.int64), + tensor_lib.TensorSpec([None], dtypes.string), + tensor_lib.TensorSpec([None], dtypes.int64) ]), (sparse_tensor.SparseTensorSpec(shape=[5, None, None]), [ - tensor_spec.TensorSpec([None, 3], dtypes.int64), - tensor_spec.TensorSpec([None], dtypes.float32), - tensor_spec.TensorSpec([3], dtypes.int64) + tensor_lib.TensorSpec([None, 3], dtypes.int64), + tensor_lib.TensorSpec([None], dtypes.float32), + tensor_lib.TensorSpec([3], dtypes.int64) ]), ]) def testComponentSpecs(self, st_spec, expected): @@ -331,7 +331,7 @@ def testFromNumpyComponents(self): ]) def testFlatTensorSpecs(self, st_spec): self.assertEqual(st_spec._flat_tensor_specs, - [tensor_spec.TensorSpec(None, dtypes.variant)]) + [tensor_lib.TensorSpec(None, dtypes.variant)]) @parameterized.parameters([ dtypes.float32, diff --git a/tensorflow/python/framework/subscribe.py 
b/tensorflow/python/framework/subscribe.py index e68412b4982df3..3e48388542930b 100644 --- a/tensorflow/python/framework/subscribe.py +++ b/tensorflow/python/framework/subscribe.py @@ -18,6 +18,7 @@ import re from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging @@ -42,7 +43,7 @@ def _recursive_apply(tensors, apply_fn): `TypeError` if undefined type in the tensors structure. """ tensors_type = type(tensors) - if isinstance(tensors, ops.Tensor): + if isinstance(tensors, tensor_lib.Tensor): return apply_fn(tensors) elif isinstance(tensors, variables.Variable): return apply_fn(tensors.value()) @@ -171,7 +172,9 @@ def _subscribe_extend(tensor, side_effects): for s in side_effects: outs += s(source_tensor) - out_ops = [out.op if isinstance(out, ops.Tensor) else out for out in outs] + out_ops = [ + out.op if isinstance(out, tensor_lib.Tensor) else out for out in outs + ] tensor.op._add_control_inputs(out_ops) # pylint: disable=protected-access return tensor diff --git a/tensorflow/python/framework/tensor.py b/tensorflow/python/framework/tensor.py index 1a5cb7df766fc0..61dbae3417036c 100644 --- a/tensorflow/python/framework/tensor.py +++ b/tensorflow/python/framework/tensor.py @@ -925,7 +925,7 @@ class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec, >>> tf.TensorSpec.from_tensor(t) TensorSpec(shape=(2, 3), dtype=tf.int32, name=None) - Contains metadata for describing the the nature of `tf.Tensor` objects + Contains metadata for describing the nature of `tf.Tensor` objects accepted or returned by some TensorFlow APIs. For example, it can be used to constrain the type of inputs accepted by diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index f4dc3b4e43a2e3..836c116506f8d2 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -933,7 +933,7 @@ def bar(self, partial): or None if it cannot be calculated. Raises: - TypeError: if tensor is not an ops.Tensor. + TypeError: if tensor is not a tensor.Tensor. """ if isinstance(tensor, core.Value): try: diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 1727db9016c890..874bd544c773a2 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -60,6 +60,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tfrt_utils @@ -376,7 +377,7 @@ def NHWCToNCHW(input_tensor): """ # tensor dim -> new axis order new_axes = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]} - if isinstance(input_tensor, ops.Tensor): + if isinstance(input_tensor, tensor_lib.Tensor): ndims = input_tensor.shape.ndims return array_ops.transpose(input_tensor, new_axes[ndims]) else: @@ -401,7 +402,7 @@ def NHWCToNCHW_VECT_C(input_shape_or_tensor): divisible by 4.
""" permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]} - is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) + is_tensor = isinstance(input_shape_or_tensor, tensor_lib.Tensor) temp_shape = ( input_shape_or_tensor.shape.as_list() if is_tensor else input_shape_or_tensor) @@ -435,7 +436,7 @@ def NCHW_VECT_CToNHWC(input_shape_or_tensor): ValueError: if last dimension of `input_shape_or_tensor` is not 4. """ permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]} - is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) + is_tensor = isinstance(input_shape_or_tensor, tensor_lib.Tensor) input_shape = ( input_shape_or_tensor.shape.as_list() if is_tensor else input_shape_or_tensor) @@ -462,7 +463,7 @@ def NCHWToNHWC(input_tensor): """ # tensor dim -> new axis order new_axes = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]} - if isinstance(input_tensor, ops.Tensor): + if isinstance(input_tensor, tensor_lib.Tensor): ndims = input_tensor.shape.ndims return array_ops.transpose(input_tensor, new_axes[ndims]) else: @@ -806,7 +807,7 @@ def decorator(self, **kwargs): def _is_tensorflow_object(obj): try: return isinstance(obj, - (ops.Tensor, variables.Variable, + (tensor_lib.Tensor, variables.Variable, tensor_shape.Dimension, tensor_shape.TensorShape)) except (ReferenceError, AttributeError): # If the object no longer exists, we don't care about it. @@ -1545,7 +1546,7 @@ def decorated(*args, **kwds): tensor_args = [] tensor_indices = [] for i, arg in enumerate(args): - if isinstance(arg, (ops.Tensor, variables.Variable)): + if isinstance(arg, (tensor_lib.Tensor, variables.Variable)): tensor_args.append(arg) tensor_indices.append(i) @@ -2699,11 +2700,14 @@ def evaluate(self, tensors): return self._eval_helper(tensors) else: sess = ops.get_default_session() + flattened_tensors = nest.flatten(tensors) if sess is None: with self.test_session() as sess: - return sess.run(tensors) + flattened_results = sess.run(flattened_tensors) else: - return sess.run(tensors) + flattened_results = sess.run(flattened_tensors) + + return nest.pack_sequence_as(tensors, flattened_results) # pylint: disable=g-doc-return-or-yield @contextlib.contextmanager @@ -3583,18 +3587,18 @@ def assertShapeEqual(self, input_a, input_b, msg=None): Raises: TypeError: If the arguments have the wrong type. """ - if not isinstance(input_a, (np.ndarray, np.generic, ops.Tensor)): + if not isinstance(input_a, (np.ndarray, np.generic, tensor_lib.Tensor)): raise TypeError( "input_a must be a Numpy ndarray, Numpy scalar, or a Tensor." f"Instead received {type(input_a)}") - if not isinstance(input_b, (np.ndarray, np.generic, ops.Tensor)): + if not isinstance(input_b, (np.ndarray, np.generic, tensor_lib.Tensor)): raise TypeError( "input_b must be a Numpy ndarray, Numpy scalar, or a Tensor." 
f"Instead received {type(input_b)}") shape_a = input_a.get_shape().as_list() if isinstance( - input_a, ops.Tensor) else input_a.shape + input_a, tensor_lib.Tensor) else input_a.shape shape_b = input_b.get_shape().as_list() if isinstance( - input_b, ops.Tensor) else input_b.shape + input_b, tensor_lib.Tensor) else input_b.shape self.assertAllEqual(shape_a, shape_b, msg=msg) def assertDeviceEqual(self, device1, device2, msg=None): @@ -3641,7 +3645,7 @@ def _GetPyList(self, a): """Converts `a` to a nested python list.""" if isinstance(a, ragged_tensor.RaggedTensor): return self.evaluate(a).to_list() - elif isinstance(a, ops.Tensor): + elif isinstance(a, tensor_lib.Tensor): a = self.evaluate(a) return a.tolist() if isinstance(a, np.ndarray) else a elif isinstance(a, np.ndarray): diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index ec5f75625d7582..eddfdb5f35e4cf 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -16,6 +16,7 @@ import collections import copy +import dataclasses import random import sys import threading @@ -41,6 +42,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -56,6 +58,26 @@ from tensorflow.python.util.protobuf import compare_test_pb2 +@dataclasses.dataclass +class MaskedTensor: + mask: bool + value: ops.Tensor + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value,) + return metadata, components + + @classmethod + def __tf_unflatten__(cls, metadata, components): + mask = metadata[0] + value = components[0] + return MaskedTensor(mask=mask, value=value) + + def __eq__(self, other): + return self.mask == other.mask and self.value == other.value + + class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase): def test_assert_ops_in_graph(self): @@ -917,6 +939,21 @@ def test_nested_tensors_evaluate(self): self.assertEqual(expected, self.evaluate(nested)) + @test_util.run_in_graph_and_eager_modes + def test_custom_dataclass_evaluate(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + mt_val = self.evaluate(mt) + self.assertEqual(mt_val.mask, True) + self.assertAllEqual(mt_val.value, [1]) + + mt2 = MaskedTensor(mask=True, value=constant_op.constant([1])) + mt2_val = self.evaluate(mt2) + self.assertEqual(mt_val, mt2_val) + + mt3 = MaskedTensor(mask=True, value=constant_op.constant([2])) + mt3_val = self.evaluate(mt3) + self.assertNotEqual(mt_val, mt3_val) + def test_run_in_graph_and_eager_modes(self): l = [] def inc(self, with_brackets): @@ -1199,9 +1236,9 @@ def add_two(x): if run_eagerly: self.assertTrue(isinstance(t, ops.EagerTensor) for t in results) else: - self.assertTrue(isinstance(t, ops.Tensor) for t in results) + self.assertTrue(isinstance(t, tensor.Tensor) for t in results) else: - self.assertTrue(isinstance(t, ops.Tensor) for t in results) + self.assertTrue(isinstance(t, tensor.Tensor) for t in results) class SyncDevicesTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/framework/weak_tensor.py b/tensorflow/python/framework/weak_tensor.py index a36d57c58f01f3..6b3f456028972d 100644 --- a/tensorflow/python/framework/weak_tensor.py +++ b/tensorflow/python/framework/weak_tensor.py 
@@ -25,6 +25,9 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib +from tensorflow.python.framework import tensor_conversion_registry + _ALLOWED_WEAK_DTYPES = ( dtypes.int32, @@ -62,7 +65,7 @@ class WeakTensor(extension_type.ExtensionType): # __name__ is required for serialization in SavedModel. __name__ = "tf.WeakTensor" - tensor: ops.Tensor + tensor: tensor_lib.Tensor def __validate__(self): if self.tensor.dtype not in _ALLOWED_WEAK_DTYPES: @@ -164,6 +167,16 @@ def to_tensor(self): """Converts this 'WeakTensor' into a 'tf.Tensor'.""" return self.tensor + def numpy(self): + """Copy of the contents of this WeakTensor into a NumPy array or scalar.""" + if not isinstance(self.tensor, ops.EagerTensor): + raise ValueError("WeakTensor.numpy() is only supported in eager mode.") + return self.tensor.numpy() + + def _as_graph_element(self): + """Convert `self` to a graph element.""" + return self.tensor + @classmethod def from_tensor(cls, tensor): """Converts a 'tf.Tensor' into a 'WeakTensor'.""" @@ -179,6 +192,10 @@ def dtype(self): def shape(self): return self.tensor.shape + @property + def is_tensor_like(self): + return True + __composite_gradient__ = WeakTensorGradient() @@ -201,3 +218,19 @@ def __next__(self): result = WeakTensor(self._weak_tensor.tensor[self._index]) self._index += 1 return result + + +def maybe_convert_to_weak_tensor(t, is_weak): + return WeakTensor(t) if is_weak else t + + +# convert_to_tensor(WeakTensor) should return a WeakTensor because WeakTensor is +# a 'Tensor' with a special dtype. +def weak_tensor_conversion_function(t): + if isinstance(t, WeakTensor): + return t + + +tensor_conversion_registry.register_tensor_conversion_function( + WeakTensor, weak_tensor_conversion_function +) diff --git a/tensorflow/python/framework/weak_tensor_test.py b/tensorflow/python/framework/weak_tensor_test.py index 9e8229e0541899..8fec36d82e3fe4 100644 --- a/tensorflow/python/framework/weak_tensor_test.py +++ b/tensorflow/python/framework/weak_tensor_test.py @@ -22,7 +22,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.framework import weak_tensor from tensorflow.python.module import module @@ -157,7 +157,7 @@ def f(c, a, b): def test_weak_tensor_in_tf_func_with_spec(self): # Test weak tensor spec with matching input. 
- weak_tensor_spec = weak_tensor.WeakTensor.Spec(tensor_spec.TensorSpec([2])) + weak_tensor_spec = weak_tensor.WeakTensor.Spec(tensor.TensorSpec([2])) wt = weak_tensor.WeakTensor(constant_op.constant([1.0, 2.0])) @def_function.function(input_signature=[weak_tensor_spec]) @@ -185,8 +185,8 @@ class CustomModule(module.Module): @def_function.function def __call__(self, x): - if isinstance(x, ops.Tensor): - raise TypeError('Weak tensor should not be ops.Tensor type.') + if isinstance(x, tensor.Tensor): + raise TypeError('Weak tensor should not be tensor.Tensor type.') return x m = CustomModule() diff --git a/tensorflow/python/grappler/BUILD b/tensorflow/python/grappler/BUILD index 229152816da5cb..26125785138c5d 100644 --- a/tensorflow/python/grappler/BUILD +++ b/tensorflow/python/grappler/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library") -load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "get_compatible_with_cloud", "tf_py_strict_test", "tf_pybind_cc_library_wrapper", "tf_python_pybind_extension") +load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "get_compatible_with_portable", "tf_py_strict_test", "tf_pybind_cc_library_wrapper", "tf_python_pybind_extension") load("//tensorflow/core/platform:build_config.bzl", "tf_protos_grappler") load("//tensorflow:tensorflow.bzl", "if_not_windows") @@ -15,7 +15,7 @@ cc_library( name = "cost_analyzer_lib", srcs = ["cost_analyzer.cc"], hdrs = ["cost_analyzer.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 0bf6466b963526..c271a5ef77a784 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -55,14 +55,14 @@ py_library( "//tensorflow/core:protos_all_py", "//tensorflow/python/client", "//tensorflow/python/client:session", - "//tensorflow/python/distribute:distribute_coordinator", "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/framework", + "//tensorflow/python/framework:composite_tensor", + "//tensorflow/python/framework:config", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/keras/distribute:distribute_coordinator_utils", @@ -159,6 +159,8 @@ py_library( deps = [ ":backend", "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/keras/distribute:distributed_file_utils", "//tensorflow/python/keras/distribute:worker_training_state", "//tensorflow/python/keras/protobuf:projector_config_proto_py", diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index bf012e88e4313f..18ebf20ab0eff3 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -43,9 +43,9 @@ from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion from 
tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend_config from tensorflow.python.keras.distribute import distribute_coordinator_utils as dc @@ -213,7 +213,7 @@ def cast_to_floatx(x): dtype('float32') """ - if isinstance(x, (ops.Tensor, + if isinstance(x, (tensor_lib.Tensor, variables_module.Variable, sparse_tensor.SparseTensor)): return math_ops.cast(x, dtype=floatx()) @@ -672,9 +672,9 @@ def _current_graph(op_input_list, graph=None): # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this # up. if (isinstance(op_input, ( - ops.Operation, ops.Tensor, composite_tensor.CompositeTensor)) and - ((not isinstance(op_input, ops.Tensor)) - or type(op_input) == ops.Tensor)): # pylint: disable=unidiomatic-typecheck + ops.Operation, tensor_lib.Tensor, composite_tensor.CompositeTensor)) and + ((not isinstance(op_input, tensor_lib.Tensor)) + or type(op_input) == tensor_lib.Tensor)): # pylint: disable=unidiomatic-typecheck graph_element = op_input else: graph_element = _as_graph_element(op_input) @@ -1266,7 +1266,7 @@ def is_keras_tensor(x): """ if not isinstance(x, - (ops.Tensor, variables_module.Variable, + (tensor_lib.Tensor, variables_module.Variable, sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor, keras_tensor.KerasTensor)): raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) + @@ -1339,7 +1339,7 @@ def placeholder(shape=None, spec = ragged_tensor.RaggedTensorSpec( shape=shape, dtype=dtype, ragged_rank=ragged_rank) else: - spec = tensor_spec.TensorSpec( + spec = tensor_lib.TensorSpec( shape=shape, dtype=dtype, name=name) x = keras_tensor.keras_tensor_from_type_spec(spec, name=name) else: @@ -3859,7 +3859,7 @@ def print_tensor(x, message='', summarize=3): Returns: The same tensor `x`, unchanged. 
""" - if isinstance(x, ops.Tensor) and hasattr(x, 'graph'): + if isinstance(x, tensor_lib.Tensor) and hasattr(x, 'graph'): with get_graph().as_default(): op = logging_ops.print_v2( message, x, output_stream=sys.stdout, summarize=summarize) @@ -4423,7 +4423,7 @@ def compute_masked_output(mask_t, flat_out, flat_mask): return tuple( array_ops.where_v2(m, o, fm) for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)) - elif isinstance(input_length, ops.Tensor): + elif isinstance(input_length, tensor_lib.Tensor): if go_backwards: max_len = math_ops.reduce_max(input_length, axis=0) rev_input_length = math_ops.subtract(max_len - 1, input_length) @@ -4476,7 +4476,7 @@ def _step(time, output_ta_t, prev_output, *states): flat_state = nest.flatten(states) flat_new_state = nest.flatten(new_states) for state, new_state in zip(flat_state, flat_new_state): - if isinstance(new_state, ops.Tensor): + if isinstance(new_state, tensor_lib.Tensor): new_state.set_shape(state.shape) flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state) @@ -4513,7 +4513,7 @@ def _step(time, output_ta_t, *states): flat_state = nest.flatten(states) flat_new_state = nest.flatten(new_states) for state, new_state in zip(flat_state, flat_new_state): - if isinstance(new_state, ops.Tensor): + if isinstance(new_state, tensor_lib.Tensor): new_state.set_shape(state.shape) flat_output = nest.flatten(output) @@ -4536,7 +4536,7 @@ def _step(time, output_ta_t, *states): # static shape inference def set_shape(output_): - if isinstance(output_, ops.Tensor): + if isinstance(output_, tensor_lib.Tensor): shape = output_.shape.as_list() shape[0] = time_steps shape[1] = batch diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 9d414ceb1488c6..c5cbd1873c3058 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -41,6 +41,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.keras import backend from tensorflow.python.keras.distribute import distributed_file_utils from tensorflow.python.keras.distribute import worker_training_state @@ -1965,10 +1966,10 @@ def on_epoch_begin(self, epoch, logs=None): lr = self.schedule(epoch, lr) except TypeError: # Support for old API for backward compatibility lr = self.schedule(epoch) - if not isinstance(lr, (ops.Tensor, float, np.float32, np.float64)): + if not isinstance(lr, (tensor_lib.Tensor, float, np.float32, np.float64)): raise ValueError('The output of the "schedule" function ' 'should be float.') - if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating: + if isinstance(lr, tensor_lib.Tensor) and not lr.dtype.is_floating: raise ValueError('The dtype of Tensor should be float') backend.set_value(self.model.optimizer.lr, backend.get_value(lr)) if self.verbose > 0: diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 51e77dc218925e..2098b1650bc920 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -54,6 +54,8 @@ py_library( "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute/coordinator:cluster_coordinator", "//tensorflow/python/eager:monitoring", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/keras:activations", 
"//tensorflow/python/keras:backend", @@ -130,6 +132,7 @@ py_library( ":input_spec", ":node", "//third_party/py/numpy", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:compat", @@ -188,7 +191,7 @@ py_library( srcs_version = "PY3", deps = [ "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/keras/utils:dataset_creator", "//tensorflow/python/keras/utils:engine_utils", @@ -221,7 +224,7 @@ py_library( srcs_version = "PY3", deps = [ "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/keras/utils:object_identity", "//tensorflow/python/lib/io:lib", "//tensorflow/python/util:nest", @@ -238,6 +241,7 @@ py_library( ":base_layer", "//tensorflow/python/data", "//tensorflow/python/eager:monitoring", + "//tensorflow/python/framework:tensor", "//tensorflow/python/keras:backend", "//tensorflow/python/module", ], @@ -250,6 +254,7 @@ py_library( deps = [ ":base_layer_utils", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/keras:backend", "//tensorflow/python/keras/utils:tf_utils", diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 48e0862fe1b05a..6b97d9ce9904ed 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -39,8 +39,8 @@ from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.keras import constraints @@ -91,7 +91,7 @@ # TODO(mdan): Should we have a single generic type for types that can be passed # to tf.cast? -_AUTOCAST_TYPES = (ops.Tensor, sparse_tensor.SparseTensor, +_AUTOCAST_TYPES = (tensor_lib.Tensor, sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor) @@ -822,7 +822,7 @@ def compute_output_signature(self, input_signature): TypeError: If input_signature contains a non-TensorSpec object. """ def check_type_return_shape(s): - if not isinstance(s, tensor_spec.TensorSpec): + if not isinstance(s, tensor_lib.TensorSpec): raise TypeError('Only TensorSpec signature types are supported, ' 'but saw signature entry: {}.'.format(s)) return s.shape @@ -835,7 +835,7 @@ def check_type_return_shape(s): # dtype. dtype = input_dtypes[0] return nest.map_structure( - lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s), + lambda s: tensor_lib.TensorSpec(dtype=dtype, shape=s), output_shape) def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs): @@ -847,7 +847,7 @@ def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs): # TODO(fchollet): consider py_func as an alternative, which # would enable us to run the underlying graph if needed. 
input_signature = nest.map_structure( - lambda x: tensor_spec.TensorSpec(shape=x.shape, dtype=x.dtype), + lambda x: tensor_lib.TensorSpec(shape=x.shape, dtype=x.dtype), inputs) output_signature = self.compute_output_signature(input_signature) return nest.map_structure(keras_tensor.KerasTensor, output_signature) diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index 6836bdfd9eeecd..3cb10b362125b4 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -32,8 +32,8 @@ from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.keras import constraints @@ -601,7 +601,7 @@ def compute_output_signature(self, input_signature): TypeError: If input_signature contains a non-TensorSpec object. """ def check_type_return_shape(s): - if not isinstance(s, tensor_spec.TensorSpec): + if not isinstance(s, tensor.TensorSpec): raise TypeError('Only TensorSpec signature types are supported, ' 'but saw signature entry: {}.'.format(s)) return s.shape @@ -614,7 +614,7 @@ def check_type_return_shape(s): # dtype. dtype = input_dtypes[0] return nest.map_structure( - lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s), + lambda s: tensor.TensorSpec(dtype=dtype, shape=s), output_shape) @generic_utils.default @@ -1815,15 +1815,15 @@ def _maybe_cast_inputs(self, inputs): dtypes.as_dtype(compute_dtype).is_floating): def f(x): """Cast a single Tensor or TensorSpec to the compute dtype.""" - cast_types = (ops.Tensor, sparse_tensor.SparseTensor, + cast_types = (tensor.Tensor, sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor) if (isinstance(x, cast_types) and x.dtype.is_floating and x.dtype.base_dtype.name != compute_dtype): return math_ops.cast(x, compute_dtype) - elif isinstance(x, tensor_spec.TensorSpec) and x.dtype.is_floating: + elif isinstance(x, tensor.TensorSpec) and x.dtype.is_floating: # Inputs may be TensorSpecs when this function is called from # model._set_inputs. 
- return tensor_spec.TensorSpec(x.shape, compute_dtype, x.name) + return tensor.TensorSpec(x.shape, compute_dtype, x.name) else: return x return nest.map_structure(f, inputs) diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer.py b/tensorflow/python/keras/engine/base_preprocessing_layer.py index 4aef248622d49c..0d29c21c83dcbb 100644 --- a/tensorflow/python/keras/engine/base_preprocessing_layer.py +++ b/tensorflow/python/keras/engine/base_preprocessing_layer.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.keras import backend from tensorflow.python.keras.engine import data_adapter from tensorflow.python.keras.engine.base_layer import Layer @@ -468,7 +469,7 @@ def convert_to_list(values, sparse_default_value=None): values, default_value=sparse_default_value) values = backend.get_value(dense_tensor) - if isinstance(values, ops.Tensor): + if isinstance(values, tensor.Tensor): values = backend.get_value(values) # We may get passed a ndarray or the code above may give us a ndarray. diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index 3119c98c29ee94..50a58757df34bc 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -32,9 +32,9 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend @@ -566,7 +566,7 @@ def _is_composite(v): return _is_scipy_sparse(v) def _is_tensor_or_composite(v): - if isinstance(v, (ops.Tensor, np.ndarray)): + if isinstance(v, (tensor.Tensor, np.ndarray)): return True return _is_composite(v) @@ -1460,7 +1460,7 @@ def expand_1d(data): def _expand_single_1d_tensor(t): # Leaves `CompositeTensor`s as-is. 
- if (isinstance(t, ops.Tensor) and + if (isinstance(t, tensor.Tensor) and isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1): return array_ops.expand_dims_v2(t, axis=-1) return t @@ -1669,9 +1669,9 @@ def _get_tensor_types(): try: import pandas as pd # pylint: disable=g-import-not-at-top - return (ops.Tensor, np.ndarray, pd.Series, pd.DataFrame) + return (tensor.Tensor, np.ndarray, pd.Series, pd.DataFrame) except ImportError: - return (ops.Tensor, np.ndarray) + return (tensor.Tensor, np.ndarray) def _is_scipy_sparse(x): diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index 4181094edf7540..920f934eb43dfd 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -23,6 +23,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer_utils @@ -602,7 +603,7 @@ def _flatten_to_reference_inputs(self, tensors): def _conform_to_reference_input(self, tensor, ref_input): """Set shape and dtype based on `keras.Input`s.""" - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): # Allow (None,) and (None, 1) Tensors to be passed interchangeably. Use # the shape specified by the `keras.Input`. t_shape = tensor.shape diff --git a/tensorflow/python/keras/engine/keras_tensor.py b/tensorflow/python/keras/engine/keras_tensor.py index a8878466367bcb..03936288362def 100644 --- a/tensorflow/python/keras/engine/keras_tensor.py +++ b/tensorflow/python/keras/engine/keras_tensor.py @@ -16,10 +16,9 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec as type_spec_module from tensorflow.python.keras.utils import object_identity from tensorflow.python.ops import array_ops @@ -143,7 +142,7 @@ def shape(self): @classmethod def from_tensor(cls, tensor): """Convert a traced (composite)tensor to a representative KerasTensor.""" - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): name = getattr(tensor, 'name', None) type_spec = type_spec_module.type_spec_from_value(tensor) inferred_value = None @@ -304,7 +303,7 @@ def __str__(self): def __repr__(self): symbolic_description = '' inferred_value_string = '' - if isinstance(self.type_spec, tensor_spec.TensorSpec): + if isinstance(self.type_spec, tensor_lib.TensorSpec): type_spec_string = 'shape=%s dtype=%s' % (self.shape, self.dtype.name) else: type_spec_string = 'type_spec=%s' % self.type_spec @@ -361,7 +360,7 @@ def name(self): @classmethod def _overload_all_operators(cls, tensor_class): # pylint: disable=invalid-name """Register overloads for all operators.""" - for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + for operator in tensor_lib.Tensor.OVERLOADABLE_OPERATORS: cls._overload_operator(tensor_class, operator) # We include `experimental_ref` for versions of TensorFlow that @@ -389,7 +388,7 @@ def _overload_operator(cls, tensor_class, 
operator): # pylint: disable=invalid- setattr(cls, operator, tensor_oper) -KerasTensor._overload_all_operators(ops.Tensor) # pylint: disable=protected-access +KerasTensor._overload_all_operators(tensor_lib.Tensor) # pylint: disable=protected-access class SparseKerasTensor(KerasTensor): @@ -556,11 +555,11 @@ def __next__(self): # 1. we do a check w/ isinstance because a key lookup based on class # would miss subclasses # 2. a list allows us to control lookup ordering -# We include ops.Tensor -> KerasTensor in the first position as a fastpath, +# We include tensor.Tensor -> KerasTensor in the first position as a fastpath, # *and* include object -> KerasTensor at the end as a catch-all. # We can re-visit these choices in the future as needed. keras_tensor_classes = [ - (ops.Tensor, KerasTensor), + (tensor_lib.Tensor, KerasTensor), (sparse_tensor.SparseTensor, SparseKerasTensor), (ragged_tensor.RaggedTensor, RaggedKerasTensor), (object, KerasTensor) diff --git a/tensorflow/python/keras/engine/node.py b/tensorflow/python/keras/engine/node.py index 657d41840fe6e1..c4d409a74d7d96 100644 --- a/tensorflow/python/keras/engine/node.py +++ b/tensorflow/python/keras/engine/node.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer_utils @@ -80,7 +81,7 @@ def __init__(self, if not ops.executing_eagerly_outside_functions(): # Create TensorFlowOpLayers if needed (in TF1) for obj in self._flat_arguments: - if (isinstance(obj, ops.Tensor) and + if (isinstance(obj, tensor_lib.Tensor) and base_layer_utils.needs_keras_history( obj, ignore_call_context=True)): base_layer_utils.create_keras_history(obj) @@ -178,7 +179,7 @@ def _serialize_keras_tensor(t): if isinstance(t, np.ndarray): return t.tolist() - if isinstance(t, ops.Tensor): + if isinstance(t, tensor_lib.Tensor): return backend.get_value(t).tolist() return t diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index a8216ffe65c5ec..56fcbaaeb4e4bc 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend from tensorflow.python.keras import callbacks as callbacks_module @@ -2908,7 +2909,8 @@ def _multi_worker_concat(v, strategy): def _is_scalar(x): - return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0 + return isinstance( + x, (tensor_lib.Tensor, variables.Variable)) and x.shape.rank == 0 def write_scalar_summaries(logs, step): diff --git a/tensorflow/python/keras/layers/legacy_rnn/BUILD b/tensorflow/python/keras/layers/legacy_rnn/BUILD index 54cdae02fd8fd6..cc69a90723249d 100644 --- a/tensorflow/python/keras/layers/legacy_rnn/BUILD +++ b/tensorflow/python/keras/layers/legacy_rnn/BUILD @@ -26,6 +26,7 @@ py_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/framework:tensor_shape", 
"//tensorflow/python/framework:tensor_util", diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py index ed41c9f2b196a4..b7bcf9483180a3 100644 --- a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py +++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -126,7 +127,7 @@ def _concat(prefix, suffix, static=False): ValueError: if prefix or suffix was `None` and asked for dynamic Tensors out. """ - if isinstance(prefix, ops.Tensor): + if isinstance(prefix, tensor.Tensor): p = prefix p_static = tensor_util.constant_value(prefix) if p.shape.ndims == 0: @@ -140,7 +141,7 @@ def _concat(prefix, suffix, static=False): p = ( constant_op.constant(p.as_list(), dtype=dtypes.int32) if p.is_fully_defined() else None) - if isinstance(suffix, ops.Tensor): + if isinstance(suffix, tensor.Tensor): s = suffix s_static = tensor_util.constant_value(suffix) if s.shape.ndims == 0: diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD index 0971812a4a1c86..070e32f7c500f2 100644 --- a/tensorflow/python/keras/optimizer_v2/BUILD +++ b/tensorflow/python/keras/optimizer_v2/BUILD @@ -44,7 +44,10 @@ py_library( "//tensorflow/python/distribute:parameter_server_strategy_v2", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", - "//tensorflow/python/framework", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:indexed_slices", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/keras:backend", "//tensorflow/python/keras:backend_config", diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent.py b/tensorflow/python/keras/optimizer_v2/gradient_descent.py index 74428c719bb547..87c3b543578973 100644 --- a/tensorflow/python/keras/optimizer_v2/gradient_descent.py +++ b/tensorflow/python/keras/optimizer_v2/gradient_descent.py @@ -15,7 +15,7 @@ """SGD optimizer implementation.""" # pylint: disable=g-classes-have-attributes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -108,7 +108,8 @@ def __init__(self, self._set_hyper("decay", self._initial_decay) self._momentum = False - if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0: + if isinstance( + momentum, tensor.Tensor) or callable(momentum) or momentum > 0: self._momentum = True if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1): raise ValueError("`momentum` must be between [0, 1].") diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index b00af22388d534..755363545e21ae 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes from 
tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.keras import initializers @@ -684,7 +685,7 @@ def _distributed_apply(self, distribution, grads_and_vars, name, apply_state): def apply_grad_to_update_var(var, grad): """Apply gradient to variable.""" - if isinstance(var, ops.Tensor): + if isinstance(var, tensor.Tensor): raise NotImplementedError("Trying to update a Tensor ", var) apply_kwargs = {} @@ -787,7 +788,7 @@ def _set_hyper(self, name, value): prev_value = self._hyper[name] if (callable(prev_value) or isinstance(prev_value, - (ops.Tensor, int, float, + (tensor.Tensor, int, float, learning_rate_schedule.LearningRateSchedule)) or isinstance(value, learning_rate_schedule.LearningRateSchedule)): self._hyper[name] = value @@ -965,8 +966,8 @@ def _create_hypers(self): with self._distribution_strategy_scope(): # Iterate hyper values deterministically. for name, value in sorted(self._hyper.items()): - if isinstance(value, - (ops.Tensor, tf_variables.Variable)) or callable(value): + if isinstance( + value, (tensor.Tensor, tf_variables.Variable)) or callable(value): # The check for `callable` covers the usage when `value` is a # `LearningRateSchedule`, in which case it does not need to create a # variable. diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop.py b/tensorflow/python/keras/optimizer_v2/rmsprop.py index a0d9d07febe452..f752c41eeaf903 100644 --- a/tensorflow/python/keras/optimizer_v2/rmsprop.py +++ b/tensorflow/python/keras/optimizer_v2/rmsprop.py @@ -18,6 +18,7 @@ import numpy as np from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.keras import backend_config from tensorflow.python.keras.optimizer_v2 import optimizer_v2 @@ -139,7 +140,8 @@ def __init__(self, self._set_hyper("rho", rho) self._momentum = False - if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0: + if isinstance( + momentum, tensor.Tensor) or callable(momentum) or momentum > 0: self._momentum = True if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1): raise ValueError("`momentum` must be between [0, 1].") diff --git a/tensorflow/python/keras/saving/utils_v1/BUILD b/tensorflow/python/keras/saving/utils_v1/BUILD index 411e567ae6cc74..b94009e93e52e1 100644 --- a/tensorflow/python/keras/saving/utils_v1/BUILD +++ b/tensorflow/python/keras/saving/utils_v1/BUILD @@ -39,6 +39,7 @@ py_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/platform:gfile", "//tensorflow/python/platform:tf_logging", diff --git a/tensorflow/python/keras/saving/utils_v1/export_output.py b/tensorflow/python/keras/saving/utils_v1/export_output.py index e6a595bf5acaf3..4ad09a95a2cc9a 100644 --- a/tensorflow/python/keras/saving/utils_v1/export_output.py +++ b/tensorflow/python/keras/saving/utils_v1/export_output.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from 
tensorflow.python.keras.saving.utils_v1 import signature_def_utils as unexported_signature_utils from tensorflow.python.saved_model import signature_def_utils @@ -86,7 +87,7 @@ def _wrap_and_check_outputs( for key, value in outputs.items(): error_name = error_label or single_output_default_name key = self._check_output_key(key, error_name) - if not isinstance(value, ops.Tensor): + if not isinstance(value, tensor.Tensor): raise ValueError( '{} output value must be a Tensor; got {}.'.format( error_name, value)) @@ -128,12 +129,12 @@ def __init__(self, scores=None, classes=None): `Tensor` with the correct dtype. """ if (scores is not None - and not (isinstance(scores, ops.Tensor) + and not (isinstance(scores, tensor.Tensor) and scores.dtype.is_floating)): raise ValueError('Classification scores must be a float32 Tensor; ' 'got {}'.format(scores)) if (classes is not None - and not (isinstance(classes, ops.Tensor) + and not (isinstance(classes, tensor.Tensor) and dtypes.as_dtype(classes.dtype) == dtypes.string)): raise ValueError('Classification classes must be a string Tensor; ' 'got {}'.format(classes)) @@ -175,7 +176,7 @@ def __init__(self, value): Raises: ValueError: if the value is not a `Tensor` with dtype tf.float32. """ - if not (isinstance(value, ops.Tensor) and value.dtype.is_floating): + if not (isinstance(value, tensor.Tensor) and value.dtype.is_floating): raise ValueError('Regression output value must be a float32 Tensor; ' 'got {}'.format(value)) self._value = value @@ -334,7 +335,7 @@ def _wrap_and_check_metrics(self, metrics): val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX - if not isinstance(metric_val, ops.Tensor): + if not isinstance(metric_val, tensor.Tensor): raise ValueError( '{} output value must be a Tensor; got {}.'.format( key, metric_val)) @@ -347,7 +348,7 @@ def _wrap_and_check_metrics(self, metrics): # We must wrap any ops (or variables) in a Tensor before export, as the # SignatureDef proto expects tensors only. 
See b/109740581 metric_op_tensor = metric_op - if not isinstance(metric_op, ops.Tensor): + if not isinstance(metric_op, tensor.Tensor): with ops.control_dependencies([metric_op]): metric_op_tensor = constant_op.constant([], name='metric_op_wrapper') diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD index af18b3fbecbb1e..3763a24e60657d 100644 --- a/tensorflow/python/keras/utils/BUILD +++ b/tensorflow/python/keras/utils/BUILD @@ -60,6 +60,7 @@ py_library( ":generic_utils", ":io_utils", ":tf_inspect", + "//tensorflow/python/framework:tensor", ], ) @@ -94,7 +95,7 @@ py_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:smart_cond", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:control_flow_ops", @@ -151,11 +152,13 @@ py_library( ], srcs_version = "PY3", deps = [ + ":engine_utils", ":generic_utils", ":tf_utils", "//tensorflow/python/distribute", - "//tensorflow/python/framework", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:check_ops", diff --git a/tensorflow/python/keras/utils/control_flow_util.py b/tensorflow/python/keras/utils/control_flow_util.py index 0730cd6bc77978..067570eb6d64b0 100644 --- a/tensorflow/python/keras/utils/control_flow_util.py +++ b/tensorflow/python/keras/utils/control_flow_util.py @@ -17,8 +17,8 @@ This file is copied from tensorflow/python/ops/control_flow_util.py. """ -from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond as smart_module +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import cond from tensorflow.python.ops import variables @@ -124,7 +124,7 @@ def constant_value(pred): # pylint: disable=invalid-name TypeError: If `pred` is not a Variable, Tensor or bool, or Python integer 1 or 0. 
""" - if isinstance(pred, ops.Tensor): + if isinstance(pred, tensor.Tensor): return tensor_util.constant_value(pred) if pred in {0, 1}: # Accept 1/0 as valid boolean values return bool(pred) diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index f0ba68db8c4365..c9c4be339e1e9b 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -36,7 +36,7 @@ import numpy as np -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from six.moves.urllib.request import urlopen from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils.generic_utils import Progbar @@ -89,7 +89,7 @@ def chunk_read(response, chunk_size=8192, reporthook=None): def is_generator_or_sequence(x): """Check if `x` is a Keras generator type.""" builtin_iterators = (str, list, tuple, dict, set, frozenset) - if isinstance(x, (ops.Tensor, np.ndarray) + builtin_iterators): + if isinstance(x, (tensor.Tensor, np.ndarray) + builtin_iterators): return False return (tf_inspect.isgenerator(x) or isinstance(x, Sequence) or diff --git a/tensorflow/python/keras/utils/metrics_utils.py b/tensorflow/python/keras/utils/metrics_utils.py index cc1621b826aee4..fde05826279ecf 100644 --- a/tensorflow/python/keras/utils/metrics_utils.py +++ b/tensorflow/python/keras/utils/metrics_utils.py @@ -24,6 +24,7 @@ from tensorflow.python.distribute import distribute_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.keras import backend from tensorflow.python.keras.utils import losses_utils @@ -148,7 +149,7 @@ def decorated(metric_obj, *args): # Results need to be wrapped in a `tf.identity` op to ensure # correct execution order. if isinstance(raw_result, - (ops.Tensor, variables_module.Variable, float, int)): + (tensor.Tensor, variables_module.Variable, float, int)): result_t = array_ops.identity(raw_result) elif isinstance(raw_result, dict): result_t = { diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index 34ec293514c66c..91c1aab5cdbada 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -24,8 +24,8 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.keras import backend as K @@ -42,7 +42,7 @@ def is_tensor_or_tensor_list(v): v = nest.flatten(v) - if v and isinstance(v[0], ops.Tensor): + if v and isinstance(v[0], tensor_lib.Tensor): return True else: return False @@ -314,7 +314,7 @@ def is_symbolic_tensor(tensor): Returns: True for symbolic tensors, False for eager tensors. 
""" - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): return hasattr(tensor, 'graph') elif is_extension_type(tensor): component_tensors = nest.flatten(tensor, expand_composites=True) @@ -378,7 +378,7 @@ def type_spec_from_value(value): # Get a TensorSpec for array-like data without # converting the data to a Tensor if hasattr(value, 'shape') and hasattr(value, 'dtype'): - return tensor_spec.TensorSpec(value.shape, value.dtype) + return tensor_lib.TensorSpec(value.shape, value.dtype) else: return type_spec.type_spec_from_value(value) @@ -477,7 +477,7 @@ def get_tensor_spec(t, dynamic_batch=False, name=None): hasattr(t._keras_history[0], '_type_spec')): return t._keras_history[0]._type_spec elif hasattr(t, 'shape') and hasattr(t, 'dtype'): - spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name) + spec = tensor_lib.TensorSpec(shape=t.shape, dtype=t.dtype, name=name) else: return None # Allow non-Tensors to pass through. @@ -521,7 +521,7 @@ def sync_to_numpy_or_python_type(tensors): return tensors.fetch() def _to_single_numpy_or_python_type(t): - if isinstance(t, ops.Tensor): + if isinstance(t, tensor_lib.Tensor): x = t.numpy() return x.item() if np.ndim(x) == 0 else x return t # Don't turn ragged or sparse tensors to NumPy. diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD index 771439070ec402..d5a07b25752942 100644 --- a/tensorflow/python/kernel_tests/array_ops/BUILD +++ b/tensorflow/python/kernel_tests/array_ops/BUILD @@ -23,10 +23,12 @@ cuda_py_strict_test( "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:config", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:test_ops", "//tensorflow/python/ops:array_ops", @@ -216,8 +218,11 @@ cuda_py_strict_test( "//tensorflow/python/eager:def_function", "//tensorflow/python/framework", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:gradient_checker", diff --git a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py index 7090bd7a3c49ae..2cfeb07d558154 100644 --- a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py @@ -32,8 +32,8 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -90,7 +90,7 
@@ def testNonBatchMatrixDynamicallyDefined(self): expected_transposed = [[1, 4], [2, 5], [3, 6]] # Shape (3, 2) @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32) + tensor_lib.TensorSpec(shape=None, dtype=dtypes.int32) ]) def transpose(matrix): self.assertIs(matrix.shape.ndims, None) @@ -109,7 +109,7 @@ def testBatchMatrixDynamicallyDefined(self): expected_transposed = [matrix_0_t, matrix_1_t] # Shape (2, 3, 2) @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32) + tensor_lib.TensorSpec(shape=None, dtype=dtypes.int32) ]) def transpose(matrix): self.assertIs(matrix.shape.ndims, None) @@ -244,8 +244,8 @@ def func(ph_tensor, ph_mask): return array_ops.boolean_mask(ph_tensor, ph_mask) f = func.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.int32), - tensor_spec.TensorSpec([None], dtypes.bool)) + tensor_lib.TensorSpec(None, dtypes.int32), + tensor_lib.TensorSpec([None], dtypes.bool)) arr = np.array([[1, 2], [3, 4]], np.int32) mask = np.array([False, True]) masked_tensor = f(arr, mask) @@ -260,8 +260,8 @@ def func(tensor, mask): with self.assertRaisesRegex(ValueError, "dimensions must be specified"): _ = func.get_concrete_function( - tensor_spec.TensorSpec([None, 2], dtypes.int32), - tensor_spec.TensorSpec(None, dtypes.bool)) + tensor_lib.TensorSpec([None, 2], dtypes.int32), + tensor_lib.TensorSpec(None, dtypes.bool)) def testMaskHasMoreDimsThanTensorRaises(self): mask = [[True, True], [False, False]] @@ -314,7 +314,7 @@ def testMaskWithAxisNonConstTensor(self): @def_function.function( autograph=False, input_signature=[ - tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32) + tensor_lib.TensorSpec(shape=None, dtype=dtypes.int32) ]) def f(axis): return array_ops.boolean_mask([1, 2, 3], [True, False, True], axis=axis) @@ -595,10 +595,12 @@ def casts_to_bool_nparray(x): except NotImplementedError: return False - if isinstance(spec, bool) or \ - (isinstance(spec, ops.Tensor) and spec.dtype == dtypes.bool) or \ - (isinstance(spec, np.ndarray) and spec.dtype == bool) or \ - (isinstance(spec, (list, tuple)) and casts_to_bool_nparray(spec)): + if ( + isinstance(spec, bool) + or (isinstance(spec, tensor_lib.Tensor) and spec.dtype == dtypes.bool) + or (isinstance(spec, np.ndarray) and spec.dtype == bool) + or (isinstance(spec, (list, tuple)) and casts_to_bool_nparray(spec)) + ): tensor = self.test.evaluate(op) np_spec = eval_if_tensor(spec) self.test.assertAllEqual(self.x_np[np_spec], tensor) @@ -753,7 +755,7 @@ def func(inp): return inp[array_ops.newaxis, :, 0] f = func.get_concrete_function( - tensor_spec.TensorSpec([2, 2], dtypes.int16)) + tensor_lib.TensorSpec([2, 2], dtypes.int16)) # TODO(b/190416665): Allow the constant to be eagerly copied/created on # the GPU. @@ -892,7 +894,7 @@ def f(x): y = x[...] 
self.assertAllEqual(y.get_shape().ndims, None) - _ = f.get_concrete_function(tensor_spec.TensorSpec(None, dtypes.float32)) + _ = f.get_concrete_function(tensor_lib.TensorSpec(None, dtypes.float32)) def testScalarInput(self): c = constant_op.constant(3) @@ -916,7 +918,7 @@ def f1(x): tensor_shape.TensorShape([2, None, 7])) _ = f1.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f2(x): @@ -925,7 +927,7 @@ def f2(x): None])) _ = f2.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f3(x): @@ -934,7 +936,7 @@ def f3(x): None])) _ = f3.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f4(x): @@ -943,7 +945,7 @@ def f4(x): tensor_shape.TensorShape([2, None, 2])) _ = f4.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f5(x): @@ -952,7 +954,7 @@ def f5(x): tensor_shape.TensorShape([2, None, 0])) _ = f5.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f6(x): @@ -961,7 +963,7 @@ def f6(x): tensor_shape.TensorShape([2, None, 1, 0])) _ = f6.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f7(x): @@ -970,7 +972,7 @@ def f7(x): tensor_shape.TensorShape([2, None, 1, 0])) _ = f7.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f8(x): @@ -979,7 +981,7 @@ def f8(x): tensor_shape.TensorShape([2, None, 1, 0])) _ = f8.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f9(x): @@ -988,7 +990,7 @@ def f9(x): tensor_shape.TensorShape([1, None, 1, 0])) _ = f9.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f10(x): @@ -997,7 +999,7 @@ def f10(x): tensor_shape.TensorShape([5, None, 1, 4])) _ = f10.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) def testTensorValuedIndexShape(self): with self.session(): @@ -1008,8 +1010,8 @@ def f1(x, y): self.tensorShapeEqual(z.get_shape(), tensor_shape.TensorShape([3, 7])) _ = f1.get_concrete_function( - tensor_spec.TensorSpec((5, 3, 7)), - tensor_spec.TensorSpec((), dtypes.int32)) + tensor_lib.TensorSpec((5, 3, 7)), + tensor_lib.TensorSpec((), dtypes.int32)) @def_function.function def f2(x, y): @@ -1017,8 +1019,8 @@ def f2(x, y): self.tensorShapeEqual(z.get_shape(), tensor_shape.TensorShape([3, 7])) _ = f2.get_concrete_function( - tensor_spec.TensorSpec((5, 3, 7)), - tensor_spec.TensorSpec((), dtypes.int32)) + tensor_lib.TensorSpec((5, 3, 7)), + tensor_lib.TensorSpec((), dtypes.int32)) @def_function.function def f3(x, y): @@ -1026,8 +1028,8 @@ def f3(x, y): self.tensorShapeEqual(z.get_shape(), tensor_shape.TensorShape([2, 7])) _ = f3.get_concrete_function( - tensor_spec.TensorSpec((5, 3, 7)), - 
tensor_spec.TensorSpec((), dtypes.int32)) + tensor_lib.TensorSpec((5, 3, 7)), + tensor_lib.TensorSpec((), dtypes.int32)) @def_function.function def f4(x, y, s): @@ -1036,9 +1038,9 @@ def f4(x, y, s): 7])) _ = f4.get_concrete_function( - tensor_spec.TensorSpec((5, 3, 7)), - tensor_spec.TensorSpec((), dtypes.int32), - tensor_spec.TensorSpec((), dtypes.int32)) + tensor_lib.TensorSpec((5, 3, 7)), + tensor_lib.TensorSpec((), dtypes.int32), + tensor_lib.TensorSpec((), dtypes.int32)) class GradSliceChecker(object): @@ -1076,7 +1078,7 @@ def __getitem__(self, spec): # compute analytic gradient for slice np_val_grad = (2 * self.varnp * self.varnp) np_sliceval_grad = np.zeros(self.var.get_shape()) - if isinstance(spec, ops.Tensor): + if isinstance(spec, tensor_lib.Tensor): spec = self.test.evaluate(spec) np_sliceval_grad[spec] = np_val_grad[spec] # verify gradient @@ -1615,7 +1617,7 @@ def testIdentityVariable(self): v = resource_variable_ops.ResourceVariable(1.0) self.evaluate(v.initializer) result = array_ops.identity(v) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor_lib.Tensor) self.assertAllEqual(result, v) @@ -2387,8 +2389,8 @@ def func(params, indices): params=params, indices=indices, batch_dims=batch_dims) # pylint: disable=cell-var-from-loop f = func.get_concrete_function( - tensor_spec.TensorSpec(params_ph_shape, dtypes.float32), - tensor_spec.TensorSpec(indices_ph_shape, dtypes.int32)) + tensor_lib.TensorSpec(params_ph_shape, dtypes.float32), + tensor_lib.TensorSpec(indices_ph_shape, dtypes.int32)) params_val = np.ones(dtype=np.float32, shape=params_shape) indices_val = np.ones(dtype=np.int32, shape=indices_shape) @@ -2419,7 +2421,7 @@ def testRepeat(self, array, repeats, axis): array = np.array(array) @def_function.function( - input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)] * 2) + input_signature=[tensor_lib.TensorSpec(None, dtypes.int32)] * 2) def repeat_fn(array, repeats): return array_ops.repeat(array, repeats, axis) @@ -2560,7 +2562,7 @@ def stop_gradient_f(x): y = stop_gradient_f(x) self.assertIsNone(tape.gradient(y, x)) # stop_gradient converts ResourceVariable to Tensor - self.assertIsInstance(y, ops.Tensor) + self.assertIsInstance(y, tensor_lib.Tensor) self.assertAllEqual(y, x) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/array_ops/constant_op_test.py b/tensorflow/python/kernel_tests/array_ops/constant_op_test.py index a580075f3fa3d7..eb974c2ddc342d 100644 --- a/tensorflow/python/kernel_tests/array_ops/constant_op_test.py +++ b/tensorflow/python/kernel_tests/array_ops/constant_op_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import importer from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -296,7 +297,7 @@ def testAsTensorForTensorInput(self): def testAsTensorForNonTensorInput(self): with ops.Graph().as_default(): x = ops.convert_to_tensor(10.0) - self.assertTrue(isinstance(x, ops.Tensor)) + self.assertTrue(isinstance(x, tensor.Tensor)) def testAsTensorForShapeInput(self): with self.cached_session(): @@ -381,7 +382,7 @@ def testIdTensor(self): with ops.Graph().as_default(): x = constant_op.constant(2.0, shape=[6], name="input") id_op = array_ops.identity(x, name="id") - self.assertTrue(isinstance(id_op.op.inputs[0], ops.Tensor)) + 
self.assertTrue(isinstance(id_op.op.inputs[0], tensor.Tensor)) self.assertProtoEquals("name: 'id' op: 'Identity' input: 'input' " "attr { key: 'T' value { type: DT_FLOAT } }", id_op.op.node_def) diff --git a/tensorflow/python/kernel_tests/array_ops/inplace_ops_test.py b/tensorflow/python/kernel_tests/array_ops/inplace_ops_test.py index dae8b333f8dfcf..1a7ec203c9467e 100644 --- a/tensorflow/python/kernel_tests/array_ops/inplace_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/inplace_ops_test.py @@ -24,10 +24,21 @@ from tensorflow.python.platform import test as test_lib +BASIC_TYPES = [ + dtypes.float32, + dtypes.int8, + dtypes.uint8, + dtypes.int32, + dtypes.int64, + dtypes.uint64, + dtypes.bfloat16, +] + + class InplaceOpsTest(test_util.TensorFlowTestCase): def testBasicUpdate(self): - for dtype in [dtypes.float32, dtypes.int32, dtypes.int64, dtypes.bfloat16]: + for dtype in BASIC_TYPES: with test_util.use_gpu(): x = array_ops.ones([7, 3], dtype) y = np.ones([7, 3], dtype.as_numpy_dtype) @@ -61,7 +72,7 @@ def testBasicUpdateBool(self): self.assertAllClose(x, y) def testBasicAdd(self): - for dtype in [dtypes.float32, dtypes.int32, dtypes.int64, dtypes.bfloat16]: + for dtype in BASIC_TYPES: with test_util.use_gpu(): x = array_ops.ones([7, 3], dtype) y = np.ones([7, 3], dtype.as_numpy_dtype) @@ -80,7 +91,7 @@ def testBasicAdd(self): self.assertAllClose(x, y) def testBasicSub(self): - for dtype in [dtypes.float32, dtypes.int32, dtypes.int64, dtypes.bfloat16]: + for dtype in BASIC_TYPES: with test_util.use_gpu(): x = array_ops.ones([7, 3], dtype) y = np.ones([7, 3], dtype.as_numpy_dtype) @@ -196,7 +207,7 @@ def testInplaceOpOnEmptyTensors(self): inplace_ops.inplace_sub, inplace_ops.inplace_update, ] - for dtype in [dtypes.float32, dtypes.int32, dtypes.int64]: + for dtype in BASIC_TYPES: for op_fn in op_fns: with test_util.use_gpu(): x = array_ops.zeros([7, 0], dtype) diff --git a/tensorflow/python/kernel_tests/control_flow/BUILD b/tensorflow/python/kernel_tests/control_flow/BUILD index 74691d56c614d6..c8e29030a9dc73 100644 --- a/tensorflow/python/kernel_tests/control_flow/BUILD +++ b/tensorflow/python/kernel_tests/control_flow/BUILD @@ -72,11 +72,14 @@ cuda_py_strict_test( "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:wrap_function", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:function", "//tensorflow/python/framework:indexed_slices", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_gen", "//tensorflow/python/ops:array_ops_stack", diff --git a/tensorflow/python/kernel_tests/control_flow/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow/control_flow_ops_py_test.py index d1d79afa66a176..f7e8f846a8076b 100644 --- a/tensorflow/python/kernel_tests/control_flow/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow/control_flow_ops_py_test.py @@ -42,8 +42,8 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework 
import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack @@ -177,7 +177,7 @@ def testRefIdentity(self): op = state_ops.assign(v, 9) v2 = control_flow_ops.with_dependencies([op], v) - self.assertTrue(isinstance(v2, ops.Tensor)) + self.assertTrue(isinstance(v2, tensor_lib.Tensor)) self.evaluate(variables.global_variables_initializer()) self.assertEqual(9, self.evaluate(v2)) @@ -2331,8 +2331,8 @@ def testWhileShapeInvariantTensorSpec(self): c = lambda i, _: i < 10 b = lambda i, x: (i + 1, array_ops_stack.stack([x, x])) shape_invariants = [ - tensor_spec.TensorSpec([], dtype=dtypes.int32), - tensor_spec.TensorSpec(None, dtype=dtypes.int32)] + tensor_lib.TensorSpec([], dtype=dtypes.int32), + tensor_lib.TensorSpec(None, dtype=dtypes.int32)] while_loop_tf.while_loop(c, b, [i, x], shape_invariants) # TODO(b/131265085) Remove this decorator when bug is fixed. @@ -2343,7 +2343,7 @@ def testWhileShapeInvariantWrongTypeSpecType(self): i = constant_op.constant(0) x = sparse_tensor.SparseTensor([[0]], [1.0], [10]) shape_invariants = [ - tensor_spec.TensorSpec([], dtype=dtypes.int32), + tensor_lib.TensorSpec([], dtype=dtypes.int32), sparse_tensor.SparseTensorSpec([None])] while_loop_tf.while_loop(c, b, [i, x], shape_invariants) @@ -3489,7 +3489,7 @@ def b(lv0, lv1, lv2): self.assertTrue(isinstance(r, list)) self.assertTrue(isinstance(r[0], named)) self.assertTrue(isinstance(r[1], tuple)) - self.assertTrue(isinstance(r[2], ops.Tensor)) + self.assertTrue(isinstance(r[2], tensor_lib.Tensor)) r_flattened = nest.flatten(r) self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0], @@ -4192,7 +4192,7 @@ def testOneValueCond(self): two = ops.convert_to_tensor(2, name="two") p = math_ops.greater_equal(c, 1) i = tf_cond.cond(p, lambda: one, lambda: two) - self.assertTrue(isinstance(i, ops.Tensor)) + self.assertTrue(isinstance(i, tensor_lib.Tensor)) # True case: c = 2 is >= 1 self.assertEqual([1], i.eval(feed_dict={c: 2})) @@ -4328,7 +4328,7 @@ def b(): return state_ops.assign(v, two) i = tf_cond.cond(p, a, b) - self.assertTrue(isinstance(i, ops.Tensor)) + self.assertTrue(isinstance(i, tensor_lib.Tensor)) self.evaluate(variables.global_variables_initializer()) self.assertEqual(0, self.evaluate(v)) diff --git a/tensorflow/python/kernel_tests/control_flow/scan_ops_test.py b/tensorflow/python/kernel_tests/control_flow/scan_ops_test.py index 7747af49cdcbb9..b6ff099c646c79 100644 --- a/tensorflow/python/kernel_tests/control_flow/scan_ops_test.py +++ b/tensorflow/python/kernel_tests/control_flow/scan_ops_test.py @@ -228,7 +228,10 @@ def _compare(self, x, axis, exclusive, reverse): with self.cached_session(): tf_out = math_ops.cumprod(x, axis, exclusive, reverse).eval() - self.assertAllClose(np_out, tf_out) + atol = rtol = 1e-6 + if x.dtype == dtypes.bfloat16.as_numpy_dtype: + atol = rtol = 1e-2 + self.assertAllClose(np_out, tf_out, atol=atol, rtol=rtol) def _compareAll(self, x, axis): for exclusive in [True, False]: diff --git a/tensorflow/python/kernel_tests/data_structures/BUILD b/tensorflow/python/kernel_tests/data_structures/BUILD index 130f24d1537274..9490160a746873 100644 --- a/tensorflow/python/kernel_tests/data_structures/BUILD +++ b/tensorflow/python/kernel_tests/data_structures/BUILD @@ -16,8 +16,10 @@ tf_py_strict_test( "no_mac", # TODO(b/129706424): Re-enable this test on Mac. 
], deps = [ + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:data_flow_ops", "//tensorflow/python/platform:client_testlib", @@ -96,8 +98,11 @@ tf_py_strict_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/module", "//tensorflow/python/ops:array_ops", @@ -247,8 +252,11 @@ cuda_py_strict_test( srcs = ["padding_fifo_queue_test.py"], deps = [ "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:data_flow_ops", diff --git a/tensorflow/python/kernel_tests/data_structures/barrier_ops_test.py b/tensorflow/python/kernel_tests/data_structures/barrier_ops_test.py index 112c7454a99094..ec10af3f818ae9 100644 --- a/tensorflow/python/kernel_tests/data_structures/barrier_ops_test.py +++ b/tensorflow/python/kernel_tests/data_structures/barrier_ops_test.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import data_flow_ops from tensorflow.python.platform import test @@ -35,7 +36,7 @@ def testConstructorWithShapes(self): shapes=((1, 2, 3), (8,)), shared_name="B", name="B") - self.assertTrue(isinstance(b.barrier_ref, ops.Tensor)) + self.assertTrue(isinstance(b.barrier_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'B' op:'Barrier' attr { diff --git a/tensorflow/python/kernel_tests/data_structures/fifo_queue_test.py b/tensorflow/python/kernel_tests/data_structures/fifo_queue_test.py index 8111a6997b9d2e..8ee4f9c03df151 100644 --- a/tensorflow/python/kernel_tests/data_structures/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/data_structures/fifo_queue_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.module import module @@ -45,7 +46,7 @@ class FIFOQueueTest(test.TestCase): def testConstructor(self): with ops.Graph().as_default(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' device: "/device:CPU:*" op:'FIFOQueueV2' attr { key: 'component_types' value { list { type: DT_FLOAT } } } @@ -61,7 +62,7 @@ def 
testMultiQueueConstructor(self): 5, (dtypes_lib.int32, dtypes_lib.float32), shared_name="foo", name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' device: "/device:CPU:*" op:'FIFOQueueV2' attr { key: 'component_types' value { list { @@ -80,7 +81,7 @@ def testConstructorWithShapes(self): shapes=(tensor_shape.TensorShape([1, 1, 2, 3]), tensor_shape.TensorShape([5, 8])), name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' device: "/device:CPU:*" op:'FIFOQueueV2' attr { key: 'component_types' value { list { @@ -1645,7 +1646,7 @@ def testConstructor(self): names=("i", "j"), shared_name="foo", name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' device: "/device:CPU:*" op:'FIFOQueueV2' attr { key: 'component_types' value { list { @@ -1666,7 +1667,7 @@ def testConstructorWithShapes(self): shapes=(tensor_shape.TensorShape([1, 1, 2, 3]), tensor_shape.TensorShape([5, 8])), name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' device: "/device:CPU:*" op:'FIFOQueueV2' attr { key: 'component_types' value { list { diff --git a/tensorflow/python/kernel_tests/data_structures/list_ops_test.py b/tensorflow/python/kernel_tests/data_structures/list_ops_test.py index 3d64a891251f4f..1d3eb1ab96c884 100644 --- a/tensorflow/python/kernel_tests/data_structures/list_ops_test.py +++ b/tensorflow/python/kernel_tests/data_structures/list_ops_test.py @@ -479,6 +479,28 @@ def testGatherUsingSpecifiedElementShape(self): self.assertEqual(t.shape.as_list(), [3]) self.assertAllEqual(self.evaluate(t), np.zeros((3,))) + def testGatherWithInvalidIndicesFails(self): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=None, num_elements=3 + ) + + # Should raise an error when the input index is negative. + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Trying to gather element -1 in a list with 3 elements.", + ): + t = list_ops.tensor_list_gather(l, [-1], element_dtype=dtypes.float32) + self.evaluate(t) + + # Should raise an error when the input index is larger than the number of + # elements in the list. 
+ with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Trying to gather element 3 in a list with 3 elements.", + ): + t = list_ops.tensor_list_gather(l, [3], element_dtype=dtypes.float32) + self.evaluate(t) + def testScatterOutputListSize(self): c0 = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_scatter(c0, [1, 3], []) diff --git a/tensorflow/python/kernel_tests/data_structures/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/data_structures/padding_fifo_queue_test.py index db21e1bee59931..1b4c9d6b961ac7 100644 --- a/tensorflow/python/kernel_tests/data_structures/padding_fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/data_structures/padding_fifo_queue_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -37,7 +38,7 @@ def testConstructor(self): with ops.Graph().as_default(): q = data_flow_ops.PaddingFIFOQueue( 10, dtypes_lib.float32, ((None,),), name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' op:'PaddingFIFOQueueV2' attr { key: 'component_types' value { list { type: DT_FLOAT } } } @@ -53,7 +54,7 @@ def testMultiQueueConstructor(self): 5, (dtypes_lib.int32, dtypes_lib.float32), ((), ()), shared_name="foo", name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' op:'PaddingFIFOQueueV2' attr { key: 'component_types' value { list { @@ -72,7 +73,7 @@ def testConstructorWithShapes(self): shapes=(tensor_shape.TensorShape([1, 1, 2, 3]), tensor_shape.TensorShape([5, 8])), name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' op:'PaddingFIFOQueueV2' attr { key: 'component_types' value { list { diff --git a/tensorflow/python/kernel_tests/io_ops/BUILD b/tensorflow/python/kernel_tests/io_ops/BUILD index 66b71de8933cac..2ff24b665f3df8 100644 --- a/tensorflow/python/kernel_tests/io_ops/BUILD +++ b/tensorflow/python/kernel_tests/io_ops/BUILD @@ -79,9 +79,12 @@ tf_py_strict_test( "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", diff --git a/tensorflow/python/kernel_tests/io_ops/parsing_ops_test.py b/tensorflow/python/kernel_tests/io_ops/parsing_ops_test.py index 87e0e537e22e27..aca17e18280440 100644 --- a/tensorflow/python/kernel_tests/io_ops/parsing_ops_test.py +++ b/tensorflow/python/kernel_tests/io_ops/parsing_ops_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import 
tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util @@ -97,7 +98,9 @@ def _test(self, kwargs, expected_values=None, expected_err=None): serialized = kwargs["serialized"] batch_size = ( self.evaluate(serialized).size - if isinstance(serialized, ops.Tensor) else np.asarray(serialized).size) + if isinstance(serialized, tensor_lib.Tensor) + else np.asarray(serialized).size + ) for k, f in kwargs["features"].items(): if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None: self.assertEqual(tuple(out[k].shape.as_list()), (batch_size,) + f.shape) diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index fa8f91b61fb7a2..42fd131137743f 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -132,7 +132,9 @@ cuda_py_strict_test( tags = ["no_windows_gpu"], deps = [ "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:linalg_ops", @@ -269,7 +271,7 @@ cuda_py_strict_test( name = "linear_operator_circulant_test", size = "medium", srcs = ["linear_operator_circulant_test.py"], - shard_count = 15, + shard_count = 32, tags = [ "no_cuda11", # TODO(b/197522782): reenable test after fixing. "optonly", # times out, b/79171797 @@ -557,7 +559,9 @@ cuda_py_strict_test( shard_count = 5, tags = ["optonly"], deps = [ - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:linalg_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py index 827d1545e716cb..88d51257b517be 100644 --- a/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops @@ -180,10 +181,10 @@ def testShapeInferenceStaticBatchWith(self, num_rows_fn, num_columns_fn): batch_shape=batch_shape) self.assertEqual(4, identity_matrix.shape.ndims) self.assertEqual((2, 3), identity_matrix.shape[:2]) - if num_rows is not None and not isinstance(num_rows, ops.Tensor): + if num_rows is not None and not isinstance(num_rows, tensor.Tensor): self.assertEqual(2, identity_matrix.shape[-2]) - if num_columns is not None and not isinstance(num_columns, ops.Tensor): + if num_columns is not None and not isinstance(num_columns, tensor.Tensor): self.assertEqual(3, identity_matrix.shape[-1]) @parameterized.parameters( diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py index e28d1b2cae2cde..e1ca2f5ce6bcad 100644 --- 
a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py @@ -17,6 +17,7 @@ import numpy as np from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops @@ -97,7 +98,7 @@ def test_zero_batch_matrices_returned_as_empty_list(self): def test_one_batch_matrix_returned_after_tensor_conversion(self): arr = rng.rand(2, 3, 4) tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr]) - self.assertTrue(isinstance(tensor, ops.Tensor)) + self.assertTrue(isinstance(tensor, tensor_lib.Tensor)) self.assertAllClose(arr, self.evaluate(tensor)) diff --git a/tensorflow/python/kernel_tests/math_ops/aggregate_ops_test.py b/tensorflow/python/kernel_tests/math_ops/aggregate_ops_test.py index 42ff056175d03a..4edfdb2e2678f0 100644 --- a/tensorflow/python/kernel_tests/math_ops/aggregate_ops_test.py +++ b/tensorflow/python/kernel_tests/math_ops/aggregate_ops_test.py @@ -73,6 +73,8 @@ def testAddN(self): self.assertAllCloseAccordingToType( expected, actual, + float_rtol=5e-6, + float_atol=5e-6, half_rtol=5e-3, half_atol=5e-3, ) diff --git a/tensorflow/python/kernel_tests/nn_ops/BUILD b/tensorflow/python/kernel_tests/nn_ops/BUILD index 1d9701a2f3eea5..e984097efaef7b 100644 --- a/tensorflow/python/kernel_tests/nn_ops/BUILD +++ b/tensorflow/python/kernel_tests/nn_ops/BUILD @@ -728,10 +728,12 @@ cuda_py_strict_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:random_seed", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", diff --git a/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py b/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py index 5b3d2150b446ce..7785b1bff2cfa4 100644 --- a/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py @@ -2334,6 +2334,7 @@ def _testAvgPoolGradSamePadding3_1(self, data_format, use_gpu): data_format=data_format, use_gpu=use_gpu) + @test_util.disable_xla("Xla does not raise error on out of bounds access") def testAvgPoolGradOutputMemoryOutOfBounds(self): #os.environ["TF_USE_ROCM_NHWC"] = "1" self.skipTest("Re-enable when NHWC is fully supported on ROCM.") diff --git a/tensorflow/python/kernel_tests/nn_ops/rnn_cell_test.py b/tensorflow/python/kernel_tests/nn_ops/rnn_cell_test.py index 527d33fc6cb732..0786182c1d6885 100644 --- a/tensorflow/python/kernel_tests/nn_ops/rnn_cell_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/rnn_cell_test.py @@ -29,8 +29,8 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops import 
array_ops from tensorflow.python.ops import array_ops_stack @@ -650,7 +650,7 @@ def _testStateTupleWithProjAndSequenceLength(self): self.assertEqual(len(outputs_notuple), len(inputs)) self.assertEqual(len(outputs_tuple), len(inputs)) self.assertTrue(isinstance(state_tuple, tuple)) - self.assertTrue(isinstance(state_notuple, ops.Tensor)) + self.assertTrue(isinstance(state_notuple, tensor.Tensor)) variables_lib.global_variables_initializer().run() input_value = np.random.randn(batch_size, input_size) @@ -3211,7 +3211,7 @@ def testSavedModel(self): with self.cached_session(): root = autotrackable.AutoTrackable() root.cell = rnn_cell_impl.LSTMCell(8) - @def_function.function(input_signature=[tensor_spec.TensorSpec([3, 8])]) + @def_function.function(input_signature=[tensor.TensorSpec([3, 8])]) def call(x): state = root.cell.zero_state(3, dtype=x.dtype) y, _ = root.cell(x, state) diff --git a/tensorflow/python/kernel_tests/nn_ops/softplus_op_test.py b/tensorflow/python/kernel_tests/nn_ops/softplus_op_test.py index 97e7fcbb0418ab..40a745691acefc 100644 --- a/tensorflow/python/kernel_tests/nn_ops/softplus_op_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/softplus_op_test.py @@ -39,7 +39,8 @@ def _testSoftplus(self, np_features, use_gpu=False): softplus = nn_ops.softplus(np_features) tf_softplus = self.evaluate(softplus) self.assertAllCloseAccordingToType( - np_softplus, tf_softplus, bfloat16_rtol=5e-2, bfloat16_atol=5e-2 + np_softplus, tf_softplus, half_rtol=5e-3, half_atol=5e-3, + bfloat16_rtol=5e-2, bfloat16_atol=5e-2 ) self.assertTrue(np.all(tf_softplus > 0)) self.assertShapeEqual(np_softplus, softplus) diff --git a/tensorflow/python/kernel_tests/sparse_ops/BUILD b/tensorflow/python/kernel_tests/sparse_ops/BUILD index b16aa002c0744d..ec88a5c81cf44f 100644 --- a/tensorflow/python/kernel_tests/sparse_ops/BUILD +++ b/tensorflow/python/kernel_tests/sparse_ops/BUILD @@ -51,9 +51,12 @@ tf_py_strict_test( srcs = ["sparse_conditional_accumulator_test.py"], deps = [ "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", "//tensorflow/python/framework:indexed_slices", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:data_flow_ops", diff --git a/tensorflow/python/kernel_tests/sparse_ops/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_ops/sparse_conditional_accumulator_test.py index 45f51861695ce5..6b85e5ab719ea4 100644 --- a/tensorflow/python/kernel_tests/sparse_ops/sparse_conditional_accumulator_test.py +++ b/tensorflow/python/kernel_tests/sparse_ops/sparse_conditional_accumulator_test.py @@ -22,6 +22,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -58,7 +59,7 @@ def testConstructor(self): with ops.Graph().as_default(): q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, name="Q") - self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) + self.assertTrue(isinstance(q.accumulator_ref, tensor.Tensor)) self.assertProtoEquals( """ name:'Q' 
op:'SparseConditionalAccumulator' @@ -81,7 +82,7 @@ def testConstructorWithShape(self): dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1, 5, 2, 8])) - self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) + self.assertTrue(isinstance(q.accumulator_ref, tensor.Tensor)) self.assertProtoEquals( """ name:'Q' op:'SparseConditionalAccumulator' diff --git a/tensorflow/python/kernel_tests/tensor_priority_test.py b/tensorflow/python/kernel_tests/tensor_priority_test.py deleted file mode 100644 index bb779f26eff30c..00000000000000 --- a/tensorflow/python/kernel_tests/tensor_priority_test.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the binary ops priority mechanism.""" -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_conversion_registry -from tensorflow.python.platform import test as test_lib - - -class TensorPriorityTest(test_lib.TestCase): - - def testSupportedRhsWithoutDelegation(self): - - class NumpyArraySubclass(np.ndarray): - pass - - supported_rhs_without_delegation = (3, 3.0, [1.0, 2.0], np.array( - [1.0, 2.0]), NumpyArraySubclass( - shape=(1, 2), buffer=np.array([1.0, 2.0])), - ops.convert_to_tensor([[1.0, 2.0]])) - for rhs in supported_rhs_without_delegation: - tensor = ops.convert_to_tensor([[10.0, 20.0]]) - res = tensor + rhs - self.assertIsInstance(res, ops.Tensor) - - def testUnsupportedRhsWithoutDelegation(self): - - class WithoutReverseAdd(object): - pass - - tensor = ops.convert_to_tensor([[10.0, 20.0]]) - rhs = WithoutReverseAdd() - with self.assertRaisesWithPredicateMatch( - TypeError, lambda e: "Expected float" in str(e)): - # pylint: disable=pointless-statement - tensor + rhs - - def testUnsupportedRhsWithDelegation(self): - - class WithReverseAdd(object): - - def __radd__(self, lhs): - return "Works!" - - tensor = ops.convert_to_tensor([[10.0, 20.0]]) - rhs = WithReverseAdd() - res = tensor + rhs - self.assertEqual(res, "Works!") - - def testFullDelegationControlUsingRegistry(self): - - class NumpyArraySubclass(np.ndarray): - - def __radd__(self, lhs): - return "Works!" - - def raise_to_delegate(value, dtype=None, name=None, as_ref=False): - del value, dtype, name, as_ref # Unused. 
- raise TypeError - - tensor_conversion_registry.register_tensor_conversion_function( - NumpyArraySubclass, raise_to_delegate, priority=0) - tensor = ops.convert_to_tensor([[10.0, 20.0]]) - rhs = NumpyArraySubclass(shape=(1, 2), buffer=np.array([1.0, 2.0])) - res = tensor + rhs - self.assertEqual(res, "Works!") - - -if __name__ == "__main__": - test_lib.main() diff --git a/tensorflow/python/kernel_tests/variables/BUILD b/tensorflow/python/kernel_tests/variables/BUILD index 0d965052fb029d..9a945cf7e3d0e2 100644 --- a/tensorflow/python/kernel_tests/variables/BUILD +++ b/tensorflow/python/kernel_tests/variables/BUILD @@ -84,6 +84,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:memory_checker", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:test_lib", @@ -174,8 +175,11 @@ tf_py_strict_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:cond", diff --git a/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py index ebe9788667c28f..6920466f64cb7a 100644 --- a/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import memory_checker from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_ops @@ -989,7 +990,7 @@ def gradient_func(*grad): result = tape.gradient(out, v) self.assertAllEqual(out, 5.) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor_lib.Tensor) self.assertAllEqual(result, 2.) 
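# A small public-API sketch of the property asserted just above: a gradient
# taken with respect to a ResourceVariable comes back as a plain Tensor, not
# as another variable. tf.Tensor stands in here for the internal
# `tensor_lib.Tensor` alias that the patched test imports.
import tensorflow as tf

v = tf.Variable(2.5)
with tf.GradientTape() as tape:
  out = 2.0 * v
grad = tape.gradient(out, v)
assert isinstance(grad, tf.Tensor)   # gradients are ordinary tensors
print(float(grad))                   # 2.0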
def testToFromProtoCachedValue(self): @@ -1805,7 +1806,7 @@ def testCompositeTensorTypeSpec(self): def testVariableInExtensionType(self): class MaskVariable(extension_type.ExtensionType): variable: resource_variable_ops.ResourceVariable - mask: ops.Tensor + mask: tensor_lib.Tensor v = resource_variable_ops.ResourceVariable([1., 2.]) self.evaluate(v.initializer) diff --git a/tensorflow/python/kernel_tests/variables/variables_test.py b/tensorflow/python/kernel_tests/variables/variables_test.py index 929452b104e7c6..45b88857090313 100644 --- a/tensorflow/python/kernel_tests/variables/variables_test.py +++ b/tensorflow/python/kernel_tests/variables/variables_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -377,7 +378,7 @@ def testOperatorWrapping(self): for attr in functools.WRAPPER_ASSIGNMENTS: self.assertEqual( getattr(variables.Variable.__add__, attr), - getattr(ops.Tensor.__add__, attr)) + getattr(tensor.Tensor.__add__, attr)) @test_util.run_deprecated_v1 def testOperators(self): diff --git a/tensorflow/python/lib/core/BUILD b/tensorflow/python/lib/core/BUILD index 5526bddbaf51ff..46152cc1402cf9 100644 --- a/tensorflow/python/lib/core/BUILD +++ b/tensorflow/python/lib/core/BUILD @@ -176,7 +176,6 @@ cc_library( ":ndarray_tensor", ":ndarray_tensor_bridge", ":py_util", - ":safe_pyobject_ptr", "//tensorflow/c:safe_ptr", "//tensorflow/c:tf_status_helper", "//tensorflow/c/eager:c_api", diff --git a/tensorflow/python/modules_with_exports.py b/tensorflow/python/modules_with_exports.py index 824692d450283c..c676d9f5bad529 100644 --- a/tensorflow/python/modules_with_exports.py +++ b/tensorflow/python/modules_with_exports.py @@ -31,12 +31,26 @@ from tensorflow.core.protobuf.config_pb2 import * from tensorflow.core.util.event_pb2 import * +# Compiler +from tensorflow.python.compiler.xla import jit +from tensorflow.python.compiler.xla import xla +from tensorflow.python.compiler.mlir import mlir + # Data from tensorflow.python import data +# TensorFlow Debugger (tfdbg). 
+from tensorflow.python.debug.lib import check_numerics_callback +from tensorflow.python.debug.lib import dumping_callback +from tensorflow.python.ops import gen_debug_ops + # Distribute from tensorflow.python import distribute +# DLPack +from tensorflow.python.dlpack.dlpack import from_dlpack +from tensorflow.python.dlpack.dlpack import to_dlpack + # Eager from tensorflow.python.eager import context from tensorflow.python.eager import def_function diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index a409f96095a3af..4943ebe4bfa040 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -1,4 +1,3 @@ -load("//tensorflow:pytype.default.bzl", "pytype_strict_library") load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test") load("//tensorflow/python:build_defs.bzl", "tf_gen_op_strict_wrapper_private_py") @@ -95,6 +94,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:function", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:dispatch", @@ -384,7 +384,7 @@ py_strict_library( ":batch_ops_gen", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:nest", "//tensorflow/python/util:tf_export", ], @@ -747,6 +747,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", ], @@ -789,6 +790,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", @@ -902,6 +904,7 @@ py_strict_library( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/util:compat", @@ -1130,6 +1133,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", ], ) @@ -1151,8 +1155,8 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:type_spec", "//tensorflow/python/util:compat", @@ -1176,6 +1180,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:dispatch", 
"//tensorflow/python/util:tf_export", @@ -1193,6 +1198,7 @@ py_strict_library( ":math_ops", "//tensorflow/python/eager:context", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:tf_export", ], ) @@ -1211,6 +1217,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/types:core", @@ -1372,6 +1379,7 @@ py_strict_library( "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:type_spec", @@ -1408,6 +1416,7 @@ py_strict_library( "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", @@ -1429,6 +1438,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:dispatch", @@ -1618,6 +1628,8 @@ py_strict_library( ":tensor_array_ops", ":unconnected_gradients", ":while_loop", + "//tensorflow/python/debug/lib:debug_gradients", + "//tensorflow/python/debug/lib:dumping_callback", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/ops/linalg/sparse:sparse_csr_matrix_grad", @@ -1651,6 +1663,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:compat", @@ -1824,6 +1837,7 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/lib/io:lib", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", @@ -1861,6 +1875,7 @@ py_strict_library( ":math_ops", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:dispatch", "//tensorflow/python/util:tf_export", @@ -1877,6 +1892,7 @@ py_strict_library( ":math_ops", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:compat", "//third_party/py/numpy", ], @@ -1941,6 +1957,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops/ragged:ragged_tensor", @@ -1970,6 +1987,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", 
"//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//third_party/py/numpy", ], @@ -1981,6 +1999,7 @@ py_strict_library( srcs_version = "PY3", deps = [ "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:object_identity", ], ) @@ -2004,6 +2023,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", @@ -2068,7 +2088,6 @@ py_strict_library( "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/saved_model:nested_structure_coder", "//tensorflow/python/trackable:base", "//tensorflow/python/types:core", @@ -2110,6 +2129,7 @@ py_strict_library( "//tensorflow/python/framework:cpp_shape_inference_proto_py", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//third_party/py/numpy", @@ -2183,6 +2203,7 @@ py_strict_library( "//tensorflow/python/framework:graph_util", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:random_seed", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/platform:device_context", @@ -2222,6 +2243,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops/ragged:ragged_math_ops", "//tensorflow/python/ops/ragged:ragged_tensor", @@ -2340,6 +2362,7 @@ py_strict_library( "//tensorflow/python/framework:config", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/trackable:autotrackable", "//tensorflow/python/util:nest", "//tensorflow/python/util:tf_export", @@ -2404,6 +2427,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/util:deprecation", @@ -2485,6 +2509,7 @@ py_strict_library( "//tensorflow/python/framework:device", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:compat", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", @@ -2514,10 +2539,9 @@ py_strict_library( ":bitwise_ops", ":math_ops", ":stateless_random_ops_v2_gen", - "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:tf_export", ], ) @@ -2556,6 +2580,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", 
"//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", @@ -2704,6 +2729,7 @@ py_strict_library( ":special_math_ops_gen", "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:deprecation", @@ -2905,6 +2931,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:smart_cond", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/trackable:resource", @@ -2963,8 +2990,8 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:type_spec", "//tensorflow/python/framework:type_spec_registry", @@ -2987,6 +3014,7 @@ py_strict_library( "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/saved_model:nested_structure_coder", "//tensorflow/python/util:nest", ], @@ -3007,6 +3035,7 @@ py_strict_library( "//tensorflow/python/eager:monitoring", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/platform:tf_logging", @@ -3035,6 +3064,7 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/trackable:base", @@ -3101,6 +3131,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", @@ -3119,6 +3150,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", @@ -3127,7 +3159,7 @@ py_strict_library( cuda_py_strict_test( name = "bitwise_ops_test", - size = "small", + size = "medium", srcs = ["bitwise_ops_test.py"], main = "bitwise_ops_test.py", python_version = "PY3", @@ -3308,7 +3340,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:function", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:test_ops", 
"//tensorflow/python/ops/ragged:ragged_factory_ops", @@ -3472,7 +3504,7 @@ cuda_py_strict_test( cuda_py_strict_test( name = "math_grad_test", - size = "small", + size = "medium", srcs = ["math_grad_test.py"], main = "math_grad_test.py", python_version = "PY3", @@ -3523,6 +3555,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/platform:test", @@ -3638,7 +3671,10 @@ cuda_py_strict_test( srcs = ["nn_test.py"], main = "nn_test.py", python_version = "PY3", - tags = ["no_windows"], + tags = [ + "no_windows", + "notap", # TODO(b/290819913) + ], xla_tags = [ "no_cuda_asan", # times out ], @@ -4279,7 +4315,7 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/keras/layers/legacy_rnn:rnn_cell_impl", @@ -4315,6 +4351,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:random_seed", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/util:deprecation", @@ -4360,7 +4397,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:indexed_slices", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/util:nest", @@ -4413,10 +4450,26 @@ py_strict_library( name = "weak_tensor_ops", srcs = ["weak_tensor_ops.py"], deps = [ - ":weak_tensor_ops_list", + ":array_ops", + ":array_ops_gen", + ":bitwise_ops_gen", + ":clip_ops", + ":image_ops_impl", + ":math_ops", + ":math_ops_gen", + ":nn_impl", + ":nn_ops", + ":nn_ops_gen", + ":special_math_ops", + "//tensorflow/python/framework:flexible_dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:weak_tensor", + "//tensorflow/python/ops/numpy_ops:np_array_ops", + "//tensorflow/python/ops/numpy_ops:np_math_ops", + "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:dispatch", + "//tensorflow/python/util:tf_decorator", ], ) @@ -4431,42 +4484,58 @@ py_strict_test( ":image_ops_impl", ":math_ops", ":weak_tensor_ops", - ":weak_tensor_ops_list", + ":weak_tensor_test_util", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:extension_type", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:weak_tensor", "//tensorflow/python/ops/numpy_ops:np_array_ops", "//tensorflow/python/ops/numpy_ops:np_config", "//tensorflow/python/ops/numpy_ops:np_math_ops", + "//tensorflow/python/ops/ragged:ragged_tensor", "//tensorflow/python/platform:test", + "//tensorflow/python/util:dispatch", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -py_strict_library( - name = "weak_tensor_ops_list", - srcs 
= ["weak_tensor_ops_list.py"], +py_strict_test( + name = "weak_tensor_math_ops_test", + srcs = ["weak_tensor_math_ops_test.py"], deps = [ ":array_ops", - ":array_ops_gen", - ":bitwise_ops_gen", - ":clip_ops", - ":image_ops_impl", ":math_ops", - ":math_ops_gen", - ":nn_impl", - ":nn_ops", - ":nn_ops_gen", - ":special_math_ops", - "//tensorflow/python/ops/numpy_ops:np_array_ops", - "//tensorflow/python/ops/numpy_ops:np_math_ops", + ":tensor_array_ops", + ":weak_tensor_ops", + ":weak_tensor_test_util", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:tf2", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/framework:weak_tensor", + "//tensorflow/python/ops/ragged:ragged_factory_ops", + "//tensorflow/python/platform:test", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) -pytype_strict_library( +py_strict_library( name = "weak_tensor_test_util", srcs = ["weak_tensor_test_util.py"], - deps = ["//tensorflow/python/framework:ops"], + deps = [ + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:weak_tensor", + "//third_party/py/numpy", + ], ) diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 4ae6802cc042df..ccbcfe11efa3de 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -22,6 +22,7 @@ from tensorflow.python.framework import indexed_slices as indexed_slices_lib from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -91,7 +92,7 @@ def _ExtractInputShapes(inputs): for x in inputs: input_shape = array_ops.shape(x) if not isinstance(input_shape, - ops.Tensor) or input_shape.op.type != "Const": + tensor.Tensor) or input_shape.op.type != "Const": fully_known = False break sizes.append(input_shape) @@ -109,7 +110,7 @@ def _ExtractInputShapes(inputs): input_values = op.inputs[start_value_index:end_value_index] out_grads = [] - if isinstance(grad, ops.Tensor): + if isinstance(grad, tensor.Tensor): if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor): # Using mod here for convenience since concat_dim is already verified # in concat implementation to be within the allowed [-rank, rank) range. 
@@ -1206,7 +1207,7 @@ def _BroadcastToGrad(op, grad): input_value = op.inputs[0] broadcast_shape = op.inputs[1] shape_dtype = dtypes.int32 - if isinstance(broadcast_shape, ops.Tensor): + if isinstance(broadcast_shape, tensor.Tensor): shape_dtype = broadcast_shape.dtype input_value_shape = array_ops.shape(input_value, out_type=shape_dtype) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 2d49e40b5e06fe..753f1b03d6789a 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -1030,7 +1031,7 @@ def _slice_helper(tensor, slice_spec, var=None): appear in TensorFlow's generated documentation. Args: - tensor: An ops.Tensor object. + tensor: A tensor.Tensor object. slice_spec: The arguments to Tensor.__getitem__. var: In the case of variable slice assignment, the Variable object to slice (i.e. tensor is the read-only view of this variable). @@ -1048,9 +1049,11 @@ def _slice_helper(tensor, slice_spec, var=None): if var is None and ops._numpy_style_slicing: # pylint: disable=protected-access return tensor._numpy_style_getitem(slice_spec) # pylint: disable=protected-access - if isinstance(slice_spec, bool) or \ - (isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \ - (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool): + if (isinstance(slice_spec, bool) + or (isinstance(slice_spec, tensor_lib.Tensor) + and slice_spec.dtype == dtypes.bool) + or (isinstance(slice_spec, np.ndarray) + and slice_spec.dtype == bool)): return boolean_mask(tensor=tensor, mask=slice_spec) if not isinstance(slice_spec, (list, tuple)): @@ -1067,7 +1070,7 @@ def _slice_helper(tensor, slice_spec, var=None): # Finds the best dtype for begin, end, and strides. dtype = None for t in [s.start, s.stop, s.step]: - if t is None or not isinstance(t, ops.Tensor): + if t is None or not isinstance(t, tensor_lib.Tensor): continue if t.dtype == dtypes.int64: dtype = dtypes.int64 @@ -1117,8 +1120,9 @@ def _slice_helper(tensor, slice_spec, var=None): begin.append(s) end.append(s + 1) # TODO(mdan): Investigate why we can't set int32 here. - if isinstance(s, ops.Tensor) and (s.dtype == dtypes.int16 or - s.dtype == dtypes.int64): + if ( + isinstance(s, tensor_lib.Tensor) + and (s.dtype == dtypes.int16 or s.dtype == dtypes.int64)): strides.append(constant_op.constant(1, dtype=s.dtype)) else: strides.append(1) @@ -1413,7 +1417,7 @@ def _SliceHelperVar(var, slice_spec): return _slice_helper(var.value(), slice_spec, var) -ops.Tensor._override_operator("__getitem__", _slice_helper) +tensor_lib.Tensor._override_operator("__getitem__", _slice_helper) @tf_export("parallel_stack") @@ -2887,7 +2891,7 @@ def zeros(shape, dtype=dtypes.float32, name=None, layout=None): else: zero = 0 - if not isinstance(shape, ops.Tensor): + if not isinstance(shape, tensor_lib.Tensor): try: if not context.executing_eagerly(): # Create a constant if it won't be very big. 
Otherwise, create a fill @@ -3202,7 +3206,7 @@ def ones(shape, dtype=dtypes.float32, name=None, layout=None): one = np.ones([]).astype(dtype.as_numpy_dtype) else: one = 1 - if not isinstance(shape, ops.Tensor): + if not isinstance(shape, tensor_lib.Tensor): try: if not context.executing_eagerly(): # Create a constant if it won't be very big. Otherwise, create a fill @@ -3403,7 +3407,7 @@ def sparse_placeholder(dtype, shape=None, name=None): dense_shape = placeholder(dtypes.int64, shape=[rank], name=shape_name) dense_shape_default = tensor_util.constant_value_as_shape(dense_shape) else: - if isinstance(shape, ops.Tensor): + if isinstance(shape, tensor_lib.Tensor): rank = shape.get_shape()[0] dense_shape_default = tensor_util.constant_value_as_shape(shape) else: @@ -3590,7 +3594,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pyl paddings_constant = _get_paddings_constant(paddings) input_shape = ( tensor_shape.TensorShape(tensor.shape) - if isinstance(tensor, ops.Tensor) else result.op.inputs[0].shape) + if isinstance(tensor, tensor_lib.Tensor) else result.op.inputs[0].shape) if (input_shape.ndims is not None and not result.shape.is_fully_defined() and paddings_constant is not None): new_shape = [] @@ -3618,7 +3622,7 @@ def _get_paddings_constant(paddings): A nested list or numbers or `None`, in which `None` indicates unknown padding size. """ - if isinstance(paddings, ops.Tensor): + if isinstance(paddings, tensor_lib.Tensor): return tensor_util.constant_value(paddings, partial=True) elif isinstance(paddings, (list, tuple)): return [_get_paddings_constant(x) for x in paddings] @@ -4402,7 +4406,7 @@ def one_hot(indices, def _all_dimensions(x): """Returns a 1D-tensor listing all dimensions in x.""" # Fast path: avoid creating Rank and Range ops if ndims is known. 
- if isinstance(x, ops.Tensor) and x.get_shape().ndims is not None: + if isinstance(x, tensor_lib.Tensor) and x.get_shape().ndims is not None: return constant_op.constant( np.arange(x.get_shape().ndims), dtype=dtypes.int32) if (isinstance(x, sparse_tensor.SparseTensor) and diff --git a/tensorflow/python/ops/batch_ops.py b/tensorflow/python/ops/batch_ops.py index dbe17201146d59..0361ea242946b3 100644 --- a/tensorflow/python/ops/batch_ops.py +++ b/tensorflow/python/ops/batch_ops.py @@ -16,7 +16,7 @@ """Operations for automatic batching and unbatching.""" from tensorflow.python.eager import def_function from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.ops import gen_batch_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_batch_ops import * @@ -92,14 +92,14 @@ def computation(*computation_args): return fn(*computation_args) computation = computation.get_concrete_function(*[ - tensor_spec.TensorSpec( + tensor.TensorSpec( dtype=x.dtype, shape=x.shape, name="batch_" + str(i)) for i, x in enumerate(args) ]) with ops.name_scope("batch") as name: for a in args: - if not isinstance(a, ops.Tensor): + if not isinstance(a, tensor.Tensor): raise ValueError("All arguments to functions decorated with " "`batch_function` are supposed to be Tensors; " f"found {a!r}.") diff --git a/tensorflow/python/ops/bincount_ops.py b/tensorflow/python/ops/bincount_ops.py index ce63aac1b0c5ba..92290a608844e9 100644 --- a/tensorflow/python/ops/bincount_ops.py +++ b/tensorflow/python/ops/bincount_ops.py @@ -17,6 +17,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops @@ -221,7 +222,7 @@ def validate_dense_weights(values, weights, dtype=None): return array_ops.constant([], dtype=dtype) return array_ops.constant([], dtype=values.dtype) - if not isinstance(weights, ops.Tensor): + if not isinstance(weights, tensor.Tensor): raise ValueError( "Argument `weights` must be a tf.Tensor if `values` is a tf.Tensor. 
" f"Received weights={weights} of type: {type(weights).__name__}") diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 102e66afb8a827..bc3d6266ad3428 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -71,7 +72,7 @@ def _maybe_constant_value_string(t): - if not isinstance(t, ops.Tensor): + if not isinstance(t, tensor_lib.Tensor): return str(t) const_t = tensor_util.constant_value(t) if const_t is not None: @@ -417,7 +418,7 @@ def _pretty_print(data_item, summarize): Returns: An appropriate string representation of data_item """ - if isinstance(data_item, ops.Tensor): + if isinstance(data_item, tensor_lib.Tensor): arr = data_item.numpy() if np.isscalar(arr): # Tensor.numpy() returns a scalar for zero-dimensional tensors @@ -526,7 +527,7 @@ def assert_proper_iterable(values): `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`. """ unintentional_iterables = ( - (ops.Tensor, sparse_tensor.SparseTensor, np.ndarray) + (tensor_lib.Tensor, sparse_tensor.SparseTensor, np.ndarray) + compat.bytes_or_text_types ) if isinstance(values, unintentional_iterables): @@ -1979,7 +1980,7 @@ def is_numeric_tensor(tensor): Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not a `tf.Tensor` object. """ - return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES + return isinstance(tensor, tensor_lib.Tensor) and tensor.dtype in NUMERIC_TYPES @tf_export( diff --git a/tensorflow/python/ops/composite_tensor_ops.py b/tensorflow/python/ops/composite_tensor_ops.py index 5067aa7f6be823..51a44613f6ddb1 100644 --- a/tensorflow/python/ops/composite_tensor_ops.py +++ b/tensorflow/python/ops/composite_tensor_ops.py @@ -18,6 +18,7 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import gen_composite_tensor_ops from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.util import nest @@ -74,7 +75,7 @@ def composite_tensor_from_variant(encoded, type_spec, name=None): TypeError: If `encoded` is not a Tensor with dtype=variant. InvalidArgumentError: If `encoded` is not compatible with `type_spec`. 
""" - if not isinstance(encoded, ops.Tensor): + if not isinstance(encoded, tensor.Tensor): raise TypeError(f"Expected `encoded` to be a Tensor, got {encoded!r}.") if encoded.dtype != dtypes.variant: raise TypeError("Expected `encoded` to have dtype=variant, got " diff --git a/tensorflow/python/ops/cond.py b/tensorflow/python/ops/cond.py index 02cbdbf182a30f..9fae845aaeb469 100644 --- a/tensorflow/python/ops/cond.py +++ b/tensorflow/python/ops/cond.py @@ -19,6 +19,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond_v2 @@ -207,7 +208,9 @@ def f2(): return tf.add(y, 23) res_f_flat = nest.flatten(res_f, expand_composites=True) for (x, y) in zip(res_t_flat, res_f_flat): - assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor) + assert ( + isinstance(x, tensor_lib.Tensor) + and isinstance(y, tensor_lib.Tensor)) if x.dtype.base_dtype != y.dtype.base_dtype: raise ValueError( "Outputs of 'true_fn' and 'false_fn' must have the same type(s). " diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index f09a280a9438a3..66e44131875dca 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec @@ -650,7 +651,7 @@ def _make_output_composite_tensors_match(op_type, branch_graphs): for branch_idx, branch_out in enumerate(branch_outs): if isinstance(branch_out, indexed_slices.IndexedSlices): continue - elif isinstance(branch_out, ops.Tensor): + elif isinstance(branch_out, tensor_lib.Tensor): with branch_graphs[branch_idx].as_default(): branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices( branch_out) diff --git a/tensorflow/python/ops/control_flow_case.py b/tensorflow/python/ops/control_flow_case.py index a8d508f358db75..be7beca29fe10a 100644 --- a/tensorflow/python/ops/control_flow_case.py +++ b/tensorflow/python/ops/control_flow_case.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import cond from tensorflow.python.ops import control_flow_assert @@ -401,7 +402,7 @@ def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name, f"Received {pred_fn_pair}.") pred, fn = pred_fn_pair - if isinstance(pred, ops.Tensor): + if isinstance(pred, tensor.Tensor): if pred.dtype != dtypes.bool: raise TypeError("pred must be Tensor of type bool: %s" % pred.name) elif not allow_python_preds: diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py index 7b8e13c8351673..6fe6d207f9805d 100644 --- a/tensorflow/python/ops/control_flow_grad.py +++ b/tensorflow/python/ops/control_flow_grad.py @@ -19,6 +19,7 @@ from tensorflow.python.framework import indexed_slices from 
tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops @@ -158,7 +159,7 @@ def _ExitGrad(op, grad): if op_ctxt.grad_state: raise TypeError("Second-order gradient for while loops not supported.") - if isinstance(grad, ops.Tensor): + if isinstance(grad, tensor.Tensor): grad_ctxt.AddName(grad.name) else: if not isinstance( @@ -220,7 +221,7 @@ def _EnterGrad(op, grad): return grad if op.get_attr("is_constant"): # Add a gradient accumulator for each loop invariant. - if isinstance(grad, ops.Tensor): + if isinstance(grad, tensor.Tensor): result = grad_ctxt.AddBackpropAccumulator(op, grad) elif isinstance(grad, indexed_slices.IndexedSlices): result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 80fcc7e51910ca..7ddbfe5e8cc356 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -27,8 +27,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops @@ -70,7 +70,7 @@ def _Identity(tensor, name=None): # TODO(b/246438937): Remove this when we expand ResourceVariables into # dt_resource tensors. tensor = variable_utils.convert_variables_to_tensors(tensor) - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access return gen_array_ops.ref_identity(tensor, name=name) else: @@ -84,7 +84,7 @@ def _Identity(tensor, name=None): def _NextIteration(tensor, name=None): tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access return ref_next_iteration(tensor, name=name) else: @@ -127,7 +127,7 @@ def _Enter(tensor, than its corresponding shape in `shape_invariant`. """ tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): if tensor.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access result = gen_control_flow_ops.ref_enter( tensor, frame_name, is_constant, parallel_iterations, name=name) @@ -162,7 +162,7 @@ def exit(tensor, name=None): # pylint: disable=redefined-builtin The same tensor as `tensor`. 
""" tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access return gen_control_flow_ops.ref_exit(tensor, name) else: @@ -197,7 +197,7 @@ def switch(data, pred, dtype=None, name=None): data = ops.internal_convert_to_tensor_or_composite( data, dtype=dtype, name="data", as_ref=True) pred = ops.convert_to_tensor(pred, name="pred") - if isinstance(data, ops.Tensor): + if isinstance(data, tensor_lib.Tensor): return gen_control_flow_ops.switch(data, pred, name=name) else: if not isinstance(data, composite_tensor.CompositeTensor): @@ -249,7 +249,7 @@ def _SwitchRefOrTensor(data, pred, name="Switch"): # var and data may be pinned to different devices, so we want to ops # created within ops.colocate_with(data) to ignore the existing stack. with ops.colocate_with(data, ignore_existing=True): - if isinstance(data, ops.Tensor): + if isinstance(data, tensor_lib.Tensor): if data.dtype._is_ref_dtype: # pylint: disable=protected-access return ref_switch(data, pred, name=name) return switch(data, pred, name=name) @@ -287,7 +287,7 @@ def merge(inputs, name=None): ops.internal_convert_to_tensor_or_composite(inp, as_ref=True) for inp in inputs ] - if all(isinstance(v, ops.Tensor) for v in inputs): + if all(isinstance(v, tensor_lib.Tensor) for v in inputs): if all(v.dtype._is_ref_dtype for v in inputs): # pylint: disable=protected-access return gen_control_flow_ops.ref_merge(inputs, name) else: @@ -296,7 +296,7 @@ def merge(inputs, name=None): # If there is a mix of tensors and indexed slices, then convert the # tensors to indexed slices. if all( - isinstance(v, (indexed_slices.IndexedSlices, ops.Tensor)) + isinstance(v, (indexed_slices.IndexedSlices, tensor_lib.Tensor)) for v in inputs): inputs = math_ops._as_indexed_slices_list(inputs, optimize=False) @@ -384,8 +384,8 @@ def _shape_invariant_to_type_spec(var, shape=None): "'shape' must be one of TypeSpec, TensorShape or None. " f"Received: {type(shape)}") - if isinstance(var, ops.Tensor): - return tensor_spec.TensorSpec(shape, var.dtype) + if isinstance(var, tensor_lib.Tensor): + return tensor_lib.TensorSpec(shape, var.dtype) else: try: return var._shape_invariant_to_type_spec(shape) # pylint: disable=protected-access @@ -408,7 +408,7 @@ def _EnforceShapeInvariant(merge_var, next_var): ValueError: If any tensor in `merge_var` has a more specific shape than its corresponding tensor in `next_var`. """ - if isinstance(merge_var, ops.Tensor): + if isinstance(merge_var, tensor_lib.Tensor): m_shape = merge_var.get_shape() n_shape = next_var.get_shape() if not _ShapeLessThanOrEqual(n_shape, m_shape): @@ -427,7 +427,7 @@ def _EnforceShapeInvariant(merge_var, next_var): def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True): """Add NextIteration and back edge from v to m.""" - if isinstance(m, ops.Tensor): + if isinstance(m, tensor_lib.Tensor): v = ops.convert_to_tensor(v) v = _NextIteration(v) if enforce_shape_invariant: @@ -1632,7 +1632,7 @@ def _InitializeValues(self, values): """Makes the values known to this context.""" self._values = set() for x in values: - if isinstance(x, ops.Tensor): + if isinstance(x, tensor_lib.Tensor): self._values.add(x.name) else: raise TypeError("'values' must be a list of Tensors. 
" @@ -1831,7 +1831,7 @@ def _FixControlInputsAndContext(self, enters): graph = ops.get_default_graph() # pylint: disable=protected-access for e in enters: - if isinstance(e, ops.Tensor): + if isinstance(e, tensor_lib.Tensor): xs = [e] else: raise TypeError("'enters' must be a list of Tensors. " @@ -1888,7 +1888,7 @@ def _AsTensorList(x, p): if isinstance(v, ops.Operation): v = with_dependencies([v], p) v = ops.convert_to_tensor_or_composite(v) - if isinstance(v, ops.Tensor): + if isinstance(v, tensor_lib.Tensor): l.append(array_ops.identity(v)) else: l.append( @@ -2150,7 +2150,7 @@ def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined ] if control_inputs: for c in control_inputs: - if isinstance(c, ops.Tensor): + if isinstance(c, tensor_lib.Tensor): c = c.op elif not isinstance(c, ops.Operation): raise TypeError( diff --git a/tensorflow/python/ops/control_flow_switch_case.py b/tensorflow/python/ops/control_flow_switch_case.py index 8cb4fe685ef104..843a088017df36 100644 --- a/tensorflow/python/ops/control_flow_switch_case.py +++ b/tensorflow/python/ops/control_flow_switch_case.py @@ -16,6 +16,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond_v2 from tensorflow.python.ops import control_flow_util as util @@ -45,7 +46,7 @@ def _indexed_case_verify_and_canonicalize_args(branch_fns, default, Returns: branch_fns: validated list of callables for each branch (default last). """ - if not isinstance(branch_index, ops.Tensor): + if not isinstance(branch_index, tensor.Tensor): raise TypeError("'branch_index' must be a Tensor, got {}".format( type(branch_index))) if not branch_index.dtype.is_integer: diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index fc46a849d86f5c..2fbd2643da45d1 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -576,7 +576,7 @@ def embedding_lookup_sparse_v2( Since row 1 and 2 of `sp_ids` only have one value each, they simply select the corresponding row from `params` as the output row. Row 1 has value `3` so it selects the `params` elements `[7, 8]` and row 2 has the value 2 so it - selects the the `params` elements `[5, 6]`. + selects the `params` elements `[5, 6]`. If `sparse_weights` is specified, it must have the same shape as `sp_ids`. `sparse_weights` is used to assign a weight to each slice of `params`. For diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index f394ce2cc77ad3..a50b445aea6394 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_functional_ops @@ -1044,7 +1045,7 @@ def WhileBody(i, n, start, delta, *args): if isinstance(for_result, ops.Operation): for_result = () # Unary functions return a single Tensor value. 
- elif isinstance(for_result, ops.Tensor): + elif isinstance(for_result, tensor.Tensor): for_result = (for_result,) return (i + 1, n, start, delta) + tuple(for_result) diff --git a/tensorflow/python/ops/gradient_checker.py b/tensorflow/python/ops/gradient_checker.py index 9f00934c8263a6..8ed4aee3d47c6b 100644 --- a/tensorflow/python/ops/gradient_checker.py +++ b/tensorflow/python/ops/gradient_checker.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops @@ -108,7 +109,7 @@ def _compute_theoretical_jacobian(x, x_shape, x_data, dy, dy_shape, dx, r_end = r_begin + x_val_size jacobian[r_begin:r_end, col] += v.flat else: - assert isinstance(dx, ops.Tensor), "dx = " + str(dx) + assert isinstance(dx, tensor.Tensor), "dx = " + str(dx) backprop = sess.run( dx, feed_dict=_extra_feeds(extra_feed_dict, {x: x_data, dy: dy_data})) jacobian[:, col] = backprop.ravel().view(jacobian.dtype) diff --git a/tensorflow/python/ops/gradient_checker_v2.py b/tensorflow/python/ops/gradient_checker_v2.py index 3a201b103fafb3..390bdbd20e92cd 100644 --- a/tensorflow/python/ops/gradient_checker_v2.py +++ b/tensorflow/python/ops/gradient_checker_v2.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -74,7 +75,7 @@ def _to_numpy(a): """ if isinstance(a, ops.EagerTensor): return a.numpy() - if isinstance(a, ops.Tensor): + if isinstance(a, tensor.Tensor): sess = ops.get_default_session() return sess.run(a) if isinstance(a, indexed_slices.IndexedSlicesValue): diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index d1ccbbf581fc60..673fcf46fc393c 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -14,6 +14,8 @@ # ============================================================================== """Implements the graph generation for computation of gradients.""" +from tensorflow.python.debug.lib import debug_gradients # pylint: disable=unused-import +from tensorflow.python.debug.lib import dumping_callback # pylint: disable=unused-import from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_grad # pylint: disable=unused-import diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 62d5bde1eae3e8..b643278e3f9eb2 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -28,7 +28,7 @@ from tensorflow.python.framework import function as framework_function from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.framework.constant_op import constant @@ -234,9 +234,9 @@ def _TestOpGrad(_, float_grad, string_grad): 
z = x * 2.0 w = z * 3.0 grads = gradients.gradients(z, [c]) - self.assertIsInstance(grads[0], ops.Tensor) + self.assertIsInstance(grads[0], tensor.Tensor) grads = gradients.gradients(w, [c]) - self.assertIsInstance(grads[0], ops.Tensor) + self.assertIsInstance(grads[0], tensor.Tensor) def testNoGradientForStringOutputsWithOpNamespace(self): with ops.Graph().as_default(): @@ -254,9 +254,9 @@ def _TestOpGrad(_, float_grad, string_grad): z = x * 2.0 w = z * 3.0 grads = gradients.gradients(z, [c]) - self.assertIsInstance(grads[0], ops.Tensor) + self.assertIsInstance(grads[0], tensor.Tensor) grads = gradients.gradients(w, [c]) - self.assertIsInstance(grads[0], ops.Tensor) + self.assertIsInstance(grads[0], tensor.Tensor) def testSingletonIndexedSlices(self): with ops.Graph().as_default(): @@ -1614,7 +1614,7 @@ def F(x): self.assertAllClose(grads_re, grads) f_graph = def_function.function( - F, input_signature=[tensor_spec.TensorSpec(None)]) + F, input_signature=[tensor.TensorSpec(None)]) grads_re = self._grad(custom_gradient.recompute_grad(f_graph))(x) grads = self._grad(f_graph)(x) self.assertAllClose(grads_re, grads) @@ -1633,8 +1633,8 @@ def F(x1, x2): f_graph = def_function.function( F, input_signature=[ - tensor_spec.TensorSpec(None, dtype=dtypes.int32), - tensor_spec.TensorSpec(None, dtype=dtypes.float32), + tensor.TensorSpec(None, dtype=dtypes.int32), + tensor.TensorSpec(None, dtype=dtypes.float32), ]) grads_re = self._grad(custom_gradient.recompute_grad(f_graph))(x1, x2) grads = self._grad(f_graph)(x1, x2) diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py index 31c06e091eb7e8..3579e4f937479d 100644 --- a/tensorflow/python/ops/gradients_util.py +++ b/tensorflow/python/ops/gradients_util.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -166,56 +167,77 @@ def _DefaultGradYs(grad_ys, if y.dtype.is_complex: raise TypeError( f"Gradients of complex tensors ({y}) must set grad_ys (y.dtype = " - f"{dtypes.as_dtype(y.dtype).name})") + f"{dtypes.as_dtype(y.dtype).name})" + ) new_grad_ys.append( array_ops.ones( - array_ops.shape(y), dtype=y.dtype, name="grad_ys_%d" % i)) + array_ops.shape(y), dtype=y.dtype, name="grad_ys_%d" % i + ) + ) continue if y.dtype.is_floating or y.dtype.is_integer: if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer: raise TypeError( f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated " f"for real or integer-valued tensor {y} with type " - f"{dtypes.as_dtype(y.dtype).name} must be real or integer") + f"{dtypes.as_dtype(y.dtype).name} must be real or integer" + ) elif y.dtype.is_complex: if not grad_y.dtype.is_complex: raise TypeError( f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated " f"for complex-valued tensor {y} with type " - f"{dtypes.as_dtype(y.dtype).name} must be real") + f"{dtypes.as_dtype(y.dtype).name} must be real" + ) elif y.dtype == dtypes.variant: if grad_y.dtype != dtypes.variant: raise TypeError( f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated " f"for variant tensor {y} with type " - f"{dtypes.as_dtype(y.dtype).name} must be variant") + f"{dtypes.as_dtype(y.dtype).name} must be variant" + ) elif y.dtype == dtypes.resource: # We assume y 
is the handle of a ResourceVariable. The gradient of a # ResourceVariable should be a numeric value, not another resource. if grad_y.dtype == dtypes.resource: - raise TypeError(f"Input gradient {grad_y} for resource tensor {y} " - "should not be a resource") + raise TypeError( + f"Input gradient {grad_y} for resource tensor {y} " + "should not be a resource" + ) else: raise TypeError( f"Tensor {y} with type {dtypes.as_dtype(y.dtype).name} must be " - "numeric to obtain a default gradient") + "numeric to obtain a default gradient" + ) # Create a grad_y tensor in the name scope of the gradient. # Required for TensorArrays to identify which gradient call a # grad_y value is coming from. if isinstance(grad_y, indexed_slices.IndexedSlices): new_grad_ys.append( indexed_slices.IndexedSlices( - indices=(array_ops.identity( - grad_y.indices, name="grad_ys_%d_indices" % i) - if isinstance(grad_y.indices, ops.Tensor) else - grad_y.indices), - values=(array_ops.identity( - grad_y.values, name="grad_ys_%d_values" % i) if isinstance( - grad_y.values, ops.Tensor) else grad_y.values), - dense_shape=(array_ops.identity( - grad_y.dense_shape, name="grad_ys_%d_shape" % i) - if isinstance(grad_y.dense_shape, ops.Tensor) else - grad_y.dense_shape))) + indices=( + array_ops.identity( + grad_y.indices, name="grad_ys_%d_indices" % i + ) + if isinstance(grad_y.indices, tensor_lib.Tensor) + else grad_y.indices + ), + values=( + array_ops.identity( + grad_y.values, name="grad_ys_%d_values" % i + ) + if isinstance(grad_y.values, tensor_lib.Tensor) + else grad_y.values + ), + dense_shape=( + array_ops.identity( + grad_y.dense_shape, name="grad_ys_%d_shape" % i + ) + if isinstance(grad_y.dense_shape, tensor_lib.Tensor) + else grad_y.dense_shape + ), + ) + ) else: new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i)) @@ -594,10 +616,11 @@ def _GradientsHelper(ys, func_call = None is_partitioned_call = _IsPartitionedCall(op) # pylint: disable=protected-access - is_func_call = ( - src_graph._is_function(op.type) or is_partitioned_call) + is_func_call = src_graph._is_function(op.type) or is_partitioned_call # pylint: enable=protected-access - has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads) + has_out_grads = any( + isinstance(g, tensor_lib.Tensor) or g for g in out_grads + ) if has_out_grads and (op not in stop_ops): try: grad_fn = ops.get_gradient_function(op) @@ -662,9 +685,12 @@ def _GradientsHelper(ys, # output, it means that the cost does not depend on output[i], # therefore dC/doutput[i] is 0. for i, out_grad in enumerate(out_grads): - if (not isinstance(out_grad, ops.Tensor) and not out_grad) and ( + if ( + not isinstance(out_grad, tensor_lib.Tensor) and not out_grad + ) and ( (not grad_fn and is_func_call) - or backprop_util.IsTrainable(op.outputs[i])): + or backprop_util.IsTrainable(op.outputs[i]) + ): # Only trainable outputs or outputs for a function call that # will use SymbolicGradient get a zero gradient. Gradient # functions should ignore the gradient for other outputs. @@ -710,7 +736,7 @@ def _GradientsHelper(ys, # line up with in_grads. 
for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)): if in_grad is not None: - if (isinstance(in_grad, ops.Tensor) and + if (isinstance(in_grad, tensor_lib.Tensor) and t_in.dtype != dtypes.resource): try: in_grad.set_shape(t_in.get_shape()) @@ -738,7 +764,7 @@ def _HasAnyNotNoneGrads(grads, op): """Return true iff op has real gradient.""" out_grads = _GetGrads(grads, op) for out_grad in out_grads: - if isinstance(out_grad, (ops.Tensor, indexed_slices.IndexedSlices)): + if isinstance(out_grad, (tensor_lib.Tensor, indexed_slices.IndexedSlices)): return True if out_grad and isinstance(out_grad, collections_abc.Sequence): if any(g is not None for g in out_grad): @@ -842,7 +868,7 @@ def _GetGrads(grads, op): def _AccumulatorShape(inputs): shape = tensor_shape.unknown_shape() for i in inputs: - if isinstance(i, ops.Tensor): + if isinstance(i, tensor_lib.Tensor): shape = shape.merge_with(i.get_shape()) return shape @@ -981,12 +1007,13 @@ def _AggregatedGrads(grads, out_grads = _GetGrads(grads, op) for i, out_grad in enumerate(out_grads): if loop_state: - if isinstance(out_grad, (ops.Tensor, indexed_slices.IndexedSlices)): + if isinstance( + out_grad, (tensor_lib.Tensor, indexed_slices.IndexedSlices)): assert control_flow_util.IsLoopSwitch(op) continue # Grads have to be Tensors or IndexedSlices if (isinstance(out_grad, collections_abc.Sequence) and not all( - isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)) + isinstance(g, (tensor_lib.Tensor, indexed_slices.IndexedSlices)) for g in out_grad if g is not None)): raise TypeError(f"Invalid gradient {out_grad} [index = {i}]. Gradients " @@ -996,7 +1023,8 @@ def _AggregatedGrads(grads, if len(out_grad) < 2: used = "nop" out_grads[i] = out_grad[0] - elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None): + elif all( + isinstance(g, tensor_lib.Tensor) for g in out_grad if g is not None): tensor_shape = _AccumulatorShape(out_grad) if aggregation_method in [ AggregationMethod.EXPERIMENTAL_TREE, diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index aeebb3cded09db..80c2d996de8954 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -98,7 +99,7 @@ def _is_tensor(x): Returns: `True` if `x` is a `tf.Tensor` or `tf.Variable`, otherwise `False`. """ - return isinstance(x, (ops.Tensor, variables.Variable)) + return isinstance(x, (tensor_lib.Tensor, variables.Variable)) def _ImageDimensions(image, rank): diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 45ce92131c915b..64e2d0e5b7feb8 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.lib.io import python_io from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_io_ops @@ -275,7 +276,7 @@ def read(self, queue, name=None): key: A string scalar Tensor. 
value: A string scalar Tensor. """ - if isinstance(queue, ops.Tensor): + if isinstance(queue, tensor_lib.Tensor): queue_ref = queue else: queue_ref = queue.queue_ref @@ -307,7 +308,7 @@ def read_up_to(self, queue, num_records, # pylint: disable=invalid-name keys: A 1-D string Tensor. values: A 1-D string Tensor. """ - if isinstance(queue, ops.Tensor): + if isinstance(queue, tensor_lib.Tensor): queue_ref = queue else: queue_ref = queue.queue_ref diff --git a/tensorflow/python/ops/linalg/sparse/BUILD b/tensorflow/python/ops/linalg/sparse/BUILD index ec7abda24c2910..fa87211b113e3b 100644 --- a/tensorflow/python/ops/linalg/sparse/BUILD +++ b/tensorflow/python/ops/linalg/sparse/BUILD @@ -1,4 +1,5 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") +load("//tensorflow:pytype.default.bzl", "pytype_strict_library") # Description: Sparse CSR support for TensorFlow. load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") @@ -53,7 +54,7 @@ py_strict_library( srcs = ["__init__.py"], ) -py_strict_library( +pytype_strict_library( name = "conjugate_gradient", srcs = ["conjugate_gradient.py"], deps = [ @@ -79,6 +80,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/ops/linalg/sparse/sparse_csr_matrix_ops.py b/tensorflow/python/ops/linalg/sparse/sparse_csr_matrix_ops.py index 3c50276a00bfb8..9e5ab10faaabd4 100644 --- a/tensorflow/python/ops/linalg/sparse/sparse_csr_matrix_ops.py +++ b/tensorflow/python/ops/linalg/sparse/sparse_csr_matrix_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -96,7 +97,7 @@ def dense_shape_and_type(matrix): ValueError: if `matrix` lacks static handle data containing the dense shape and dtype. 
""" - if not isinstance(matrix, ops.Tensor): + if not isinstance(matrix, tensor_lib.Tensor): raise TypeError("matrix should be a tensor, but saw: %s" % (matrix,)) if matrix.dtype != dtypes.variant: raise TypeError( @@ -352,7 +353,9 @@ def _matrix(self): return self._csr_matrix def _from_matrix(self, matrix, handle_data=None): - assert isinstance(matrix, ops.Tensor) and matrix.dtype == dtypes.variant + assert ( + isinstance(matrix, tensor_lib.Tensor) and matrix.dtype == dtypes.variant + ) ret = type(self).__new__(type(self)) # pylint: disable=protected-access ret._dtype = self._dtype diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index f007c2021254c1..237cbd5a5d4650 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -18,6 +18,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond from tensorflow.python.ops import gen_array_ops @@ -62,7 +63,7 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind): gramian = math_ops.matmul( matrix, matrix, adjoint_a=first_kind, adjoint_b=not first_kind) - if isinstance(l2_regularizer, ops.Tensor) or l2_regularizer != 0: + if isinstance(l2_regularizer, tensor_lib.Tensor) or l2_regularizer != 0: matrix_shape = array_ops.shape(matrix) batch_shape = matrix_shape[:-2] if first_kind: diff --git a/tensorflow/python/ops/linalg_ops_impl.py b/tensorflow/python/ops/linalg_ops_impl.py index 4481bb4e350dcb..45393ccba5349b 100644 --- a/tensorflow/python/ops/linalg_ops_impl.py +++ b/tensorflow/python/ops/linalg_ops_impl.py @@ -18,6 +18,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import compat @@ -42,8 +43,8 @@ def eye(num_rows, num_columns = num_rows if num_columns is None else num_columns # We cannot statically infer what the diagonal size should be: - if (isinstance(num_rows, ops.Tensor) or - isinstance(num_columns, ops.Tensor)): + if (isinstance(num_rows, tensor.Tensor) or + isinstance(num_columns, tensor.Tensor)): diag_size = math_ops.minimum(num_rows, num_columns) else: # We can statically infer the diagonal size, and whether it is square. @@ -56,9 +57,12 @@ def eye(num_rows, diag_size = np.minimum(num_rows, num_columns) # We can not statically infer the shape of the tensor. 
- if isinstance(batch_shape, ops.Tensor) or isinstance(diag_size, ops.Tensor): + if isinstance(batch_shape, tensor.Tensor) or isinstance( + diag_size, tensor.Tensor + ): batch_shape = ops.convert_to_tensor( - batch_shape, name='shape', dtype=dtypes.int32) + batch_shape, name='shape', dtype=dtypes.int32 + ) diag_shape = array_ops.concat((batch_shape, [diag_size]), axis=0) if not is_square: shape = array_ops.concat((batch_shape, [num_rows, num_columns]), axis=0) diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py index 5577b04e5502e6..bc2aeca89b26b4 100644 --- a/tensorflow/python/ops/list_ops.py +++ b/tensorflow/python/ops/list_ops.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -383,7 +384,7 @@ def _build_element_shape(shape): Returns: A None-free shape that can be converted to a tensor. """ - if isinstance(shape, ops.Tensor): + if isinstance(shape, tensor_lib.Tensor): return shape if isinstance(shape, tensor_shape.TensorShape): # `TensorShape.as_list` requires rank to be known. @@ -398,7 +399,7 @@ def _build_element_shape(shape): def convert(val): if val is None: return -1 - if isinstance(val, ops.Tensor): + if isinstance(val, tensor_lib.Tensor): return val if isinstance(val, tensor_shape.Dimension): return val.value if val.value is not None else -1 diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index fe06e18ee57232..b4eeedabe55407 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -706,7 +707,7 @@ def __init__(self, ValueError: when the filename is empty, or when the table key and value data types do not match the expected data types. """ - if not isinstance(filename, ops.Tensor) and not filename: + if not isinstance(filename, tensor_lib.Tensor) and not filename: raise ValueError("`filename` argument required for tf.lookup.TextFileInitializer") self._filename_arg = filename @@ -1499,7 +1500,7 @@ def index_table_from_file(vocabulary_file=None, num_oov_buckets) if vocab_size is not None and vocab_size < 1: vocab_file_value = vocabulary_file - if isinstance(vocabulary_file, ops.Tensor): + if isinstance(vocabulary_file, tensor_lib.Tensor): vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?" raise ValueError("`vocab_size` must be greater than 0, got %d for " "vocabulary_file: %s." 
% (vocab_size, vocab_file_value)) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 48664a17b63070..71d03a2704eea7 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -85,8 +86,8 @@ def SmartBroadcastGradientArgs(x, y, grad): # NOTE: It may be productive to apply these optimizations in the eager case # as well. if context.executing_eagerly() or not ( - isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor) - and isinstance(grad, ops.Tensor)): + isinstance(x, tensor.Tensor) and isinstance(y, tensor.Tensor) + and isinstance(grad, tensor.Tensor)): sx = array_ops.shape(x) sy = array_ops.shape(y) rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) @@ -1303,7 +1304,7 @@ def _AddGrad(op, grad): # No gradient skipping, so do the full gradient computation pass x = op.inputs[0] - if (isinstance(grad, ops.Tensor) and + if (isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(x, y, grad)): return grad, grad (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = ( @@ -1337,7 +1338,7 @@ def _SubGrad(op, grad): # No gradient skipping, so do the full gradient computation pass x = op.inputs[0] - if (isinstance(grad, ops.Tensor) and + if (isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(x, y, grad)): return grad, -grad (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = ( @@ -1371,7 +1372,7 @@ def _MulGrad(op, grad): # No gradient skipping, so do the full gradient computation pass x = op.inputs[0] - if (isinstance(grad, ops.Tensor) and + if (isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(x, y, grad) and grad.dtype in (dtypes.int32, dtypes.float32)): return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) @@ -1403,7 +1404,7 @@ def _MulNoNanGrad(op, grad): """The gradient of scalar multiplication with NaN-suppression.""" x = op.inputs[0] y = op.inputs[1] - if (isinstance(grad, ops.Tensor) and + if (isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(x, y, grad)): return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad) assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype) @@ -1625,7 +1626,7 @@ def _SquaredDifferenceGrad(op, grad): # Tensor (not a number like 2.0) which causes it to convert to Tensor. 
x_grad = math_ops.scalar_mul(2.0, grad) * (x - y) - if (isinstance(grad, ops.Tensor) and + if (isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(x, y, grad)): return x_grad, -x_grad diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 39f8e77a520cd6..c90bde289564df 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -76,6 +76,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -995,8 +996,8 @@ def cast(x, dtype, name=None): """ base_type = dtypes.as_dtype(dtype).base_dtype - if isinstance(x, - (ops.Tensor, _resource_variable_type)) and base_type == x.dtype: + if isinstance( + x, (tensor_lib.Tensor, _resource_variable_type)) and base_type == x.dtype: return x with ops.name_scope(name, "Cast", [x]) as name: if isinstance(x, sparse_tensor.SparseTensor): @@ -1386,8 +1387,8 @@ def to_complex128(x, name="ToComplex128"): return cast(x, dtypes.complex128, name=name) -ops.Tensor._override_operator("__neg__", gen_math_ops.neg) -ops.Tensor._override_operator("__abs__", abs) +tensor_lib.Tensor._override_operator("__neg__", gen_math_ops.neg) +tensor_lib.Tensor._override_operator("__abs__", abs) def _maybe_get_dtype(x): @@ -1396,7 +1397,7 @@ def _maybe_get_dtype(x): # value (not just dtype) of np.ndarray to decide the result type. if isinstance(x, numbers.Real): return x - if isinstance(x, ops.Tensor): + if isinstance(x, tensor_lib.Tensor): return x.dtype.as_numpy_dtype if isinstance(x, dtypes.DType): return x.as_numpy_dtype @@ -1442,7 +1443,7 @@ def maybe_promote_tensors(*tensors, force_same_dtype=False): result_type = np_dtypes._result_type( *[_maybe_get_dtype(x) for x in nest.flatten(tensors)]) def _promote_or_cast(x): - if isinstance(x, ops.Tensor): + if isinstance(x, tensor_lib.Tensor): x = cast(x, result_type) else: x = ops.convert_to_tensor(x, result_type) @@ -1450,7 +1451,8 @@ def _promote_or_cast(x): return [_promote_or_cast(x) for x in tensors] -def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor): +def _OverrideBinaryOperatorHelper( + func, op_name, clazz_object=tensor_lib.Tensor): """Register operators with different tensor and scalar versions. If `clazz_object` is `SparseTensor`, assumes `func` takes `(sp_indices, @@ -1516,7 +1518,7 @@ def r_binary_op_wrapper(y, x): r_binary_op_wrapper.__doc__ = doc binary_op_wrapper_sparse.__doc__ = doc - if clazz_object is ops.Tensor: + if clazz_object is tensor_lib.Tensor: clazz_object._override_operator("__%s__" % op_name, binary_op_wrapper) del binary_op_wrapper clazz_object._override_operator("__r%s__" % op_name, r_binary_op_wrapper) @@ -1835,7 +1837,7 @@ def _add_dispatch(x, y, name=None): Returns: The result of the elementwise `+` operation. 
""" - if not isinstance(y, ops.Tensor) and not isinstance( + if not isinstance(y, tensor_lib.Tensor) and not isinstance( y, sparse_tensor.SparseTensor): y = ops.convert_to_tensor(y, dtype_hint=x.dtype.base_dtype, name="y") if x.dtype == dtypes.string: @@ -1953,7 +1955,7 @@ def invert_(x, name=None): _OverrideBinaryOperatorHelper(and_, "and") _OverrideBinaryOperatorHelper(or_, "or") _OverrideBinaryOperatorHelper(xor_, "xor") -ops.Tensor._override_operator("__invert__", invert_) +tensor_lib.Tensor._override_operator("__invert__", invert_) def _promote_dtypes_decorator(fn): @@ -1963,13 +1965,13 @@ def wrapper(x, y, *args, **kwargs): return tf_decorator.make_decorator(fn, wrapper) -ops.Tensor._override_operator("__lt__", _promote_dtypes_decorator( +tensor_lib.Tensor._override_operator("__lt__", _promote_dtypes_decorator( gen_math_ops.less)) -ops.Tensor._override_operator("__le__", _promote_dtypes_decorator( +tensor_lib.Tensor._override_operator("__le__", _promote_dtypes_decorator( gen_math_ops.less_equal)) -ops.Tensor._override_operator("__gt__", _promote_dtypes_decorator( +tensor_lib.Tensor._override_operator("__gt__", _promote_dtypes_decorator( gen_math_ops.greater)) -ops.Tensor._override_operator("__ge__", _promote_dtypes_decorator( +tensor_lib.Tensor._override_operator("__ge__", _promote_dtypes_decorator( gen_math_ops.greater_equal)) @@ -2077,8 +2079,11 @@ def tensor_equals(self, other): if other is None: return False g = getattr(self, "graph", None) - if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and - (g is None or g.building_function)): + if ( + tensor_lib.Tensor._USE_EQUALITY + and ops.executing_eagerly_outside_functions() + and (g is None or g.building_function) + ): self, other = maybe_promote_tensors(self, other) return gen_math_ops.equal(self, other, incompatible_shape_error=False) else: @@ -2115,7 +2120,10 @@ def tensor_not_equals(self, other): """ if other is None: return True - if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): + if ( + tensor_lib.Tensor._USE_EQUALITY + and ops.executing_eagerly_outside_functions() + ): self, other = maybe_promote_tensors(self, other) return gen_math_ops.not_equal(self, other, incompatible_shape_error=False) else: @@ -2123,8 +2131,8 @@ def tensor_not_equals(self, other): return self is not other -ops.Tensor._override_operator("__eq__", tensor_equals) -ops.Tensor._override_operator("__ne__", tensor_not_equals) +tensor_lib.Tensor._override_operator("__eq__", tensor_equals) +tensor_lib.Tensor._override_operator("__ne__", tensor_not_equals) @tf_export("range") @@ -2184,11 +2192,11 @@ def range(start, limit=None, delta=1, dtype=None, name="range"): # pylint: disa start, limit = 0, start with ops.name_scope(name, "Range", [start, limit, delta]) as name: - if not isinstance(start, ops.Tensor): + if not isinstance(start, tensor_lib.Tensor): start = ops.convert_to_tensor(start, dtype=dtype, name="start") - if not isinstance(limit, ops.Tensor): + if not isinstance(limit, tensor_lib.Tensor): limit = ops.convert_to_tensor(limit, dtype=dtype, name="limit") - if not isinstance(delta, ops.Tensor): + if not isinstance(delta, tensor_lib.Tensor): delta = ops.convert_to_tensor(delta, dtype=dtype, name="delta") # infer dtype if not explicitly provided @@ -3941,7 +3949,7 @@ def _as_indexed_slices(x, optimize=True): TypeError: If 'x' is not a Tensor or an IndexedSlices object. 
""" # TODO(touts): op_scope - if not isinstance(x, (ops.Tensor, indexed_slices.IndexedSlices)): + if not isinstance(x, (tensor_lib.Tensor, indexed_slices.IndexedSlices)): raise TypeError(f"Not a Tensor or IndexedSlices: {type(x)}.") if isinstance(x, indexed_slices.IndexedSlices): return x @@ -4109,7 +4117,7 @@ def add_n(inputs, name=None): "Tensor/IndexedSlices with the same dtype and shape.") inputs = indexed_slices.convert_n_to_tensor_or_indexed_slices(inputs) if not all( - isinstance(x, (ops.Tensor, indexed_slices.IndexedSlices)) + isinstance(x, (tensor_lib.Tensor, indexed_slices.IndexedSlices)) for x in inputs): raise ValueError("Inputs must be an iterable of at least one " "Tensor/IndexedSlices with the same dtype and shape.") @@ -4185,7 +4193,7 @@ def _input_error(): if not inputs or not isinstance(inputs, (list, tuple)): raise _input_error() inputs = indexed_slices.convert_n_to_tensor_or_indexed_slices(inputs) - if not all(isinstance(x, ops.Tensor) for x in inputs): + if not all(isinstance(x, tensor_lib.Tensor) for x in inputs): raise _input_error() if not all(x.dtype == inputs[0].dtype for x in inputs): raise _input_error() @@ -4194,7 +4202,7 @@ def _input_error(): else: shape = tensor_shape.unknown_shape() for input_tensor in inputs: - if isinstance(input_tensor, ops.Tensor): + if isinstance(input_tensor, tensor_lib.Tensor): shape = shape.merge_with(input_tensor.get_shape()) # tensor_dtype is for safety only; operator's output type computed in C++ @@ -4542,7 +4550,7 @@ def conj(x, name=None): Equivalent to numpy.conj. @end_compatibility """ - if isinstance(x, ops.Tensor): + if isinstance(x, tensor_lib.Tensor): dt = x.dtype if dt.is_floating or dt.is_integer: return x diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 350524ef4aa689..01eefe80f74ba2 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients @@ -784,9 +785,9 @@ def testConsistent(self): def testWithPythonValue(self): # Test case for https://github.com/tensorflow/tensorflow/issues/39475 x = math_ops.divide(5, 2) - self.assertIsInstance(x, ops.Tensor) + self.assertIsInstance(x, tensor_lib.Tensor) x = math_ops.divide(5, array_ops.constant(2.0)) - self.assertIsInstance(x, ops.Tensor) + self.assertIsInstance(x, tensor_lib.Tensor) def intEdgeTestData(self, dtype): """Edge-case test data for integer types.""" @@ -1206,7 +1207,7 @@ def testEqualityNoDowncast(self, is_equals, float_literal): x = constant_op.constant(4) try: result = op(x, float_literal) - if isinstance(result, ops.Tensor): + if isinstance(result, tensor_lib.Tensor): result = self.evaluate(result) except TypeError: # Throwing a TypeError is OK diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index d402ff88d53540..81ca682f9464d7 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -178,6 +178,7 @@ from tensorflow.python.framework import graph_util from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import 
tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -1242,7 +1243,8 @@ def convolution_internal( not tensor_util.is_tf_type(filters)): with ops.name_scope("convolution_internal", None, [filters, input]): filters = ops.convert_to_tensor(filters, name='filters') - if (not isinstance(input, ops.Tensor) and not tensor_util.is_tf_type(input)): + if (not isinstance(input, tensor_lib.Tensor) and not tensor_util.is_tf_type( + input)): with ops.name_scope("convolution_internal", None, [filters, input]): input = ops.convert_to_tensor(input, name="input") @@ -2239,7 +2241,7 @@ def conv1d_transpose( input = array_ops.expand_dims(input, spatial_start_dim) filters = array_ops.expand_dims(filters, 0) output_shape = list(output_shape) if not isinstance( - output_shape, ops.Tensor) else output_shape + output_shape, tensor_lib.Tensor) else output_shape output_shape = array_ops.concat([output_shape[: spatial_start_dim], [1], output_shape[spatial_start_dim:]], 0) @@ -3819,7 +3821,7 @@ def _swap_axis(input_tensor, dim_index, last_index, name=None): return compute_op(inputs, name=name) dim_val = dim - if isinstance(dim, ops.Tensor): + if isinstance(dim, tensor_lib.Tensor): dim_val = tensor_util.constant_value(dim) if dim_val is not None and not -shape.ndims <= dim_val < shape.ndims: raise errors_impl.InvalidArgumentError( @@ -3833,7 +3835,7 @@ def _swap_axis(input_tensor, dim_index, last_index, name=None): # In case dim is negative (and is not last dimension -1), add shape.ndims ndims = array_ops.rank(inputs) - if not isinstance(dim, ops.Tensor): + if not isinstance(dim, tensor_lib.Tensor): if dim < 0: dim += ndims else: diff --git a/tensorflow/python/ops/numpy_ops/BUILD b/tensorflow/python/ops/numpy_ops/BUILD index fc6e6d7b420141..11a521d3d98053 100644 --- a/tensorflow/python/ops/numpy_ops/BUILD +++ b/tensorflow/python/ops/numpy_ops/BUILD @@ -57,6 +57,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", @@ -66,7 +67,6 @@ py_strict_library( "//tensorflow/python/ops:manip_ops", "//tensorflow/python/ops:math_ops", "//tensorflow/python/ops:sort_ops", - "//tensorflow/python/util:dispatch", "//tensorflow/python/util:nest", "//third_party/py/numpy", ], @@ -126,6 +126,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", "//tensorflow/python/ops:bitwise_ops", @@ -137,7 +138,6 @@ py_strict_library( "//tensorflow/python/ops:sort_ops", "//tensorflow/python/ops:special_math_ops", "//tensorflow/python/ops:while_loop", - "//tensorflow/python/util:dispatch", "//third_party/py/numpy", ], ) @@ -148,7 +148,7 @@ py_strict_library( deps = [ ":np_dtypes", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", ], ) @@ -175,6 +175,7 @@ cuda_py_strict_test( "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:client_testlib", 
"//tensorflow/python/util:nest", @@ -231,6 +232,7 @@ cuda_py_strict_test( ":np_math_ops", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py index 10b676e1d3f075..638e4935f95029 100644 --- a/tensorflow/python/ops/numpy_ops/np_array_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack @@ -39,7 +40,6 @@ from tensorflow.python.ops.numpy_ops import np_dtypes from tensorflow.python.ops.numpy_ops import np_export from tensorflow.python.ops.numpy_ops import np_utils -from tensorflow.python.util import dispatch from tensorflow.python.util import nest @@ -51,7 +51,6 @@ def empty(shape, dtype=float): # pylint: disable=redefined-outer-name return zeros(shape, dtype) -@dispatch.add_dispatch_support @np_utils.np_doc('empty_like') def empty_like(a, dtype=None): return zeros_like(a, dtype) @@ -64,7 +63,6 @@ def zeros(shape, dtype=float): # pylint: disable=redefined-outer-name return array_ops.zeros(shape, dtype=dtype) -@dispatch.add_dispatch_support @np_utils.np_doc('zeros_like') def zeros_like(a, dtype=None): # pylint: disable=missing-docstring dtype = np_utils.result_type_unary(a, dtype) @@ -80,7 +78,6 @@ def ones(shape, dtype=float): # pylint: disable=redefined-outer-name return array_ops.ones(shape, dtype=dtype) -@dispatch.add_dispatch_support @np_utils.np_doc('ones_like') def ones_like(a, dtype=None): dtype = np_utils.result_type_unary(a, dtype) @@ -154,7 +151,7 @@ def _array_internal(val, dtype=None, copy=True, ndmin=0): # pylint: disable=red """Main implementation of np.array().""" result_t = val - if not isinstance(result_t, ops.Tensor): + if not isinstance(result_t, tensor_lib.Tensor): dtype = np_utils.result_type_unary(result_t, dtype) # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int) @@ -199,7 +196,6 @@ def true_fn(): # TODO(wangpeng): investigate whether we can make `copy` default to False. 
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args -@dispatch.add_dispatch_support @np_utils.np_doc_only('array') def array(val, dtype=None, copy=True, ndmin=0): # pylint: disable=redefined-outer-name """Since Tensors are immutable, a copy is made only if val is placed on a @@ -216,7 +212,6 @@ def array(val, dtype=None, copy=True, ndmin=0): # pylint: disable=redefined-out # pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args -@dispatch.add_dispatch_support @np_utils.np_doc('asarray') def asarray(a, dtype=None): if dtype: @@ -227,20 +222,17 @@ def asarray(a, dtype=None): return array(a, dtype, copy=False) -@dispatch.add_dispatch_support @np_utils.np_doc('asanyarray') def asanyarray(a, dtype=None): return asarray(a, dtype) -@dispatch.add_dispatch_support @np_utils.np_doc('ascontiguousarray') def ascontiguousarray(a, dtype=None): return array(a, dtype, ndmin=1) # Numerical ranges. -@dispatch.add_dispatch_support @np_utils.np_doc('arange') def arange(start, stop=None, step=1, dtype=None): """Returns `step`-separated values in the range [start, stop). @@ -283,7 +275,6 @@ def arange(start, stop=None, step=1, dtype=None): # Building matrices. -@dispatch.add_dispatch_support @np_utils.np_doc('diag') def diag(v, k=0): # pylint: disable=missing-docstring """Raises an error if input is not 1- or 2-d.""" @@ -319,7 +310,6 @@ def _diag_part(v, k): return result -@dispatch.add_dispatch_support @np_utils.np_doc('diagonal') def diagonal(a, offset=0, axis1=0, axis2=1): # pylint: disable=missing-docstring a = asarray(a) @@ -351,7 +341,6 @@ def _zeros(): # pylint: disable=missing-docstring return a -@dispatch.add_dispatch_support @np_utils.np_doc('diagflat') def diagflat(v, k=0): v = asarray(v) @@ -420,7 +409,6 @@ def compress(condition, a, axis=None): # pylint: disable=redefined-outer-name,m return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis) -@dispatch.add_dispatch_support @np_utils.np_doc('copy') def copy(a): return array(a, copy=True) @@ -438,7 +426,6 @@ def _maybe_promote_to_int(a): return a -@dispatch.add_dispatch_support @np_utils.np_doc('cumprod') def cumprod(a, axis=None, dtype=None): # pylint: disable=missing-docstring a = asarray(a, dtype=dtype) @@ -455,7 +442,6 @@ def cumprod(a, axis=None, dtype=None): # pylint: disable=missing-docstring return math_ops.cumprod(a, axis) -@dispatch.add_dispatch_support @np_utils.np_doc('cumsum') def cumsum(a, axis=None, dtype=None): # pylint: disable=missing-docstring a = asarray(a, dtype=dtype) @@ -472,7 +458,6 @@ def cumsum(a, axis=None, dtype=None): # pylint: disable=missing-docstring return math_ops.cumsum(a, axis) -@dispatch.add_dispatch_support @np_utils.np_doc('imag') def imag(val): val = asarray(val) @@ -548,7 +533,7 @@ def _reduce(tf_fn, elif promote_int == _TO_FLOAT: a = math_ops.cast(a, np_dtypes.default_float_type()) - if isinstance(axis, ops.Tensor) and axis.dtype not in ( + if isinstance(axis, tensor_lib.Tensor) and axis.dtype not in ( dtypes.int32, dtypes.int64): axis = math_ops.cast(axis, dtypes.int64) @@ -570,7 +555,6 @@ def size(x, axis=None): # pylint: disable=missing-docstring return array_ops.size_v2(x) -@dispatch.add_dispatch_support @np_utils.np_doc('sum') def sum(a, axis=None, dtype=None, keepdims=None): # pylint: disable=redefined-builtin return _reduce( @@ -582,7 +566,6 @@ def sum(a, axis=None, dtype=None, keepdims=None): # pylint: disable=redefined-b tf_bool_fn=math_ops.reduce_any) 
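# Illustrative usage sketch (tf.experimental.numpy, TF 2.x) for the
# _reduce helper above, which backs sum here and the prod/mean wrappers
# that follow; exact dtype handling follows the promote_int policy.
import tensorflow.experimental.numpy as tnp

a = tnp.array([[1, 2], [3, 4]])
print(tnp.sum(a))          # 10
print(tnp.sum(a, axis=0))  # [4 6]
print(tnp.prod(a))         # 24
print(tnp.mean(a))         # 2.5 (integer input promoted to float)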
-@dispatch.add_dispatch_support @np_utils.np_doc('prod') def prod(a, axis=None, dtype=None, keepdims=None): return _reduce( @@ -594,7 +577,6 @@ def prod(a, axis=None, dtype=None, keepdims=None): tf_bool_fn=math_ops.reduce_all) -@dispatch.add_dispatch_support @np_utils.np_doc('mean', unsupported_params=['out']) def mean(a, axis=None, dtype=None, out=None, keepdims=None): if out is not None: @@ -608,7 +590,6 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=None): promote_int=_TO_FLOAT) -@dispatch.add_dispatch_support @np_utils.np_doc('amax', unsupported_params=['out']) def amax(a, axis=None, out=None, keepdims=None): if out is not None: @@ -624,7 +605,6 @@ def amax(a, axis=None, out=None, keepdims=None): preserve_bool=True) -@dispatch.add_dispatch_support @np_utils.np_doc('amin', unsupported_params=['out']) def amin(a, axis=None, out=None, keepdims=None): if out is not None: @@ -640,7 +620,6 @@ def amin(a, axis=None, out=None, keepdims=None): preserve_bool=True) -@dispatch.add_dispatch_support @np_utils.np_doc('var') def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None): # pylint: disable=missing-docstring if dtype: @@ -688,7 +667,6 @@ def reduce_fn(input_tensor, axis, keepdims): return result -@dispatch.add_dispatch_support @np_utils.np_doc('std') def std(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstring return _reduce( @@ -700,14 +678,12 @@ def std(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstr promote_int=_TO_FLOAT) -@dispatch.add_dispatch_support @np_utils.np_doc('ravel') def ravel(a): # pylint: disable=missing-docstring a = asarray(a) return array_ops.reshape(a, [-1]) -@dispatch.add_dispatch_support @np_utils.np_doc('real') def real(val): val = asarray(val) @@ -747,7 +723,6 @@ def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring return result -@dispatch.add_dispatch_support @np_utils.np_doc('around') def around(a, decimals=0): # pylint: disable=missing-docstring a = asarray(a) @@ -770,7 +745,6 @@ def around(a, decimals=0): # pylint: disable=missing-docstring setattr(np_arrays.ndarray, '__round__', around) -@dispatch.add_dispatch_support @np_utils.np_doc('reshape') def reshape(a, newshape, order='C'): """order argument can only b 'C' or 'F'.""" @@ -801,21 +775,18 @@ def _reshape_method_wrapper(a, *newshape, **kwargs): return reshape(a, newshape, order=order) -@dispatch.add_dispatch_support @np_utils.np_doc('expand_dims') def expand_dims(a, axis): a = asarray(a) return array_ops.expand_dims(a, axis=axis) -@dispatch.add_dispatch_support @np_utils.np_doc('squeeze') def squeeze(a, axis=None): a = asarray(a) return array_ops.squeeze(a, axis) -@dispatch.add_dispatch_support @np_utils.np_doc('flatten', link=np_utils.NoLink()) def flatten(a, order='C'): a = asarray(a) @@ -830,7 +801,6 @@ def flatten(a, order='C'): '(column major).') -@dispatch.add_dispatch_support @np_utils.np_doc('transpose') def transpose(a, axes=None): a = asarray(a) @@ -839,7 +809,6 @@ def transpose(a, axes=None): return array_ops.transpose(a=a, perm=axes) -@dispatch.add_dispatch_support @np_utils.np_doc('swapaxes') def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring a = asarray(a) @@ -872,7 +841,6 @@ def f(x): return a -@dispatch.add_dispatch_support @np_utils.np_doc('moveaxis') def moveaxis(a, source, destination): # pylint: disable=missing-docstring """Raises ValueError if source, destination not in (-ndim(a), ndim(a)).""" @@ -1096,7 +1064,7 @@ def broadcast_to(array, shape): # pylint: disable=redefined-outer-name 
@np_utils.np_doc('stack') def stack(arrays, axis=0): # pylint: disable=missing-function-docstring - if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)): + if isinstance(arrays, (np_arrays.ndarray, tensor_lib.Tensor)): arrays = asarray(arrays) if axis == 0: return arrays @@ -1280,7 +1248,6 @@ def tril(m, k=0): # pylint: disable=missing-docstring array_ops.broadcast_to(mask, array_ops.shape(m)), m, z) -@dispatch.add_dispatch_support @np_utils.np_doc('triu') def triu(m, k=0): # pylint: disable=missing-docstring m = asarray(m) @@ -1302,7 +1269,6 @@ def triu(m, k=0): # pylint: disable=missing-docstring array_ops.broadcast_to(mask, array_ops.shape(m)), z, m) -@dispatch.add_dispatch_support @np_utils.np_doc('flip') def flip(m, axis=None): # pylint: disable=missing-docstring m = asarray(m) @@ -1315,13 +1281,11 @@ def flip(m, axis=None): # pylint: disable=missing-docstring return array_ops.reverse(m, [axis]) -@dispatch.add_dispatch_support @np_utils.np_doc('flipud') def flipud(m): # pylint: disable=missing-docstring return flip(m, 0) -@dispatch.add_dispatch_support @np_utils.np_doc('fliplr') def fliplr(m): # pylint: disable=missing-docstring return flip(m, 1) @@ -1340,7 +1304,6 @@ def roll(a, shift, axis=None): # pylint: disable=missing-docstring return array_ops.reshape(a, original_shape) -@dispatch.add_dispatch_support @np_utils.np_doc('rot90') def rot90(m, k=1, axes=(0, 1)): # pylint: disable=missing-docstring m_rank = array_ops.rank(m) @@ -1361,7 +1324,6 @@ def rot90(m, k=1, axes=(0, 1)): # pylint: disable=missing-docstring return flip(transpose(m, perm), ax2) -@dispatch.add_dispatch_support @np_utils.np_doc('vander') def vander(x, N=None, increasing=False): # pylint: disable=missing-docstring,invalid-name x = asarray(x) @@ -1520,19 +1482,16 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring # pylint: disable=redefined-builtin,undefined-variable -@dispatch.add_dispatch_support @np_utils.np_doc('max', link=np_utils.AliasOf('amax')) def max(a, axis=None, keepdims=None): return amax(a, axis=axis, keepdims=keepdims) -@dispatch.add_dispatch_support @np_utils.np_doc('min', link=np_utils.AliasOf('amin')) def min(a, axis=None, keepdims=None): return amin(a, axis=axis, keepdims=keepdims) -@dispatch.add_dispatch_support @np_utils.np_doc('round', link=np_utils.AliasOf('around')) def round(a, decimals=0): return around(a, decimals=decimals) @@ -1880,10 +1839,17 @@ def _as_spec_tuple(slice_spec): def _getitem(self, slice_spec): """Implementation of ndarray.__getitem__.""" - if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and - slice_spec.dtype == dtypes.bool) or - (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and - slice_spec.dtype == np.bool_)): + if ( + isinstance(slice_spec, bool) + or ( + isinstance(slice_spec, tensor_lib.Tensor) + and slice_spec.dtype == dtypes.bool + ) + or ( + isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) + and slice_spec.dtype == np.bool_ + ) + ): return array_ops.boolean_mask(tensor=self, mask=slice_spec) if not isinstance(slice_spec, tuple): @@ -1895,10 +1861,17 @@ def _getitem(self, slice_spec): def _with_index_update_helper(update_method, a, slice_spec, updates): """Implementation of ndarray._with_index_*.""" - if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and - slice_spec.dtype == dtypes.bool) or - (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and - slice_spec.dtype == np.bool_)): + if ( + isinstance(slice_spec, bool) + or ( + isinstance(slice_spec, 
tensor_lib.Tensor) + and slice_spec.dtype == dtypes.bool + ) + or ( + isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) + and slice_spec.dtype == np.bool_ + ) + ): slice_spec = nonzero(slice_spec) if not isinstance(slice_spec, tuple): diff --git a/tensorflow/python/ops/numpy_ops/np_arrays.py b/tensorflow/python/ops/numpy_ops/np_arrays.py index 987f7738c17073..78257ae37ec66b 100644 --- a/tensorflow/python/ops/numpy_ops/np_arrays.py +++ b/tensorflow/python/ops/numpy_ops/np_arrays.py @@ -17,7 +17,7 @@ # pylint: disable=g-direct-tensorflow-import from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.ops.numpy_ops import np_dtypes @@ -47,4 +47,4 @@ def convert_to_tensor(value, dtype=None, dtype_hint=None): value, dtype=dtype, dtype_hint=dtype_hint) -ndarray = ops.Tensor +ndarray = tensor.Tensor diff --git a/tensorflow/python/ops/numpy_ops/np_arrays_test.py b/tensorflow/python/ops/numpy_ops/np_arrays_test.py index 9985c6ce9d909e..6bba3cdbafce11 100644 --- a/tensorflow/python/ops/numpy_ops/np_arrays_test.py +++ b/tensorflow/python/ops/numpy_ops/np_arrays_test.py @@ -19,6 +19,7 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops.numpy_ops import np_arrays # Required for operator overloads @@ -193,8 +194,8 @@ def testFromToCompositeTensor(self): # Each ndarray contains only one tensor, so the flattened output should be # just 2 tensors in a list. self.assertLen(flattened, 2) - self.assertIsInstance(flattened[0], ops.Tensor) - self.assertIsInstance(flattened[1], ops.Tensor) + self.assertIsInstance(flattened[0], tensor.Tensor) + self.assertIsInstance(flattened[1], tensor.Tensor) repacked = nest.pack_sequence_as(tensors, flattened, expand_composites=True) self.assertLen(repacked, 2) @@ -208,7 +209,7 @@ def testFromToCompositeTensor(self): # TODO(wangpeng): Test in graph mode as well. Also test in V2 (the requirement # for setting _USE_EQUALITY points to V2 behavior not being on). 
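# Illustrative usage sketch (TF 2.x eager) of the boolean-mask branches in
# _getitem / _with_index_update_helper above: indexing with a boolean array
# falls back to a boolean_mask-style gather instead of a strided slice.
import tensorflow.experimental.numpy as tnp

x = tnp.array([10, 20, 30, 40])
mask = tnp.array([True, False, True, False])
print(x[mask])    # [10 30]
print(x[x > 15])  # [20 30 40]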
ops.enable_eager_execution() - ops.Tensor._USE_EQUALITY = True + tensor.Tensor._USE_EQUALITY = True ops.set_dtype_conversion_mode('legacy') np_math_ops.enable_numpy_methods_on_tensor() test.main() diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py index 701c8f1eb1a92d..7aad67c44c0304 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import bitwise_ops @@ -40,7 +41,6 @@ from tensorflow.python.ops.numpy_ops import np_dtypes from tensorflow.python.ops.numpy_ops import np_export from tensorflow.python.ops.numpy_ops import np_utils -from tensorflow.python.util import dispatch pi = np_export.np_export_constant(__name__, 'pi', np.pi) @@ -569,7 +569,6 @@ def bitwise_xor(x1, x2): return _bitwise_binary_op(bitwise_ops.bitwise_xor, x1, x2) -@dispatch.add_dispatch_support @np_utils.np_doc('bitwise_not', link=np_utils.AliasOf('invert')) def bitwise_not(x): @@ -602,73 +601,61 @@ def _scalar(tf_fn, x, promote_to_float=False): return tf_fn(x) -@dispatch.add_dispatch_support @np_utils.np_doc('log') def log(x): return _scalar(math_ops.log, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('exp') def exp(x): return _scalar(math_ops.exp, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('sqrt') def sqrt(x): return _scalar(math_ops.sqrt, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('abs', link=np_utils.AliasOf('absolute')) def abs(x): # pylint: disable=redefined-builtin return _scalar(math_ops.abs, x) -@dispatch.add_dispatch_support @np_utils.np_doc('absolute') def absolute(x): return abs(x) -@dispatch.add_dispatch_support @np_utils.np_doc('fabs') def fabs(x): return abs(x) -@dispatch.add_dispatch_support @np_utils.np_doc('ceil') def ceil(x): return _scalar(math_ops.ceil, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('floor') def floor(x): return _scalar(math_ops.floor, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('conj') def conj(x): return _scalar(math_ops.conj, x) -@dispatch.add_dispatch_support @np_utils.np_doc('negative') def negative(x): return _scalar(math_ops.negative, x) -@dispatch.add_dispatch_support @np_utils.np_doc('reciprocal') def reciprocal(x): return _scalar(math_ops.reciprocal, x) -@dispatch.add_dispatch_support @np_utils.np_doc('signbit') def signbit(x): @@ -680,79 +667,66 @@ def f(x): return _scalar(f, x) -@dispatch.add_dispatch_support @np_utils.np_doc('sin') def sin(x): return _scalar(math_ops.sin, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('cos') def cos(x): return _scalar(math_ops.cos, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('tan') def tan(x): return _scalar(math_ops.tan, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('sinh') def sinh(x): return _scalar(math_ops.sinh, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('cosh') def cosh(x): return _scalar(math_ops.cosh, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('tanh') def tanh(x): return _scalar(math_ops.tanh, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('arcsin') def arcsin(x): return _scalar(math_ops.asin, x, True) -@dispatch.add_dispatch_support 
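# Illustrative usage sketch (tf.experimental.numpy, TF 2.x) for the
# _scalar-based unary wrappers above: integer inputs are promoted to the
# default float type before the underlying math_ops kernel runs.
import tensorflow.experimental.numpy as tnp

x = tnp.array([0, 1, 4])
print(tnp.sqrt(x))                 # [0. 1. 2.]
print(tnp.exp(tnp.array(0.0)))     # 1.0
print(tnp.arcsin(tnp.array(1.0)))  # ~1.5707963 (pi/2)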
@np_utils.np_doc('arccos') def arccos(x): return _scalar(math_ops.acos, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('arctan') def arctan(x): return _scalar(math_ops.atan, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('arcsinh') def arcsinh(x): return _scalar(math_ops.asinh, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('arccosh') def arccosh(x): return _scalar(math_ops.acosh, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('arctanh') def arctanh(x): return _scalar(math_ops.atanh, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('deg2rad') def deg2rad(x): @@ -762,7 +736,6 @@ def f(x): return _scalar(f, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('rad2deg') def rad2deg(x): return x * (180.0 / np.pi) @@ -773,7 +746,6 @@ def rad2deg(x): ] -@dispatch.add_dispatch_support @np_utils.np_doc('angle') def angle(z, deg=False): # pylint: disable=missing-function-docstring @@ -790,7 +762,6 @@ def f(x): return y -@dispatch.add_dispatch_support @np_utils.np_doc('cbrt') def cbrt(x): @@ -802,13 +773,11 @@ def f(x): return _scalar(f, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('conjugate', link=np_utils.AliasOf('conj')) def conjugate(x): return _scalar(math_ops.conj, x) -@dispatch.add_dispatch_support @np_utils.np_doc('exp2') def exp2(x): @@ -818,13 +787,11 @@ def f(x): return _scalar(f, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('expm1') def expm1(x): return _scalar(math_ops.expm1, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('fix') def fix(x): @@ -880,7 +847,6 @@ def nan_reduction(a, axis=None, dtype=None, keepdims=False): nanprod = _make_nan_reduction('nanprod', np_array_ops.prod, 1) -@dispatch.add_dispatch_support @np_utils.np_doc('nanmean') def nanmean(a, axis=None, dtype=None, keepdims=None): # pylint: disable=missing-docstring a = np_array_ops.array(a) @@ -921,31 +887,26 @@ def isposinf(x): return False -@dispatch.add_dispatch_support @np_utils.np_doc('log2') def log2(x): return log(x) / np.log(2) -@dispatch.add_dispatch_support @np_utils.np_doc('log10') def log10(x): return log(x) / np.log(10) -@dispatch.add_dispatch_support @np_utils.np_doc('log1p') def log1p(x): return _scalar(math_ops.log1p, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('positive') def positive(x): return _scalar(lambda x: x, x) -@dispatch.add_dispatch_support @np_utils.np_doc('sinc') def sinc(x): @@ -957,13 +918,11 @@ def f(x): return _scalar(f, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('square') def square(x): return _scalar(math_ops.square, x) -@dispatch.add_dispatch_support @np_utils.np_doc('diff') def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring @@ -1245,7 +1204,6 @@ def _argsort(a, axis, stable): return np_array_ops.array(tf_ans, dtype=np.intp) -@dispatch.add_dispatch_support @np_utils.np_doc('sort') def sort(a, axis=-1, kind='quicksort', order=None): # pylint: disable=missing-docstring if kind != 'quicksort': @@ -1292,7 +1250,6 @@ def append(arr, values, axis=None): return concatenate([arr, values], axis=axis) -@dispatch.add_dispatch_support @np_utils.np_doc('average') def average(a, axis=None, weights=None, returned=False): # pylint: disable=missing-docstring if axis is not None and not isinstance(axis, int): @@ -1355,7 +1312,6 @@ def rank_not_equal_case(): return avg -@dispatch.add_dispatch_support @np_utils.np_doc('trace') def trace(a, offset=0, axis1=0, axis2=1, dtype=None): # pylint: disable=missing-docstring if dtype: @@ -1456,39 +1412,44 @@ def 
_tensor_size(self): def _tensor_tolist(self): - if isinstance(self, ops.EagerTensor): - return self._numpy().tolist() # pylint: disable=protected-access + if ops.is_symbolic_tensor(self): + raise ValueError('Symbolic Tensors do not support the tolist API.') - raise ValueError('Symbolic Tensors do not support the tolist API.') + return self._numpy().tolist() # pylint: disable=protected-access -def enable_numpy_methods_on_tensor(): - """Adds additional NumPy methods on tf.Tensor class.""" +def _enable_numpy_methods(tensor_class): + """A helper method for adding additional NumPy methods.""" t = property(_tensor_t) - setattr(ops.Tensor, 'T', t) + setattr(tensor_class, 'T', t) ndim = property(_tensor_ndim) - setattr(ops.Tensor, 'ndim', ndim) + setattr(tensor_class, 'ndim', ndim) size = property(_tensor_size) - setattr(ops.Tensor, 'size', size) + setattr(tensor_class, 'size', size) - setattr(ops.Tensor, '__pos__', _tensor_pos) - setattr(ops.Tensor, 'tolist', _tensor_tolist) + setattr(tensor_class, '__pos__', _tensor_pos) + setattr(tensor_class, 'tolist', _tensor_tolist) # TODO(b/178540516): Make a custom `setattr` that changes the method's # docstring to the TF one. - setattr(ops.Tensor, 'transpose', np_array_ops.transpose) - setattr(ops.Tensor, 'flatten', np_array_ops.flatten) - setattr(ops.Tensor, 'reshape', np_array_ops._reshape_method_wrapper) # pylint: disable=protected-access - setattr(ops.Tensor, 'ravel', np_array_ops.ravel) - setattr(ops.Tensor, 'clip', clip) - setattr(ops.Tensor, 'astype', math_ops.cast) - setattr(ops.Tensor, '__round__', np_array_ops.around) - setattr(ops.Tensor, 'max', np_array_ops.amax) - setattr(ops.Tensor, 'mean', np_array_ops.mean) - setattr(ops.Tensor, 'min', np_array_ops.amin) + setattr(tensor_class, 'transpose', np_array_ops.transpose) + setattr(tensor_class, 'flatten', np_array_ops.flatten) + setattr(tensor_class, 'reshape', np_array_ops._reshape_method_wrapper) # pylint: disable=protected-access + setattr(tensor_class, 'ravel', np_array_ops.ravel) + setattr(tensor_class, 'clip', clip) + setattr(tensor_class, 'astype', math_ops.cast) + setattr(tensor_class, '__round__', np_array_ops.around) + setattr(tensor_class, 'max', np_array_ops.amax) + setattr(tensor_class, 'mean', np_array_ops.mean) + setattr(tensor_class, 'min', np_array_ops.amin) # TODO(wangpeng): Remove `data` when all uses of it are removed data = property(lambda self: self) - setattr(ops.Tensor, 'data', data) + setattr(tensor_class, 'data', data) + + +def enable_numpy_methods_on_tensor(): + """Adds additional NumPy methods on tf.Tensor class.""" + _enable_numpy_methods(tensor.Tensor) diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops_test.py b/tensorflow/python/ops/numpy_ops/np_math_ops_test.py index 534a1dc9335f39..2a6b6368e8fb14 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops_test.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops_test.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops.numpy_ops import np_array_ops from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.ops.numpy_ops import np_math_ops @@ -377,7 +378,7 @@ def testIsInf(self): self.assertFalse(np_math_ops.isneginf(x2)) if __name__ == '__main__': - ops.enable_tensor_equality() + tensor.enable_tensor_equality() ops.enable_eager_execution() ops.set_dtype_conversion_mode('legacy') np_math_ops.enable_numpy_methods_on_tensor() diff --git 
a/tensorflow/python/ops/numpy_ops/tests/BUILD b/tensorflow/python/ops/numpy_ops/tests/BUILD index 8e03c69abaa7aa..953ee8828cdbcb 100644 --- a/tensorflow/python/ops/numpy_ops/tests/BUILD +++ b/tensorflow/python/ops/numpy_ops/tests/BUILD @@ -23,7 +23,6 @@ py_strict_library( ":extensions", "//tensorflow:tensorflow_py", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:gradient_checker_v2", "//tensorflow/python/ops/numpy_ops:np_array_ops", diff --git a/tensorflow/python/ops/numpy_ops/tests/extensions.py b/tensorflow/python/ops/numpy_ops/tests/extensions.py index b915dff755bc42..901a662ff80c7e 100644 --- a/tensorflow/python/ops/numpy_ops/tests/extensions.py +++ b/tensorflow/python/ops/numpy_ops/tests/extensions.py @@ -447,7 +447,7 @@ def eval_on_shapes(f, static_argnums=(), allow_static_outputs=False): def abstractify(args): def _abstractify(x): x = _canonicalize_jit_arg(x) - if isinstance(x, (ops.Tensor, tf_np.ndarray)): + if isinstance(x, (tensor_lib.Tensor, tf_np.ndarray)): return tensor_lib.TensorSpec(x.shape, x.dtype) else: return x @@ -472,7 +472,7 @@ def recorder(args, kwargs, res): def is_tensor_like(x): if hasattr(x, "_type_spec"): return True # x is a CompositeTensor - return isinstance(x, (tf_np.ndarray, ops.Tensor)) + return isinstance(x, (tf_np.ndarray, tensor_lib.Tensor)) py_values = nest.map_structure( lambda x: None if is_tensor_like(x) else x, res ) @@ -494,7 +494,7 @@ def is_tensor_like(x): # pylint: disable=missing-docstring def f_return(*args): def to_tensor_spec(x): - if isinstance(x, ops.Tensor): + if isinstance(x, tensor_lib.Tensor): return tensor_lib.TensorSpec(x.shape, x.dtype) else: return x @@ -1574,14 +1574,14 @@ def dataset_as_numpy(dataset): # Type check for Tensors and Datasets for ds_el in flat_ds: - if not isinstance(ds_el, (ops.Tensor, dataset_ops.DatasetV2)): + if not isinstance(ds_el, (tensor_lib.Tensor, dataset_ops.DatasetV2)): types = nest.map_structure(type, nested_ds) raise ValueError("Arguments to dataset_as_numpy must be (possibly nested " "structure of) tf.Tensors or tf.data.Datasets. Got: %s" % types) for ds_el in flat_ds: - if isinstance(ds_el, ops.Tensor): + if isinstance(ds_el, tensor_lib.Tensor): np_el = tf_np.asarray(ds_el) elif isinstance(ds_el, dataset_ops.DatasetV2): np_el = _eager_dataset_iterator(ds_el) @@ -1888,7 +1888,7 @@ def wrapper(*args): flattened_input_args = nest.flatten(args) flattened_per_device_args = [[] for _ in devices] for arg in flattened_input_args: - if isinstance(arg, ops.Tensor): + if isinstance(arg, tensor_lib.Tensor): # TODO(nareshmodi): Try and use the dynamic shape instead. 
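# Illustrative sketch (TF 2.x) of the abstractify step above: concrete
# tensors are replaced by tf.TensorSpec(shape, dtype) so a function can be
# traced from shapes and dtypes alone. `abstractify` here is a simplified,
# hypothetical version of the helper in extensions.py.
import tensorflow as tf

def abstractify(x):
  return tf.TensorSpec(x.shape, x.dtype) if isinstance(x, tf.Tensor) else x

spec = abstractify(tf.constant([[1.0, 2.0]]))
print(spec)  # TensorSpec(shape=(1, 2), dtype=tf.float32, name=None)

@tf.function
def double(t):
  return t * 2.0

concrete = double.get_concrete_function(spec)  # traced from the spec alone
print(concrete(tf.ones([1, 2])))  # [[2. 2.]]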
if (not arg.shape.rank) or arg.shape[0] != len(devices): # TODO(nareshmodi): Fix this restriction @@ -1932,7 +1932,7 @@ def wrapper(*args): tensors = [] for j, device in enumerate(devices): assert isinstance( - flattened_results[j][i], ops.Tensor + flattened_results[j][i], tensor_lib.Tensor ), "currently only tensor return items are supported" tensors.append(flattened_results[j][i]) final_tree.append(ShardedNdArray(tensors)) diff --git a/tensorflow/python/ops/numpy_ops/tests/test_util.py b/tensorflow/python/ops/numpy_ops/tests/test_util.py index 27840cc70251fc..cf178a3f5dfbcc 100644 --- a/tensorflow/python/ops/numpy_ops/tests/test_util.py +++ b/tensorflow/python/ops/numpy_ops/tests/test_util.py @@ -32,7 +32,6 @@ import numpy.random as npr from tensorflow.python.util import nest -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor from tensorflow.python.framework import dtypes from tensorflow.python.ops import gradient_checker_v2 @@ -83,7 +82,7 @@ def _dtype(x): - if isinstance(x, ops.Tensor): + if isinstance(x, tensor.Tensor): return x.dtype.as_numpy_dtype return (getattr(x, 'dtype', None) or onp.dtype(python_scalar_dtypes.get(type(x), None)) or diff --git a/tensorflow/python/ops/op_selector.py b/tensorflow/python/ops/op_selector.py index 77258b8f117726..3ddbe8c8f433ec 100644 --- a/tensorflow/python/ops/op_selector.py +++ b/tensorflow/python/ops/op_selector.py @@ -15,6 +15,7 @@ """Tools for selecting ops in a graph.""" from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.util import object_identity @@ -27,7 +28,7 @@ def is_differentiable(op): def is_iterable(obj): """Return true if the object is iterable.""" - if isinstance(obj, ops.Tensor): + if isinstance(obj, tensor_lib.Tensor): return False try: _ = iter(obj) @@ -94,7 +95,7 @@ def get_unique_graph(tops, check_types=None, none_if_empty=False): if not is_iterable(tops): raise TypeError("{} is not iterable".format(type(tops))) if check_types is None: - check_types = (ops.Operation, ops.Tensor) + check_types = (ops.Operation, tensor_lib.Tensor) elif not is_iterable(check_types): check_types = (check_types,) g = None @@ -153,9 +154,9 @@ def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False): if not ts: return [] if check_graph: - check_types = None if ignore_ops else ops.Tensor + check_types = None if ignore_ops else tensor_lib.Tensor get_unique_graph(ts, check_types=check_types) - return [t for t in ts if isinstance(t, ops.Tensor)] + return [t for t in ts if isinstance(t, tensor_lib.Tensor)] def get_generating_ops(ts): @@ -272,7 +273,7 @@ def get_backward_walk_ops(seed_ops, # Empty iterable. return [] - if isinstance(first_seed_op, ops.Tensor): + if isinstance(first_seed_op, tensor_lib.Tensor): ts = make_list_of_t(seed_ops, allow_graph=False) seed_ops = get_generating_ops(ts) else: @@ -318,7 +319,7 @@ class UnliftableError(Exception): def _as_operation(op_or_tensor): - if isinstance(op_or_tensor, ops.Tensor): + if isinstance(op_or_tensor, tensor_lib.Tensor): return op_or_tensor.op return op_or_tensor @@ -338,7 +339,7 @@ def show_path(from_op, tensors, sources): Returns: A python string containing the path, or "??" if none is found. 
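# Illustrative sketch (TF 2.x graph mode) of the backward-walk idea behind
# get_backward_walk_ops above: starting from a tensor, repeatedly step to
# the producing op and its inputs. `backward_ops` is a simplified stand-in,
# not the real op_selector implementation.
import tensorflow as tf

g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0, name="a")
  b = tf.constant(2.0, name="b")
  c = tf.add(a, b, name="c")

def backward_ops(t):
  seen, stack = [], [t.op]
  while stack:
    op = stack.pop()
    if op in seen:
      continue
    seen.append(op)
    stack.extend(inp.op for inp in op.inputs)
  return [op.name for op in seen]

print(backward_ops(c))  # e.g. ['c', 'b', 'a'] (traversal order may vary)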
""" - if isinstance(from_op, ops.Tensor): + if isinstance(from_op, tensor_lib.Tensor): from_op = from_op.op if not isinstance(tensors, list): diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD index 2e3a3ad2f00fe7..a86b5d93a0cf70 100644 --- a/tensorflow/python/ops/parallel_for/BUILD +++ b/tensorflow/python/ops/parallel_for/BUILD @@ -37,8 +37,8 @@ py_strict_library( "//tensorflow/python/framework:ops", "//tensorflow/python/framework:smart_cond", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_gen", @@ -95,6 +95,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:type_spec", @@ -256,6 +257,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:clip_ops", @@ -277,6 +279,7 @@ py_strict_library( deps = [ ":control_flow_ops", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:check_ops", "//tensorflow/python/ops:gradients_impl", diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py index d4f102e09221d6..e65e4fdd1c1a2c 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec @@ -314,7 +315,7 @@ def _pfor_impl(loop_fn, for loop_fn_output in nest.flatten(loop_fn_output_tensors): if (loop_fn_output is not None and not isinstance( loop_fn_output, - (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))): + (ops.Operation, tensor.Tensor, sparse_tensor.SparseTensor))): if isinstance(loop_fn_output, indexed_slices.IndexedSlices): logging.warn("Converting %s to a dense representation may make it slow." 
" Alternatively, output the indices and values of the" diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py index 3ef8e0cc58ce9e..da667a5e1bbde5 100644 --- a/tensorflow/python/ops/parallel_for/gradients.py +++ b/tensorflow/python/ops/parallel_for/gradients.py @@ -14,6 +14,7 @@ # ============================================================================== """Jacobian ops.""" from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import gradients_impl as gradient_ops @@ -66,7 +67,7 @@ def loop_fn(i): parallel_iterations=parallel_iterations) for i, out in enumerate(pfor_outputs): - if isinstance(out, ops.Tensor): + if isinstance(out, tensor.Tensor): new_shape = array_ops.concat( [output_shape, array_ops.shape(out)[1:]], axis=0) out = array_ops.reshape(out, new_shape) diff --git a/tensorflow/python/ops/parallel_for/math_test.py b/tensorflow/python/ops/parallel_for/math_test.py index 932e07bde749aa..240c94fbdd4077 100644 --- a/tensorflow/python/ops/parallel_for/math_test.py +++ b/tensorflow/python/ops/parallel_for/math_test.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as framework_ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -149,8 +150,8 @@ def loop_fn(i): def test_binary_cwise_ops(self): # Enable tensor equality to test `equal` and `not_equal` ops below. - default_equality = framework_ops.Tensor._USE_EQUALITY - framework_ops.enable_tensor_equality() + default_equality = tensor.Tensor._USE_EQUALITY + tensor.enable_tensor_equality() try: logical_ops = [ math_ops.logical_and, math_ops.logical_or, math_ops.logical_xor @@ -225,7 +226,7 @@ def loop_fn(i): self._test_loop_fn(loop_fn, 3) finally: if not default_equality: - framework_ops.disable_tensor_equality() + tensor.disable_tensor_equality() def test_approximate_equal(self): x = random_ops.random_uniform([3, 5]) diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 472c196c02ed71..8905f4efc8c32a 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -34,8 +34,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack @@ -229,7 +229,7 @@ def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config): self._pfor_config = pfor_config self._pfor_ops = set(pfor_ops) self._pfor_op_ids = set(x._id for x in pfor_ops) - assert isinstance(exit_node, ops.Tensor) + assert isinstance(exit_node, tensor_lib.Tensor) self._while_context = exit_node.op._get_control_flow_context() assert isinstance(self._while_context, control_flow_ops.WhileContext) self._context_name = self._while_context.name @@ -260,13 +260,13 @@ def __init__(self, exit_node, pfor_ops, 
fallback_to_while_loop, pfor_config): # to different Operations/Tensors of a single cycle as illustrated above. # List of Switch ops (ops.Operation) that feed into an Exit Node. self._exit_switches = [] - # List of inputs (ops.Tensor) to NextIteration. + # List of inputs (tensor_lib.Tensor) to NextIteration. self._body_outputs = [] # List of list of control inputs of the NextIteration nodes. self._next_iter_control_inputs = [] # List of Merge ops (ops.Operation). self._enter_merges = [] - # List of output (ops.Tensor) of Exit nodes. + # List of output (tensor_lib.Tensor) of Exit nodes. self._outputs = [] # List of Enter Tensors. @@ -1071,7 +1071,7 @@ def wrap(tensor, is_stacked=True, is_sparse_stacked=False): """Helper to create a WrappedTensor object.""" assert isinstance(is_stacked, bool) assert isinstance(is_sparse_stacked, bool) - assert isinstance(tensor, ops.Tensor) + assert isinstance(tensor, tensor_lib.Tensor) assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is " "stacked via a sparse " "conversion, it must also be " @@ -1116,7 +1116,7 @@ def while_body(i, *ta_list): # TODO(agarwal): Add tf.debugging asserts to check that the shapes across # the different iterations are the same. for out, ta in zip(op_outputs, ta_list): - assert isinstance(out, ops.Tensor) + assert isinstance(out, tensor_lib.Tensor) outputs.append(ta.write(i, out)) return tuple([i + 1] + outputs) @@ -1143,7 +1143,7 @@ def _has_reductions(self): def _set_iters(self, iters): """Set number of pfor iterations.""" - if isinstance(iters, ops.Tensor): + if isinstance(iters, tensor_lib.Tensor): iters = tensor_util.constant_value(iters) self._maybe_iters = iters @@ -1170,12 +1170,12 @@ def reduce(self, fn, *args): # Creates a concrete function that will be used for reduction. tensor_specs = [] for arg in args: - if not isinstance(arg, ops.Tensor): + if not isinstance(arg, tensor_lib.Tensor): raise ValueError(f"Got a non-Tensor argument {arg} in reduce.") batched_shape = tensor_shape.TensorShape([self._maybe_iters ]).concatenate(arg.shape) tensor_specs.append( - tensor_spec.TensorSpec(shape=batched_shape, dtype=arg.dtype)) + tensor_lib.TensorSpec(shape=batched_shape, dtype=arg.dtype)) concrete_function = def_function.function(fn).get_concrete_function( *tensor_specs) @@ -1184,7 +1184,7 @@ def reduce(self, fn, *args): pl_outputs = [] with ops.control_dependencies(args): for output in concrete_function.outputs: - if not isinstance(output, ops.Tensor): + if not isinstance(output, tensor_lib.Tensor): raise ValueError(f"Got a non-Tensor output {output} while running " "reduce.") # Note that we use placeholder_with_default just to make XLA happy since @@ -1249,7 +1249,7 @@ def reduce_sum(self, x): def _lookup_reduction(self, t): """Lookups Tensor `t` in the reduction maps.""" - assert isinstance(t, ops.Tensor), t + assert isinstance(t, tensor_lib.Tensor), t return self._reduce_map.get(t.op) @@ -1298,7 +1298,7 @@ def __init__(self, """Creates an object to rewrite a parallel-for loop. Args: - loop_var: ops.Tensor output of a Placeholder operation. The value should + loop_var: Tensor output of a Placeholder operation. The value should be an int32 scalar representing the loop iteration number. loop_len: A scalar or scalar Tensor representing the number of iterations the loop is run for. @@ -1316,7 +1316,7 @@ def __init__(self, pfor_config: PForConfig object used while constructing the loop body. warn: Whether or not to warn on while loop conversions. 
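# Illustrative usage sketch (TF 2.x): tf.vectorized_map is the public entry
# point that builds the pfor converter above and vectorizes the loop body
# across the leading dimension.
import tensorflow as tf

x = tf.random.uniform([8, 3])

def per_row(row):
  return tf.tensordot(row, row, axes=1)  # squared L2 norm of one row

print(tf.vectorized_map(per_row, x).shape)  # (8,)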
""" - assert isinstance(loop_var, ops.Tensor) + assert isinstance(loop_var, tensor_lib.Tensor) assert loop_var.op.type == "PlaceholderWithDefault" self._loop_var = loop_var loop_len_value = tensor_util.constant_value(loop_len) @@ -1425,7 +1425,7 @@ def convert(self, y): """Returns the converted value corresponding to y. Args: - y: A ops.Tensor or a ops.Operation object. If latter, y should not have + y: A Tensor or a ops.Operation object. If latter, y should not have any outputs. Returns: @@ -1436,10 +1436,10 @@ def convert(self, y): return None if isinstance(y, sparse_tensor.SparseTensor): return self._convert_sparse(y) - assert isinstance(y, (ops.Tensor, ops.Operation)), y + assert isinstance(y, (tensor_lib.Tensor, ops.Operation)), y output = self._convert_helper(y) if isinstance(output, WrappedTensor): - assert isinstance(y, ops.Tensor) + assert isinstance(y, tensor_lib.Tensor) return self._unwrap_or_tile(output) else: assert isinstance(y, ops.Operation) @@ -1453,7 +1453,8 @@ def _was_converted(self, t): return converted_t.t is not t def _add_conversion(self, old_output, new_output): - assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output + assert isinstance( + old_output, (tensor_lib.Tensor, ops.Operation)), old_output assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output self._conversion_map[old_output] = new_output @@ -1467,7 +1468,7 @@ def _convert_reduction(self, y): (reduction_fn, reduction_args) = reduction batched_args = [] for reduction_arg in reduction_args: - assert isinstance(reduction_arg, ops.Tensor), reduction_arg + assert isinstance(reduction_arg, tensor_lib.Tensor), reduction_arg # Tensor being reduced should already be converted due to a control # dependency on the created placeholder. # Note that in cases where reduction_arg is in an outer context, one @@ -1499,7 +1500,7 @@ def _convert_helper(self, op_or_tensor): "Got %s", y) y_op = y else: - assert isinstance(y, ops.Tensor), y + assert isinstance(y, tensor_lib.Tensor), y y_op = y.op is_while_loop = y_op.type == "Exit" @@ -1891,7 +1892,7 @@ def _channel_flatten_input(x, data_format): We then merge the S and C dimension. Args: - x: ops.Tensor to transform. + x: tensor_lib.Tensor to transform. data_format: "NCHW" or "NHWC". 
Returns: @@ -2588,7 +2589,7 @@ def _convert_gather(pfor_input): if param_stacked: pfor_input.stack_inputs(stack_indices=[1]) indices = pfor_input.stacked_input(1) - if isinstance(axis, ops.Tensor): + if isinstance(axis, tensor_lib.Tensor): axis = array_ops.where(axis >= 0, axis + 1, axis) else: axis = axis + 1 if axis >= 0 else axis diff --git a/tensorflow/python/ops/parsing_config.py b/tensorflow/python/ops/parsing_config.py index 32ff8190c083f2..8be4c9c79fc988 100644 --- a/tensorflow/python/ops/parsing_config.py +++ b/tensorflow/python/ops/parsing_config.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -498,7 +499,7 @@ def _make_dense_default(self, key, shape, dtype): else: if default_value is None: default_value = constant_op.constant([], dtype=dtype) - elif not isinstance(default_value, ops.Tensor): + elif not isinstance(default_value, tensor.Tensor): key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key) default_value = ops.convert_to_tensor( default_value, dtype=dtype, name=key_name) diff --git a/tensorflow/python/ops/ragged/ragged_cross_op_test.py b/tensorflow/python/ops/ragged/ragged_cross_op_test.py index 0408051f1f5c89..c098c13644f342 100644 --- a/tensorflow/python/ops/ragged/ragged_cross_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_cross_op_test.py @@ -27,7 +27,6 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_ragged_array_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.ragged import ragged_array_ops from tensorflow.python.ops.ragged import ragged_factory_ops @@ -350,29 +349,32 @@ def testRaggedCrossLargeBatch(self): dict( testcase_name='BadDType', inputs=[ragged_const([[1.1], [2.2, 3.3]])], - message=r'Unexpected dtype for inputs\[0\]'), + message=r'Unexpected dtype for inputs\[0\]', + ), dict( testcase_name='StaticBatchSizeMismatch1', - inputs=[ragged_const([[1]]), - ragged_const([[2], [3]])], + inputs=[ragged_const([[1]]), ragged_const([[2], [3]])], exception=(ValueError, errors.InvalidArgumentError), - message='inputs must all have the same batch dimension size'), + message='inputs must all have the same batch dimension size', + ), dict( testcase_name='StaticBatchSizeMismatch2', - inputs=[ragged_const([[1]]), - dense_const([[2], [3]])], + inputs=[ragged_const([[1]]), dense_const([[2], [3]])], exception=(ValueError, errors.InvalidArgumentError), - message='inputs must all have the same batch dimension size'), + message='inputs must all have the same batch dimension size', + ), dict( testcase_name='3DDenseTensor', inputs=[dense_const([[[1]]])], exception=(ValueError, errors.InvalidArgumentError), - message='tf.ragged.cross only supports inputs with rank=2'), + message='tf.ragged.cross only supports inputs with rank=2', + ), dict( testcase_name='0DDenseTensor', inputs=[dense_const(1)], exception=(ValueError, errors.InvalidArgumentError), - message='tf.ragged.cross only supports inputs with rank=2'), + message='tf.ragged.cross only supports inputs with rank=2', + ), ]) def testStaticError(self, inputs, exception=ValueError, message=None): with self.assertRaisesRegex(exception, message): @@ -382,25 +384,29 @@ 
def testStaticError(self, inputs, exception=ValueError, message=None): dict( testcase_name='3DRaggedTensor', inputs=[ragged_const([[[1]]], ragged_rank=1)], - message='tf.ragged.cross only supports inputs with rank=2'), + message='tf.ragged.cross only supports inputs with rank=2', + ), dict( testcase_name='0DDenseTensor', inputs=[dense_const(1)], signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]], exception=(ValueError, errors.InvalidArgumentError), - message='tf.ragged.cross only supports inputs with rank=2'), + message='tf.ragged.cross only supports inputs with rank=2', + ), dict( testcase_name='1DDenseTensor', inputs=[dense_const([1])], signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]], exception=(ValueError, errors.InvalidArgumentError), - message='tf.ragged.cross only supports inputs with rank=2'), + message='tf.ragged.cross only supports inputs with rank=2', + ), dict( testcase_name='3DDenseTensor', inputs=[dense_const([[[1]]])], signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]], exception=(ValueError, errors.InvalidArgumentError), - message='tf.ragged.cross only supports inputs with rank=2'), + message='tf.ragged.cross only supports inputs with rank=2', + ), ]) def testRuntimeError(self, inputs, @@ -458,7 +464,15 @@ def testRaggedValuesAndSplitsMustMatch(self): out_values_type=dtypes.string, out_row_splits_type=dtypes.int64)) - def testRaggedCrossInvalidValue(self): + @parameterized.named_parameters([ + dict(testcase_name='EmptySplits', ragged_splits=[]), + dict( + testcase_name='NegativeSplits', ragged_splits=[-216, -114, -58, -54] + ), + dict(testcase_name='TooLargeValueSplits', ragged_splits=[0, 1, 2, 10]), + dict(testcase_name='UnsortedSplits', ragged_splits=[0, 2, 2, 1]), + ]) + def testRaggedCrossInvalidRaggedSplits(self, ragged_splits): # Test case in GitHub isseu 59114. 
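# Illustrative usage sketch (TF 2.x) for the invalid-row_splits cases above
# (EmptySplits, NegativeSplits, TooLargeValueSplits, UnsortedSplits): valid
# row_splits must be non-empty, start at 0, be non-decreasing, and end at
# the number of values; with valid inputs tf.ragged.cross joins the
# per-row combinations.
import tensorflow as tf

rt = tf.RaggedTensor.from_row_splits(values=["a", "b", "c"],
                                     row_splits=[0, 1, 3])  # valid splits
crossed = tf.ragged.cross([rt, tf.ragged.constant([["x"], ["y", "z"]])])
print(crossed.to_list())  # e.g. [[b'a_X_x'], [b'b_X_y', b'b_X_z', b'c_X_y', b'c_X_z']]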
with self.assertRaisesRegex( (ValueError, errors.InvalidArgumentError), 'Invalid RaggedTensor' @@ -468,8 +482,8 @@ def testRaggedCrossInvalidValue(self): ragged_values = [ ragged_values_0, ] - ragged_row_splits_0_tensor = random_ops.random_uniform( - [4], minval=-256, maxval=257, dtype=dtypes.int64 + ragged_row_splits_0_tensor = ragged_const( + ragged_splits, dtype=dtypes.int64 ) ragged_row_splits_0 = array_ops.identity(ragged_row_splits_0_tensor) ragged_row_splits = [ diff --git a/tensorflow/python/ops/random_ops_util.py b/tensorflow/python/ops/random_ops_util.py index 1e81f4691b88df..4f9eefcc920e31 100644 --- a/tensorflow/python/ops/random_ops_util.py +++ b/tensorflow/python/ops/random_ops_util.py @@ -18,7 +18,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import bitwise_ops @@ -69,7 +69,7 @@ def convert_alg_to_int(alg): return alg if isinstance(alg, Algorithm): return alg.value - if isinstance(alg, ops.Tensor): + if isinstance(alg, tensor.Tensor): return alg if isinstance(alg, str): # canonicalized alg diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 153ae97b89f1b7..6ad2dd866e371c 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -43,7 +43,6 @@ from tensorflow.python.framework import tensor as tensor_module from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -76,7 +75,7 @@ def get_eager_safe_handle_data(handle): """Get the data handle from the Tensor `handle`.""" - assert isinstance(handle, ops.Tensor) + assert isinstance(handle, tensor_module.Tensor) if isinstance(handle, ops.EagerTensor): return handle._handle_data # pylint: disable=protected-access @@ -268,7 +267,7 @@ class EagerResourceDeleter: __slots__ = ["_handle", "_handle_device", "_context"] def __init__(self, handle, handle_device): - if not isinstance(handle, ops.Tensor): + if not isinstance(handle, tensor_module.Tensor): raise ValueError( (f"Passed handle={handle} to EagerResourceDeleter. Was expecting " f"the handle to be a `tf.Tensor`.")) @@ -1933,7 +1932,7 @@ def _init_from_args( "`variable_def`. You provided neither.") init_from_fn = callable(initial_value) - if isinstance(initial_value, ops.Tensor) and hasattr( + if isinstance(initial_value, tensor_module.Tensor) and hasattr( initial_value, "graph") and initial_value.graph.building_function: raise ValueError(f"Argument `initial_value` ({initial_value}) could not " "be lifted out of a `tf.function`. " @@ -2540,7 +2539,7 @@ def __eq__(self, other): return isinstance(other, PList) and self.components == other.components -class VariableSpec(tensor_spec.DenseSpec): +class VariableSpec(tensor_module.DenseSpec): """Describes a tf.Variable. 
A `VariableSpec` provides metadata describing the `tf.Variable` objects @@ -2626,7 +2625,8 @@ def _from_components(self, components): raise ValueError(f"Components of a ResourceVariable must only contain " f"its resource handle, got f{components} instead.") handle = components[0] - if not isinstance(handle, ops.Tensor) or handle.dtype != dtypes.resource: + if not isinstance( + handle, tensor_module.Tensor) or handle.dtype != dtypes.resource: raise ValueError(f"The handle of a ResourceVariable must be a resource " f"tensor, got {handle} instead.") return ResourceVariable(trainable=self.trainable, @@ -2637,7 +2637,7 @@ def _from_components(self, components): @property def _component_specs(self): return [ - tensor_spec.TensorSpec( + tensor_module.TensorSpec( [], dtypes.DType( dtypes.resource._type_enum, # pylint: disable=protected-access @@ -2696,7 +2696,7 @@ def placeholder_value(self, placeholder_context): # exists in the PlaceholderContext variable = placeholder_context.get_placeholder(self.alias_id) else: - spec = tensor_spec.TensorSpec([], dtypes.resource) + spec = tensor_module.TensorSpec([], dtypes.resource) spec_context = trace_type.InternalPlaceholderContext( context_graph.outer_graph) spec_context.update_naming_scope(name) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index a0c3bee073d497..e3c26ff539387f 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -17,6 +17,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -125,7 +126,7 @@ def _infer_state_dtype(explicit_dtype, state): def _maybe_tensor_shape_from_tensor(shape): - if isinstance(shape, ops.Tensor): + if isinstance(shape, tensor.Tensor): return tensor_shape.as_shape(tensor_util.constant_value(shape)) else: return shape diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index e190f4a35cd81c..9cff803335c6a3 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -23,7 +23,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_impl @@ -84,7 +84,7 @@ def _concat(prefix, suffix, static=False): ValueError: if prefix or suffix was `None` and asked for dynamic Tensors out. 
""" - if isinstance(prefix, ops.Tensor): + if isinstance(prefix, tensor.Tensor): p = prefix p_static = tensor_util.constant_value(prefix) if p.shape.ndims == 0: @@ -102,7 +102,7 @@ def _concat(prefix, suffix, static=False): if p.is_fully_defined() else None ) - if isinstance(suffix, ops.Tensor): + if isinstance(suffix, tensor.Tensor): s = suffix s_static = tensor_util.constant_value(suffix) if s.shape.ndims == 0: diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py index e6aff616a32ae9..9ef5794fb1a67c 100644 --- a/tensorflow/python/ops/session_ops.py +++ b/tensorflow/python/ops/session_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.util import compat @@ -166,7 +167,7 @@ def get_session_handle(data, name=None): ``` """ - if not isinstance(data, ops.Tensor): + if not isinstance(data, tensor_lib.Tensor): raise TypeError("`data` must be of type Tensor.") # Colocate this operation with data. diff --git a/tensorflow/python/ops/signal/BUILD b/tensorflow/python/ops/signal/BUILD index 683c1e52e38e32..a87dd8bf525b32 100644 --- a/tensorflow/python/ops/signal/BUILD +++ b/tensorflow/python/ops/signal/BUILD @@ -52,6 +52,7 @@ py_strict_library( ":shape_ops", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/ops/signal/mel_ops.py b/tensorflow/python/ops/signal/mel_ops.py index bcb306f7873495..47d85859ddd9b4 100644 --- a/tensorflow/python/ops/signal/mel_ops.py +++ b/tensorflow/python/ops/signal/mel_ops.py @@ -16,6 +16,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -75,7 +76,7 @@ def _validate_arguments(num_mel_bins, sample_rate, if lower_edge_hertz >= upper_edge_hertz: raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' % (lower_edge_hertz, upper_edge_hertz)) - if not isinstance(sample_rate, ops.Tensor): + if not isinstance(sample_rate, tensor.Tensor): if sample_rate <= 0.0: raise ValueError('sample_rate must be positive. Got: %s' % sample_rate) if upper_edge_hertz > sample_rate / 2: @@ -156,7 +157,7 @@ def linear_to_mel_weight_matrix(num_mel_bins=20, """ with ops.name_scope(name, 'linear_to_mel_weight_matrix') as name: # Convert Tensor `sample_rate` to float, if possible. 
- if isinstance(sample_rate, ops.Tensor): + if isinstance(sample_rate, tensor.Tensor): maybe_const_val = tensor_util.constant_value(sample_rate) if maybe_const_val is not None: sample_rate = maybe_const_val diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 42d26e27d79ed9..b52e299149a874 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -98,7 +99,7 @@ def _convert_to_sparse_tensors(sp_inputs): def _make_int64_tensor(value, name): if isinstance(value, compat.integral_types): return ops.convert_to_tensor(value, name=name, dtype=dtypes.int64) - if not isinstance(value, ops.Tensor): + if not isinstance(value, tensor_lib.Tensor): raise TypeError("{} must be an integer value".format(name)) if value.dtype == dtypes.int64: return value @@ -215,7 +216,7 @@ def sparse_expand_dims(sp_input, axis=None, name=None): with ops.name_scope(name, default_name="expand_dims", values=[sp_input]): if isinstance(axis, compat.integral_types): axis = ops.convert_to_tensor(axis, name="axis", dtype=dtypes.int32) - elif not isinstance(axis, ops.Tensor): + elif not isinstance(axis, tensor_lib.Tensor): raise TypeError("axis must be an integer value in range [-rank(sp_input)" " - 1, rank(sp_input)]") @@ -717,7 +718,8 @@ def _sparse_cross_internal_v2(inputs): if not isinstance(inputs, (tuple, list)): raise TypeError("Inputs must be a list") if not all( - isinstance(i, sparse_tensor.SparseTensor) or isinstance(i, ops.Tensor) + isinstance( + i, sparse_tensor.SparseTensor) or isinstance(i, tensor_lib.Tensor) for i in inputs): raise TypeError("All inputs must be Tensor or SparseTensor.") sparse_inputs = [ @@ -747,7 +749,8 @@ def _sparse_cross_internal(inputs, if not isinstance(inputs, (tuple, list)): raise TypeError("Inputs must be a list") if not all( - isinstance(i, sparse_tensor.SparseTensor) or isinstance(i, ops.Tensor) + isinstance( + i, sparse_tensor.SparseTensor) or isinstance(i, tensor_lib.Tensor) for i in inputs): raise TypeError("All inputs must be SparseTensors") @@ -1901,7 +1904,7 @@ def sparse_merge_impl(sp_ids, if isinstance(sp_ids, sparse_tensor.SparseTensorValue) or isinstance( sp_ids, sparse_tensor.SparseTensor): sp_ids = [sp_ids] - if not (isinstance(vocab_size, ops.Tensor) or + if not (isinstance(vocab_size, tensor_lib.Tensor) or isinstance(vocab_size, numbers.Integral)): raise TypeError("vocab_size has to be a Tensor or Python int. Found %s" % type(vocab_size)) @@ -1914,7 +1917,8 @@ def sparse_merge_impl(sp_ids, raise TypeError("vocab_size has to be a list of Tensors or Python ints. " "Found %s" % type(vocab_size)) for dim in vocab_size: - if not (isinstance(dim, ops.Tensor) or isinstance(dim, numbers.Integral)): + if not (isinstance( + dim, tensor_lib.Tensor) or isinstance(dim, numbers.Integral)): raise TypeError( "vocab_size has to be a list of Tensors or Python ints. 
Found %s" % type(dim)) diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index 77a7fc0631f3d7..6ee429a6783e27 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -28,6 +28,7 @@ from tensorflow.compiler.tf2xla.ops import gen_xla_ops from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -1060,7 +1061,7 @@ def _reshape_if_necessary(tensor, new_shape): new_shape = tuple(-1 if x is None else x for x in new_shape) cur_shape = tuple(x.value for x in tensor.shape.dims) if (len(new_shape) == len(cur_shape) and - all(not isinstance(d1, ops.Tensor) and (d0 == d1 or d1 == -1) + all(not isinstance(d1, tensor_lib.Tensor) and (d0 == d1 or d1 == -1) for d0, d1 in zip(cur_shape, new_shape))): return tensor else: diff --git a/tensorflow/python/ops/stateful_random_ops.py b/tensorflow/python/ops/stateful_random_ops.py index ae396e66dceb61..a33ea0c9f21a2e 100644 --- a/tensorflow/python/ops/stateful_random_ops.py +++ b/tensorflow/python/ops/stateful_random_ops.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import gen_stateful_random_ops @@ -161,7 +162,7 @@ def _get_state_size(alg): def _check_state_shape(shape, alg): - if isinstance(alg, ops.Tensor) and not context.executing_eagerly(): + if isinstance(alg, tensor.Tensor) and not context.executing_eagerly(): return shape.assert_is_compatible_with([_get_state_size(int(alg))]) diff --git a/tensorflow/python/ops/structured/BUILD b/tensorflow/python/ops/structured/BUILD index 0b98081990e0f8..708242c253f540 100644 --- a/tensorflow/python/ops/structured/BUILD +++ b/tensorflow/python/ops/structured/BUILD @@ -43,8 +43,8 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:extension_type", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:type_spec", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:check_ops", @@ -72,7 +72,7 @@ py_strict_library( srcs_version = "PY3", deps = [ ":structured_tensor", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops/ragged:dynamic_ragged_shape", @@ -91,6 +91,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", "//tensorflow/python/ops:random_ops", @@ -116,8 +117,8 @@ py_strict_test( "//tensorflow/python/framework:extension_type", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", 
"//tensorflow/python/ops/ragged:dynamic_ragged_shape", @@ -163,8 +164,8 @@ py_strict_test( ":structured_tensor", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:random_seed", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", diff --git a/tensorflow/python/ops/structured/structured_array_ops.py b/tensorflow/python/ops/structured/structured_array_ops.py index 3bee40bc9e1d1d..9805418517399b 100644 --- a/tensorflow/python/ops/structured/structured_array_ops.py +++ b/tensorflow/python/ops/structured/structured_array_ops.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -515,7 +516,7 @@ def _structured_tensor_from_row_partitions(shape, row_partitions): # pylint: disable=protected_access def _all_nested_row_partitions(rt): """Returns all nested row partitions in rt, including for dense dimensions.""" - if isinstance(rt, ops.Tensor): + if isinstance(rt, tensor_lib.Tensor): if rt.shape.rank <= 1: return () else: @@ -529,7 +530,7 @@ def _all_nested_row_partitions(rt): def _structured_tensor_like(t): """Create a StructuredTensor with the shape of a (composite) tensor.""" - if isinstance(t, ops.Tensor): + if isinstance(t, tensor_lib.Tensor): return _structured_tensor_from_dense_tensor(t) if ragged_tensor.is_ragged(t): return StructuredTensor.from_fields( diff --git a/tensorflow/python/ops/structured/structured_array_ops_test.py b/tensorflow/python/ops/structured/structured_array_ops_test.py index 09421202488416..04f21fb28880c1 100644 --- a/tensorflow/python/ops/structured/structured_array_ops_test.py +++ b/tensorflow/python/ops/structured/structured_array_ops_test.py @@ -17,8 +17,8 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -57,11 +57,11 @@ def assertAllEqual(self, a, b, msg=None): self.assertIsNone(e, (msg + ": " if msg else "") + str(e)) a_tensors = [ x for x in nest.flatten(a, expand_composites=True) - if isinstance(x, ops.Tensor) + if isinstance(x, tensor.Tensor) ] b_tensors = [ x for x in nest.flatten(b, expand_composites=True) - if isinstance(x, ops.Tensor) + if isinstance(x, tensor.Tensor) ] self.assertLen(a_tensors, len(b_tensors)) a_arrays, b_arrays = self.evaluate((a_tensors, b_tensors)) diff --git a/tensorflow/python/ops/structured/structured_tensor.py b/tensorflow/python/ops/structured/structured_tensor.py index 696589fcb39c9d..752d29895fe1a3 100644 --- a/tensorflow/python/ops/structured/structured_tensor.py +++ b/tensorflow/python/ops/structured/structured_tensor.py @@ -23,8 +23,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import 
tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -39,8 +39,12 @@ from tensorflow.python.util.tf_export import tf_export # Each field may contain one of the following types of Tensors. -_FieldValue = Union[ops.Tensor, ragged_tensor.RaggedTensor, 'StructuredTensor', - extension_type.ExtensionType] +_FieldValue = Union[ + tensor.Tensor, + ragged_tensor.RaggedTensor, + 'StructuredTensor', + extension_type.ExtensionType +] # Function that takes a FieldValue as input and returns the transformed # FieldValue. _FieldFn = Callable[[_FieldValue], _FieldValue] @@ -134,7 +138,7 @@ def _old_init(cls, fields, shape, nrows, row_partitions, internal=False): """ assert isinstance(fields, dict), fields assert isinstance(shape, tensor_shape.TensorShape), shape - assert nrows is None or isinstance(nrows, ops.Tensor), nrows + assert nrows is None or isinstance(nrows, tensor.Tensor), nrows assert row_partitions is None or isinstance(row_partitions, tuple), row_partitions return StructuredTensor( @@ -786,7 +790,7 @@ def _tensor_getitem(self, key): if not (k.start is None and k.stop is None and k.step is None): # TODO(edloper): Better static shape analysis here. result_shape[d] = None - elif isinstance(k, (int, ops.Tensor)): + elif isinstance(k, (int, tensor.Tensor)): result_shape[d] = -1 # mark for deletion elif k is None: raise ValueError('Slicing not supported for tf.newaxis') @@ -1008,7 +1012,7 @@ def _from_pylist_of_value(cls, pyval, typespec, path_so_far): return ragged_factory_ops.constant(pyval) except Exception as exc: raise ValueError('Error parsing path %r' % (path_so_far,)) from exc - elif isinstance(typespec, tensor_spec.TensorSpec): + elif isinstance(typespec, tensor.TensorSpec): try: result = constant_op.constant(pyval, typespec.dtype) except Exception as exc: @@ -1049,7 +1053,7 @@ def _from_pyscalar(cls, pyval, typespec, path_so_far): except Exception as exc: raise ValueError('Error parsing path %r' % (path_so_far,)) from exc else: - if not (isinstance(typespec, tensor_spec.TensorSpec) and + if not (isinstance(typespec, tensor.TensorSpec) and typespec.shape.rank == 0): raise ValueError('Value at %r does not match typespec: %r vs %r' % (path_so_far, typespec, pyval)) @@ -1200,7 +1204,7 @@ def rank(self): def _convert_to_structured_field_value(value): """Converts `value` to a Tensor, RaggedTensor, or StructuredTensor.""" if isinstance(value, - (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)): + (tensor.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)): return value elif ragged_tensor.is_ragged(value): return ragged_tensor.convert_to_tensor_or_ragged_tensor(value) @@ -1215,7 +1219,7 @@ def _convert_to_structured_field_value(value): def _find_shape_dtype( - fields: Mapping[str, _FieldValue], nrows: Optional[ops.Tensor], + fields: Mapping[str, _FieldValue], nrows: Optional[tensor.Tensor], row_partitions: Optional[Sequence[RowPartition]]) -> dtypes.DType: """Return a consistent dtype for fields, nrows, & row_partitions. @@ -1232,7 +1236,7 @@ def _find_shape_dtype( If int32 is explicitly specified, return int32. Otherwise, return int64. 
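Many of the hunks above (StructuredTensor specs, `VariableSpec`, `TensorArray`) swap `tensor_spec.TensorSpec` for `tensor.TensorSpec`; the public `tf.TensorSpec` object itself is unchanged, only its internal import path moves. As a reminder of what such a spec expresses for the fields being described here:

```python
import tensorflow as tf

# A TensorSpec is a value-free description of dtype and shape, used above for
# StructuredTensor field specs and VariableSpec component specs.
spec = tf.TensorSpec(shape=[None], dtype=tf.int32)
print(spec.is_compatible_with(tf.constant([1, 2, 3])))   # True
print(spec.is_compatible_with(tf.constant([1.0, 2.0])))  # False (wrong dtype)
```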
""" field_dtypes = [_field_shape_dtype(v) for v in fields.values()] - nrows_dtypes = [nrows.dtype] if isinstance(nrows, ops.Tensor) else [] + nrows_dtypes = [nrows.dtype] if isinstance(nrows, tensor.Tensor) else [] rp_dtypes = [] if row_partitions is None else [ rp.dtype for rp in row_partitions ] @@ -1266,7 +1270,7 @@ def _merge_nrows(nrows, static_nrows, value, dtype, validate): A tuple `(nrows, static_nrows)`. """ static_value_nrows = tensor_shape.dimension_at_index(value.shape, 0) - if isinstance(value, ops.Tensor): + if isinstance(value, tensor.Tensor): value_nrows = array_ops.shape(value, out_type=dtype)[0] else: value_nrows = value.nrows() @@ -1287,7 +1291,7 @@ def _merge_nrows(nrows, static_nrows, value, dtype, validate): def _merge_row_partitions(row_partitions, value, rank, dtype, validate): """Merges `row_partitions` with `row_partitions(value)`.""" - if isinstance(value, ops.Tensor): + if isinstance(value, tensor.Tensor): value_row_partitions = _row_partitions_for_tensor(value, rank, dtype) elif isinstance(value, ragged_tensor.RaggedTensor): @@ -1486,7 +1490,7 @@ def _replace_row_partitions(value, new_partitions): A value that is equivalent to `value`, where outer row partitions have been replaced by `new_partitions`. """ - if isinstance(value, ops.Tensor) or not new_partitions: + if isinstance(value, tensor.Tensor) or not new_partitions: return value elif isinstance(value, ragged_tensor.RaggedTensor): @@ -1532,14 +1536,14 @@ def _partition_outer_dimension(value, row_partition): `result.rank = value.rank + 1`. """ is_ragged = row_partition.uniform_row_length() is None - if isinstance(value, ops.Tensor) and not is_ragged: + if isinstance(value, tensor.Tensor) and not is_ragged: new_shape = array_ops.concat( [[row_partition.nrows(), row_partition.uniform_row_length()], array_ops.shape(value, out_type=row_partition.dtype)[1:]], axis=0) return array_ops.reshape(value, new_shape) - elif isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)): + elif isinstance(value, (tensor.Tensor, ragged_tensor.RaggedTensor)): return ragged_tensor.RaggedTensor._from_row_partition( # pylint: disable=protected-access value, row_partition) else: @@ -1558,7 +1562,7 @@ def _partition_outer_dimension(value, row_partition): def _merge_dims(value, outer_axis, inner_axis): """Merges `outer_axis...inner_axis` of `value` into a single dimension.""" assert outer_axis < inner_axis - if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)): + if isinstance(value, (tensor.Tensor, ragged_tensor.RaggedTensor)): return ragged_tensor.merge_dims(value, outer_axis, inner_axis) else: assert isinstance(value, StructuredTensor) @@ -1575,7 +1579,7 @@ def _merge_dims(value, outer_axis, inner_axis): def _dynamic_ragged_shape_spec_from_spec( spec: Union[dynamic_ragged_shape.DynamicRaggedShape.Spec, ragged_tensor.RaggedTensorSpec, StructuredTensor.Spec, - tensor_spec.TensorSpec] + tensor.TensorSpec] ) -> dynamic_ragged_shape.DynamicRaggedShape.Spec: if isinstance(spec, StructuredTensor.Spec): return spec._ragged_shape # pylint: disable=protected-access @@ -1630,7 +1634,7 @@ def _dynamic_ragged_shape_from_tensor( return field._ragged_shape # pylint: disable=protected-access shape = array_ops.shape_v2(field, out_type=dtype) - if isinstance(shape, ops.Tensor): + if isinstance(shape, tensor.Tensor): return dynamic_ragged_shape.DynamicRaggedShape( row_partitions=[], inner_shape=shape) elif isinstance(shape, dynamic_ragged_shape.DynamicRaggedShape): @@ -1697,7 +1701,7 @@ def _dynamic_ragged_shape_init(fields, shape, 
nrows, row_partitions): """Produce a DynamicRaggedShape for StructuredTensor.""" assert isinstance(fields, dict), fields assert isinstance(shape, tensor_shape.TensorShape), shape - assert nrows is None or isinstance(nrows, ops.Tensor) or isinstance( + assert nrows is None or isinstance(nrows, tensor.Tensor) or isinstance( nrows, int), nrows assert row_partitions is None or isinstance(row_partitions, tuple), row_partitions diff --git a/tensorflow/python/ops/structured/structured_tensor_dynamic.py b/tensorflow/python/ops/structured/structured_tensor_dynamic.py index 84944861af1830..8aa434831e616a 100644 --- a/tensorflow/python/ops/structured/structured_tensor_dynamic.py +++ b/tensorflow/python/ops/structured/structured_tensor_dynamic.py @@ -14,7 +14,7 @@ # ============================================================================== """Dynamic shape for structured Tensors.""" -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops.ragged import dynamic_ragged_shape @@ -26,7 +26,7 @@ def _dynamic_ragged_shape_init(fields, shape, nrows, row_partitions): """Produce a DynamicRaggedShape for StructuredTensor.""" assert isinstance(fields, dict), fields assert isinstance(shape, tensor_shape.TensorShape), shape - assert nrows is None or isinstance(nrows, ops.Tensor), nrows + assert nrows is None or isinstance(nrows, tensor.Tensor), nrows assert isinstance(row_partitions, tuple), row_partitions rank = shape.rank diff --git a/tensorflow/python/ops/structured/structured_tensor_test.py b/tensorflow/python/ops/structured/structured_tensor_test.py index 3e84e21ae6f018..b182c530b1a393 100644 --- a/tensorflow/python/ops/structured/structured_tensor_test.py +++ b/tensorflow/python/ops/structured/structured_tensor_test.py @@ -25,8 +25,8 @@ from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops.ragged import ragged_factory_ops @@ -611,7 +611,7 @@ def testFromFields(self, for field, value in fields.items(): self.assertIsInstance( struct.field_value(field), - (ops.Tensor, structured_tensor.StructuredTensor, + (tensor.Tensor, structured_tensor.StructuredTensor, ragged_tensor.RaggedTensor)) self.assertAllEqual(struct.field_value(field), value) @@ -791,7 +791,7 @@ def testPartitionOuterDims(self): dtype=dtypes.int64), _fields={ "x": - tensor_spec.TensorSpec([2, 2], dtypes.int32), + tensor.TensorSpec([2, 2], dtypes.int32), "y": ragged_tensor.RaggedTensorSpec([2, 2, None], dtypes.int32) @@ -855,8 +855,8 @@ def testPartitionOuterDimsErrors(self): "pyval": {"a": 12, "b": [1, 2, 3], "c": [[1, 2], [3]]}, "type_spec": StructuredTensor.Spec._from_fields_and_rank( fields={ - "a": tensor_spec.TensorSpec([], dtypes.int32), - "b": tensor_spec.TensorSpec([None], dtypes.int32), + "a": tensor.TensorSpec([], dtypes.int32), + "b": tensor.TensorSpec([None], dtypes.int32), "c": ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)}, rank=0), @@ -889,7 +889,7 @@ def testPartitionOuterDimsErrors(self): "testcase_name": "EmptyListWithTypeSpecAndFields", "pyval": [], "type_spec": 
structured_tensor.StructuredTensor.Spec._from_fields_and_rank( - fields={"a": tensor_spec.TensorSpec([0], dtypes.int32)}, + fields={"a": tensor.TensorSpec([0], dtypes.int32)}, rank=1), "expected": lambda: StructuredTensor.from_fields(shape=[0], fields={ "a": []}) @@ -963,7 +963,7 @@ def testPartitionOuterDimsErrors(self): "pyval": [[{"a": 1}, {"a": 2}, {"a": 3},], [{"a": 4}, {"a": 5}, {"a": 6}]], "type_spec": structured_tensor.StructuredTensorSpec([2, 3], { - "a": tensor_spec.TensorSpec(None, dtypes.int32)}), + "a": tensor.TensorSpec(None, dtypes.int32)}), "expected": lambda: StructuredTensor.from_fields( shape=[2, 3], fields={"a": [[1, 2, 3], [4, 5, 6]]}) }, @@ -979,8 +979,8 @@ def testPyvalConversion(self, pyval, expected, type_spec=None): def testStructuredTensorSpecFactory(self): spec = StructuredTensor.Spec._from_fields_and_rank( fields={ - "a": tensor_spec.TensorSpec([], dtypes.int32), - "b": tensor_spec.TensorSpec([None], dtypes.int32), + "a": tensor.TensorSpec([], dtypes.int32), + "b": tensor.TensorSpec([None], dtypes.int32), "c": ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32) }, rank=0) @@ -1042,19 +1042,19 @@ def testToPyval(self, st, expected): dict(testcase_name="TypeSpecMismatch_DictKey", pyval={"a": 1}, type_spec=StructuredTensor.Spec._from_fields_and_rank( - fields={"b": tensor_spec.TensorSpec([1], dtypes.int32)}, + fields={"b": tensor.TensorSpec([1], dtypes.int32)}, rank=1), msg=r"Value at \(\) does not match typespec"), dict(testcase_name="TypeSpecMismatch_ListDictKey", pyval=[{"a": 1}], type_spec=StructuredTensor.Spec._from_fields_and_rank( - fields={"b": tensor_spec.TensorSpec([1], dtypes.int32)}, + fields={"b": tensor.TensorSpec([1], dtypes.int32)}, rank=1), msg=r"Value at \(\) does not match typespec"), dict(testcase_name="TypeSpecMismatch_RankMismatch", pyval=[{"a": 1}], type_spec=StructuredTensor.Spec._from_fields_and_rank( - fields={"a": tensor_spec.TensorSpec([], dtypes.int32)}, + fields={"a": tensor.TensorSpec([], dtypes.int32)}, rank=0), msg=r"Value at \(\) does not match typespec \(rank mismatch\)"), dict(testcase_name="TypeSpecMismatch_Scalar", @@ -1068,14 +1068,14 @@ def testToPyval(self, st, expected): dict(testcase_name="TypeSpecMismatch_ListTensor", pyval={"a": [[1]]}, type_spec=StructuredTensor.Spec._from_fields_and_rank( - fields={"a": tensor_spec.TensorSpec([], dtypes.int32)}, + fields={"a": tensor.TensorSpec([], dtypes.int32)}, rank=0), msg=r"Value at \('a',\) does not match typespec"), dict(testcase_name="TypeSpecMismatch_ListTensorDeep", pyval={"a": {"b": [[1]]}}, type_spec=StructuredTensor.Spec._from_fields_and_rank( fields={"a": StructuredTensor.Spec._from_fields_and_rank( - fields={"b": tensor_spec.TensorSpec([], dtypes.int32)}, + fields={"b": tensor.TensorSpec([], dtypes.int32)}, rank=0 )}, rank=0), @@ -1095,7 +1095,7 @@ def testToPyval(self, st, expected): dict(testcase_name="TypeSpecMismatch_ListStruct", pyval=[[1]], type_spec=StructuredTensor.Spec._from_fields_and_rank( - fields={"a": tensor_spec.TensorSpec([1, 1], dtypes.int32)}, + fields={"a": tensor.TensorSpec([1, 1], dtypes.int32)}, rank=2), msg=r"Value at \(\) does not match typespec"), dict(testcase_name="InconsistentDictionaryDepth", diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index 95b6f67879890c..9ef7cced15cf3a 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from 
tensorflow.python.framework import smart_cond +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -992,10 +993,14 @@ def graph_v1(param, step=None, name=None): Raises: TypeError: If `param` isn't already a `tf.Tensor` in graph mode. """ - if not context.executing_eagerly() and not isinstance(param, ops.Tensor): - raise TypeError("graph() needs a argument `param` to be tf.Tensor " - "(e.g. tf.placeholder) in graph mode, but received " - f"param={param} of type {type(param).__name__}.") + if not context.executing_eagerly() and not isinstance( + param, tensor_lib.Tensor + ): + raise TypeError( + "graph() needs a argument `param` to be tf.Tensor " + "(e.g. tf.placeholder) in graph mode, but received " + f"param={param} of type {type(param).__name__}." + ) writer = _summary_state.writer if writer is None: return control_flow_ops.no_op() @@ -1170,7 +1175,7 @@ def _serialize_graph(arbitrary_graph): def _choose_step(step): if step is None: return training_util.get_or_create_global_step() - if not isinstance(step, ops.Tensor): + if not isinstance(step, tensor_lib.Tensor): return ops.convert_to_tensor(step, dtypes.int64) return step diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index 3dba50011dd5a7..0459bff690c853 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -28,8 +28,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec_registry @@ -103,7 +103,7 @@ def __init__(self, raise ValueError( "Cannot provide both `handle` and `tensor_array_name` arguments at " "the same time.") - if handle is not None and not isinstance(handle, ops.Tensor): + if handle is not None and not isinstance(handle, tensor_lib.Tensor): raise TypeError( f"Expected `handle` to be a Tensor, but got `{handle}` of type " f"`{type(handle)}` instead.") @@ -452,21 +452,26 @@ def __init__(self, self._dynamic_size = dynamic_size self._size = size - if (flow is not None and - (not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)): + if flow is not None and ( + not isinstance(flow, tensor_lib.Tensor) or flow.dtype != dtypes.variant + ): raise TypeError( - f"Expected `flow` to be a variant tensor, but received `{flow.dtype}` " - f"instead.") + f"Expected `flow` to be a variant tensor, but received `{flow.dtype}`" + " instead." + ) if flow is None and size is None: - raise ValueError("Argument `size` must be provided if argument `flow` " - "is not provided.") + raise ValueError( + "Argument `size` must be provided if argument `flow` is not provided." + ) if flow is not None and size is not None: - raise ValueError("Cannot provide both `flow` and `size` arguments " - "at the same time.") + raise ValueError( + "Cannot provide both `flow` and `size` arguments at the same time." + ) if flow is not None and element_shape is not None: raise ValueError( "Cannot provide both `flow` and `element_shape` arguments" - "at the same time.") + "at the same time." 
+ ) self._dtype = dtypes.as_dtype(dtype).base_dtype @@ -1434,7 +1439,7 @@ def _serialize(self): @property def _component_specs(self): - return [tensor_spec.TensorSpec([], dtypes.variant)] + return [tensor_lib.TensorSpec([], dtypes.variant)] def _to_components(self, value): if not isinstance(value, TensorArray): @@ -1510,7 +1515,7 @@ def placeholder_value(self, placeholder_context): return self._value def _flatten(self): - return [tensor_spec.TensorSpec([], dtypes.variant)] + return [tensor_lib.TensorSpec([], dtypes.variant)] def _from_tensors(self, tensors): return next(tensors) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index a86b8c2999f49a..33dd0438fa2f2f 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -27,6 +27,7 @@ from tensorflow.python.eager import monitoring from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -698,7 +699,7 @@ def _get_partitioned_variable(self, sharded variable exists for the given name but with different sharding. """ initializing_from_value = initializer is not None and isinstance( - initializer, ops.Tensor) + initializer, tensor.Tensor) if name in self._vars: raise ValueError( "A partitioner was provided, but an unpartitioned version of the " @@ -780,7 +781,7 @@ def _get_partitioned_variable(self, elif callable(initializer): init = initializer init_shape = var_shape - elif isinstance(initializer, ops.Tensor): + elif isinstance(initializer, tensor.Tensor): init = array_ops.slice(initializer, var_offset, var_shape) # Use the dtype of the given tensor. dtype = init.dtype.base_dtype diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 3e6dc5e3a5145e..76d8bb9ad6ecea 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -25,6 +25,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -981,7 +982,7 @@ def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint @classmethod def _OverloadAllOperators(cls): # pylint: disable=invalid-name """Register overloads for all operators.""" - for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + for operator in tensor_lib.Tensor.OVERLOADABLE_OPERATORS: cls._OverloadOperator(operator) # For slicing, bind getitem differently than a tensor (use SliceHelperVar # instead) @@ -990,9 +991,10 @@ def _OverloadAllOperators(cls): # pylint: disable=invalid-name @classmethod def _OverloadOperator(cls, operator): # pylint: disable=invalid-name - """Defer an operator overload to `ops.Tensor`. + """Defer an operator overload to `tensor_lib.Tensor`. - We pull the operator out of ops.Tensor dynamically to avoid ordering issues. + We pull the operator out of tensor_lib.Tensor dynamically to avoid ordering + issues. Args: operator: string. The operator name. 
@@ -1004,7 +1006,7 @@ def _OverloadOperator(cls, operator): # pylint: disable=invalid-name if operator == "__eq__" or operator == "__ne__": return - tensor_oper = getattr(ops.Tensor, operator) + tensor_oper = getattr(tensor_lib.Tensor, operator) def _run_op(a, *args, **kwargs): # pylint: disable=protected-access @@ -1014,17 +1016,24 @@ def _run_op(a, *args, **kwargs): setattr(cls, operator, _run_op) def __hash__(self): - if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access + if ( + tensor_lib.Tensor._USE_EQUALITY + and ops.executing_eagerly_outside_functions() + ): # pylint: disable=protected-access raise TypeError( "Variable is unhashable. " - f"Instead, use variable.ref() as the key. (Variable: {self})") + f"Instead, use variable.ref() as the key. (Variable: {self})" + ) else: return id(self) # TODO(gjn): duplicate of math_ops.tensor_equals, consider removing def __eq__(self, other): """Compares two variables element-wise for equality.""" - if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access + if ( + tensor_lib.Tensor._USE_EQUALITY + and ops.executing_eagerly_outside_functions() + ): # pylint: disable=protected-access return gen_math_ops.equal(self, other, incompatible_shape_error=False) else: # In legacy graph mode, tensor equality is object equality @@ -1033,7 +1042,10 @@ def __eq__(self, other): # TODO(gjn): duplicate of math_ops.tensor_not_equals, consider removing def __ne__(self, other): """Compares two variables element-wise for equality.""" - if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access + if ( + tensor_lib.Tensor._USE_EQUALITY + and ops.executing_eagerly_outside_functions() + ): # pylint: disable=protected-access return gen_math_ops.not_equal(self, other, incompatible_shape_error=False) else: # In legacy graph mode, tensor equality is object equality @@ -1342,7 +1354,7 @@ def _try_guard_against_uninitialized_dependencies(name, initial_value): Raises: TypeError: If `initial_value` is not a `Tensor`. """ - if not isinstance(initial_value, ops.Tensor): + if not isinstance(initial_value, tensor_lib.Tensor): raise TypeError("initial_value needs to be a Tensor: %s" % initial_value) # Don't modify initial_value if it contains any cyclic dependencies. diff --git a/tensorflow/python/ops/weak_tensor_math_ops_test.py b/tensorflow/python/ops/weak_tensor_math_ops_test.py new file mode 100644 index 00000000000000..82dda558c67c8f --- /dev/null +++ b/tensorflow/python/ops/weak_tensor_math_ops_test.py @@ -0,0 +1,458 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
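The reflowed `__hash__`/`__eq__`/`__ne__` blocks above preserve the existing TF2 behaviour: when tensor equality is enabled, `Variable` comparison defers to `Tensor`'s element-wise operators, which is exactly why variables are unhashable and `variable.ref()` is the supported dictionary key. A short illustration with the public API (standard behaviour, not introduced by this change):

```python
import tensorflow as tf

v = tf.Variable([1, 2, 3])
w = tf.Variable([1, 0, 3])

print((v == w).numpy())   # [ True False  True]  -> element-wise, like Tensors

# Because __eq__ is element-wise, Variables cannot be hashed directly; use
# .ref() to get a stable, hashable reference instead.
seen = {v.ref(): "first"}
print(seen[v.ref()])      # first
```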
+# ============================================================================== +"""Tests for tensorflow.ops.math_ops on WeakTensor.""" +from absl.testing import parameterized +import numpy as np + +from tensorflow.core.framework import full_type_pb2 +from tensorflow.python import tf2 +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor +from tensorflow.python.framework import test_util +from tensorflow.python.framework.weak_tensor import WeakTensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import weak_tensor_ops # pylint: disable=unused-import +from tensorflow.python.ops import weak_tensor_test_util +from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.platform import googletest + +_convert_to_input_type = weak_tensor_test_util.convert_to_input_type +_get_weak_tensor = weak_tensor_test_util.get_weak_tensor + + +@test_util.run_all_in_graph_and_eager_modes +class ReduceTest(test_util.TensorFlowTestCase, parameterized.TestCase): + # Test unary ops with optional dtype arg. + + @parameterized.parameters( + ("WeakTensor", WeakTensor), + ("Python", WeakTensor), + ("NumPy", tensor.Tensor), + ("Tensor", tensor.Tensor), + ) + def testReduceAllDims(self, input_type, result_type): + test_input = _convert_to_input_type( + [[1, 2, 3], [4, 5, 6]], input_type, np.int32 + ) + with test_util.device(use_gpu=True): + res = math_ops.reduce_sum(test_input) + self.assertIsInstance(res, result_type) + self.assertEqual(self.evaluate(res), 21) + + def testReduceExtendType(self): + test_in = np.random.randn(1000, 1000).astype(np.float32) + in_f32 = _get_weak_tensor(test_in, dtypes.float32) + in_bfl6 = math_ops.cast(test_in, dtypes.bfloat16) + + out_f32 = self.evaluate(math_ops.reduce_sum(in_f32)) + out_bf16 = self.evaluate(math_ops.reduce_sum(in_bfl6)) + expected = math_ops.cast(out_f32, dtypes.bfloat16) + + self.assertAllClose(out_bf16, expected, 1e-3) + + def testCountNonzero(self): + # simple case + x = _get_weak_tensor([[0, -2, 0], [4, 0, 0]], dtypes.int32) + self.assertEqual(self.evaluate(math_ops.count_nonzero(x)), 2) + + # boolean input + x = math_ops.not_equal(x, 0) + self.assertEqual(self.evaluate(math_ops.count_nonzero(x)), 2) + + # would overflow if int8 would be used for internal calculations + x = 2 * np.ones(512, dtype=np.int8) + self.assertEqual(self.evaluate(math_ops.count_nonzero(x)), 512) + + @parameterized.parameters( + ("WeakTensor", WeakTensor), + ("Python", WeakTensor), + ("NumPy", tensor.Tensor), + ("Tensor", tensor.Tensor), + ) + def testReduceExplicitAxes(self, input_type, result_type): + x = _convert_to_input_type([[1, 2, 3], [4, 5, 6]], input_type, np.int32) + with test_util.device(use_gpu=True): + for axis in (0, -2): + res = math_ops.reduce_sum(x, axis=axis) + self.assertIsInstance(res, result_type) + self.assertAllEqual(res, [5, 7, 9]) + for axis in (1, -1): + res = math_ops.reduce_sum(x, axis=axis) + self.assertIsInstance(res, result_type) + self.assertAllEqual(res, [6, 15]) + for axis in (None, (0, 1), (1, 0), (-1, 0), (0, -1), (-2, 1), (1, -2), + (-1, -2), (-2, -1)): + res = math_ops.reduce_sum(x, axis=axis) + self.assertIsInstance(res, result_type) + self.assertEqual(self.evaluate(res), 21) + + def 
testReduceInvalidAxis(self): + if context.executing_eagerly(): + # The shape check is in run a graph construction time. In eager mode, + # it misses the check, magically return result given wrong shape. + return + x = _get_weak_tensor([[1, 2, 3], [4, 5, 6]], dtype=np.int32) + axis = np.array([[0], [1]]) + with self.assertRaisesRegex(ValueError, "must be at most rank 1"): + math_ops.reduce_sum(x, axis) + + def testReduceVar(self): + x = _get_weak_tensor([[0, 0, 0], [0, 0, 0]], dtype=dtypes.float32) + self.assertAllClose(self.evaluate(math_ops.reduce_variance(x)), 0) + self.assertAllClose( + self.evaluate(math_ops.reduce_variance(x, axis=0)), [0, 0, 0]) + + x = _get_weak_tensor([[1, 2, 1, 1], [1, 1, 0, 1]]) + with self.assertRaisesRegex(TypeError, "must be either real or complex"): + math_ops.reduce_variance(x) + + x = _get_weak_tensor([[1.0, 2.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0]]) + self.assertEqual(self.evaluate(math_ops.reduce_variance(x)), 0.25) + x_np = np.array([[1, 2, 1, 1], [1, 1, 0, 1]], "float32") + self.assertEqual(np.var(x_np), 0.25) + self.assertEqual(self.evaluate(math_ops.reduce_variance(x_np)), 0.25) + + x = ragged_factory_ops.constant([[5., 1., 4., 1.], [], [5., 9., 2.], [5.], + []]) + self.assertAllClose(math_ops.reduce_variance(x, axis=0), [0., 16., 1., 0.]) + + def testReduceVarComplex(self): + # Ensure that complex values are handled to be consistent with numpy + complex_ys = [([0 - 1j, 0 + 1j], dtypes.float64), + (np.array([0 - 1j, 0 + 1j], "complex64"), dtypes.float32), + (np.array([0 - 1j, 0 + 1j], "complex128"), dtypes.float64)] + for y, dtype in complex_ys: + y_result = math_ops.reduce_variance(y) + self.assertEqual(np.var(y), 1.0) + self.assertEqual(self.evaluate(y_result), 1.0) + self.assertEqual(y_result.dtype, dtype) + + def testReduceStd(self): + x = _get_weak_tensor([[0, 0, 0], [0, 0, 0]], dtypes.float32) + self.assertAllClose(self.evaluate(math_ops.reduce_std(x)), 0) + self.assertAllClose( + self.evaluate(math_ops.reduce_std(x, axis=0)), [0, 0, 0]) + + x = _get_weak_tensor([[1, 2, 1, 1], [1, 1, 0, 1]]) + with self.assertRaisesRegex(TypeError, "must be either real or complex"): + math_ops.reduce_std(x) + + x = [[1., 2., 1., 1.], [1., 1., 0., 1.]] + res = math_ops.reduce_std(x) + self.assertEqual(self.evaluate(res), 0.5) + self.assertIsInstance(res, WeakTensor) + x_np = np.array(x) + self.assertEqual(np.std(x_np), 0.5) + self.assertEqual(self.evaluate(math_ops.reduce_std(x_np)), 0.5) + self.assertIsInstance(math_ops.reduce_std(x_np), tensor.Tensor) + + x = ragged_factory_ops.constant([[5., 1., 4., 1.], [], [5., 9., 2.], [5.], + []]) + self.assertAllClose(math_ops.reduce_std(x, axis=0), [0., 4., 1., 0.]) + + def testReduceStdComplex(self): + # Ensure that complex values are handled to be consistent with numpy + complex_ys = [([0 - 1j, 0 + 1j], dtypes.float64), + (np.array([0 - 1j, 0 + 1j], "complex64"), dtypes.float32), + (np.array([0 - 1j, 0 + 1j], "complex128"), dtypes.float64)] + for y, dtype in complex_ys: + y_result = math_ops.reduce_std(y) + self.assertEqual(np.std(y), 1.0) + self.assertEqual(self.evaluate(y_result), 1.0) + self.assertEqual(y_result.dtype, dtype) + + +@test_util.run_all_in_graph_and_eager_modes +class LogSumExpTest(test_util.TensorFlowTestCase): + + def testReduceLogSumExp(self): + for dtype in [np.float16, np.float32, np.double]: + x_np = np.random.rand(5, 5).astype(dtype) + with test_util.use_gpu(): + y_tf_np = math_ops.reduce_logsumexp(x_np) + y_np = np.log(np.sum(np.exp(x_np))) + self.assertAllClose(y_tf_np, y_np) + + def 
testReductionIndices(self): + for dtype in [np.float16, np.float32, np.double]: + x_np = np.random.rand(5, 5).astype(dtype) + with test_util.use_gpu(): + y_tf = math_ops.reduce_logsumexp(x_np, axis=[0]) + y_np = np.log(np.sum(np.exp(x_np), axis=0)) + self.assertShapeEqual(y_np, y_tf) + y_tf_np = self.evaluate(y_tf) + self.assertAllClose(y_tf_np, y_np) + + def testReductionIndices2(self): + for dtype in [np.float16, np.float32, np.double]: + x_np = np.random.rand(5, 5).astype(dtype) + with test_util.use_gpu(): + y_tf = math_ops.reduce_logsumexp(x_np, axis=0) + y_np = np.log(np.sum(np.exp(x_np), axis=0)) + self.assertShapeEqual(y_np, y_tf) + y_tf_np = self.evaluate(y_tf) + self.assertAllClose(y_tf_np, y_np) + + def testKeepDims(self): + for dtype in [np.float16, np.float32, np.double]: + x_np = np.random.rand(5, 5).astype(dtype) + with test_util.use_gpu(): + y_tf_np = math_ops.reduce_logsumexp(x_np, keepdims=True) + self.assertEqual(y_tf_np.shape.rank, x_np.ndim) + y_np = np.log(np.sum(np.exp(x_np), keepdims=True)) + self.assertAllClose(y_tf_np, y_np) + + def testOverflow(self): + x = [1000, 1001, 1002, 1003] + for dtype in [np.float32, np.double]: + x_np = np.array(x, dtype=dtype) + max_np = np.max(x_np) + with self.assertRaisesRegex(RuntimeWarning, + "overflow encountered in exp"): + out = np.log(np.sum(np.exp(x_np))) + if out == np.inf: + raise RuntimeWarning("overflow encountered in exp") + + with test_util.use_gpu(): + x_tf = _get_weak_tensor(x_np, shape=x_np.shape) + y_tf_np = math_ops.reduce_logsumexp(x_tf) + y_np = np.log(np.sum(np.exp(x_np - max_np))) + max_np + self.assertAllClose(y_tf_np, y_np) + + def testUnderflow(self): + x = [-1000, -1001, -1002, -1003] + for dtype in [np.float32, np.double]: + x_np = np.array(x, dtype=dtype) + max_np = np.max(x_np) + with self.assertRaisesRegex(RuntimeWarning, + "divide by zero encountered in log"): + out = np.log(np.sum(np.exp(x_np))) + if out == -np.inf: + raise RuntimeWarning("divide by zero encountered in log") + + with test_util.use_gpu(): + x_tf = _get_weak_tensor(x_np, shape=x_np.shape) + y_tf_np = math_ops.reduce_logsumexp(x_tf) + y_np = np.log(np.sum(np.exp(x_np - max_np))) + max_np + self.assertAllClose(y_tf_np, y_np) + + def testInfinity(self): + with test_util.use_gpu(): + res = math_ops.reduce_logsumexp(-np.inf) + self.assertEqual(-np.inf, self.evaluate(res)) + + +@test_util.run_all_in_graph_and_eager_modes +class RoundTest(test_util.TensorFlowTestCase): + + def testRounding(self): + x = np.arange(-5.0, 5.0, .25) + for dtype in [np.float32, np.double, np.int32]: + x_np = np.array(x, dtype=dtype) + with test_util.device(use_gpu=True): + x_tf = _get_weak_tensor(x_np, shape=x_np.shape) + y_tf = math_ops.round(x_tf) + y_tf_np = self.evaluate(y_tf) + y_np = np.round(x_np) + self.assertAllClose(y_tf_np, y_np, atol=1e-2) + + +class SignTest(test_util.TensorFlowTestCase): + + def test_complex_sign_gradient(self): + with context.eager_mode(): + x = math_ops.complex(1., 1.) 
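The overflow/underflow tests in this new file rely on the usual log-sum-exp stabilisation, `logsumexp(x) = max(x) + log(sum(exp(x - max(x))))`, which is why `tf.reduce_logsumexp` stays finite where the naive NumPy expression blows up. A NumPy-only sketch of the same `[1000, ..., 1003]` inputs used above:

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0, 1003.0])

naive = np.log(np.sum(np.exp(x)))                           # exp overflows
stable = np.max(x) + np.log(np.sum(np.exp(x - np.max(x))))  # max-shift trick

print(naive)   # inf (NumPy emits an "overflow encountered in exp" warning)
print(stable)  # ~1003.44
```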
+ with backprop.GradientTape() as t: + t.watch(x) + y = math_ops.sign(x) + self.assertAllClose( + t.gradient(y, x), math_ops.complex(0.353553, -0.353553)) + + +@test_util.run_all_in_graph_and_eager_modes +class ReciprocalNoNanTest(test_util.TensorFlowTestCase): + + allowed_dtypes = [dtypes.float32, dtypes.float64, dtypes.complex128] + + def testBasic(self): + for dtype in self.allowed_dtypes: + x = _get_weak_tensor([1.0, 2.0, 0.0, 4.0], dtype=dtype) + + y = math_ops.reciprocal_no_nan(x) + + target = _get_weak_tensor([1.0, 0.5, 0.0, 0.25], dtype=dtype) + + self.assertAllEqual(y, target) + self.assertEqual(y.dtype.base_dtype, target.dtype.base_dtype) + + def testInverse(self): + for dtype in self.allowed_dtypes: + x = np.random.choice([0, 1, 2, 4, 5], size=(5, 5, 5)) + x = _get_weak_tensor(x, dtype=dtype) + + y = math_ops.reciprocal_no_nan(math_ops.reciprocal_no_nan(x)) + + self.assertAllClose(y, x) + self.assertEqual(y.dtype.base_dtype, x.dtype.base_dtype) + + +class EqualityTest(test_util.TensorFlowTestCase, parameterized.TestCase): + + @test_util.run_all_in_graph_and_eager_modes + def testEqualityNone(self): + x = _get_weak_tensor([1.0, 2.0, 0.0, 4.0], dtype=dtypes.float32) + self.assertNotEqual(x, None) + self.assertNotEqual(None, x) + self.assertFalse(math_ops.tensor_equals(x, None)) + self.assertTrue(math_ops.tensor_not_equals(x, None)) + + @parameterized.named_parameters( + (f"-is_equals={is_equals}-float_literal_type={type(float_literal)}" # pylint: disable=g-complex-comprehension + f"-float_literal={float_literal}", is_equals, float_literal) + for float_literal in [4.6, np.float32(4.6), 4.4, np.float32(4.4)] + for is_equals in [True, False]) + def testEqualityNoDowncast(self, is_equals, float_literal): + if (tf2.enabled() and isinstance(float_literal, np.float32) or + not tf2.enabled() and isinstance(float_literal, float)): + # TODO(b/199262800): Remove this skip + self.skipTest("There is a bug in type promotion.") + if is_equals: + op = math_ops.tensor_equals + else: + op = math_ops.tensor_not_equals + x = _get_weak_tensor(4) + try: + result = op(x, float_literal) + if isinstance(result, tensor.Tensor): + result = self.evaluate(result) + except TypeError: + # Throwing a TypeError is OK + return + self.assertEqual(result, not is_equals) + + +@test_util.run_all_in_graph_and_eager_modes +class ErfcinvTest(test_util.TensorFlowTestCase): + + def testErfcinv(self): + values = _get_weak_tensor( + np.random.uniform(0.1, 1.9, size=int(1e4)).astype(np.float32) + ) + approx_id = math_ops.erfc(math_ops.erfcinv(values)) + self.assertAllClose(values, self.evaluate(approx_id)) + + +@test_util.run_all_in_graph_and_eager_modes +class ArgMaxMinTest(test_util.TensorFlowTestCase): + + def _generateRandomWeakTensor(self, dtype, shape): + if dtype.is_integer: + array = np.random.default_rng().integers( + low=dtype.min, high=dtype.max, size=shape, endpoint=True) + return _get_weak_tensor(array, dtype=dtype) + else: + array = np.random.default_rng().uniform(low=-1.0, high=1.0, size=shape) + return _get_weak_tensor(array, dtype=dtype) + + def _getValidDtypes(self): + return (dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64) + + def testArgMax(self): + shape = (24, 8) + for dtype in self._getValidDtypes(): + tf_values = self._generateRandomWeakTensor(dtype, shape) + np_values = self.evaluate(tf_values) + for axis in range(0, len(shape)): + np_max = np.argmax(np_values, axis=axis) + tf_max = math_ops.argmax(tf_values, axis=axis) + self.assertAllEqual(tf_max, np_max) + + def 
testArgMaxReturnsFirstOccurence(self): + for dtype in self._getValidDtypes(): + values = _get_weak_tensor( + [[10, 11, 15, 15, 10], [12, 12, 10, 10, 12]], dtype=dtype + ) + self.assertAllEqual( + math_ops.argmax(values, axis=1), + np.argmax(self.evaluate(values), axis=1)) + + # Long tensor to ensure works with multithreading/GPU + values = array_ops.zeros(shape=(193681,), dtype=dtype) + self.assertAllEqual(math_ops.argmax(values), 0) + + def testArgMaxUint16(self): + shape = (24, 8) + for dtype in self._getValidDtypes(): + tf_values = self._generateRandomWeakTensor(dtype, shape) + np_values = self.evaluate(tf_values) + for axis in range(0, len(shape)): + np_max = np.argmax(np_values, axis=axis) + tf_max = math_ops.argmax( + tf_values, axis=axis, output_type=dtypes.uint16) + self.assertAllEqual(tf_max, np_max) + + def testArgMin(self): + shape = (24, 8) + for dtype in self._getValidDtypes(): + tf_values = self._generateRandomWeakTensor(dtype, shape) + np_values = self.evaluate(tf_values) + for axis in range(0, len(shape)): + np_min = np.argmin(np_values, axis=axis) + tf_min = math_ops.argmin(tf_values, axis=axis) + self.assertAllEqual(tf_min, np_min) + + def testArgMinReturnsFirstOccurence(self): + for dtype in self._getValidDtypes(): + values = _get_weak_tensor( + [[10, 11, 15, 15, 10], [12, 12, 10, 10, 12]], dtype=dtype + ) + self.assertAllEqual( + math_ops.argmin(values, axis=1), + np.argmin(self.evaluate(values), axis=1)) + + # Long tensor to ensure works with multithreading/GPU + values = array_ops.zeros(shape=(193681,), dtype=dtype) + self.assertAllEqual(math_ops.argmin(values), 0) + + +class CastTest(test_util.TensorFlowTestCase): + + def testCastWithFullType(self): + + @def_function.function + def test_fn(): + ta = tensor_array_ops.TensorArray(dtypes.int32, size=1) + h = math_ops.cast(ta.flow, dtypes.variant) + + t = full_type_pb2.FullTypeDef( + type_id=full_type_pb2.TFT_PRODUCT, + args=[full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_ARRAY)]) + h.op.experimental_set_type(t) + + ta = tensor_array_ops.TensorArray(dtypes.int32, flow=h) + ta = ta.write(0, _get_weak_tensor(1)) + return ta.stack() + + self.assertAllEqual(self.evaluate(test_fn()), [1]) + +if __name__ == "__main__": + ops.set_dtype_conversion_mode("all") + googletest.main() diff --git a/tensorflow/python/ops/weak_tensor_ops.py b/tensorflow/python/ops/weak_tensor_ops.py index 43ac9db14565c3..c82fe80396a115 100644 --- a/tensorflow/python/ops/weak_tensor_ops.py +++ b/tensorflow/python/ops/weak_tensor_ops.py @@ -16,44 +16,468 @@ import inspect +from tensorflow.python.framework import flexible_dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework.weak_tensor import WeakTensor -from tensorflow.python.ops import weak_tensor_ops_list +from tensorflow.python.framework import tensor +from tensorflow.python.framework import weak_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_bitwise_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import image_ops_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops.numpy_ops import np_array_ops +from tensorflow.python.ops.numpy_ops import np_math_ops +from tensorflow.python.platform import 
tf_logging as logging from tensorflow.python.util import dispatch +from tensorflow.python.util import tf_decorator -# This file must depend on math_ops so that e.g. `__add__` is -# added to the Tensor class. -for operator in ops.Tensor.OVERLOADABLE_OPERATORS: - tensor_oper = getattr(ops.Tensor, operator) - setattr(WeakTensor, operator, tensor_oper) - # List of unary ops that have support for WeakTensor. -_TF_UNARY_APIS = weak_tensor_ops_list.ALL_UNARY_OPS +_TF_UNARY_APIS = [] +_TF_BINARY_APIS = [] + + +# ============================================================================== +# Utils to handle WeakTensor inputs and outputs. +# ============================================================================== +# pylint: disable=g-doc-args,g-doc-return-or-yield +def _convert_or_cast(x, dtype, name): + """Converts/casts the input x to dtype.""" + # TODO(b/290216343): remove this branch once we fix the precision loss bug in + # tf.cast. + if isinstance(x, (int, float, complex)): + return ops.convert_to_tensor(x, dtype=dtype, name=name) + else: + return math_ops.cast(x, dtype=dtype, name=name) -def register_unary_weak_tensor_dispatcher(op): - """Add dispatch for WeakTensor inputs.""" +def weak_tensor_unary_op_wrapper(op): + """Infers input type and adds WeakTensor support to unary ops. + + This wrapper infers input type according to the auto dtype conversion + semantics - Tensor and NumPy inputs as Tensor of corresponding dtype and + WeakTensor and python inputs as WeakTensor of corresponding dtype. If the + inferred input dtype is "weak" and the op doesn't specify a return dtype, + returns WeakTensor. + """ signature = inspect.signature(op) - weak_tensor_arg_name = next(iter(signature.parameters.keys())) + arg_names = iter(signature.parameters.keys()) + x_arg_name = next(arg_names) - @dispatch.dispatch_for_api(op, {weak_tensor_arg_name: WeakTensor}) def wrapper(*args, **kwargs): + if not ops.is_auto_dtype_conversion_enabled(): + return op(*args, **kwargs) bound_arguments = signature.bind(*args, **kwargs) bound_arguments.apply_defaults() bound_kwargs = bound_arguments.arguments - bound_kwargs[weak_tensor_arg_name] = bound_kwargs[ - weak_tensor_arg_name - ].to_tensor() - - # Only return WeakTensor if there is no dtype specified. - if bound_kwargs.get("dtype", None) is None: - return WeakTensor.from_tensor((op(**bound_kwargs))) - else: + x = bound_kwargs[x_arg_name] + # No input/output handling needed when input is a Tensor because Tensor + # input in unary op always outputs a Tensor. + if isinstance(x, tensor.Tensor): + return op(**bound_kwargs) + # Infer input type and determine the result promotion type. + try: + target_type, is_weak = flexible_dtypes.result_type(x) + # NotImplementedError is thrown from result_type when x is an + # unsupported input type (e.g. CompositeTensor). + except NotImplementedError: + logging.warning( + "The new dtype semantics do not support" + f" {op.__module__}.{op.__name__}({type(x)}). Falling back to old" + " semantics." + ) return op(**bound_kwargs) + bound_kwargs[x_arg_name] = _convert_or_cast(x, target_type, "x") + # Only return WeakTensor when dtype is NOT specified. + if bound_kwargs.get("dtype", None) is not None: + is_weak = False + return weak_tensor.maybe_convert_to_weak_tensor(op(**bound_kwargs), is_weak) + + wrapper = tf_decorator.make_decorator(op, wrapper) + # Update dispatch dictionary to store monkey-patched op references. 
+ _update_weak_tensor_patched_ops_in_dispatch_dict(wrapper) + + # Add the updated function to list of unary ops with WeakTensor support. + _TF_UNARY_APIS.append(wrapper) return wrapper -for tf_unary_api in _TF_UNARY_APIS: - register_unary_weak_tensor_dispatcher(tf_unary_api) +def weak_tensor_binary_op_wrapper(op): + """Determines result promotion type and adds WeakTensor support to binary ops. + + This wrapper first infers dtype of any Tensor, WeakTensor, python/numpy + inputs. Then, both inputs are promoted to the correct promotion result dtype. + If the result promotion dtype is "weak", returns WeakTensor. + """ + + signature = inspect.signature(op) + arg_names = iter(signature.parameters.keys()) + x_arg_name = next(arg_names) + y_arg_name = next(arg_names) + + def wrapper(*args, **kwargs): + if not ops.is_auto_dtype_conversion_enabled(): + return op(*args, **kwargs) + bound_arguments = signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + bound_kwargs = bound_arguments.arguments + x = bound_kwargs[x_arg_name] + y = bound_kwargs[y_arg_name] + # Infer input type and determine the result promotion type. + try: + target_type, is_weak = flexible_dtypes.result_type(x, y) + # NotImplementedError is thrown from result_type when x or y is an + # unsupported input type (e.g. CompositeTensor). + except NotImplementedError: + logging.warning( + "The new dtype semantics do not support" + f" {op.__module__}.{op.__name__}({type(x)}, {type(y)}). Falling back" + " to old semantics." + ) + return op(**bound_kwargs) + + bound_kwargs[x_arg_name] = _convert_or_cast(x, target_type, "x") + bound_kwargs[y_arg_name] = _convert_or_cast(y, target_type, "y") + return weak_tensor.maybe_convert_to_weak_tensor(op(**bound_kwargs), is_weak) + + wrapper = tf_decorator.make_decorator(op, wrapper) + + # Update dispatch dictionary to store monkey-patched op references. + _update_weak_tensor_patched_ops_in_dispatch_dict(wrapper) + + # Add the updated function to list of binary ops with WeakTensor support. + _TF_BINARY_APIS.append(wrapper) + return wrapper + + +# TODO(b/290672237): Investigate if there is a more elegant solution. +def _update_weak_tensor_patched_ops_in_dispatch_dict(patched_op): + """Update dispatch dictionary to store WeakTensor patched op references. + + _TYPE_BASED_DISPATCH_SIGNATURES in dispatch.py stores mappings from op + reference to all the dispatchers it's registered with. We need to update + this dictionary to add a mapping from the patched-op reference to the + signature dictionary the unpatched-op reference is mapped to. This ensures + that dispatch can be reigstered and unregistered with monkey-patched ops. + """ + dispatch_dict = dispatch._TYPE_BASED_DISPATCH_SIGNATURES # pylint: disable=protected-access + unpatched_api = patched_op.__wrapped__ + if unpatched_api in dispatch_dict: + dispatch_dict[patched_op] = dispatch_dict[unpatched_api] + + +# ============================================================================== +# Monkey patching to add WeakTensor Support. 
+# ============================================================================== +# Elementwise unary ops +math_ops.abs = weak_tensor_unary_op_wrapper(math_ops.abs) +math_ops.softplus = weak_tensor_unary_op_wrapper(math_ops.softplus) +math_ops.sign = weak_tensor_unary_op_wrapper(math_ops.sign) +math_ops.real = weak_tensor_unary_op_wrapper(math_ops.real) +math_ops.imag = weak_tensor_unary_op_wrapper(math_ops.imag) +math_ops.angle = weak_tensor_unary_op_wrapper(math_ops.angle) +math_ops.round = weak_tensor_unary_op_wrapper(math_ops.round) +math_ops.sigmoid = weak_tensor_unary_op_wrapper(math_ops.sigmoid) +math_ops.log_sigmoid = weak_tensor_unary_op_wrapper(math_ops.log_sigmoid) +math_ops.conj = weak_tensor_unary_op_wrapper(math_ops.conj) +math_ops.reciprocal_no_nan = weak_tensor_unary_op_wrapper( + math_ops.reciprocal_no_nan +) +math_ops.erfinv = weak_tensor_unary_op_wrapper(math_ops.erfinv) +math_ops.ndtri = weak_tensor_unary_op_wrapper(math_ops.ndtri) +math_ops.erfcinv = weak_tensor_unary_op_wrapper(math_ops.erfcinv) +math_ops.ceil = weak_tensor_unary_op_wrapper(math_ops.ceil) +math_ops.sqrt = weak_tensor_unary_op_wrapper(math_ops.sqrt) +math_ops.exp = weak_tensor_unary_op_wrapper(math_ops.exp) +math_ops.rsqrt = weak_tensor_unary_op_wrapper(math_ops.rsqrt) +math_ops.acos = weak_tensor_unary_op_wrapper(math_ops.acos) +math_ops.floor = weak_tensor_unary_op_wrapper(math_ops.floor) +gen_bitwise_ops.invert = weak_tensor_unary_op_wrapper(gen_bitwise_ops.invert) +gen_math_ops.acosh = weak_tensor_unary_op_wrapper(gen_math_ops.acosh) +gen_math_ops.asin = weak_tensor_unary_op_wrapper(gen_math_ops.asin) +gen_math_ops.asinh = weak_tensor_unary_op_wrapper(gen_math_ops.asinh) +gen_math_ops.atan = weak_tensor_unary_op_wrapper(gen_math_ops.atan) +gen_math_ops.atanh = weak_tensor_unary_op_wrapper(gen_math_ops.atanh) +gen_math_ops.cos = weak_tensor_unary_op_wrapper(gen_math_ops.cos) +gen_math_ops.cosh = weak_tensor_unary_op_wrapper(gen_math_ops.cosh) +gen_math_ops.digamma = weak_tensor_unary_op_wrapper(gen_math_ops.digamma) +gen_math_ops.erf = weak_tensor_unary_op_wrapper(gen_math_ops.erf) +gen_math_ops.erfc = weak_tensor_unary_op_wrapper(gen_math_ops.erfc) +gen_math_ops.expm1 = weak_tensor_unary_op_wrapper(gen_math_ops.expm1) +gen_math_ops.lgamma = weak_tensor_unary_op_wrapper(gen_math_ops.lgamma) +gen_math_ops.log = weak_tensor_unary_op_wrapper(gen_math_ops.log) +gen_math_ops.log1p = weak_tensor_unary_op_wrapper(gen_math_ops.log1p) +gen_math_ops.neg = weak_tensor_unary_op_wrapper(gen_math_ops.neg) +gen_math_ops.reciprocal = weak_tensor_unary_op_wrapper(gen_math_ops.reciprocal) +gen_math_ops.rint = weak_tensor_unary_op_wrapper(gen_math_ops.rint) +gen_math_ops.sin = weak_tensor_unary_op_wrapper(gen_math_ops.sin) +gen_math_ops.sinh = weak_tensor_unary_op_wrapper(gen_math_ops.sinh) +gen_math_ops.square = weak_tensor_unary_op_wrapper(gen_math_ops.square) +gen_math_ops.tan = weak_tensor_unary_op_wrapper(gen_math_ops.tan) +gen_math_ops.tanh = weak_tensor_unary_op_wrapper(gen_math_ops.tanh) +array_ops.zeros_like = weak_tensor_unary_op_wrapper(array_ops.zeros_like) +array_ops.zeros_like_v2 = weak_tensor_unary_op_wrapper(array_ops.zeros_like_v2) +array_ops.ones_like = weak_tensor_unary_op_wrapper(array_ops.ones_like) +array_ops.ones_like_v2 = weak_tensor_unary_op_wrapper(array_ops.ones_like_v2) +gen_array_ops.check_numerics = weak_tensor_unary_op_wrapper( + gen_array_ops.check_numerics +) +nn_ops.relu6 = weak_tensor_unary_op_wrapper(nn_ops.relu6) +nn_ops.leaky_relu = weak_tensor_unary_op_wrapper(nn_ops.leaky_relu) 
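[Editorial note] These module-level re-assignments (the list continues below) swap each op for its wrapped counterpart in place, so existing call sites pick up the WeakTensor behavior without any import changes. For readers unfamiliar with the pattern, here is a small, self-contained Python sketch of the same technique; the `scale` function, the tuple tagging, and the toy "weak scalar" rule are invented for illustration and are not part of this patch.

```python
import functools
import inspect


def scale(x, factor=2.0):
  """Hypothetical stand-in for an op that gets wrapped."""
  return x * factor


def unary_wrapper(op):
  """Toy analogue of weak_tensor_unary_op_wrapper."""
  signature = inspect.signature(op)
  first_arg = next(iter(signature.parameters))

  @functools.wraps(op)  # keeps the original reachable via __wrapped__
  def wrapper(*args, **kwargs):
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    x = bound.arguments[first_arg]
    # Toy rule: plain Python scalars count as "weak" inputs.
    is_weak = isinstance(x, (int, float, complex))
    result = op(**bound.arguments)
    # The real wrapper returns a WeakTensor here; we just tag the result.
    return (result, "weak") if is_weak else (result, "strong")

  return wrapper


scale = unary_wrapper(scale)   # same shape as the re-assignments above
print(scale(3))                # (6.0, 'weak')
print(scale.__wrapped__(3))    # 6.0, the unwrapped original
```

Keeping the original reachable through `__wrapped__` (the TF code gets the same effect from `tf_decorator.make_decorator`) is what lets `_update_weak_tensor_patched_ops_in_dispatch_dict` above re-key the dispatch table from the unpatched op to the patched one.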
+nn_ops.gelu = weak_tensor_unary_op_wrapper(nn_ops.gelu) +nn_ops.log_softmax = weak_tensor_unary_op_wrapper(nn_ops.log_softmax) +nn_ops.log_softmax_v2 = weak_tensor_unary_op_wrapper(nn_ops.log_softmax_v2) +nn_impl.swish = weak_tensor_unary_op_wrapper(nn_impl.swish) +gen_nn_ops.elu = weak_tensor_unary_op_wrapper(gen_nn_ops.elu) +gen_nn_ops.relu = weak_tensor_unary_op_wrapper(gen_nn_ops.relu) +gen_nn_ops.selu = weak_tensor_unary_op_wrapper(gen_nn_ops.selu) +gen_nn_ops.softsign = weak_tensor_unary_op_wrapper(gen_nn_ops.softsign) +image_ops_impl.random_brightness = weak_tensor_unary_op_wrapper( + image_ops_impl.random_brightness +) +image_ops_impl.stateless_random_brightness = weak_tensor_unary_op_wrapper( + image_ops_impl.stateless_random_brightness +) +image_ops_impl.adjust_brightness = weak_tensor_unary_op_wrapper( + image_ops_impl.adjust_brightness +) +image_ops_impl.adjust_gamma = weak_tensor_unary_op_wrapper( + image_ops_impl.adjust_gamma +) +clip_ops.clip_by_value = weak_tensor_unary_op_wrapper(clip_ops.clip_by_value) +special_math_ops.dawsn = weak_tensor_unary_op_wrapper(special_math_ops.dawsn) +special_math_ops.expint = weak_tensor_unary_op_wrapper(special_math_ops.expint) +special_math_ops.fresnel_cos = weak_tensor_unary_op_wrapper( + special_math_ops.fresnel_cos +) +special_math_ops.fresnel_sin = weak_tensor_unary_op_wrapper( + special_math_ops.fresnel_sin +) +special_math_ops.spence = weak_tensor_unary_op_wrapper(special_math_ops.spence) +special_math_ops.bessel_i0 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_i0 +) +special_math_ops.bessel_i0e = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_i0e +) +special_math_ops.bessel_i1 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_i1 +) +special_math_ops.bessel_i1e = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_i1e +) +special_math_ops.bessel_k0 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_k0 +) +special_math_ops.bessel_k0e = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_k0e +) +special_math_ops.bessel_k1 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_k1 +) +special_math_ops.bessel_k1e = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_k1e +) +special_math_ops.bessel_j0 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_j0 +) +special_math_ops.bessel_j1 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_j1 +) +special_math_ops.bessel_y0 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_y0 +) +special_math_ops.bessel_y1 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_y1 +) + +# TF Non-Elementwise Unary Ops +math_ops.reduce_euclidean_norm = weak_tensor_unary_op_wrapper( + math_ops.reduce_euclidean_norm +) +math_ops.reduce_logsumexp = weak_tensor_unary_op_wrapper( + math_ops.reduce_logsumexp +) +math_ops.reduce_max = weak_tensor_unary_op_wrapper(math_ops.reduce_max) +math_ops.reduce_max_v1 = weak_tensor_unary_op_wrapper(math_ops.reduce_max_v1) +math_ops.reduce_mean = weak_tensor_unary_op_wrapper(math_ops.reduce_mean) +math_ops.reduce_mean_v1 = weak_tensor_unary_op_wrapper(math_ops.reduce_mean_v1) +math_ops.reduce_min = weak_tensor_unary_op_wrapper(math_ops.reduce_min) +math_ops.reduce_min_v1 = weak_tensor_unary_op_wrapper(math_ops.reduce_min_v1) +math_ops.reduce_prod = weak_tensor_unary_op_wrapper(math_ops.reduce_prod) +math_ops.reduce_prod_v1 = weak_tensor_unary_op_wrapper(math_ops.reduce_prod_v1) +math_ops.reduce_std = weak_tensor_unary_op_wrapper(math_ops.reduce_std) +math_ops.reduce_sum = 
weak_tensor_unary_op_wrapper(math_ops.reduce_sum) +math_ops.reduce_sum_v1 = weak_tensor_unary_op_wrapper(math_ops.reduce_sum_v1) +math_ops.reduce_variance = weak_tensor_unary_op_wrapper( + math_ops.reduce_variance +) +math_ops.trace = weak_tensor_unary_op_wrapper(math_ops.trace) +array_ops.reshape = weak_tensor_unary_op_wrapper(array_ops.reshape) +array_ops.depth_to_space = weak_tensor_unary_op_wrapper( + array_ops.depth_to_space +) +array_ops.depth_to_space_v2 = weak_tensor_unary_op_wrapper( + array_ops.depth_to_space_v2 +) +array_ops.expand_dims = weak_tensor_unary_op_wrapper(array_ops.expand_dims) +array_ops.expand_dims_v2 = weak_tensor_unary_op_wrapper( + array_ops.expand_dims_v2 +) +array_ops.extract_image_patches = weak_tensor_unary_op_wrapper( + array_ops.extract_image_patches +) +array_ops.extract_image_patches_v2 = weak_tensor_unary_op_wrapper( + array_ops.extract_image_patches_v2 +) +array_ops.identity = weak_tensor_unary_op_wrapper(array_ops.identity) +array_ops.matrix_diag = weak_tensor_unary_op_wrapper(array_ops.matrix_diag) +array_ops.matrix_diag_part = weak_tensor_unary_op_wrapper( + array_ops.matrix_diag_part +) +array_ops.matrix_transpose = weak_tensor_unary_op_wrapper( + array_ops.matrix_transpose +) +array_ops.space_to_depth = weak_tensor_unary_op_wrapper( + array_ops.space_to_depth +) +array_ops.space_to_depth_v2 = weak_tensor_unary_op_wrapper( + array_ops.space_to_depth_v2 +) +array_ops.squeeze = weak_tensor_unary_op_wrapper(array_ops.squeeze) +array_ops.squeeze_v2 = weak_tensor_unary_op_wrapper(array_ops.squeeze_v2) +array_ops.stop_gradient = weak_tensor_unary_op_wrapper(array_ops.stop_gradient) +array_ops.tensor_diag_part = weak_tensor_unary_op_wrapper( + array_ops.tensor_diag_part +) +array_ops.transpose = weak_tensor_unary_op_wrapper(array_ops.transpose) +array_ops.transpose_v2 = weak_tensor_unary_op_wrapper(array_ops.transpose_v2) + +# TF NumPy Unary Ops +np_math_ops.abs = weak_tensor_unary_op_wrapper(np_math_ops.abs) +np_math_ops.absolute = weak_tensor_unary_op_wrapper(np_math_ops.absolute) +np_math_ops.angle = weak_tensor_unary_op_wrapper(np_math_ops.angle) +np_math_ops.arccos = weak_tensor_unary_op_wrapper(np_math_ops.arccos) +np_math_ops.arcsin = weak_tensor_unary_op_wrapper(np_math_ops.arcsin) +np_math_ops.arcsinh = weak_tensor_unary_op_wrapper(np_math_ops.arcsinh) +np_math_ops.arctan = weak_tensor_unary_op_wrapper(np_math_ops.arctan) +np_math_ops.arctanh = weak_tensor_unary_op_wrapper(np_math_ops.arctanh) +np_math_ops.bitwise_not = weak_tensor_unary_op_wrapper(np_math_ops.bitwise_not) +np_math_ops.cbrt = weak_tensor_unary_op_wrapper(np_math_ops.cbrt) +np_math_ops.ceil = weak_tensor_unary_op_wrapper(np_math_ops.ceil) +np_math_ops.conj = weak_tensor_unary_op_wrapper(np_math_ops.conj) +np_math_ops.conjugate = weak_tensor_unary_op_wrapper(np_math_ops.conjugate) +np_math_ops.cos = weak_tensor_unary_op_wrapper(np_math_ops.cos) +np_math_ops.cosh = weak_tensor_unary_op_wrapper(np_math_ops.cosh) +np_math_ops.deg2rad = weak_tensor_unary_op_wrapper(np_math_ops.deg2rad) +np_math_ops.exp = weak_tensor_unary_op_wrapper(np_math_ops.exp) +np_math_ops.exp2 = weak_tensor_unary_op_wrapper(np_math_ops.exp2) +np_math_ops.expm1 = weak_tensor_unary_op_wrapper(np_math_ops.expm1) +np_math_ops.fabs = weak_tensor_unary_op_wrapper(np_math_ops.fabs) +np_math_ops.fix = weak_tensor_unary_op_wrapper(np_math_ops.fix) +np_math_ops.floor = weak_tensor_unary_op_wrapper(np_math_ops.floor) +np_math_ops.log = weak_tensor_unary_op_wrapper(np_math_ops.log) +np_math_ops.negative = 
weak_tensor_unary_op_wrapper(np_math_ops.negative) +np_math_ops.rad2deg = weak_tensor_unary_op_wrapper(np_math_ops.rad2deg) +np_math_ops.reciprocal = weak_tensor_unary_op_wrapper(np_math_ops.reciprocal) +np_math_ops.sin = weak_tensor_unary_op_wrapper(np_math_ops.sin) +np_math_ops.sinh = weak_tensor_unary_op_wrapper(np_math_ops.sinh) +np_math_ops.sqrt = weak_tensor_unary_op_wrapper(np_math_ops.sqrt) +np_math_ops.tan = weak_tensor_unary_op_wrapper(np_math_ops.tan) +np_math_ops.tanh = weak_tensor_unary_op_wrapper(np_math_ops.tanh) +np_math_ops.nanmean = weak_tensor_unary_op_wrapper(np_math_ops.nanmean) +np_math_ops.log2 = weak_tensor_unary_op_wrapper(np_math_ops.log2) +np_math_ops.log10 = weak_tensor_unary_op_wrapper(np_math_ops.log10) +np_math_ops.log1p = weak_tensor_unary_op_wrapper(np_math_ops.log1p) +np_math_ops.positive = weak_tensor_unary_op_wrapper(np_math_ops.positive) +np_math_ops.sinc = weak_tensor_unary_op_wrapper(np_math_ops.sinc) +np_math_ops.square = weak_tensor_unary_op_wrapper(np_math_ops.square) +np_math_ops.diff = weak_tensor_unary_op_wrapper(np_math_ops.diff) +np_math_ops.sort = weak_tensor_unary_op_wrapper(np_math_ops.sort) +np_math_ops.average = weak_tensor_unary_op_wrapper(np_math_ops.average) +np_math_ops.trace = weak_tensor_unary_op_wrapper(np_math_ops.trace) +np_array_ops.amax = weak_tensor_unary_op_wrapper(np_array_ops.amax) +np_array_ops.amin = weak_tensor_unary_op_wrapper(np_array_ops.amin) +np_array_ops.around = weak_tensor_unary_op_wrapper(np_array_ops.around) +np_array_ops.arange = weak_tensor_unary_op_wrapper(np_array_ops.arange) +np_array_ops.array = weak_tensor_unary_op_wrapper(np_array_ops.array) +np_array_ops.asanyarray = weak_tensor_unary_op_wrapper(np_array_ops.asanyarray) +np_array_ops.asarray = weak_tensor_unary_op_wrapper(np_array_ops.asarray) +np_array_ops.ascontiguousarray = weak_tensor_unary_op_wrapper( + np_array_ops.ascontiguousarray +) +np_array_ops.copy = weak_tensor_unary_op_wrapper(np_array_ops.copy) +np_array_ops.cumprod = weak_tensor_unary_op_wrapper(np_array_ops.cumprod) +np_array_ops.cumsum = weak_tensor_unary_op_wrapper(np_array_ops.cumsum) +np_array_ops.diag = weak_tensor_unary_op_wrapper(np_array_ops.diag) +np_array_ops.diagflat = weak_tensor_unary_op_wrapper(np_array_ops.diagflat) +np_array_ops.diagonal = weak_tensor_unary_op_wrapper(np_array_ops.diagonal) +np_array_ops.empty_like = weak_tensor_unary_op_wrapper(np_array_ops.empty_like) +np_array_ops.expand_dims = weak_tensor_unary_op_wrapper( + np_array_ops.expand_dims +) +np_array_ops.flatten = weak_tensor_unary_op_wrapper(np_array_ops.flatten) +np_array_ops.flip = weak_tensor_unary_op_wrapper(np_array_ops.flip) +np_array_ops.fliplr = weak_tensor_unary_op_wrapper(np_array_ops.fliplr) +np_array_ops.flipud = weak_tensor_unary_op_wrapper(np_array_ops.flipud) +np_array_ops.full_like = weak_tensor_unary_op_wrapper(np_array_ops.full_like) +np_array_ops.imag = weak_tensor_unary_op_wrapper(np_array_ops.imag) +np_array_ops.max = weak_tensor_unary_op_wrapper(np_array_ops.max) +np_array_ops.mean = weak_tensor_unary_op_wrapper(np_array_ops.mean) +np_array_ops.min = weak_tensor_unary_op_wrapper(np_array_ops.min) +np_array_ops.moveaxis = weak_tensor_unary_op_wrapper(np_array_ops.moveaxis) +np_array_ops.ones_like = weak_tensor_unary_op_wrapper(np_array_ops.ones_like) +np_array_ops.prod = weak_tensor_unary_op_wrapper(np_array_ops.prod) +np_array_ops.ravel = weak_tensor_unary_op_wrapper(np_array_ops.ravel) +np_array_ops.real = weak_tensor_unary_op_wrapper(np_array_ops.real) +np_array_ops.reshape = 
weak_tensor_unary_op_wrapper(np_array_ops.reshape) +np_array_ops.rot90 = weak_tensor_unary_op_wrapper(np_array_ops.rot90) +np_array_ops.round = weak_tensor_unary_op_wrapper(np_array_ops.round) +np_array_ops.squeeze = weak_tensor_unary_op_wrapper(np_array_ops.squeeze) +np_array_ops.std = weak_tensor_unary_op_wrapper(np_array_ops.std) +np_array_ops.sum = weak_tensor_unary_op_wrapper(np_array_ops.sum) +np_array_ops.swapaxes = weak_tensor_unary_op_wrapper(np_array_ops.swapaxes) +np_array_ops.transpose = weak_tensor_unary_op_wrapper(np_array_ops.transpose) +np_array_ops.triu = weak_tensor_unary_op_wrapper(np_array_ops.triu) +np_array_ops.vander = weak_tensor_unary_op_wrapper(np_array_ops.vander) +np_array_ops.var = weak_tensor_unary_op_wrapper(np_array_ops.var) +np_array_ops.zeros_like = weak_tensor_unary_op_wrapper(np_array_ops.zeros_like) + +# ============================================================================== +# Update old op references. +# ============================================================================== +# Update Tensor dunder methods. +tensor.Tensor.__add__ = math_ops.add +tensor.Tensor.__sub__ = math_ops.sub +tensor.Tensor.__mul__ = math_ops.multiply +tensor.Tensor.__div__ = math_ops.div +tensor.Tensor.__truediv__ = math_ops.truediv +tensor.Tensor.__floordiv__ = math_ops.floordiv +tensor.Tensor.__mod__ = gen_math_ops.floor_mod +tensor.Tensor.__pow__ = math_ops.pow +tensor.Tensor.__matmul__ = math_ops.matmul + +# Set WeakTensor dunder methods. +weak_tensor.WeakTensor.__invert__ = math_ops.invert_ +weak_tensor.WeakTensor.__neg__ = gen_math_ops.neg +weak_tensor.WeakTensor.__abs__ = math_ops.abs +weak_tensor.WeakTensor.__add__ = math_ops.add +weak_tensor.WeakTensor.__sub__ = math_ops.sub +weak_tensor.WeakTensor.__mul__ = math_ops.multiply +weak_tensor.WeakTensor.__div__ = math_ops.div +weak_tensor.WeakTensor.__truediv__ = math_ops.truediv +weak_tensor.WeakTensor.__floordiv__ = math_ops.floordiv +weak_tensor.WeakTensor.__mod__ = gen_math_ops.floor_mod +weak_tensor.WeakTensor.__pow__ = math_ops.pow +weak_tensor.WeakTensor.__matmul__ = math_ops.matmul + +# Add/Update NumPy methods in Tensor and WeakTensor. +np_math_ops.enable_numpy_methods_on_tensor() +np_math_ops._enable_numpy_methods(weak_tensor.WeakTensor) # pylint: disable=protected-access diff --git a/tensorflow/python/ops/weak_tensor_ops_list.py b/tensorflow/python/ops/weak_tensor_ops_list.py deleted file mode 100644 index 067feb621e0e61..00000000000000 --- a/tensorflow/python/ops/weak_tensor_ops_list.py +++ /dev/null @@ -1,251 +0,0 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Lists of ops that support WeakTensor.""" - -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import gen_array_ops -from tensorflow.python.ops import gen_bitwise_ops -from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import gen_nn_ops -from tensorflow.python.ops import image_ops_impl -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_impl -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import special_math_ops -from tensorflow.python.ops.numpy_ops import np_array_ops -from tensorflow.python.ops.numpy_ops import np_math_ops - - -# Below are lists of unary ops that return a WeakTensor when given a WeakTensor -# input. These are some of the reasons why ops may not support WeakTensor. -# (1) The return dtype is specified. (e.g. tofloat(), cast(), is_finite()) -# (2) The list is prioritized to unary elementwise ops, TF-NumPy ops, math_ops, -# and array_ops. -# (3) There is no "weak" string type so any string ops are not supported. -# If you wish to add support to a specific unary op, add the unary op to a -# corresponding list. - -_ELEMENTWISE_UNARY_OPS = [ - math_ops.abs, - math_ops.softplus, - math_ops.sign, - math_ops.real, - math_ops.imag, - math_ops.angle, - math_ops.round, - math_ops.sigmoid, - math_ops.log_sigmoid, - math_ops.conj, - math_ops.reciprocal_no_nan, - math_ops.erfinv, - math_ops.ndtri, - math_ops.erfcinv, - math_ops.ceil, - math_ops.sqrt, - math_ops.exp, - math_ops.rsqrt, - math_ops.acos, - math_ops.floor, - gen_bitwise_ops.invert, - gen_math_ops.acosh, - gen_math_ops.asin, - gen_math_ops.asinh, - gen_math_ops.atan, - gen_math_ops.atanh, - gen_math_ops.cos, - gen_math_ops.cosh, - gen_math_ops.digamma, - gen_math_ops.erf, - gen_math_ops.erfc, - gen_math_ops.expm1, - gen_math_ops.lgamma, - gen_math_ops.log, - gen_math_ops.log1p, - gen_math_ops.neg, - gen_math_ops.reciprocal, - gen_math_ops.rint, - gen_math_ops.sin, - gen_math_ops.sinh, - gen_math_ops.square, - gen_math_ops.tan, - gen_math_ops.tanh, - array_ops.zeros_like, - array_ops.zeros_like_v2, - array_ops.ones_like, - array_ops.ones_like_v2, - gen_array_ops.check_numerics, - nn_ops.relu6, - nn_ops.leaky_relu, - nn_ops.gelu, - nn_ops.log_softmax, - gen_nn_ops.elu, - gen_nn_ops.relu, - gen_nn_ops.selu, - gen_nn_ops.softsign, - image_ops_impl.random_brightness, - image_ops_impl.stateless_random_brightness, - image_ops_impl.adjust_brightness, - image_ops_impl.adjust_gamma, - nn_impl.swish, - clip_ops.clip_by_value, - special_math_ops.dawsn, - special_math_ops.expint, - special_math_ops.fresnel_cos, - special_math_ops.fresnel_sin, - special_math_ops.spence, - special_math_ops.bessel_i0, - special_math_ops.bessel_i0e, - special_math_ops.bessel_i1, - special_math_ops.bessel_i1e, - special_math_ops.bessel_k0, - special_math_ops.bessel_k0e, - special_math_ops.bessel_k1, - special_math_ops.bessel_k1e, - special_math_ops.bessel_j0, - special_math_ops.bessel_j1, - special_math_ops.bessel_y0, - special_math_ops.bessel_y1, -] -_TF_UNARY_OPS = [ - math_ops.reduce_euclidean_norm, - math_ops.reduce_logsumexp, - math_ops.reduce_max, - math_ops.reduce_max_v1, - math_ops.reduce_mean, - math_ops.reduce_mean_v1, - math_ops.reduce_min, - math_ops.reduce_min_v1, - math_ops.reduce_prod, - math_ops.reduce_prod_v1, - math_ops.reduce_std, - math_ops.reduce_sum, - math_ops.reduce_sum_v1, - math_ops.reduce_variance, - 
math_ops.trace, - array_ops.depth_to_space, - array_ops.depth_to_space_v2, - array_ops.expand_dims, - array_ops.expand_dims_v2, - array_ops.extract_image_patches, - array_ops.extract_image_patches_v2, - array_ops.identity, - array_ops.matrix_diag, - array_ops.matrix_diag_part, - array_ops.matrix_transpose, - array_ops.shape, - array_ops.shape_v2, - array_ops.size, - array_ops.size_v2, - array_ops.space_to_depth, - array_ops.space_to_depth_v2, - array_ops.squeeze, - array_ops.squeeze_v2, - array_ops.stop_gradient, - array_ops.tensor_diag_part, - array_ops.transpose, - array_ops.transpose_v2, -] -_TF_NUMPY_UNARY_OPS = [ - np_math_ops.abs, - np_math_ops.absolute, - np_math_ops.angle, - np_math_ops.arccos, - np_math_ops.arcsin, - np_math_ops.arcsinh, - np_math_ops.arctan, - np_math_ops.arctanh, - np_math_ops.bitwise_not, - np_math_ops.cbrt, - np_math_ops.ceil, - np_math_ops.conj, - np_math_ops.conjugate, - np_math_ops.cos, - np_math_ops.cosh, - np_math_ops.deg2rad, - np_math_ops.exp, - np_math_ops.exp2, - np_math_ops.expm1, - np_math_ops.fabs, - np_math_ops.fix, - np_math_ops.floor, - np_math_ops.log, - np_math_ops.negative, - np_math_ops.rad2deg, - np_math_ops.reciprocal, - np_math_ops.sin, - np_math_ops.sinh, - np_math_ops.sqrt, - np_math_ops.tan, - np_math_ops.tanh, - np_math_ops.nanmean, - np_math_ops.log2, - np_math_ops.log10, - np_math_ops.log1p, - np_math_ops.positive, - np_math_ops.sinc, - np_math_ops.square, - np_math_ops.diff, - np_math_ops.sort, - np_math_ops.average, - np_math_ops.trace, - np_array_ops.amax, - np_array_ops.amin, - np_array_ops.around, - np_array_ops.arange, - np_array_ops.array, - np_array_ops.asanyarray, - np_array_ops.asarray, - np_array_ops.ascontiguousarray, - np_array_ops.copy, - np_array_ops.cumprod, - np_array_ops.cumsum, - np_array_ops.diag, - np_array_ops.diagflat, - np_array_ops.diagonal, - np_array_ops.empty_like, - np_array_ops.expand_dims, - np_array_ops.flatten, - np_array_ops.flip, - np_array_ops.fliplr, - np_array_ops.flipud, - np_array_ops.imag, - np_array_ops.max, - np_array_ops.mean, - np_array_ops.min, - np_array_ops.moveaxis, - np_array_ops.ones_like, - np_array_ops.prod, - np_array_ops.ravel, - np_array_ops.real, - np_array_ops.reshape, - np_array_ops.rot90, - np_array_ops.round, - np_array_ops.squeeze, - np_array_ops.std, - np_array_ops.sum, - np_array_ops.swapaxes, - np_array_ops.transpose, - np_array_ops.triu, - np_array_ops.vander, - np_array_ops.var, - np_array_ops.zeros_like, -] - -# Below are lists of binary ops that have support for WeakTensor input(s). -_ELEMENTWISE_BINARY_OPS = [] - -ALL_UNARY_OPS = _ELEMENTWISE_UNARY_OPS + _TF_UNARY_OPS + _TF_NUMPY_UNARY_OPS -ALL_BINARY_OPS = _ELEMENTWISE_BINARY_OPS diff --git a/tensorflow/python/ops/weak_tensor_ops_test.py b/tensorflow/python/ops/weak_tensor_ops_test.py index db15af9a0b63e4..18708068b5d75f 100644 --- a/tensorflow/python/ops/weak_tensor_ops_test.py +++ b/tensorflow/python/ops/weak_tensor_ops_test.py @@ -13,11 +13,15 @@ # limitations under the License. 
# ============================================================================== """Tests for TF ops with WeakTensor input.""" + from absl.testing import parameterized +import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.framework.weak_tensor import WeakTensor from tensorflow.python.ops import array_ops @@ -26,15 +30,21 @@ from tensorflow.python.ops import gen_bitwise_ops from tensorflow.python.ops import image_ops_impl from tensorflow.python.ops import math_ops -from tensorflow.python.ops import weak_tensor_ops # pylint: disable=unused-import -from tensorflow.python.ops import weak_tensor_ops_list +from tensorflow.python.ops import weak_tensor_ops +from tensorflow.python.ops import weak_tensor_test_util from tensorflow.python.ops.numpy_ops import np_array_ops from tensorflow.python.ops.numpy_ops import np_config from tensorflow.python.ops.numpy_ops import np_math_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import googletest +from tensorflow.python.util import dispatch + + +_get_weak_tensor = weak_tensor_test_util.get_weak_tensor +_convert_to_input_type = weak_tensor_test_util.convert_to_input_type -_TF_UNARY_APIS = weak_tensor_ops_list.ALL_UNARY_OPS +_TF_UNARY_APIS = weak_tensor_ops._TF_UNARY_APIS _TF_UNARY_APIS_SPECIFIC_DTYPE = [ math_ops.to_float, math_ops.to_double, @@ -51,9 +61,11 @@ image_ops_impl.adjust_brightness, clip_ops.clip_by_value, np_array_ops.expand_dims, + np_array_ops.full_like, np_array_ops.moveaxis, np_array_ops.reshape, np_array_ops.swapaxes, + array_ops.reshape, array_ops.depth_to_space, array_ops.depth_to_space_v2, array_ops.expand_dims, @@ -86,101 +98,299 @@ ] +class MyTensor(extension_type.ExtensionType): + value: tensor.Tensor + + class WeakTensorOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): # Test unary ops with one input. - @parameterized.parameters( - set(_TF_UNARY_APIS) - set(_TF_UNARY_APIS_WITH_MULT_INPUT) + @parameterized.named_parameters( + (api.__module__ + "." + api.__name__, api) + for api in set(_TF_UNARY_APIS) - set(_TF_UNARY_APIS_WITH_MULT_INPUT) ) def test_unary_ops_return_weak_tensor(self, unary_api): - op_input = _get_test_input(unary_api) - res = unary_api(op_input) - # Check that WeakTensor is returned. + weak_tensor_input, python_input, tensor_input, numpy_input = ( + _get_test_input(unary_api) + ) + + # Check that WeakTensor input outputs a WeakTensor. + res = unary_api(weak_tensor_input) self.assertIsInstance(res, WeakTensor) + expected_result = unary_api(weak_tensor_input.tensor) # Check that the actual result is correct. - expected_result = unary_api(op_input.tensor) self.assertAllEqual(res, expected_result) + # Check that python nested scalar type (weak type) returns a WeakTensor. + res = unary_api(python_input) + self.assertIsInstance(res, WeakTensor) + + # Check that normal Tensor input outputs a Tensor. + res = unary_api(tensor_input) + self.assertIsInstance(res, tensor.Tensor) + + # Check that numpy type input outputs a Tensor. + res = unary_api(numpy_input) + self.assertIsInstance(res, tensor.Tensor) + # Test unary ops with multiple inputs. 
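[Editorial note] The parameterized test above pins down the convention the wrapped unary ops now follow: WeakTensor and plain Python inputs come back as WeakTensor, while Tensor and NumPy inputs come back as a regular Tensor. Condensed into a usage sketch, assuming the internal modules exist at the paths this patch uses (none of this is public API, and it requires the patch to be applied):

```python
import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor
from tensorflow.python.framework.weak_tensor import WeakTensor
from tensorflow.python.ops import math_ops
# Importing weak_tensor_ops applies the monkey patches.
from tensorflow.python.ops import weak_tensor_ops  # pylint: disable=unused-import
from tensorflow.python.ops import weak_tensor_test_util

ops.set_dtype_conversion_mode("all")  # opt in to the new dtype semantics

wt = weak_tensor_test_util.get_weak_tensor([1.0, 2.0, 3.0], dtypes.float32)
assert isinstance(math_ops.sqrt(wt), WeakTensor)               # weak in, weak out
assert isinstance(math_ops.sqrt([1.0, 2.0, 3.0]), WeakTensor)  # Python inputs are weak
assert isinstance(math_ops.sqrt(np.array([1.0])), tensor.Tensor)  # NumPy in, Tensor out
```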
- def test_multi_arg_unary_ops_return_weak_tensor(self): - a = WeakTensor(constant_op.constant([1, 2, 3], dtypes.float32)) + @parameterized.parameters( + ("WeakTensor", dtypes.float32, WeakTensor), + ("Python", dtypes.float32, WeakTensor), + ("NumPy", np.float32, tensor.Tensor), + ("NumPy", None, tensor.Tensor), + ("Tensor", dtypes.float32, tensor.Tensor), + ) + def test_multi_arg_unary_ops_return_weak_tensor( + self, input_type, input_dtype, result_type + ): + test_input = _convert_to_input_type( + [1.0, 2.0, 3.0], input_type, input_dtype + ) + self.assertIsInstance( + gen_array_ops.check_numerics(test_input, message=""), result_type + ) self.assertIsInstance( - gen_array_ops.check_numerics(a, message=""), WeakTensor + image_ops_impl.random_brightness(test_input, 0.2), result_type ) - self.assertIsInstance(image_ops_impl.random_brightness(a, 0.2), WeakTensor) self.assertIsInstance( image_ops_impl.stateless_random_brightness( - image=a, max_delta=0.2, seed=(1, 2) + image=test_input, max_delta=0.2, seed=(1, 2) ), - WeakTensor, + result_type, ) self.assertIsInstance( - image_ops_impl.adjust_brightness(a, delta=0.2), WeakTensor + image_ops_impl.adjust_brightness(test_input, delta=0.2), result_type ) self.assertIsInstance( - clip_ops.clip_by_value(a, clip_value_min=1.1, clip_value_max=2.2), - WeakTensor, + clip_ops.clip_by_value( + test_input, clip_value_min=1.1, clip_value_max=2.2 + ), + result_type, ) - self.assertIsInstance(np_array_ops.expand_dims(a, axis=0), WeakTensor) self.assertIsInstance( - np_array_ops.moveaxis(a, source=0, destination=0), WeakTensor + np_array_ops.expand_dims(test_input, axis=0), result_type ) - self.assertIsInstance(np_array_ops.reshape(a, newshape=(3,)), WeakTensor) self.assertIsInstance( - np_array_ops.swapaxes(a, axis1=0, axis2=0), WeakTensor + np_array_ops.moveaxis(test_input, source=0, destination=0), result_type + ) + self.assertIsInstance( + np_array_ops.reshape(test_input, newshape=(3,)), result_type + ) + self.assertIsInstance( + np_array_ops.swapaxes(test_input, axis1=0, axis2=0), result_type + ) + self.assertIsInstance( + array_ops.reshape(test_input, shape=(3,)), result_type + ) + self.assertIsInstance( + array_ops.expand_dims(test_input, axis=0), result_type ) - self.assertIsInstance(array_ops.expand_dims(a, axis=0), WeakTensor) # Test unary ops with a specific return dtype. @parameterized.parameters(_TF_UNARY_APIS_SPECIFIC_DTYPE) def test_unary_ops_return_normal_tensor(self, unary_api_specific_dtype): - a = WeakTensor(constant_op.constant([1, 2, 3], dtypes.float32)) - res = unary_api_specific_dtype(a) - self.assertIsInstance(res, ops.Tensor) + # All inputs should output a normal Tensor because return dtype is + # specified. + weak_tensor_input = _get_weak_tensor([1, 2, 3], dtypes.float32) + res = unary_api_specific_dtype(weak_tensor_input) + self.assertIsInstance(res, tensor.Tensor) + + python_input = [1.0, 2.0, 3.0] + res = unary_api_specific_dtype(python_input) + self.assertIsInstance(res, tensor.Tensor) + + tensor_input = constant_op.constant([1.0, 2.0, 3.0], dtypes.float32) + res = unary_api_specific_dtype(tensor_input) + self.assertIsInstance(res, tensor.Tensor) + + tensor_input = np.array([1.0, 2.0, 3.0]) + res = unary_api_specific_dtype(tensor_input) + self.assertIsInstance(res, tensor.Tensor) # Test unary ops with optional dtype arg. 
- def test_elementwise_unary_ops_optional_dtype(self): - a = WeakTensor(constant_op.constant([1, 2, 3], dtypes.float32)) + @parameterized.parameters( + ("WeakTensor", dtypes.float32, WeakTensor), + ("Python", None, WeakTensor), + ("NumPy", np.float32, tensor.Tensor), + ("NumPy", None, tensor.Tensor), + ("Tensor", dtypes.float32, tensor.Tensor), + ) + def test_elementwise_unary_ops_optional_dtype( + self, input_type, input_dtype, result_type + ): + test_input = _convert_to_input_type( + [1.0, 2.0, 3.0], input_type, input_dtype + ) # No dtype specified in the argument. - self.assertIsInstance(array_ops.zeros_like(a), WeakTensor) - self.assertIsInstance(array_ops.ones_like(a), WeakTensor) - self.assertIsInstance(array_ops.ones_like(a, dtype=None), WeakTensor) + self.assertIsInstance(array_ops.zeros_like(test_input), result_type) + self.assertIsInstance(array_ops.ones_like(test_input), result_type) + self.assertIsInstance( + array_ops.ones_like(test_input, dtype=None), result_type + ) # dtype specified in the argument. self.assertIsInstance( - array_ops.zeros_like(a, dtype=dtypes.int32), ops.Tensor + array_ops.zeros_like(test_input, dtype=dtypes.int32), tensor.Tensor ) self.assertIsInstance( - array_ops.ones_like(a, dtype=dtypes.int32), ops.Tensor + array_ops.ones_like(test_input, dtype=dtypes.int32), tensor.Tensor ) - self.assertIsInstance(array_ops.zeros_like(a, dtypes.int32), ops.Tensor) - self.assertIsInstance(array_ops.ones_like(a, dtypes.int32), ops.Tensor) self.assertIsInstance( - np_array_ops.arange( - WeakTensor(constant_op.constant(5)), 0, 1, dtypes.float32 - ), - ops.Tensor, + array_ops.zeros_like(test_input, dtypes.int32), tensor.Tensor + ) + self.assertIsInstance( + array_ops.ones_like(test_input, dtypes.int32), tensor.Tensor + ) + + @parameterized.parameters( + ("WeakTensor", dtypes.float32, None, WeakTensor), + ("WeakTensor", dtypes.float32, dtypes.int32, tensor.Tensor), + ("Python", None, None, WeakTensor), + ("Python", None, dtypes.int32, tensor.Tensor), + ("NumPy", None, None, tensor.Tensor), + ("NumPy", None, np.int32, tensor.Tensor), + ("Tensor", dtypes.float32, None, tensor.Tensor), + ("Tensor", dtypes.float32, dtypes.int32, tensor.Tensor), + ) + # Test unary ops with multiple args that includes an optional dtype arg. + def test_elementwise_unary_ops_optional_dtype_with_multi_args( + self, input_type, input_dtype, dtype_arg, result_type + ): + test_input = _convert_to_input_type(5, input_type, input_dtype) + self.assertIsInstance( + np_array_ops.arange(test_input, 10, dtype=dtype_arg), result_type + ) + self.assertIsInstance( + np_array_ops.full_like(test_input, 1, dtype=dtype_arg), result_type ) # Test unary ops that require dtype arg. 
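[Editorial note] The optional-dtype cases above, together with the explicit-dtype cases that follow, encode the other half of the convention: as soon as a concrete dtype is requested, the result is an ordinary Tensor, because an explicit dtype removes the "weakness" of the value. A minimal sketch under the same assumptions as before (internal, patched modules, not public API):

```python
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor
from tensorflow.python.framework.weak_tensor import WeakTensor
from tensorflow.python.ops import array_ops
# Importing weak_tensor_ops applies the monkey patches.
from tensorflow.python.ops import weak_tensor_ops  # pylint: disable=unused-import
from tensorflow.python.ops import weak_tensor_test_util

ops.set_dtype_conversion_mode("all")

wt = weak_tensor_test_util.get_weak_tensor([1.0, 2.0, 3.0], dtypes.float32)
assert isinstance(array_ops.ones_like(wt), WeakTensor)  # dtype omitted, stays weak
assert isinstance(array_ops.ones_like(wt, dtype=dtypes.int32), tensor.Tensor)
assert not isinstance(array_ops.ones_like(wt, dtype=dtypes.int32), WeakTensor)
```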
def test_unary_ops_explicit_dtype_return(self): - a = WeakTensor(constant_op.constant([1, 2, 3], dtypes.float32)) - self.assertIsInstance(math_ops.cast(a, dtypes.int32), ops.Tensor) - self.assertIsInstance(math_ops.saturate_cast(a, dtypes.int32), ops.Tensor) + wt_input = _get_weak_tensor([1, 2, 3], dtypes.float32) + self.assertIsInstance(math_ops.cast(wt_input, dtypes.int32), tensor.Tensor) + self.assertIsInstance( + math_ops.saturate_cast(wt_input, dtypes.int32), tensor.Tensor + ) + + python_input = [1.0, 2.0, 3.0] + self.assertIsInstance( + math_ops.cast(python_input, dtypes.int32), tensor.Tensor + ) + self.assertIsInstance( + math_ops.saturate_cast(python_input, dtypes.int32), tensor.Tensor + ) + + def test_unsupported_input_type_in_weak_tensor_ops(self): + rt = ragged_tensor.RaggedTensor.from_row_splits( + values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8] + ) + # Any unsupported type should be ignored in WeakTensor wrapper. + self.assertIsInstance(math_ops.abs(rt), ragged_tensor.RaggedTensor) + def test_update_weak_tensor_patched_ops_in_dispatch_dict(self): + dispatch_dict = dispatch._TYPE_BASED_DISPATCH_SIGNATURES + # Test that we can use the updated op reference as a key to the dispatch + # dictionary. + self.assertTrue(hasattr(math_ops.abs, "_tf_decorator")) + self.assertNotEmpty(dispatch_dict[math_ops.abs]) + def test_weak_tensor_ops_dispatch(self): + @dispatch.dispatch_for_api(math_ops.abs) + def my_abs(x: MyTensor): + return MyTensor(math_ops.abs(x.value)) + + self.assertIsInstance(my_abs(MyTensor(constant_op.constant(1.0))), MyTensor) + + # Test unregistering dispatch with patched op reference. + dispatch.unregister_dispatch_for(my_abs) + with self.assertRaises(ValueError): + math_ops.abs(MyTensor(constant_op.constant(1.0))) + + def testWeakTensorDunderMethods(self): + x = _get_weak_tensor([1, 2, 3]) + + self.assertIsInstance(abs(x), WeakTensor) + self.assertIsInstance(~x, WeakTensor) + self.assertIsInstance(-x, WeakTensor) + + @parameterized.parameters( + ("T", WeakTensor), + ("ndim", int), + ("size", None), + ("data", WeakTensor), + ) + def testNumpyAttributesOnWeakTensor(self, np_attribute, result_type): + a = weak_tensor_test_util.get_weak_tensor(([1, 2, 3])) + b = constant_op.constant([1, 2, 3]) + + self.assertTrue(hasattr(a, np_attribute)) + wt_np_attr = getattr(a, np_attribute) + t_np_attr = getattr(b, np_attribute) + if result_type is None: + # The result type may differ depending on which machine test runs on + # (e.g. 
size) + self.assertEqual(type(wt_np_attr), type(t_np_attr)) + else: + self.assertIsInstance(wt_np_attr, result_type) + self.assertAllEqual(wt_np_attr, t_np_attr) + + @parameterized.parameters( + ("__pos__", WeakTensor), + ("__round__", WeakTensor, 2), + ("tolist", list), + ("flatten", WeakTensor), + ("transpose", WeakTensor), + ("reshape", WeakTensor, (3, 1)), + ("ravel", WeakTensor), + ("clip", tensor.Tensor, 1.1, 2.2), + ("astype", tensor.Tensor, dtypes.float32), + ("max", WeakTensor), + ("mean", WeakTensor), + ("min", WeakTensor), + ) + def testNumpyMethodsOnWeakTensor(self, np_method, result_type, *args): + a = weak_tensor_test_util.get_weak_tensor(([1, 2, 3])) + b = constant_op.constant([1, 2, 3]) + self.assertTrue(hasattr(a, np_method)) + + wt_np_method_call = getattr(a, np_method) + t_np_method_call = getattr(b, np_method) + wt_np_result = wt_np_method_call(*args) + t_np_result = t_np_method_call(*args) + self.assertIsInstance(wt_np_result, result_type) + self.assertAllEqual(wt_np_result, t_np_result) + + +# TODO(b/289333658): Add tf.constant(x) with no dtype arg as a "weak" input +# after adding WeakTensor construction logic to tf.constant. def _get_test_input(op): if op in _TF_UNARY_APIS_WITH_INT_INPUT: - return WeakTensor(constant_op.constant(5, dtypes.int32)) + return ( + _get_weak_tensor(5, dtypes.int32), + 5, + constant_op.constant(5, dtypes.int32), + np.array(5), + ) elif op in _TF_UNARY_APIS_WITH_2D_INPUT: - return WeakTensor(constant_op.constant([[1, 2], [3, 4]], dtypes.int32)) + return ( + _get_weak_tensor([[1, 2], [3, 4]], dtypes.int32), + [[1, 2], [3, 4]], + constant_op.constant([[1, 2], [3, 4]], dtypes.int32), + np.array([[1, 2], [3, 4]]), + ) else: - return WeakTensor(constant_op.constant([1, 2, 3], dtypes.float32)) + return ( + _get_weak_tensor([1.0, 2.0, 3.0], dtype=dtypes.float32), + [1.0, 2.0, 3.0], + constant_op.constant([1.0, 2.0, 3.0], dtype=dtypes.float32), + np.array([1.0, 2.0, 3.0]), + ) if __name__ == "__main__": ops.enable_eager_execution() # Enabling numpy behavior adds some NumPy methods to the Tensor class, which # TF-NumPy ops depend on. 
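[Editorial note] For completeness, these are the two opt-in hooks the test mains in this patch rely on, copied as they appear in the diff. Both are internal APIs, so treat this as a sketch of the wiring rather than a supported configuration:

```python
from tensorflow.python.framework import ops
from tensorflow.python.ops.numpy_ops import np_config

# The wrappers above fall back to the old behavior unless auto dtype
# conversion is enabled; "all" presumably enables it for every wrapped op.
ops.set_dtype_conversion_mode("all")

# Adds the NumPy methods that Tensor (and WeakTensor) depend on, with
# promotion following the requested dtype_conversion_mode.
np_config.enable_numpy_behavior(dtype_conversion_mode="all")
```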
- np_config.enable_numpy_behavior() + np_config.enable_numpy_behavior(dtype_conversion_mode="all") googletest.main() diff --git a/tensorflow/python/ops/weak_tensor_test_util.py b/tensorflow/python/ops/weak_tensor_test_util.py index eec7b936a24015..aa117def50c086 100644 --- a/tensorflow/python/ops/weak_tensor_test_util.py +++ b/tensorflow/python/ops/weak_tensor_test_util.py @@ -14,7 +14,28 @@ # ============================================================================== """Utils for WeakTensor related tests.""" +import numpy as np + +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.framework.weak_tensor import WeakTensor + + +def convert_to_input_type(base_input, input_type, dtype=None): + if input_type == "WeakTensor": + return WeakTensor(constant_op.constant(base_input, dtype=dtype)) + elif input_type == "Tensor": + return constant_op.constant(base_input, dtype=dtype) + elif input_type == "NumPy": + return np.array(base_input, dtype=dtype) + elif input_type == "Python": + return base_input + else: + raise ValueError(f"The provided input_type {input_type} is not supported.") + + +def get_weak_tensor(*args, **kwargs): + return WeakTensor(constant_op.constant(*args, **kwargs)) class DtypeConversionTestEnv: diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index b0b505df8c0cf4..e702a8f7e8bf2b 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util @@ -805,7 +806,7 @@ def _get_structured_grad_output(outputs, grads, body_grad_graph): dense_shape=outputs[outputs_idx + 2])) outputs_idx += 3 else: - assert isinstance(output, ops.Tensor) + assert isinstance(output, tensor_lib.Tensor) result.append(outputs[outputs_idx]) outputs_idx += 1 diff --git a/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py b/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py index 64f7ed8d40db61..56e352a63c4207 100644 --- a/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py +++ b/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py @@ -18,7 +18,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph from tensorflow.python.framework import indexed_slices -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -176,7 +176,7 @@ def _create_grad_indexed_slices_init(grad_output_slices, forward_input): Zeros IndexedSlices, created in current Graph. 
""" assert isinstance(grad_output_slices, indexed_slices.IndexedSlices) - assert isinstance(forward_input, ops.Tensor) + assert isinstance(forward_input, tensor.Tensor) values_out = grad_output_slices.values indices_out = grad_output_slices.indices diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index 3ba487d65d4c68..a3a6e8d5bd78cd 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") -load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "get_compatible_with_cloud", "tf_python_pybind_extension") +load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "get_compatible_with_portable", "tf_python_pybind_extension") load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") package( @@ -136,7 +136,7 @@ tf_python_pybind_extension( cc_library( name = "python_hooks", hdrs = ["python_hooks.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_profiler_copts() + ["-fexceptions"], features = ["-use_header_modules"], # Incompatible with -fexceptions. visibility = [ diff --git a/tensorflow/python/pywrap_dtensor_device.cc b/tensorflow/python/pywrap_dtensor_device.cc index df42d1ec27eea5..842abaf8393f7c 100644 --- a/tensorflow/python/pywrap_dtensor_device.cc +++ b/tensorflow/python/pywrap_dtensor_device.cc @@ -395,6 +395,8 @@ PYBIND11_MODULE(_pywrap_dtensor_device, m) { tensor_handle, element_layouts, device_info, status.get()); }); py::class_(m, "Mesh") + .def(py::init([](Mesh& mesh) { return mesh; }), py::arg("mesh"), + "Create a copy of a mesh.") .def(py::init(&Mesh::CreateMesh)) .def(py::init([](absl::string_view single_device) { auto mesh = Mesh::GetSingleDeviceMesh(single_device); diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 767114344e2ad9..34f69a011a43f0 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -76,6 +76,7 @@ py_strict_library( "//tensorflow/core:protos_all_py", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/lib/io:lib", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:tf_logging", @@ -187,6 +188,7 @@ tf_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:test_ops", "//tensorflow/python/lib/io:lib", @@ -248,11 +250,11 @@ tf_py_strict_test( ":utils", "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:context", - "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", @@ -275,6 +277,7 @@ py_strict_library( "//tensorflow/core:protos_all_py", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/util:deprecation", 
"//tensorflow/python/util:tf_export", @@ -326,8 +329,7 @@ py_strict_library( "//tensorflow/python/eager:function", "//tensorflow/python/eager/polymorphic_function:attributes", "//tensorflow/python/framework:composite_tensor", - "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/trackable:base", "//tensorflow/python/types:core", @@ -735,7 +737,7 @@ py_strict_library( "//tensorflow/python/framework:function_def_to_graph", "//tensorflow/python/framework:op_def_registry", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:type_spec", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:custom_gradient", @@ -772,10 +774,9 @@ tf_py_strict_test( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:extension_type", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:type_spec", diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index bf2e6241d7caab..18bdc53a3d77e8 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -25,6 +25,7 @@ from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging @@ -514,7 +515,7 @@ def _add_train_op(self, train_op): TypeError if Train op is not of type `Operation`. """ if train_op is not None: - if (not isinstance(train_op, ops.Tensor) and + if (not isinstance(train_op, tensor.Tensor) and not isinstance(train_op, ops.Operation)): raise TypeError(f"`train_op` {train_op} needs to be a Tensor or Op.") ops.add_to_collection(constants.TRAIN_OP_KEY, train_op) @@ -737,7 +738,7 @@ def _asset_path_from_tensor(path_tensor): Raises: TypeError if tensor does not match expected op type, dtype or value. """ - if not isinstance(path_tensor, ops.Tensor): + if not isinstance(path_tensor, tensor.Tensor): raise TypeError(f"Asset path tensor {path_tensor} must be a Tensor.") if path_tensor.op.type != "Const": raise TypeError(f"Asset path tensor {path_tensor} must be of type constant." 
diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index 58d185e185b666..4b1c57a746e240 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -30,7 +30,7 @@ from tensorflow.python.framework import function_def_to_graph as function_def_lib from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import custom_gradient @@ -44,7 +44,8 @@ def _is_tensor(t): - return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable)) + return isinstance( + t, (tensor.Tensor, resource_variable_ops.BaseResourceVariable)) # TODO(b/205016027): Update this to just use ConcreteFunction.__call__ with the @@ -72,7 +73,7 @@ def _call_concrete_function(function, inputs): flatten_expected = nest.flatten(expected_structure, expand_composites=True) tensor_inputs = [] for arg, expected in zip(flatten_inputs, flatten_expected): - if isinstance(expected, tensor_spec.TensorSpec): + if isinstance(expected, tensor.TensorSpec): tensor_inputs.append( ops.convert_to_tensor(arg, dtype_hint=expected.dtype)) elif isinstance(expected, resource_variable_ops.VariableSpec): @@ -89,7 +90,7 @@ def _try_convert_to_tensor_spec(arg, dtype_hint): # Note: try conversion in a FuncGraph to avoid polluting current context. with func_graph_lib.FuncGraph(name="guess_conversion").as_default(): result = ops.convert_to_tensor(arg, dtype_hint=dtype_hint) - return tensor_spec.TensorSpec(shape=result.shape, dtype=result.dtype) + return tensor.TensorSpec(shape=result.shape, dtype=result.dtype) except (TypeError, ValueError): return None @@ -103,10 +104,10 @@ def _concrete_function_callable_with(function, inputs, allow_conversion): return False for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)): - if isinstance(expected, tensor_spec.TensorSpec): + if isinstance(expected, tensor.TensorSpec): if allow_conversion: arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype) - if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec): + if not _is_tensor(arg) and not isinstance(arg, tensor.TensorSpec): return False if arg.dtype != expected.dtype: return False diff --git a/tensorflow/python/saved_model/model_utils/BUILD b/tensorflow/python/saved_model/model_utils/BUILD index 86485c3b7619b9..4fd87596a74aed 100644 --- a/tensorflow/python/saved_model/model_utils/BUILD +++ b/tensorflow/python/saved_model/model_utils/BUILD @@ -43,6 +43,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/saved_model:signature_def_utils", ], @@ -59,8 +60,8 @@ py_strict_test( "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", "//tensorflow/python/ops:metrics", diff --git a/tensorflow/python/saved_model/model_utils/export_output.py 
b/tensorflow/python/saved_model/model_utils/export_output.py index c38b12525d90d9..8903a08fba798d 100644 --- a/tensorflow/python/saved_model/model_utils/export_output.py +++ b/tensorflow/python/saved_model/model_utils/export_output.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.saved_model import signature_def_utils @@ -86,7 +87,7 @@ def _wrap_and_check_outputs( for key, value in outputs.items(): error_name = error_label or single_output_default_name key = self._check_output_key(key, error_name) - if not isinstance(value, ops.Tensor): + if not isinstance(value, tensor.Tensor): raise ValueError( '{} output value must be a Tensor; got {}.'.format( error_name, value)) @@ -128,12 +129,12 @@ def __init__(self, scores=None, classes=None): `Tensor` with the correct dtype. """ if (scores is not None - and not (isinstance(scores, ops.Tensor) + and not (isinstance(scores, tensor.Tensor) and scores.dtype.is_floating)): raise ValueError('Classification scores must be a float32 Tensor; ' 'got {}'.format(scores)) if (classes is not None - and not (isinstance(classes, ops.Tensor) + and not (isinstance(classes, tensor.Tensor) and dtypes.as_dtype(classes.dtype) == dtypes.string)): raise ValueError('Classification classes must be a string Tensor; ' 'got {}'.format(classes)) @@ -186,7 +187,7 @@ def __init__(self, value): Raises: ValueError: if the value is not a `Tensor` with dtype tf.float32. """ - if not (isinstance(value, ops.Tensor) and value.dtype.is_floating): + if not (isinstance(value, tensor.Tensor) and value.dtype.is_floating): raise ValueError('Regression output value must be a float32 Tensor; ' 'got {}'.format(value)) self._value = value @@ -355,7 +356,7 @@ def _wrap_and_check_metrics(self, metrics): val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX - if not isinstance(metric_val, ops.Tensor): + if not isinstance(metric_val, tensor.Tensor): raise ValueError( '{} output value must be a Tensor; got {}.'.format( key, metric_val)) @@ -368,7 +369,7 @@ def _wrap_and_check_metrics(self, metrics): # We must wrap any ops (or variables) in a Tensor before export, as the # SignatureDef proto expects tensors only. 
See b/109740581 metric_op_tensor = metric_op - if not isinstance(metric_op, ops.Tensor): + if not isinstance(metric_op, tensor.Tensor): with ops.control_dependencies([metric_op]): metric_op_tensor = constant_op.constant([], name='metric_op_wrapper') diff --git a/tensorflow/python/saved_model/model_utils/export_output_test.py b/tensorflow/python/saved_model/model_utils/export_output_test.py index 072208e9f30868..9c84e544ec35d2 100644 --- a/tensorflow/python/saved_model/model_utils/export_output_test.py +++ b/tensorflow/python/saved_model/model_utils/export_output_test.py @@ -20,8 +20,8 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import metrics as metrics_module @@ -385,15 +385,15 @@ def test_metric_op_is_tensor(self): self.assertTrue(outputter.metrics['metrics_1/update_op'].name.startswith( 'mean/update_op')) self.assertIsInstance( - outputter.metrics['metrics_1/update_op'], ops.Tensor) - self.assertIsInstance(outputter.metrics['metrics_1/value'], ops.Tensor) + outputter.metrics['metrics_1/update_op'], tensor.Tensor) + self.assertIsInstance(outputter.metrics['metrics_1/value'], tensor.Tensor) self.assertEqual(outputter.metrics['metrics_2/value'], metrics['metrics_2'][0]) self.assertTrue(outputter.metrics['metrics_2/update_op'].name.startswith( 'metric_op_wrapper')) self.assertIsInstance( - outputter.metrics['metrics_2/update_op'], ops.Tensor) + outputter.metrics['metrics_2/update_op'], tensor.Tensor) if __name__ == '__main__': diff --git a/tensorflow/python/saved_model/nested_structure_coder_test.py b/tensorflow/python/saved_model/nested_structure_coder_test.py index f010471da138de..c2b9e12d437605 100644 --- a/tensorflow/python/saved_model/nested_structure_coder_test.py +++ b/tensorflow/python/saved_model/nested_structure_coder_test.py @@ -26,10 +26,9 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import extension_type -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.framework import type_spec @@ -165,7 +164,7 @@ def testDtype(self): self.assertEqual(structure, decoded) def testEncodeDecodeTensorSpec(self): - structure = [tensor_spec.TensorSpec([1, 2, 3], dtypes.int64, "hello")] + structure = [tensor.TensorSpec([1, 2, 3], dtypes.int64, "hello")] self.assertTrue(nested_structure_coder.can_encode(structure)) encoded = nested_structure_coder.encode_structure(structure) expected = struct_pb2.StructuredValue() @@ -181,7 +180,7 @@ def testEncodeDecodeTensorSpec(self): self.assertEqual(structure, decoded) def testEncodeDecodeTensorSpecWithNoName(self): - structure = [tensor_spec.TensorSpec([1, 2, 3], dtypes.int64)] + structure = [tensor.TensorSpec([1, 2, 3], dtypes.int64)] self.assertTrue(nested_structure_coder.can_encode(structure)) encoded = nested_structure_coder.encode_structure(structure) expected = 
struct_pb2.StructuredValue() @@ -276,12 +275,12 @@ def testEncodeDecodeExtensionTypeSpec(self): class Zoo(extension_type.ExtensionType): __name__ = "tf.nested_structure_coder_test.Zoo" zookeepers: typing.Tuple[str, ...] - animals: typing.Mapping[str, ops.Tensor] + animals: typing.Mapping[str, tensor.Tensor] structure = [ Zoo.Spec( zookeepers=["Zoey", "Zack"], - animals={"tiger": tensor_spec.TensorSpec([16])}) + animals={"tiger": tensor.TensorSpec([16])}) ] self.assertTrue(nested_structure_coder.can_encode(structure)) @@ -327,8 +326,7 @@ def testDecodeUnknownTensorSpec(self): def testEncodeDecodeBoundedTensorSpec(self): structure = [ - tensor_spec.BoundedTensorSpec([1, 2, 3], dtypes.int64, 0, 10, - "hello_0_10") + tensor.BoundedTensorSpec([1, 2, 3], dtypes.int64, 0, 10, "hello_0_10") ] self.assertTrue(nested_structure_coder.can_encode(structure)) encoded = nested_structure_coder.encode_structure(structure) @@ -350,8 +348,7 @@ def testEncodeDecodeBoundedTensorSpec(self): def testEncodeDecodeBoundedTensorSpecNoName(self): structure = [ - tensor_spec.BoundedTensorSpec((28, 28, 3), dtypes.float64, -2, - (1, 1, 20)) + tensor.BoundedTensorSpec((28, 28, 3), dtypes.float64, -2, (1, 1, 20)) ] self.assertTrue(nested_structure_coder.can_encode(structure)) encoded = nested_structure_coder.encode_structure(structure) @@ -378,7 +375,7 @@ def testEncodeDataSetSpec(self): dataset_ops.DatasetSpec({ "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32), "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32), - "t": tensor_spec.TensorSpec([10, 8], dtypes.string) + "t": tensor.TensorSpec([10, 8], dtypes.string) }) ] self.assertTrue(nested_structure_coder.can_encode(structure)) diff --git a/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py b/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py index 99a2802f631ad9..94425f2ab92174 100644 --- a/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py +++ b/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py @@ -89,6 +89,23 @@ def test_read_saved_model_singleprint_from_sm(self): "12074714563970609759", # saved_object_graph_hash ])) + def test_read_chunked_saved_model_fingerprint(self): + if is_oss: + self.skipTest("Experimental image format disabled in OSS.") + export_dir = test.test_src_dir_path( + "cc/saved_model/testdata/chunked_saved_model/chunked_model") + fingerprint = fingerprint_pb2.FingerprintDef().FromString( + pywrap_fingerprinting.CreateFingerprintDef(export_dir)) + self.assertGreater(fingerprint.saved_model_checksum, 0) + # We test for multiple fingerprints due to non-determinism when building + # with different compilation_mode flag options. 
+ self.assertIn(fingerprint.graph_def_program_hash, + [906548630859202535, 9562420523583756263]) + self.assertEqual(fingerprint.signature_def_hash, 1043582354059066488) + self.assertIn(fingerprint.saved_object_graph_hash, + [11894619660760763927, 2766043449526180728]) + self.assertEqual(fingerprint.checkpoint_hash, 0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index a0f6478ae55fa1..87fd9151bba05c 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io @@ -939,7 +940,7 @@ def testTrainOp(self): "AssignAddVariableOp") else: self.assertIsInstance( - loader_impl.get_train_op(meta_graph_def), ops.Tensor) + loader_impl.get_train_op(meta_graph_def), tensor_lib.Tensor) def testTrainOpGroup(self): export_dir = self._get_export_dir("test_train_op_group") @@ -995,7 +996,7 @@ def testTrainOpAfterVariables(self): "AssignAddVariableOp") else: self.assertIsInstance( - loader_impl.get_train_op(meta_graph_def), ops.Tensor) + loader_impl.get_train_op(meta_graph_def), tensor_lib.Tensor) with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["pre_foo"], export_dir) diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py index b2911b174b2239..5de175e2a85687 100644 --- a/tensorflow/python/saved_model/signature_def_utils_impl.py +++ b/tensorflow/python/saved_model/signature_def_utils_impl.py @@ -19,6 +19,7 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import utils_impl as utils @@ -104,7 +105,7 @@ def regression_signature_def(examples, predictions): """ if examples is None: raise ValueError('Regression `examples` cannot be None.') - if not isinstance(examples, ops.Tensor): + if not isinstance(examples, tensor_lib.Tensor): raise ValueError('Expected regression `examples` to be of type Tensor. ' f'Found `examples` of type {type(examples)}.') if predictions is None: @@ -157,7 +158,7 @@ def classification_signature_def(examples, classes, scores): """ if examples is None: raise ValueError('Classification `examples` cannot be None.') - if not isinstance(examples, ops.Tensor): + if not isinstance(examples, tensor_lib.Tensor): raise ValueError('Classification `examples` must be a string Tensor. 
' f'Found `examples` of type {type(examples)}.') if classes is None and scores is None: diff --git a/tensorflow/python/saved_model/signature_serialization.py b/tensorflow/python/saved_model/signature_serialization.py index 9cadfc9076e3ff..38362c8087a838 100644 --- a/tensorflow/python/saved_model/signature_serialization.py +++ b/tensorflow/python/saved_model/signature_serialization.py @@ -20,8 +20,7 @@ from tensorflow.python.eager import function as defun from tensorflow.python.eager.polymorphic_function import attributes from tensorflow.python.framework import composite_tensor -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.ops import resource_variable_ops from tensorflow.python.saved_model import function_serialization from tensorflow.python.saved_model import revived_types @@ -192,7 +191,7 @@ def signature_wrapper(**kwargs): if signature_function.structured_input_signature is not None: # The structured input signature may contain other non-tensor arguments. inputs = filter( - lambda x: isinstance(x, tensor_spec.TensorSpec), + lambda x: isinstance(x, tensor.TensorSpec), nest.flatten( signature_function.structured_input_signature, expand_composites=True, @@ -207,10 +206,10 @@ def signature_wrapper(**kwargs): inputs, ): keyword = compat.as_str(keyword) - if isinstance(inp, tensor_spec.TensorSpec): - spec = tensor_spec.TensorSpec(inp.shape, inp.dtype, name=keyword) + if isinstance(inp, tensor.TensorSpec): + spec = tensor.TensorSpec(inp.shape, inp.dtype, name=keyword) else: - spec = tensor_spec.TensorSpec.from_tensor(inp, name=keyword) + spec = tensor.TensorSpec.from_tensor(inp, name=keyword) tensor_spec_signature[keyword] = spec final_concrete = wrapped_function._get_concrete_function_garbage_collected( # pylint: disable=protected-access **tensor_spec_signature @@ -240,7 +239,7 @@ def signature_wrapper(**kwargs): arg_names[-len_default:], # pylint: disable=protected-access flattened_defaults or [], ): - if not isinstance(default, ops.Tensor): + if not isinstance(default, tensor.Tensor): continue defaults.setdefault(signature_key, {})[arg] = default return concrete_signatures, wrapped_functions, defaults @@ -269,7 +268,7 @@ def _normalize_outputs(outputs, function_name, signature_key): f"the function {compat.as_str_any(function_name)} used to generate " f"the SavedModel signature {signature_key!r}." 
) - if not isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)): + if not isinstance(value, (tensor.Tensor, composite_tensor.CompositeTensor)): raise ValueError( f"Got a non-Tensor value {value!r} for key {key!r} in the output of " f"the function {compat.as_str_any(function_name)} used to generate " diff --git a/tensorflow/python/saved_model/utils_test.py b/tensorflow/python/saved_model/utils_test.py index a30cd5253d1136..52d44b8999b603 100644 --- a/tensorflow/python/saved_model/utils_test.py +++ b/tensorflow/python/saved_model/utils_test.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -102,7 +103,7 @@ def testGetTensorFromInfoDense(self): expected = array_ops.placeholder(dtypes.float32, 1, name="x") tensor_info = utils.build_tensor_info(expected) actual = utils.get_tensor_from_tensor_info(tensor_info) - self.assertIsInstance(actual, ops.Tensor) + self.assertIsInstance(actual, tensor.Tensor) self.assertEqual(expected.name, actual.name) @test_util.run_v1_only( @@ -134,7 +135,7 @@ def testGetTensorFromInfoInOtherGraph(self): array_ops.placeholder(dtypes.float32, 1, name="other") actual = utils.get_tensor_from_tensor_info(tensor_info, graph=expected_graph) - self.assertIsInstance(actual, ops.Tensor) + self.assertIsInstance(actual, tensor.Tensor) self.assertIs(actual.graph, expected_graph) self.assertEqual(expected.name, actual.name) diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index a8f7a81b14954c..d3e94328c64911 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -263,6 +263,7 @@ pytype_strict_library( "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:function", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/lib/io:lib", "//tensorflow/python/ops:array_ops", @@ -781,6 +782,7 @@ pytype_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", @@ -827,6 +829,7 @@ pytype_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", "//tensorflow/python/ops:embedding_ops", @@ -876,6 +879,7 @@ pytype_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:embedding_ops", "//tensorflow/python/ops:math_ops", @@ -926,7 +930,7 @@ tpu_py_strict_test( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/lib/io:lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:cond", diff --git 
a/tensorflow/python/tpu/tensor_tracer.py b/tensorflow/python/tpu/tensor_tracer.py index 8a2e69e46961c5..e8550ddeb1de87 100644 --- a/tensorflow/python/tpu/tensor_tracer.py +++ b/tensorflow/python/tpu/tensor_tracer.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import graph_io from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops @@ -1300,7 +1301,7 @@ def _filter_execution_path_operations(self, operations, fetches): for fetch in fetches: if isinstance(fetch, ops.Operation): op_fetches.append(fetch) - elif isinstance(fetch, ops.Tensor): + elif isinstance(fetch, tensor_lib.Tensor): op_fetches.append(fetch.op) else: raise RuntimeError('Given fetch:%s is neither a tensor nor an op.' @@ -1741,7 +1742,7 @@ def _process_tensor_fetches(self, tensor_fetches): 'empty list.') fetches = [] for fetch in tensor_fetches: - if isinstance(fetch, ops.Tensor): + if isinstance(fetch, tensor_lib.Tensor): fetches.append(fetch) else: raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch) @@ -1759,7 +1760,7 @@ def _process_op_fetches(self, op_fetches): for fetch in op_fetches: if isinstance(fetch, ops.Operation): fetches.append(fetch) - elif isinstance(fetch, ops.Tensor): + elif isinstance(fetch, tensor_lib.Tensor): fetches.append(fetch.op) else: logging.warning('Ignoring the given op_fetch:%s, which is not an op.' % @@ -1768,7 +1769,7 @@ def _process_op_fetches(self, op_fetches): def _convert_fetches_to_input_format(self, input_fetches, current_fetches): """Changes current_fetches' format, so that it matches input_fetches.""" - if isinstance(input_fetches, ops.Tensor): + if isinstance(input_fetches, tensor_lib.Tensor): if len(current_fetches) != 1: raise RuntimeError('Tensor tracer input/output fetches do not match.') return current_fetches[0] diff --git a/tensorflow/python/tpu/tpu_embedding_for_serving.py b/tensorflow/python/tpu/tpu_embedding_for_serving.py index 9914e084bb1f37..fb8a3205e1358c 100644 --- a/tensorflow/python/tpu/tpu_embedding_for_serving.py +++ b/tensorflow/python/tpu/tpu_embedding_for_serving.py @@ -22,6 +22,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import embedding_ops @@ -292,7 +293,7 @@ def serve_tensors(embedding_features): table = tables[feature.table] if weight is not None: - if isinstance(inp, ops.Tensor): + if isinstance(inp, tensor.Tensor): raise ValueError( "Weight specified for {}, but input is dense.".format(path)) elif type(weight) is not type(inp): @@ -303,7 +304,7 @@ def serve_tensors(embedding_features): raise ValueError("Weight specified for {}, but this is a sequence " "feature.".format(path)) - if isinstance(inp, ops.Tensor): + if isinstance(inp, tensor.Tensor): if feature.max_sequence_length > 0: raise ValueError("Feature {} is a sequence feature but a dense tensor " "was passed.".format(path)) @@ -324,7 +325,7 @@ def serve_tensors(embedding_features): def _embedding_lookup_for_sparse_tensor( inp: sparse_tensor.SparseTensor, weight: Optional[sparse_tensor.SparseTensor], table: tf_variables.Variable, - feature: tpu_embedding_v2_utils.FeatureConfig) 
-> ops.Tensor: + feature: tpu_embedding_v2_utils.FeatureConfig) -> tensor.Tensor: """Embedding lookup for sparse tensor based on its feature config. Args: @@ -380,7 +381,7 @@ def _embedding_lookup_for_sparse_tensor( def _embedding_lookup_for_ragged_tensor( inp: ragged_tensor.RaggedTensor, weight: Optional[ragged_tensor.RaggedTensor], table: tf_variables.Variable, - feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor: + feature: tpu_embedding_v2_utils.FeatureConfig) -> tensor.Tensor: """Embedding lookup for ragged tensor based on its feature config. Args: diff --git a/tensorflow/python/tpu/tpu_embedding_v1.py b/tensorflow/python/tpu/tpu_embedding_v1.py index 7b19500025bbc1..259650cd9f8396 100644 --- a/tensorflow/python/tpu/tpu_embedding_v1.py +++ b/tensorflow/python/tpu/tpu_embedding_v1.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops @@ -179,9 +180,9 @@ def _maybe_build(self): def _apply_combiner_to_embeddings( self, - embeddings: ops.Tensor, - weight: ops.Tensor, - combiner: Optional[Text] = None) -> ops.Tensor: + embeddings: tensor.Tensor, + weight: tensor.Tensor, + combiner: Optional[Text] = None) -> tensor.Tensor: """Apply the combiner to the embedding look up result on second to last axis. Args: @@ -213,8 +214,9 @@ def _apply_combiner_to_embeddings( f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}") return embeddings - def _pad_or_truncate_with_sequence_length(self, embeddings: ops.Tensor, - sequence_length: int) -> ops.Tensor: + def _pad_or_truncate_with_sequence_length( + self, embeddings: tensor.Tensor, sequence_length: int + ) -> tensor.Tensor: """Pad or truncate the embedding lookup result based on the sequence length. Args: @@ -272,7 +274,7 @@ def embedding_lookup(self, table = self.embedding_tables[feature.table] if weight is not None: - if isinstance(inp, ops.Tensor): + if isinstance(inp, tensor.Tensor): raise ValueError( "Weight specified for {}, but input is dense.".format(path)) elif type(weight) is not type(inp): @@ -283,7 +285,7 @@ def embedding_lookup(self, raise ValueError("Weight specified for {}, but this is a sequence " "feature.".format(path)) - if isinstance(inp, ops.Tensor): + if isinstance(inp, tensor.Tensor): if feature.max_sequence_length > 0: raise ValueError( "Feature {} is a sequence feature but a dense tensor " @@ -307,7 +309,7 @@ def _embedding_lookup_for_sparse_tensor( self, inp: sparse_tensor.SparseTensor, weight: Optional[sparse_tensor.SparseTensor], table: tf_variables.Variable, - feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor: + feature: tpu_embedding_v2_utils.FeatureConfig) -> tensor.Tensor: """Embedding lookup for sparse tensor based on its feature config. Args: @@ -352,7 +354,7 @@ def _embedding_lookup_for_ragged_tensor( self, inp: ragged_tensor.RaggedTensor, weight: Optional[ragged_tensor.RaggedTensor], table: tf_variables.Variable, - feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor: + feature: tpu_embedding_v2_utils.FeatureConfig) -> tensor.Tensor: """Embedding lookup for ragged tensor based on its feature config. 
Args: @@ -398,7 +400,10 @@ def ragged_to_dense_outside_compilation(inp, weight, batch_size, feature): # If the data batch size is a factor of the output batch size, the # divide result will be the sequence length. Ignore the weights and # combiner. - elif output_batch_size > batch_size and output_batch_size % batch_size == 0: + elif ( + output_batch_size > batch_size + and output_batch_size % batch_size == 0 + ): # Pad or truncate in the sequence dimension seq_length = output_batch_size // batch_size inp = inp.to_tensor(shape=(batch_size, seq_length)) diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py index 787bf6e23bd6b0..bf954ac0a55e77 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2.py +++ b/tensorflow/python/tpu/tpu_embedding_v2.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework.tensor_shape import TensorShape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -687,7 +688,7 @@ def tpu_step(tpu_features): full_output_shape = [x * num_cores_per_replica for x in output_shape] + [ feature.table.dim ] - if gradient is not None and not isinstance(gradient, ops.Tensor): + if gradient is not None and not isinstance(gradient, tensor_lib.Tensor): raise ValueError( f"found non-tensor type: {type(gradient)} at path {path}.") if gradient is not None: @@ -992,7 +993,7 @@ def _generate_enqueue_op( # early. for inp, weight, (path, feature) in zip( flat_inputs, flat_weights, flat_features): - if isinstance(inp, ops.Tensor): + if isinstance(inp, tensor_lib.Tensor): self._add_data_for_tensor(inp, weight, indices_or_row_splits, values, weights, int_zeros, float_zeros, path) elif isinstance(inp, sparse_tensor.SparseTensor): @@ -1310,7 +1311,7 @@ def generate_enqueue_ops(): def _split_fn(ts, idx): if ts is None: return None - elif isinstance(ts, ops.Tensor): + elif isinstance(ts, tensor_lib.Tensor): return array_ops.split( ts, num_or_size_splits=self._num_cores_per_replica, @@ -1389,7 +1390,7 @@ def _get_input_shapes( else: tensor = maybe_tensor - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): input_shapes.append( self._get_input_shape_for_tensor(tensor, feature, per_replica, path) ) diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py index eb0ce826ca69c6..505aea97aa021a 100644 --- a/tensorflow/python/tpu/tpu_outside_compilation_test.py +++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py @@ -33,7 +33,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.lib.io import tf_record from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond @@ -812,11 +812,11 @@ def train_step(x): partitioned_tpu_fn = _tpu_partitioned_call_wrapper(tpu_fn) concrete = partitioned_tpu_fn.get_concrete_function( - x=tensor_spec.TensorSpec( + x=tensor.TensorSpec( shape=(1), dtype=dtypes.float32, name="input_tensor")) self.assertIsInstance( - concrete(array_ops.ones((1), dtype=dtypes.float32))[0], ops.Tensor) + concrete(array_ops.ones((1), dtype=dtypes.float32))[0], tensor.Tensor) if __name__ == 
"__main__": diff --git a/tensorflow/python/trackable/BUILD b/tensorflow/python/trackable/BUILD index f84fad992352b0..6d787f3c242e86 100644 --- a/tensorflow/python/trackable/BUILD +++ b/tensorflow/python/trackable/BUILD @@ -185,6 +185,7 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:tf_decorator", "//tensorflow/python/util:tf_export", ], diff --git a/tensorflow/python/trackable/resource.py b/tensorflow/python/trackable/resource.py index 823d70b8f10c34..ee4a5c1361cbba 100644 --- a/tensorflow/python/trackable/resource.py +++ b/tensorflow/python/trackable/resource.py @@ -21,6 +21,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.trackable import base from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export @@ -152,7 +153,7 @@ def _resource_handle(self): @_resource_handle.setter def _resource_handle(self, value): - if isinstance(value, (ops.Tensor, ops.EagerTensor)): + if isinstance(value, (tensor.Tensor, ops.EagerTensor)): value._parent_trackable = weakref.ref(self) # pylint: disable=protected-access self._resource_handle_value = value diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD index 30b5f7bf81f2f4..c134527b87b731 100644 --- a/tensorflow/python/training/BUILD +++ b/tensorflow/python/training/BUILD @@ -297,6 +297,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/layers:layers_util", "//tensorflow/python/ops:array_ops", @@ -347,6 +348,7 @@ py_strict_library( "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:cond", "//tensorflow/python/ops:control_flow_ops", "//tensorflow/python/ops:init_ops", @@ -379,6 +381,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", "//tensorflow/python/ops:gradients", @@ -532,6 +535,7 @@ py_strict_library( "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", "//tensorflow/python/ops:data_flow_ops", @@ -1080,6 +1084,7 @@ py_strict_library( "//tensorflow/python/framework", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:cond", "//tensorflow/python/ops:init_ops", "//tensorflow/python/ops:resource_variable_ops", @@ -1687,6 +1692,7 @@ tf_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", 
"//tensorflow/python/ops:control_flow_assert", diff --git a/tensorflow/python/training/experimental/BUILD b/tensorflow/python/training/experimental/BUILD index 02c64add3f36a1..604c97765c0710 100644 --- a/tensorflow/python/training/experimental/BUILD +++ b/tensorflow/python/training/experimental/BUILD @@ -87,6 +87,7 @@ py_strict_test( "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:check_ops", diff --git a/tensorflow/python/training/experimental/loss_scale_test.py b/tensorflow/python/training/experimental/loss_scale_test.py index eaa9a55d5c5a73..e42a95d65c2acb 100644 --- a/tensorflow/python/training/experimental/loss_scale_test.py +++ b/tensorflow/python/training/experimental/loss_scale_test.py @@ -22,6 +22,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -86,7 +87,7 @@ def test_serialization(self): @test_util.run_in_graph_and_eager_modes def test_call_type(self): scalar = loss_scale_module.FixedLossScale(123) - self.assertIsInstance(scalar(), ops.Tensor) + self.assertIsInstance(scalar(), tensor_lib.Tensor) @test_util.run_in_graph_and_eager_modes def test_repr(self): @@ -301,7 +302,7 @@ def test_get(self): @test_util.run_in_graph_and_eager_modes def test_call_type(self): scalar = loss_scale_module.DynamicLossScale() - self.assertIsInstance(scalar(), ops.Tensor) + self.assertIsInstance(scalar(), tensor_lib.Tensor) @parameterized.named_parameters(*TESTCASES) @test_util.run_in_graph_and_eager_modes diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index ee1d019f75630e..23bd73220042c1 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops @@ -251,7 +252,7 @@ def string_input_producer(string_tensor, @end_compatibility """ not_null_err = "string_input_producer requires a non-null input tensor" - if not isinstance(string_tensor, ops.Tensor) and not string_tensor: + if not isinstance(string_tensor, tensor_lib.Tensor) and not string_tensor: raise ValueError(not_null_err) with ops.name_scope(name, "input_producer", [string_tensor]) as name: diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index 6e0e1a99ec2205..7ad3e772f2a983 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import 
control_flow_assert @@ -84,8 +85,9 @@ def test_defaults_empty_graph(self): self.assertTrue(isinstance(scaffold.init_op, ops.Operation)) self.assertEqual(None, scaffold.init_feed_dict) self.assertEqual(None, scaffold.init_fn) - self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor)) - self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor)) + self.assertTrue(isinstance(scaffold.ready_op, tensor.Tensor)) + self.assertTrue(isinstance( + scaffold.ready_for_local_init_op, tensor.Tensor)) self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation)) self.assertEqual(None, scaffold.local_init_feed_dict) self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver)) @@ -107,8 +109,9 @@ def test_defaults_no_variables(self): self.assertTrue(isinstance(scaffold.init_op, ops.Operation)) self.assertEqual(None, scaffold.init_feed_dict) self.assertEqual(None, scaffold.init_fn) - self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor)) - self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor)) + self.assertTrue(isinstance(scaffold.ready_op, tensor.Tensor)) + self.assertTrue(isinstance( + scaffold.ready_for_local_init_op, tensor.Tensor)) self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation)) self.assertEqual(None, scaffold.local_init_feed_dict) self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver)) diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 1f3d05a2163d2f..d310f3488f1524 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -17,6 +17,7 @@ from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import cond from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops @@ -531,7 +532,7 @@ def apply(self, var_list=None): if var_list is None: var_list = variables.trainable_variables() for v in var_list: - if (isinstance(v, ops.Tensor) + if (isinstance(v, tensor.Tensor) and ops.executing_eagerly_outside_functions()): raise TypeError( "tf.train.ExponentialMovingAverage does not support non-Variable" diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 4af9a1ebe43664..aa59f2e343cdf8 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients @@ -115,7 +116,7 @@ def target(self): return self._v._ref() # pylint: disable=protected-access def update_op(self, optimizer, g): - if isinstance(g, ops.Tensor): + if isinstance(g, tensor.Tensor): update_op = optimizer._apply_dense(g, self._v) # pylint: disable=protected-access if self._v.constraint is not None: with ops.control_dependencies([update_op]): @@ -197,7 +198,7 @@ def update_op(self, optimizer, g): def _get_processor(v): """The processor of v.""" if context.executing_eagerly(): - if isinstance(v, ops.Tensor): + if isinstance(v, tensor.Tensor): return _TensorProcessor(v) else: return _DenseResourceVariableProcessor(v) @@ -208,7 +209,7 @@ def 
_get_processor(v): return _DenseResourceVariableProcessor(v) if isinstance(v, variables.Variable): return _RefVariableProcessor(v) - if isinstance(v, ops.Tensor): + if isinstance(v, tensor.Tensor): return _TensorProcessor(v) raise NotImplementedError("Trying to optimize unsupported type ", v) @@ -690,7 +691,7 @@ def apply_gradients( raise TypeError( "Gradient must be convertible to a Tensor" " or IndexedSlices, or None: %s" % g) - if not isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)): + if not isinstance(g, (tensor.Tensor, indexed_slices.IndexedSlices)): raise TypeError( "Gradient must be a Tensor, IndexedSlices, or None: %s" % g) p = _get_processor(v) @@ -739,7 +740,7 @@ def apply_gradients( apply_updates = state_ops.assign_add(global_step, 1, name=name) if not context.executing_eagerly(): - if isinstance(apply_updates, ops.Tensor): + if isinstance(apply_updates, tensor.Tensor): apply_updates = apply_updates.op train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) if apply_updates not in train_op: @@ -791,7 +792,7 @@ def update(v, g): except TypeError: raise TypeError("Gradient must be convertible to a Tensor" " or IndexedSlices, or None: %s" % g) - if not isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)): + if not isinstance(g, (tensor.Tensor, indexed_slices.IndexedSlices)): raise TypeError( "Gradient must be a Tensor, IndexedSlices, or None: %s" % g) p = _get_processor(v) @@ -834,7 +835,7 @@ def finish(self, update_ops): kwargs={"name": name}) if not context.executing_eagerly(): - if isinstance(apply_updates, ops.Tensor): + if isinstance(apply_updates, tensor.Tensor): apply_updates = apply_updates.op train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) if apply_updates not in train_op: diff --git a/tensorflow/python/training/saving/BUILD b/tensorflow/python/training/saving/BUILD index 52a0af997f706c..4538bbdf973775 100644 --- a/tensorflow/python/training/saving/BUILD +++ b/tensorflow/python/training/saving/BUILD @@ -51,6 +51,7 @@ py_strict_library( "//tensorflow/python/framework:device", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops_gen", diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py index ecf4d319df58d7..c14361a23f2a14 100644 --- a/tensorflow/python/training/saving/saveable_object_util.py +++ b/tensorflow/python/training/saving/saveable_object_util.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -94,7 +95,7 @@ class ResourceVariableSaveable(saveable_object.SaveableObject): def __init__(self, var, slice_spec, name): self._var_device = var.device self._var_shape = var.shape - if isinstance(var, ops.Tensor): + if isinstance(var, tensor_lib.Tensor): self.handle_op = var.op.inputs[0] tensor = var elif resource_variable_ops.is_resource_variable(var): @@ -145,7 +146,7 @@ def restore(self, restored_tensors, restored_shapes): def _tensor_comes_from_variable(v): - return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS + return isinstance(v, tensor_lib.Tensor) and v.op.type in _VARIABLE_OPS def 
saveable_objects_for_op(op, name): @@ -589,7 +590,7 @@ def restore(self, restored_tensors, restored_shapes): if not ops.executing_eagerly_outside_functions() and any([ spec._tensor.op.type in _REF_VARIABLE_OPS for spec in self.specs - if isinstance(spec._tensor, ops.Tensor)]): + if isinstance(spec._tensor, tensor_lib.Tensor)]): return restore_fn(restored_tensor_dict) # pylint: enable=protected-access diff --git a/tensorflow/python/training/sync_replicas_optimizer.py b/tensorflow/python/training/sync_replicas_optimizer.py index 195c928764a0f4..26c26dc1ff7627 100644 --- a/tensorflow/python/training/sync_replicas_optimizer.py +++ b/tensorflow/python/training/sync_replicas_optimizer.py @@ -17,6 +17,7 @@ from tensorflow.python.distribute import distribute_lib from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -277,7 +278,7 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None): if grad is None: aggregated_grad.append(None) # pass-through. continue - elif isinstance(grad, ops.Tensor): + elif isinstance(grad, tensor.Tensor): grad_accum = data_flow_ops.ConditionalAccumulator( grad.dtype, shape=var.get_shape(), diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py index 050828f7637c6d..778ad9771f8591 100644 --- a/tensorflow/python/training/training_util.py +++ b/tensorflow/python/training/training_util.py @@ -17,6 +17,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_io from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import cond from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops @@ -333,7 +334,7 @@ def assert_global_step(global_step_tensor): global_step_tensor: `Tensor` to test. """ if not (isinstance(global_step_tensor, variables.Variable) or - isinstance(global_step_tensor, ops.Tensor) or + isinstance(global_step_tensor, tensor.Tensor) or resource_variable_ops.is_resource_variable(global_step_tensor)): raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' 
% global_step_tensor) diff --git a/tensorflow/python/types/core.py b/tensorflow/python/types/core.py index f6d6158a12f22a..83b67ef71ba3fd 100644 --- a/tensorflow/python/types/core.py +++ b/tensorflow/python/types/core.py @@ -56,7 +56,7 @@ def shape(self): pass -# `ops.EagerTensor` subclasses `Symbol` by way of subclassing `ops.Tensor`; +# `ops.EagerTensor` subclasses `Symbol` by way of subclassing `tensor.Tensor`; # care should be taken when performing `isinstance` checks on `Value`, e.g.: # # ``` diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD index d74446c38cbf5b..1ac862f43a66eb 100644 --- a/tensorflow/python/util/BUILD +++ b/tensorflow/python/util/BUILD @@ -253,7 +253,7 @@ tf_py_strict_test( ":tf_inspect", "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", @@ -276,6 +276,7 @@ tf_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:extension_type", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", @@ -1115,6 +1116,7 @@ tf_py_strict_test( "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/ops:variables", diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py index 024ef220260417..898af79480875e 100644 --- a/tensorflow/python/util/deprecation_test.py +++ b/tensorflow/python/util/deprecation_test.py @@ -23,7 +23,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -1062,14 +1062,14 @@ def test_deprecated_arg_values_when_value_is_none(self, mock_warning): def _fn(arg0): # pylint: disable=unused-argument pass - ops.enable_tensor_equality() + tensor.enable_tensor_equality() initial_count = mock_warning.call_count # Check that we avoid error from explicit `var == None` check. 
_fn(arg0=variables.Variable(0)) self.assertEqual(initial_count, mock_warning.call_count) _fn(arg0=None) self.assertEqual(initial_count + 1, mock_warning.call_count) - ops.disable_tensor_equality() + tensor.disable_tensor_equality() class DeprecationArgumentsTest(test.TestCase): diff --git a/tensorflow/python/util/dispatch_test.py b/tensorflow/python/util/dispatch_test.py index db01441afba9f5..7bb8e8f8898f6a 100644 --- a/tensorflow/python/util/dispatch_test.py +++ b/tensorflow/python/util/dispatch_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -105,13 +106,13 @@ def is_tensor_like(self): @classmethod def _overload_all_operators(cls): # pylint: disable=invalid-name """Register overloads for all operators.""" - for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + for operator in tensor_lib.Tensor.OVERLOADABLE_OPERATORS: cls._overload_operator(operator) @classmethod def _overload_operator(cls, operator): # pylint: disable=invalid-name - """Overload an operator with the same overloading as `ops.Tensor`.""" - tensor_oper = getattr(ops.Tensor, operator) + """Overload an operator with the same overloading as `tensor_lib.Tensor`.""" + tensor_oper = getattr(tensor_lib.Tensor, operator) # Compatibility with Python 2: # Python 2 unbound methods have type checks for the first arg, @@ -459,13 +460,13 @@ def testGlobalDispatcherLinearOperators(self): class MaskedTensor(extension_type.ExtensionType): """Simple ExtensionType for testing v2 dispatch.""" - values: ops.Tensor - mask: ops.Tensor + values: tensor_lib.Tensor + mask: tensor_lib.Tensor class SillyTensor(extension_type.ExtensionType): """Simple ExtensionType for testing v2 dispatch.""" - value: ops.Tensor + value: tensor_lib.Tensor how_silly: float @@ -565,7 +566,7 @@ def masked_concat(values, axis, name=None): dispatch.unregister_dispatch_for(masked_concat) def testDispatchForUnion(self): - MaybeMasked = typing.Union[MaskedTensor, ops.Tensor] + MaybeMasked = typing.Union[MaskedTensor, tensor_lib.Tensor] @dispatch.dispatch_for_api(math_ops.add, { "x": MaybeMasked, @@ -936,7 +937,8 @@ def testGetApisWithTypeBasedDispatch(self): self.assertIn(array_ops.concat, dispatch_apis) def testTypeBasedDispatchTargetsFor(self): - MaskedTensorList = typing.List[typing.Union[MaskedTensor, ops.Tensor]] + MaskedTensorList = typing.List[ + typing.Union[MaskedTensor, tensor_lib.Tensor]] try: @dispatch.dispatch_for_api(math_ops.add) diff --git a/tensorflow/python/util/variable_utils_test.py b/tensorflow/python/util/variable_utils_test.py index 41c81812e3322b..9aaa0d0e1b5fc8 100644 --- a/tensorflow/python/util/variable_utils_test.py +++ b/tensorflow/python/util/variable_utils_test.py @@ -18,6 +18,7 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -65,9 +66,9 @@ def test_convert_variables_to_tensors(self): results = variable_utils.convert_variables_to_tensors(data) expected_results = [1, 2, 3, [4], 5, ct] # Only ResourceVariables 
are converted to Tensors. - self.assertIsInstance(results[0], ops.Tensor) - self.assertIsInstance(results[1], ops.Tensor) - self.assertIsInstance(results[2], ops.Tensor) + self.assertIsInstance(results[0], tensor.Tensor) + self.assertIsInstance(results[1], tensor.Tensor) + self.assertIsInstance(results[2], tensor.Tensor) self.assertIsInstance(results[3], list) self.assertIsInstance(results[4], int) self.assertIs(results[5], ct) @@ -82,7 +83,7 @@ def test_convert_variables_in_composite_tensor(self): self.assertIsInstance(ct2.component, resource_variable_ops.ResourceVariable) result = variable_utils.convert_variables_to_tensors(ct2) - self.assertIsInstance(result.component, ops.Tensor) + self.assertIsInstance(result.component, tensor.Tensor) self.assertAllEqual(result.component, 1) def test_replace_variables_with_atoms(self): @@ -99,7 +100,7 @@ def test_replace_variables_with_atoms(self): # Only ResourceVariables are replaced with int 0s. self.assertIsInstance(results[0], int) self.assertIsInstance(results[1], int) - self.assertIsInstance(results[2], ops.Tensor) + self.assertIsInstance(results[2], tensor.Tensor) self.assertIsInstance(results[3], list) self.assertIsInstance(results[4], int) results[2] = self.evaluate(results[2]) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index faa9b490b6d324..ca57f9081e4643 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -29,11 +29,13 @@ load( load( "@local_config_tensorrt//:build_defs.bzl", "if_tensorrt", + "if_tensorrt_exec", ) load( "@local_config_cuda//cuda:build_defs.bzl", "cuda_library", "if_cuda", + "if_cuda_exec", ) load( "@local_config_rocm//rocm:build_defs.bzl", @@ -61,6 +63,12 @@ load( "//third_party/llvm_openmp:openmp.bzl", "windows_llvm_openmp_linkopts", ) +load( + "//tensorflow:py.default.bzl", + _plain_py_binary = "py_binary", + _plain_py_library = "py_library", + _plain_py_test = "py_test", +) load("@bazel_skylib//lib:new_sets.bzl", "sets") load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") @@ -457,6 +465,16 @@ def tf_copts( }) ) +def tf_copts_exec( + android_optimization_level_override = "-O2", + is_external = False, + allow_exceptions = False): + return tf_copts( + android_optimization_level_override, + is_external, + allow_exceptions, + ) + if_cuda_exec(["-DGOOGLE_CUDA=1"]) + if_tensorrt_exec(["-DGOOGLE_TENSORRT=1"]) + def tf_openmp_copts(): # We assume when compiling on Linux gcc/clang will be used and MSVC on Windows return select({ @@ -549,7 +567,7 @@ def tf_gen_op_libs( for n in op_lib_names: cc_library( name = n + "_op_lib", - copts = tf_copts(is_external = is_external), + copts = tf_copts_exec(is_external = is_external), features = features, srcs = [sub_directory + n + ".cc"], deps = deps + [clean_dep("//tensorflow/core:framework")], @@ -1314,32 +1332,6 @@ generate_op_reg_offsets = rule( implementation = _generate_op_reg_offsets_impl, ) -# Generates a Python library target wrapping the ops registered in "deps". -# -# Args: -# name: used as the name of the generated target and as a name component of -# the intermediate files. -# out: name of the python file created by this rule. If None, then -# "ops/gen_{name}.py" is used. -# hidden: Optional list of ops names to make private in the Python module. -# It is invalid to specify both "hidden" and "op_allowlist". -# visibility: passed to py_library. -# deps: list of dependencies for the intermediate tool used to generate the -# python target. NOTE these `deps` are not applied to the final python -# library target itself. 
-# require_shape_functions: Unused. Leave this as False. -# hidden_file: optional file that contains a list of op names to make private -# in the generated Python module. Each op name should be on a line by -# itself. Lines that start with characters that are invalid op name -# starting characters are treated as comments and ignored. -# generated_target_name: name of the generated target (overrides the -# "name" arg) -# op_whitelist: [DEPRECATED] an older spelling for "op_allowlist" -# op_allowlist: if not empty, only op names in this list will be wrapped. It -# is invalid to specify both "hidden" and "op_allowlist". -# cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the -# specified ops. - def tf_gen_op_wrapper_py( name, out = None, @@ -1357,7 +1349,39 @@ def tf_gen_op_wrapper_py( testonly = False, copts = [], extra_py_deps = None, - py_lib_rule = native.py_library): + py_lib_rule = _plain_py_library): + """Generates a Python library target wrapping the ops registered in "deps". + + Args: + name: used as the name of the generated target and as a name component of + the intermediate files. + out: name of the python file created by this rule. If None, then + "ops/gen_{name}.py" is used. + hidden: Optional list of ops names to make private in the Python module. + It is invalid to specify both "hidden" and "op_allowlist". + visibility: passed to py_library. + deps: list of dependencies for the intermediate tool used to generate the + python target. NOTE these `deps` are not applied to the final python + library target itself. + require_shape_functions: Unused. Leave this as False. + hidden_file: optional file that contains a list of op names to make private + in the generated Python module. Each op name should be on a line by + itself. Lines that start with characters that are invalid op name + starting characters are treated as comments and ignored. + generated_target_name: name of the generated target (overrides the + "name" arg) + op_whitelist: [DEPRECATED] an older spelling for "op_allowlist" + op_allowlist: if not empty, only op names in this list will be wrapped. It + is invalid to specify both "hidden" and "op_allowlist". + cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the + specified ops. + api_def_srcs: undocumented. + compatible_with: undocumented. + testonly: undocumented. + copts: undocumented. + extra_py_deps: undocumented. + py_lib_rule: undocumented. + """ _ = require_shape_functions # Unused. if op_whitelist and op_allowlist: fail("op_whitelist is deprecated. 
Only use op_allowlist.") @@ -1377,7 +1401,7 @@ def tf_gen_op_wrapper_py( deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))] tf_cc_binary( name = tool_name, - copts = copts + tf_copts(), + copts = copts + tf_copts_exec(), linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + cc_linkopts, linkstatic = 1, # Faster to link this one-time-use binary dynamically visibility = [clean_dep("//tensorflow:internal")], @@ -1566,7 +1590,7 @@ register_extension_info( label_regex_for_dep = "{extension_name}", ) -# TODO(jakeharmon): Replace with or implement in terms of tsl_gpu_cc_test, which doesn't add a +# TODO(jakeharmon): Replace with an implementation which doesn't add a # dependency on core:common_runtime def tf_gpu_cc_test( name, @@ -2261,7 +2285,8 @@ def tf_custom_op_py_library( deps = [], **kwargs): _ignore = [kernels] - native.py_library( + _make_tags_mutable(kwargs) + _plain_py_library( name = name, data = dso, srcs = srcs, @@ -2448,7 +2473,7 @@ def pywrap_tensorflow_macro_opensource( # link the pyd (which is just a dll) because of missing dependencies. _create_symlink("ml_dtypes.so", "//tensorflow/tsl/python/lib/core:ml_dtypes.so") - native.py_library( + _plain_py_library( name = name, srcs = [":" + name + ".py"], srcs_version = "PY3", @@ -2483,10 +2508,11 @@ pywrap_tensorflow_macro = pywrap_tensorflow_macro_opensource # Note that this only works on Windows. See the definition of # //third_party/tensorflow/tools/pip_package:win_pip_package_marker for specific reasons. # 2. When --define=no_tensorflow_py_deps=false (by default), it's a normal py_test. -def py_test(deps = [], data = [], kernels = [], exec_properties = None, test_rule = native.py_test, **kwargs): +def py_test(deps = [], data = [], kernels = [], exec_properties = None, test_rule = _plain_py_test, **kwargs): if not exec_properties: exec_properties = tf_exec_properties(kwargs) + _make_tags_mutable(kwargs) test_rule( deps = select({ "//conditions:default": deps, @@ -2510,13 +2536,14 @@ register_extension_info( # See https://github.com/tensorflow/tensorflow/issues/22390 def py_binary(name, deps = [], **kwargs): # Add an extra target for dependencies to avoid nested select statement. - native.py_library( + _plain_py_library( name = name + "_deps", deps = deps, ) # Python version placeholder - native.py_binary( + _make_tags_mutable(kwargs) + _plain_py_binary( name = name, deps = select({ "//conditions:default": [":" + name + "_deps"], @@ -2527,7 +2554,18 @@ def py_binary(name, deps = [], **kwargs): def pytype_library(name, pytype_deps = [], pytype_srcs = [], **kwargs): # Types not enforced in OSS. - native.py_library(name = name, **kwargs) + _make_tags_mutable(kwargs) + _plain_py_library(name = name, **kwargs) + +# Tensorflow uses rules_python 0.0.1, and in that version of rules_python, +# the rules require the tags value to be a mutable list because they +# modify it in-place. Later versions of rules_python don't have this +# requirement. +def _make_tags_mutable(kwargs): + if "tags" in kwargs and kwargs["tags"] != None: + # The value might be a frozen list, which looks just like + # a regular list. So always make a copy. 
+ kwargs["tags"] = list(kwargs["tags"]) def tf_py_test( name, @@ -3136,7 +3174,7 @@ def pybind_extension_opensource( testonly = testonly, ) - native.py_library( + _plain_py_library( name = name, data = select({ clean_dep("//tensorflow:windows"): [pyd_file], @@ -3407,9 +3445,6 @@ def tf_grpc_cc_dependencies(): def get_compatible_with_portable(): return [] -def get_compatible_with_cloud(): - return [] - def filegroup(**kwargs): native.filegroup(**kwargs) diff --git a/tensorflow/tensorflow.default.bzl b/tensorflow/tensorflow.default.bzl index 017268250c3c6c..9c6515f9798e5e 100644 --- a/tensorflow/tensorflow.default.bzl +++ b/tensorflow/tensorflow.default.bzl @@ -8,7 +8,6 @@ load( _cuda_py_test = "cuda_py_test", _filegroup = "filegroup", _genrule = "genrule", - _get_compatible_with_cloud = "get_compatible_with_cloud", _get_compatible_with_portable = "get_compatible_with_portable", _if_indexing_source_code = "if_indexing_source_code", _if_not_mobile_or_arm_or_lgpl_restricted = "if_not_mobile_or_arm_or_lgpl_restricted", @@ -81,7 +80,6 @@ tf_external_workspace_visible = _tf_external_workspace_visible tf_grpc_dependencies = _tf_grpc_dependencies tf_grpc_cc_dependencies = _tf_grpc_cc_dependencies get_compatible_with_portable = _get_compatible_with_portable -get_compatible_with_cloud = _get_compatible_with_cloud cc_header_only_library = _cc_header_only_library tf_gen_op_libs = _tf_gen_op_libs tf_gen_op_wrapper_cc = _tf_gen_op_wrapper_cc diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py index 3066c0e597f3f0..0212b07d7600cf 100644 --- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py +++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py @@ -89,9 +89,9 @@ def _SkipMember(cls, member): # pylint: disable=unused-argument # Differences created by typing implementations. -_NORMALIZE_TYPE[( - 'tensorflow.python.framework.ops.Tensor')] = ( - "") +_NORMALIZE_TYPE[ + 'tensorflow.python.framework.tensor.Tensor' +] = "" _NORMALIZE_TYPE['typing.Generic'] = "" # TODO(b/203104448): Remove once the golden files are generated in Python 3.7. _NORMALIZE_TYPE[""] = 'typing.Union' diff --git a/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh b/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh index d2846da30469e8..8e8ccab623261c 100644 --- a/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh +++ b/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh @@ -16,6 +16,7 @@ set -x ARM_SKIP_TESTS="-//tensorflow/lite/... 
\ +-//tensorflow/compiler/xla/service/gpu:fusion_merger_test \ -//tensorflow/python/kernel_tests/nn_ops:atrous_conv2d_test \ -//tensorflow/python/kernel_tests/nn_ops:conv_ops_test \ " diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index effcbfcf85a386..37394c6eb9a010 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -9,7 +9,7 @@ load( "tf_cc_test", "tf_copts", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud", "tf_py_strict_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable", "tf_py_strict_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -25,7 +25,7 @@ cc_library( hdrs = [ "transform_utils.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_copts(), visibility = ["//visibility:public"], deps = [ diff --git a/tensorflow/tools/optimization/BUILD b/tensorflow/tools/optimization/BUILD index f6ab1fb0d2a64a..fc12208932e00b 100644 --- a/tensorflow/tools/optimization/BUILD +++ b/tensorflow/tools/optimization/BUILD @@ -6,7 +6,7 @@ load( "tf_cc_binary", "tf_cuda_library", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,7 +18,7 @@ tf_cuda_library( name = "optimization_pass_runner_lib", srcs = ["optimization_pass_runner.cc"], hdrs = ["optimization_pass_runner.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index a8ef26478df3a3..d9d87af17935c1 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -115,7 +115,6 @@ COMMON_PIP_DEPS = [ "//tensorflow/lite/python:tflite_convert", "//tensorflow/lite/toco/python:toco_from_protos", "//tensorflow/lite/tools:visualize", - "//tensorflow/python/autograph/converters:list_comprehensions", "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/autograph/impl/testing:pybind_for_testing", "//tensorflow/python/autograph/pyct/testing:basic_definitions", @@ -160,7 +159,6 @@ COMMON_PIP_DEPS = [ "//tensorflow/python/kernel_tests/signal:test_util", "//tensorflow/python/kernel_tests/sparse_ops:sparse_xent_op_test_base", "//tensorflow/python/lib:__init__", - "//tensorflow/python/ops:weak_tensor_ops", "//tensorflow/python/ops/parallel_for:test_util", "//tensorflow/python/ops/structured:structured_tensor_dynamic", "//tensorflow/python/platform:resource_loader", diff --git a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl index fff3dd70496e89..a2bdd6a7eedafe 100644 --- a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl +++ b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl @@ -34,7 +34,6 @@ def aarch64_compiler_configure(): ml2014_tf_aarch64_configs( name_container_map = { "ml2014_aarch64": "docker://localhost/tensorflow-build-aarch64", - "ml2014_aarch64-python3.8": "docker://localhost/tensorflow-build-aarch64:latest-python3.8", "ml2014_aarch64-python3.9": 
"docker://localhost/tensorflow-build-aarch64:latest-python3.9", "ml2014_aarch64-python3.10": "docker://localhost/tensorflow-build-aarch64:latest-python3.10", "ml2014_aarch64-python3.11": "docker://localhost/tensorflow-build-aarch64:latest-python3.11", @@ -72,7 +71,6 @@ def aarch64_compiler_configure(): ml2014_tf_aarch64_configs( name_container_map = { "ml2014_clang_aarch64": "docker://localhost/tensorflow-build-aarch64", - "ml2014_clang_aarch64-python3.8": "docker://localhost/tensorflow-build-aarch64:latest-python3.8", "ml2014_clang_aarch64-python3.9": "docker://localhost/tensorflow-build-aarch64:latest-python3.9", "ml2014_clang_aarch64-python3.10": "docker://localhost/tensorflow-build-aarch64:latest-python3.10", "ml2014_clang_aarch64-python3.11": "docker://localhost/tensorflow-build-aarch64:latest-python3.11", diff --git a/tensorflow/tsl/BUILD b/tensorflow/tsl/BUILD index da15e3cf07d892..47e987dcdab035 100644 --- a/tensorflow/tsl/BUILD +++ b/tensorflow/tsl/BUILD @@ -485,7 +485,6 @@ cc_library( name = "grpc++", visibility = ["//visibility:public"], deps = select({ - ":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"], "//conditions:default": ["@com_github_grpc_grpc//:grpc++"], }), ) diff --git a/tensorflow/tsl/distributed_runtime/preemption/BUILD b/tensorflow/tsl/distributed_runtime/preemption/BUILD index 97572b3285543a..6eeb7e8ae47bff 100644 --- a/tensorflow/tsl/distributed_runtime/preemption/BUILD +++ b/tensorflow/tsl/distributed_runtime/preemption/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/tsl/platform:build_config.bzl", "tsl_cc_test") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud", "tsl_grpc_cc_dependencies") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable", "tsl_grpc_cc_dependencies") load("//tensorflow/tsl:tsl.bzl", "set_external_visibility") package( @@ -15,7 +15,7 @@ cc_library( name = "preemption_notifier", srcs = ["preemption_notifier.cc"], hdrs = ["preemption_notifier.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", diff --git a/tensorflow/tsl/platform/BUILD b/tensorflow/tsl/platform/BUILD index 3eb5ff9089abc6..fd61677798e865 100644 --- a/tensorflow/tsl/platform/BUILD +++ b/tensorflow/tsl/platform/BUILD @@ -14,6 +14,7 @@ load( load( "//tensorflow/tsl/platform:build_config.bzl", "tf_cuda_libdevice_path_deps", + "tf_error_logging_deps", "tf_fingerprint_deps", "tf_google_mobile_srcs_no_runtime", "tf_logging_deps", @@ -327,13 +328,14 @@ cc_library( ":errors", ":logging", ":macros", + ":platform", ":status", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - ], + ] + tf_platform_deps("statusor"), ) cc_library( @@ -665,6 +667,7 @@ exports_files( "env.cc", "env.h", "env_time.h", + "error_logging.h", "file_system.cc", "file_system.h", "file_system_helper.cc", @@ -730,6 +733,7 @@ filegroup( "setround.h", "snappy.h", "status.h", + "statusor.h", "tracing.h", "unbounded_work_queue.h", ], @@ -1096,6 +1100,16 @@ cc_library( deps = tf_logging_deps(), ) +cc_library( + name = "error_logging", + compatible_with = get_compatible_with_portable(), + textual_hdrs = ["error_logging.h"], + visibility = [ + "//visibility:public", + ], + deps = tf_error_logging_deps(), +) + cc_library( name = "prefetch", hdrs 
= ["prefetch.h"], diff --git a/tensorflow/tsl/platform/build_config.bzl b/tensorflow/tsl/platform/build_config.bzl index a257152eea8ebd..8515b784a585cc 100644 --- a/tensorflow/tsl/platform/build_config.bzl +++ b/tensorflow/tsl/platform/build_config.bzl @@ -14,6 +14,7 @@ load( _tf_additional_tensor_coding_deps = "tf_additional_tensor_coding_deps", _tf_additional_test_deps = "tf_additional_test_deps", _tf_cuda_libdevice_path_deps = "tf_cuda_libdevice_path_deps", + _tf_error_logging_deps = "tf_error_logging_deps", _tf_fingerprint_deps = "tf_fingerprint_deps", _tf_google_mobile_srcs_no_runtime = "tf_google_mobile_srcs_no_runtime", _tf_google_mobile_srcs_only_runtime = "tf_google_mobile_srcs_only_runtime", @@ -54,6 +55,7 @@ tf_additional_rpc_deps = _tf_additional_rpc_deps tf_additional_tensor_coding_deps = _tf_additional_tensor_coding_deps tf_additional_test_deps = _tf_additional_test_deps tf_cuda_libdevice_path_deps = _tf_cuda_libdevice_path_deps +tf_error_logging_deps = _tf_error_logging_deps tf_fingerprint_deps = _tf_fingerprint_deps tf_google_mobile_srcs_no_runtime = _tf_google_mobile_srcs_no_runtime tf_google_mobile_srcs_only_runtime = _tf_google_mobile_srcs_only_runtime diff --git a/tensorflow/tsl/platform/cloud/BUILD b/tensorflow/tsl/platform/cloud/BUILD index a8bf3e30fee028..6266747a67248a 100644 --- a/tensorflow/tsl/platform/cloud/BUILD +++ b/tensorflow/tsl/platform/cloud/BUILD @@ -118,7 +118,6 @@ cc_library( ":http_request", ":ram_file_block_cache", ":time_util", - "//tensorflow/tsl/lib/gtl:map_util", "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:file_statistics", @@ -136,6 +135,7 @@ cc_library( "//tensorflow/tsl/profiler/lib:traceme", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@jsoncpp_git//:jsoncpp", ], alwayslink = 1, @@ -162,7 +162,6 @@ cc_library( ":http_request", ":ram_file_block_cache", ":time_util", - "//tensorflow/tsl/lib/gtl:map_util", "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:file_statistics", @@ -180,6 +179,7 @@ cc_library( "//tensorflow/tsl/profiler/lib:traceme", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@jsoncpp_git//:jsoncpp", ], alwayslink = 1, diff --git a/tensorflow/tsl/platform/cloud/gcs_file_system.cc b/tensorflow/tsl/platform/cloud/gcs_file_system.cc index f451279053e74c..9baeb4d8266aa4 100644 --- a/tensorflow/tsl/platform/cloud/gcs_file_system.cc +++ b/tensorflow/tsl/platform/cloud/gcs_file_system.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #ifndef _WIN32 #include @@ -1459,6 +1460,9 @@ Status GcsFileSystem::FileExists(const string& fname, TransactionToken* token) { TF_RETURN_IF_ERROR(BucketExists(bucket, &result)); if (result) { return OkStatus(); + } else { + return absl::NotFoundError( + absl::StrCat("The specified bucket ", fname, " was not found.")); } } diff --git a/tensorflow/tsl/platform/cloud/gcs_file_system_test.cc b/tensorflow/tsl/platform/cloud/gcs_file_system_test.cc index 61a588eea96b2a..d9bea9b43d54e8 100644 --- a/tensorflow/tsl/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/tsl/platform/cloud/gcs_file_system_test.cc @@ -1649,10 +1649,8 @@ TEST(GcsFileSystemTest, FileExists_NotAsBucket) { 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */, false /* compose append */); - EXPECT_TRUE( - errors::IsInvalidArgument(fs.FileExists("gs://bucket2/", nullptr))); - EXPECT_TRUE( - errors::IsInvalidArgument(fs.FileExists("gs://bucket2", nullptr))); + EXPECT_TRUE(absl::IsNotFound(fs.FileExists("gs://bucket2/", nullptr))); + EXPECT_TRUE(absl::IsNotFound(fs.FileExists("gs://bucket2", nullptr))); } TEST(GcsFileSystemTest, FileExists_StatCache) { diff --git a/tensorflow/tsl/platform/default/BUILD b/tensorflow/tsl/platform/default/BUILD index 84af5a2fbca3ab..cc59928cb393ba 100644 --- a/tensorflow/tsl/platform/default/BUILD +++ b/tensorflow/tsl/platform/default/BUILD @@ -206,6 +206,24 @@ cc_library( deps = ["//tensorflow/tsl/platform:types"], ) +cc_library( + name = "error_logging", + srcs = ["error_logging.cc"], + hdrs = ["//tensorflow/tsl/platform:error_logging.h"], + tags = [ + "manual", + "no_oss", + "nobuilder", + ], + textual_hdrs = ["error_logging.h"], + deps = [ + "//tensorflow/tsl/platform", + "@com_google_absl//absl/base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "human_readable_json", srcs = ["human_readable_json.cc"], @@ -546,6 +564,22 @@ cc_library( visibility = set_external_visibility(["//tensorflow:__subpackages__"]), ) +cc_library( + name = "statusor", + tags = [ + "manual", + "no_oss", + "nobuilder", + ], + textual_hdrs = ["statusor.h"], + visibility = set_external_visibility(["//tensorflow:__subpackages__"]), + deps = [ + "//tensorflow/tsl/platform:macros", + "//tensorflow/tsl/platform:status", + "@com_google_absl//absl/status:statusor", + ], +) + bzl_library( name = "cuda_build_defs_bzl", srcs = ["cuda_build_defs.bzl"], @@ -577,6 +611,7 @@ filegroup( "posix_file_system.h", "stacktrace.h", "status.h", + "statusor.h", "tracing_impl.h", "//tensorflow/tsl/platform/profile_utils:cpu_utils.h", "//tensorflow/tsl/platform/profile_utils:i_cpu_utils_helper.h", @@ -608,6 +643,7 @@ exports_files( srcs = glob( ["*"], exclude = [ + "error_logging.h", "integral_types.h", "logging.h", "test.cc", @@ -618,6 +654,7 @@ exports_files( exports_files( srcs = [ + "error_logging.h", "integral_types.h", "logging.h", "test.cc", diff --git a/tensorflow/tsl/platform/default/build_config.bzl b/tensorflow/tsl/platform/default/build_config.bzl index 8c8f606dd644e3..e2720d0eec03dd 100644 --- a/tensorflow/tsl/platform/default/build_config.bzl +++ b/tensorflow/tsl/platform/default/build_config.bzl @@ -243,7 +243,6 @@ def cc_proto_library( if use_grpc_plugin: cc_libs += select({ - clean_dep("//tensorflow/tsl:linux_s390x"): ["//external:grpc_lib_unsecure"], "//conditions:default": ["//external:grpc_lib"], }) @@ -326,7 
+325,6 @@ def cc_grpc_library( proto_targets += srcs extra_deps += select({ - clean_dep("//tensorflow/tsl:linux_s390x"): ["//external:grpc_lib_unsecure"], "//conditions:default": ["//external:grpc_lib"], }) @@ -663,6 +661,7 @@ def tf_additional_lib_hdrs(): clean_dep("//tensorflow/tsl/platform/default:notification.h"), clean_dep("//tensorflow/tsl/platform/default:stacktrace.h"), clean_dep("//tensorflow/tsl/platform/default:status.h"), + clean_dep("//tensorflow/tsl/platform/default:statusor.h"), clean_dep("//tensorflow/tsl/platform/default:tracing_impl.h"), clean_dep("//tensorflow/tsl/platform/default:unbounded_work_queue.h"), ] + select({ @@ -842,6 +841,9 @@ def tf_platform_alias(name, platform_dir = "//tensorflow/tsl/platform/"): def tf_logging_deps(): return [clean_dep("//tensorflow/tsl/platform/default:logging")] +def tf_error_logging_deps(): + return [clean_dep("//tensorflow/tsl/platform/default:error_logging")] + def tf_resource_deps(): return [clean_dep("//tensorflow/tsl/platform/default:resource")] diff --git a/tensorflow/tsl/platform/default/error_logging.cc b/tensorflow/tsl/platform/default/error_logging.cc new file mode 100644 index 00000000000000..59efa3dc148124 --- /dev/null +++ b/tensorflow/tsl/platform/default/error_logging.cc @@ -0,0 +1,29 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tsl/platform/default/error_logging.h" + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" + +namespace tsl::error_logging { + +absl::Status Log(absl::string_view component, absl::string_view subcomponent, + absl::string_view error_msg) { + // no-op, intentionally empty function + return absl::OkStatus(); +} + +} // namespace tsl::error_logging diff --git a/tensorflow/tsl/platform/default/error_logging.h b/tensorflow/tsl/platform/default/error_logging.h new file mode 100644 index 00000000000000..26360b7e5b72e4 --- /dev/null +++ b/tensorflow/tsl/platform/default/error_logging.h @@ -0,0 +1,29 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_TSL_PLATFORM_DEFAULT_ERROR_LOGGING_H_ +#define TENSORFLOW_TSL_PLATFORM_DEFAULT_ERROR_LOGGING_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" + +namespace tsl::error_logging { + +absl::Status Log(absl::string_view component, absl::string_view subcomponent, + absl::string_view error_msg); + +} + +#endif // TENSORFLOW_TSL_PLATFORM_DEFAULT_ERROR_LOGGING_H_ diff --git a/tensorflow/tsl/platform/default/statusor.h b/tensorflow/tsl/platform/default/statusor.h new file mode 100644 index 00000000000000..300b4906f0f8db --- /dev/null +++ b/tensorflow/tsl/platform/default/statusor.h @@ -0,0 +1,33 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_TSL_PLATFORM_DEFAULT_STATUSOR_H_ +#define TENSORFLOW_TSL_PLATFORM_DEFAULT_STATUSOR_H_ + +#include "absl/status/statusor.h" +#include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/platform/status.h" + +#define TF_ASSIGN_OR_RETURN(lhs, rexpr) \ + TF_ASSIGN_OR_RETURN_IMPL( \ + TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr) + +#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + if (TF_PREDICT_FALSE(!statusor.ok())) { \ + return statusor.status(); \ + } \ + lhs = std::move(statusor).value() + +#endif // TENSORFLOW_TSL_PLATFORM_DEFAULT_STATUSOR_H_ diff --git a/tensorflow/tsl/platform/error_logging.h b/tensorflow/tsl/platform/error_logging.h new file mode 100644 index 00000000000000..d27d0115f37391 --- /dev/null +++ b/tensorflow/tsl/platform/error_logging.h @@ -0,0 +1,27 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_TSL_PLATFORM_ERROR_LOGGING_H_ +#define TENSORFLOW_TSL_PLATFORM_ERROR_LOGGING_H_ + +#include "tensorflow/tsl/platform/platform.h" + +#if defined(PLATFORM_GOOGLE) +#include "tensorflow/tsl/platform/google/error_logging.h" // IWYU pragma: export +#else +#include "tensorflow/tsl/platform/default/error_logging.h" // IWYU pragma: export +#endif + +#endif // TENSORFLOW_TSL_PLATFORM_ERROR_LOGGING_H_ diff --git a/tensorflow/tsl/platform/statusor.h b/tensorflow/tsl/platform/statusor.h index 34bf3e38d20e6e..cf7a95d45a7ec8 100644 --- a/tensorflow/tsl/platform/statusor.h +++ b/tensorflow/tsl/platform/statusor.h @@ -72,12 +72,22 @@ limitations under the License. #include "absl/status/statusor.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/platform/platform.h" #include "tensorflow/tsl/platform/status.h" +// Include appropriate platform-dependent `TF_ASSIGN_OR_RETURN`. +#if defined(PLATFORM_GOOGLE) +#include "tensorflow/tsl/platform/google/statusor.h" // IWYU pragma: export +#else +#include "tensorflow/tsl/platform/default/statusor.h" // IWYU pragma: export +#endif + namespace tsl { using absl::StatusOr; +} // namespace tsl + #define TF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ TF_ASSERT_OK_AND_ASSIGN_IMPL( \ TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ @@ -91,17 +101,4 @@ using absl::StatusOr; #define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y) #define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y -#define TF_ASSIGN_OR_RETURN(lhs, rexpr) \ - TF_ASSIGN_OR_RETURN_IMPL( \ - TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr) - -#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ - auto statusor = (rexpr); \ - if (TF_PREDICT_FALSE(!statusor.ok())) { \ - return statusor.status(); \ - } \ - lhs = std::move(statusor).value() - -} // namespace tsl - #endif // TENSORFLOW_TSL_PLATFORM_STATUSOR_H_ diff --git a/tensorflow/tsl/platform/statusor_test.cc b/tensorflow/tsl/platform/statusor_test.cc index 9dca858fa2f854..c6eb12c782a5e0 100644 --- a/tensorflow/tsl/platform/statusor_test.cc +++ b/tensorflow/tsl/platform/statusor_test.cc @@ -697,5 +697,34 @@ void BM_StatusOrFactoryFailLongMsg(::testing::benchmark::State& state) { } BENCHMARK(BM_StatusOrFactoryFailLongMsg); +#if defined(PLATFORM_GOOGLE) + +StatusOr GetError() { + return absl::InvalidArgumentError("An invalid argument error"); +} + +StatusOr PropagateError() { + TF_ASSIGN_OR_RETURN(int a, GetError()); + return a; +} + +StatusOr PropagateError2() { + TF_ASSIGN_OR_RETURN(int a, PropagateError()); + return a; +} + +TEST(Status, StackTracePropagation) { + StatusOr s = PropagateError2(); + auto sources = s.status().GetSourceLocations(); + ASSERT_EQ(sources.size(), 3); + + for (int i = 0; i < 3; ++i) { + ASSERT_EQ(sources[i].file_name(), + "third_party/tensorflow/tsl/platform/statusor_test.cc"); + } +} + +#endif + } // namespace } // namespace tsl diff --git a/tensorflow/tsl/tsl.default.bzl b/tensorflow/tsl/tsl.default.bzl index 37b772a9f3b12c..b4f01e6aec0efe 100644 --- a/tensorflow/tsl/tsl.default.bzl +++ b/tensorflow/tsl/tsl.default.bzl @@ -2,8 +2,6 @@ load( "//tensorflow/tsl:tsl.bzl", - "clean_dep", - "two_gpu_tags", _filegroup = "filegroup", _get_compatible_with_portable = "get_compatible_with_portable", _if_not_mobile_or_arm_or_lgpl_restricted = "if_not_mobile_or_arm_or_lgpl_restricted", @@ -11,18 +9,6 @@ load( 
_tsl_grpc_cc_dependencies = "tsl_grpc_cc_dependencies", _tsl_pybind_extension = "tsl_pybind_extension", ) -load( - "//tensorflow/tsl/platform:build_config.bzl", - "tsl_cc_test", -) -load( - "//tensorflow/tsl/platform:build_config_root.bzl", - "tf_gpu_tests_tags", -) -load( - "@local_config_cuda//cuda:build_defs.bzl", - "if_cuda", -) get_compatible_with_portable = _get_compatible_with_portable filegroup = _filegroup @@ -30,92 +16,3 @@ if_not_mobile_or_arm_or_lgpl_restricted = _if_not_mobile_or_arm_or_lgpl_restrict internal_hlo_deps = _internal_hlo_deps tsl_grpc_cc_dependencies = _tsl_grpc_cc_dependencies tsl_pybind_extension = _tsl_pybind_extension - -def get_compatible_with_cloud(): - return [] - -def tsl_gpu_cc_test( - name, - srcs = [], - deps = [], - tags = [], - data = [], - size = "medium", - linkstatic = 0, - args = [], - linkopts = [], - **kwargs): - """Create tests for cpu, gpu and optionally 2gpu - - Args: - name: unique name for this test target. - srcs: list of C and C++ files that are processed to create the binary target. - deps: list of other libraries to be linked in to the binary target. - tags: useful for categorizing the tests - data: files needed by this rule at runtime. - size: classification of how much time/resources the test requires. - linkstatic: link the binary in static mode. - args: command line arguments that Bazel passes to the target. - linkopts: add these flags to the C++ linker command. - **kwargs: Extra arguments to the rule. - """ - targets = [] - tsl_cc_test( - name = name + "_cpu", - size = size, - srcs = srcs, - args = args, - data = data, - copts = if_cuda(["-DNV_CUDNN_DISABLE_EXCEPTION"]), - linkopts = linkopts, - linkstatic = linkstatic, - tags = tags, - deps = deps, - **kwargs - ) - targets.append(name + "_cpu") - tsl_cc_test( - name = name + "_gpu", - size = size, - srcs = srcs, - args = args, - data = data, - copts = if_cuda(["-DNV_CUDNN_DISABLE_EXCEPTION"]), - linkopts = linkopts, - linkstatic = select({ - # TODO(allenl): Remove Mac static linking when Bazel 0.6 is out. - clean_dep("//tensorflow/tsl:macos"): 1, - "@local_config_cuda//cuda:using_nvcc": 1, - "@local_config_cuda//cuda:using_clang": 1, - "//conditions:default": 0, - }), - tags = tags + tf_gpu_tests_tags(), - deps = deps, - **kwargs - ) - targets.append(name + "_gpu") - if "multi_gpu" in tags or "multi_and_single_gpu" in tags: - cleaned_tags = tags + two_gpu_tags - if "requires-gpu-nvidia" in cleaned_tags: - cleaned_tags.remove("requires-gpu-nvidia") - tsl_cc_test( - name = name + "_2gpu", - size = size, - srcs = srcs, - args = args, - data = data, - linkopts = linkopts, - linkstatic = select({ - # TODO(allenl): Remove Mac static linking when Bazel 0.6 is out. 
- clean_dep("//tensorflow/tsl:macos"): 1, - "@local_config_cuda//cuda:using_nvcc": 1, - "@local_config_cuda//cuda:using_clang": 1, - "//conditions:default": 0, - }), - tags = cleaned_tags, - deps = deps, - **kwargs - ) - targets.append(name + "_2gpu") - - native.test_suite(name = name, tests = targets, tags = tags) diff --git a/tensorflow/tsl/util/BUILD b/tensorflow/tsl/util/BUILD index 0e624d68a91e22..e7d376ad04f369 100644 --- a/tensorflow/tsl/util/BUILD +++ b/tensorflow/tsl/util/BUILD @@ -303,6 +303,18 @@ filegroup( visibility = set_external_visibility(["//tensorflow/core/util:__pkg__"]), ) +filegroup( + name = "onednn_util_hdrs", + srcs = [ + "onednn_threadpool.h", + ], + visibility = set_external_visibility([ + "//tensorflow/compiler/xla:__pkg__", + "//tensorflow/core:__pkg__", + "//tensorflow/core/framework:__pkg__", + ]), +) + filegroup( name = "android_test_hdrs", testonly = 1, diff --git a/tensorflow/core/util/mkl_threadpool.h b/tensorflow/tsl/util/onednn_threadpool.h similarity index 72% rename from tensorflow/core/util/mkl_threadpool.h rename to tensorflow/tsl/util/onednn_threadpool.h index e160c75661265a..ed9989bd3c2511 100644 --- a/tensorflow/core/util/mkl_threadpool.h +++ b/tensorflow/tsl/util/onednn_threadpool.h @@ -14,8 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_ -#define TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_ +#ifndef TENSORFLOW_TSL_UTIL_ONEDNN_THREADPOOL_H_ +#define TENSORFLOW_TSL_UTIL_ONEDNN_THREADPOOL_H_ #ifdef INTEL_MKL #include @@ -25,17 +25,15 @@ limitations under the License. #include #include -#include "dnnl_threadpool.hpp" -#include "dnnl.hpp" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/platform/blocking_counter.h" -#include "tensorflow/core/platform/cpu_info.h" -#include "tensorflow/core/platform/threadpool.h" -#include "tensorflow/core/util/onednn_env_vars.h" - #define EIGEN_USE_THREADS -namespace tensorflow { +#include "dnnl.hpp" +#include "dnnl_threadpool.hpp" +#include "tensorflow/tsl/platform/blocking_counter.h" +#include "tensorflow/tsl/platform/cpu_info.h" +#include "tensorflow/tsl/platform/threadpool.h" + +namespace tsl { #ifndef ENABLE_ONEDNN_OPENMP using dnnl::threadpool_interop::threadpool_iface; @@ -75,28 +73,20 @@ inline void run_jobs(bool balance, int i, int n, int njobs, } } -struct MklDnnThreadPool : public threadpool_iface { - MklDnnThreadPool() = default; +class OneDnnThreadPool : public threadpool_iface { + public: + OneDnnThreadPool() = default; - MklDnnThreadPool(OpKernelContext* ctx, int num_threads = -1) { - eigen_interface_ = ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); -#if DNNL_VERSION_MAJOR >= 3 || \ - (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) - if (num_threads == -1) { - dnnl_threadpool_interop_set_max_concurrency( - eigen_interface_->NumThreads()); - num_threads_ = eigen_interface_->NumThreads(); - } else { - dnnl_threadpool_interop_set_max_concurrency(num_threads); - num_threads_ = num_threads; - } -#else - num_threads_ = - num_threads == -1 ? 
eigen_interface_->NumThreads() : num_threads; -#endif // DNNL_VERSION_MAJOR >= 3 || - // (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) + OneDnnThreadPool(Eigen::ThreadPoolInterface* eigen_interface, + int num_threads = -1) + : eigen_interface_(eigen_interface) { + set_num_and_max_threads(num_threads); + } + OneDnnThreadPool(Eigen::ThreadPoolInterface* eigen_interface, + bool can_use_caller_thread, int num_threads = -1) + : eigen_interface_(eigen_interface), + can_use_caller_thread_(can_use_caller_thread) { + set_num_and_max_threads(num_threads); } virtual int get_num_threads() const override { return num_threads_; } virtual bool get_in_parallel() const override { @@ -121,10 +111,10 @@ struct MklDnnThreadPool : public threadpool_iface { // If use_caller_thread, schedule njobs-1 jobs to thread pool and run last // job directly. const bool use_caller_thread = - ThreadPoolUseCallerThread() && nthr == port::NumSchedulableCPUs(); + can_use_caller_thread_ && nthr == port::NumSchedulableCPUs(); const int njobs_to_schedule = use_caller_thread ? njobs - 1 : njobs; - BlockingCounter counter(njobs_to_schedule); + tsl::BlockingCounter counter(njobs_to_schedule); std::function handle_range = [=, &handle_range, &counter]( int first, int last) { while (last - first > 1) { @@ -152,25 +142,38 @@ struct MklDnnThreadPool : public threadpool_iface { counter.Wait(); } - ~MklDnnThreadPool() {} + ~OneDnnThreadPool() {} private: Eigen::ThreadPoolInterface* eigen_interface_ = nullptr; - int num_threads_ = 1; // Execute in caller thread. + int num_threads_ = 1; // Execute in caller thread. + bool can_use_caller_thread_ = false; // true if the user set the env variable + // to use caller thread also. + inline void set_num_and_max_threads(int num_threads) { + num_threads_ = + num_threads == -1 ? eigen_interface_->NumThreads() : num_threads; +#if DNNL_VERSION_MAJOR >= 3 || \ + (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) + dnnl_threadpool_interop_set_max_concurrency(num_threads_); +#endif // DNNL_VERSION_MAJOR >= 3 || + // (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) + } }; #else -// This struct was just added to enable successful OMP-based build. -struct MklDnnThreadPool { - MklDnnThreadPool() = default; - MklDnnThreadPool(OpKernelContext* ctx) {} - MklDnnThreadPool(OpKernelContext* ctx, int num_threads) {} +// This class was just added to enable successful OMP-based build. 
+class OneDnnThreadPool { + public: + OneDnnThreadPool() = default; + OneDnnThreadPool(Eigen::ThreadPoolInterface* eigen_interface) {} + OneDnnThreadPool(Eigen::ThreadPoolInterface* eigen_interface, + bool can_use_caller_thread, int num_threads = -1) {} }; #endif // !ENABLE_ONEDNN_OPENMP -} // namespace tensorflow +} // namespace tsl #endif // INTEL_MKL -#endif // TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_ +#endif // TENSORFLOW_TSL_UTIL_ONEDNN_THREADPOOL_H_ diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 26196eef24e520..c4e64dbfa66d25 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -169,9 +169,9 @@ def _tf_repositories(): tf_http_archive( name = "cpuinfo", - strip_prefix = "cpuinfo-3dc310302210c1891ffcfb12ae67b11a3ad3a150", - sha256 = "ba668f9f8ea5b4890309b7db1ed2e152aaaf98af6f9a8a63dbe1b75c04e52cb9", - urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/3dc310302210c1891ffcfb12ae67b11a3ad3a150.zip"), + strip_prefix = "cpuinfo-87d8234510367db49a65535021af5e1838a65ac2", + sha256 = "609fc42c47482c1fc125dccac65e843f640e792540162581c4b7eb6ff81c826a", + urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/87d8234510367db49a65535021af5e1838a65ac2.zip"), ) tf_http_archive( diff --git a/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/gpus/cuda/build_defs.bzl.tpl index 71acfa7cb7d629..189d3e3e784003 100644 --- a/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/gpus/cuda/build_defs.bzl.tpl @@ -4,7 +4,6 @@ def if_cuda(if_true, if_false = []): Returns a select statement which evaluates to if_true if we're building with CUDA enabled. Otherwise, the select statement evaluates to if_false. - """ return select({ "@local_config_cuda//:is_cuda_enabled": if_true, @@ -16,13 +15,21 @@ def if_cuda_clang(if_true, if_false = []): Returns a select statement which evaluates to if_true if we're building with cuda-clang. Otherwise, the select statement evaluates to if_false. - """ return select({ "@local_config_cuda//cuda:using_clang": if_true, "//conditions:default": if_false }) +def if_cuda_exec(if_true, if_false = []): + """Synonym for if_cuda. + + Selects if_true both in target and in exec configurations. In principle, + if_cuda would only need to select if_true in a target configuration, but + not in an exec configuration, but this is not currently implemented. + """ + return if_cuda(if_true, if_false) + def cuda_compiler(if_cuda_clang, if_nvcc, neither = []): """Shorthand for select()'ing on wheteher we're building with cuda-clang or nvcc. diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..e69de29bb2d1d6 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +0,0 @@ -Auto generated patch. Do not edit or delete it, even if empty. 
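Note on the onednn_threadpool.h move above: the renamed OneDnnThreadPool no longer takes an OpKernelContext; it is constructed from an Eigen::ThreadPoolInterface* (plus an optional can_use_caller_thread flag), which is what removes the TSL dependency on core/framework. The following is only a rough usage sketch, not part of this patch; it assumes a build with INTEL_MKL/oneDNN enabled, and the thread-pool setup shown is an arbitrary illustration.

#include "tensorflow/tsl/platform/env.h"
#include "tensorflow/tsl/platform/threadpool.h"
#include "tensorflow/tsl/util/onednn_threadpool.h"

int main() {
  // Build a TSL thread pool and hand its Eigen interface to the new class.
  tsl::thread::ThreadPool pool(tsl::Env::Default(), "onednn_demo",
                               /*num_threads=*/4);
  // Old API (removed here): MklDnnThreadPool tp(ctx);  // needed OpKernelContext
  // New API: constructed directly from the Eigen thread-pool interface.
  tsl::OneDnnThreadPool onednn_tp(pool.AsEigenThreadPool(),
                                  /*can_use_caller_thread=*/false);
  // onednn_tp can now be attached to a oneDNN stream via
  // dnnl::threadpool_interop, as before the rename.
  return 0;
}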
diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 278db4f8354cd9..ddb4704d4b456c 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "dbaa5838c13e5593b9de37b8f3daffe4cb914a17" - LLVM_SHA256 = "79a02eb8733ec1f51c23fdc0cfc123fb023d855fe53ca59515cd8c6cb2af8993" + LLVM_COMMIT = "1936bb81aafdbb3b4c9770a24fc77ba07669bd19" + LLVM_SHA256 = "7b519ebd1b17dd59b94e5836b431486ec3e2d020c7799c70ce9ee706d25d3c5d" tf_http_archive( name = name, diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 3ad60b9c2c5bb0..3bb703ca414a73 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1275,15 +1275,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl } }; -@@ -1143,7 +1210,7 @@ - // function. This is sufficient because we only support one function per - // program at the moment. - // TODO(#1048): Find out why .maxIterations = 1 no longer works. -- // There have been recent refactors to applyPatternsAndFoldGreedily -+ // There have been recent refactors in applyPatternsAndFoldGreedily - // upstream, and that might be the reason. - GreedyRewriteConfig config; - config.useTopDownTraversal = true; @@ -1181,7 +1248,9 @@ patterns.add(&getContext()); patterns.add(&getContext()); diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 1934d33bf9868a..c7d30c5b560974 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "20b1da42266a1f351b8315bc195faabceaa74f3e" - STABLEHLO_SHA256 = "7ab70ba2d0aa3c7331df912b674c2825cc168cb691db171a2343d453e4a53811" + STABLEHLO_COMMIT = "41bad512515d609ccd3896d74bf697e7d456e1d3" + STABLEHLO_SHA256 = "01d143b57efda2fcf5e3482cbd0c4beae2a51164082e0797f0093cdbd8c82b06" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/tensorrt/build_defs.bzl.tpl b/third_party/tensorrt/build_defs.bzl.tpl index 6d00513827b380..83fcc7d69717b1 100644 --- a/third_party/tensorrt/build_defs.bzl.tpl +++ b/third_party/tensorrt/build_defs.bzl.tpl @@ -3,3 +3,7 @@ def if_tensorrt(if_true, if_false=[]): """Tests whether TensorRT was enabled during the configure process.""" return %{if_tensorrt} + +def if_tensorrt_exec(if_true, if_false=[]): + """Synonym for if_tensorrt.""" + return %{if_tensorrt} diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 8b5a8623f70020..f456dbaf7cae33 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "00a31b7ce92e062de48321d9ff50ad414144a47b" - TFRT_SHA256 = "113355c7dd55eb34346e2264544f309068acee5f8102a1a8f2146fc6a571cece" + TFRT_COMMIT = "08a6b6ecfc7cce7d0c8388fe7a9c73352467091e" + TFRT_SHA256 = "bb6f479caeba3b28f033a9a420b23cb00f9d235ac8df312b1a57fda1ef2f8039" tf_http_archive( name = "tf_runtime", diff --git a/third_party/triton/cl545371535.patch b/third_party/triton/cl545371535.patch deleted file mode 100644 index f010d586bb02da..00000000000000 --- a/third_party/triton/cl545371535.patch +++ /dev/null @@ -1,29 +0,0 @@ -diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp -index cd8d1b82e..f0f1127d5 100644 ---- a/lib/Dialect/Triton/IR/Ops.cpp -+++ b/lib/Dialect/Triton/IR/Ops.cpp -@@ -6,7 +6,7 @@ - #include "mlir/IR/OperationSupport.h" - #include "triton/Dialect/Triton/IR/Dialect.h" - #include "triton/Dialect/Triton/IR/Types.h" --#include "triton/Dialect/TritonGPU/IR/Attributes.h" -+//#include "triton/Dialect/TritonGPU/IR/Attributes.h" - - namespace mlir { - namespace triton { -@@ -404,6 +404,7 @@ LogicalResult mlir::triton::DotOp::verify() { - auto bTy = getOperand(1).getType().cast(); - if (aTy.getElementType() != bTy.getElementType()) - return emitError("element types of operands A and B must match"); -+#if 0 // TODO(csigg): avoid cyclic BUILD dependency. - auto aEncoding = - aTy.getEncoding().dyn_cast_or_null(); - auto bEncoding = -@@ -415,6 +416,7 @@ LogicalResult mlir::triton::DotOp::verify() { - return emitError("mismatching encoding between A and B operands"); - if (aEncoding.getMMAv2kWidth() != bEncoding.getMMAv2kWidth()) - return emitError("mismatching kWidth between A and B operands"); -+#endif - return mlir::success(); - } - diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 0488d04b39c462..219c10b85e2e7f 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl545371535" - TRITON_SHA256 = "97e9af5aa986744b9d3807e8a473b2b2056c8bedc74842b607d40cf780e8ac5a" + TRITON_COMMIT = "cl546794996" + TRITON_SHA256 = "57d4b5f1e68bb4df93528bd5394ba3338bef7bf9c0afdc96b44371fba650c037" tf_http_archive( name = "triton", @@ -16,6 +16,5 @@ def repo(): # For temporary changes which haven't landed upstream yet. patch_file = [ "//third_party/triton:cl536931041.patch", - "//third_party/triton:cl545371535.patch", ], )
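For reference on the statusor/error_logging changes earlier in this patch: TF_ASSIGN_OR_RETURN is still spelled the same but is now supplied by the platform-specific statusor.h that tensorflow/tsl/platform/statusor.h includes, and the new error_logging target resolves to an intentional no-op in OSS builds. A minimal sketch of how calling code might look; ParseConfig, LoadConfig, and "MyComponent" are hypothetical names used only for illustration and do not appear in this patch.

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/tsl/platform/error_logging.h"
#include "tensorflow/tsl/platform/statusor.h"

// Hypothetical helper, only for illustration.
tsl::StatusOr<int> ParseConfig(absl::string_view text);

absl::Status LoadConfig(absl::string_view text) {
  // TF_ASSIGN_OR_RETURN now comes from the platform-specific header
  // (default/statusor.h in OSS, google/statusor.h internally).
  TF_ASSIGN_OR_RETURN(int value, ParseConfig(text));
  if (value < 0) {
    // In OSS this Log() is the empty default implementation and simply
    // returns OkStatus(); internally it can record the error elsewhere.
    tsl::error_logging::Log("MyComponent", "LoadConfig",
                            "negative config value")
        .IgnoreError();
    return absl::InvalidArgumentError("negative config value");
  }
  return absl::OkStatus();
}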