diff --git a/.bazelrc b/.bazelrc index 76f824f372e0d3..9e565e91a1b903 100644 --- a/.bazelrc +++ b/.bazelrc @@ -219,13 +219,16 @@ build:mkl_aarch64_threadpool -c opt build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda +# Default CUDA and CUDNN versions. +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" +# This flag is needed to include hermetic CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true # CUDA: This config refers to building CUDA op kernels with clang. build:cuda_clang --config=cuda -# Enable TensorRT optimizations https://developer.nvidia.com/tensorrt -build:cuda_clang --config=tensorrt -build:cuda_clang --action_env=TF_CUDA_CLANG="1" build:cuda_clang --@local_config_cuda//:cuda_compiler=clang +build:cuda_clang --copt=-Qunused-arguments # Select supported compute capabilities (supported graphics cards). # This is the same as the official TensorFlow builds. # See https://developer.nvidia.com/cuda-gpus#compute @@ -234,22 +237,22 @@ build:cuda_clang --@local_config_cuda//:cuda_compiler=clang # release while SASS is only forward compatible inside the current # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. -build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Set lld as the linker. +build:cuda_clang --host_linkopt="-fuse-ld=lld" +build:cuda_clang --host_linkopt="-lm" +build:cuda_clang --linkopt="-fuse-ld=lld" +build:cuda_clang --linkopt="-lm" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang -build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" -build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" -build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" -build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +build:cuda_clang_official --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda_clang_official --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" # Build with nvcc for CUDA and clang for host build:nvcc_clang --config=cuda -# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang -build:nvcc_clang --action_env=TF_CUDA_CLANG="1" build:nvcc_clang --action_env=TF_NVCC_CLANG="1" build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc @@ -545,10 +548,6 @@ build:rbe_linux_cuda --config=cuda_clang_official build:rbe_linux_cuda --config=rbe_linux_cpu # For Remote build execution -- GPU configuration build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.17-clang_config_cuda" -build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.17-clang_config_tensorrt" -build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.17-clang_config_nccl" -test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --config=nvcc_clang @@ -633,7 +632,6 @@ build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-18/bin/cla # Test-related settings below this point. test:release_linux_base --build_tests_only --keep_going --test_output=errors --verbose_failures=true test:release_linux_base --local_test_jobs=HOST_CPUS -test:release_linux_base --test_env=LD_LIBRARY_PATH # Give only the list of failed tests at the end of the log test:release_linux_base --test_summary=short @@ -647,7 +645,6 @@ build:release_gpu_linux --config=release_cpu_linux # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. # Note that linux cpu and cuda builds share the same toolchain now. build:release_gpu_linux --config=cuda_clang_official -test:release_gpu_linux --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" # Local test jobs has to be 4 because parallel_gpu_execute is fragile, I think test:release_gpu_linux --test_timeout=300,450,1200,3600 --local_test_jobs=4 --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute @@ -656,6 +653,7 @@ build:release_arm64_linux --config=linux_arm64 build:release_arm64_linux --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" build:release_arm64_linux --config=mkl_aarch64_threadpool build:release_arm64_linux --copt=-flax-vector-conversions +test:release_arm64_linux --flaky_test_attempts=3 # The old gcc linux build options are preserved in the unsupported_*_linux # configs. If your project fails to build with Clang, you can use these @@ -677,9 +675,8 @@ build:unsupported_gpu_linux --config=unsupported_cpu_linux build:unsupported_gpu_linux --action_env=TF_CUDA_VERSION="11" build:unsupported_gpu_linux --action_env=TF_CUDNN_VERSION="8" build:unsupported_gpu_linux --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" -build:unsupported_gpu_linux --config=tensorrt build:unsupported_gpu_linux --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2" -build:unsupported_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64:/usr/local/tensorrt/lib" +build:unsupported_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64" build:unsupported_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain @@ -774,7 +771,7 @@ test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflo # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium --flaky_test_attempts=3 +test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 @@ -812,7 +809,7 @@ test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflo # inherit from build. build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium +build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test # CROSS-COMPILE ARM64 PYCPP diff --git a/.github/workflows/sigbuild-docker.yml b/.github/workflows/sigbuild-docker.yml index 78e72b4ef419b0..78e7fd75085523 100644 --- a/.github/workflows/sigbuild-docker.yml +++ b/.github/workflows/sigbuild-docker.yml @@ -60,6 +60,14 @@ jobs: registry: gcr.io username: _json_key password: ${{ secrets.GCP_CREDS }} + - + name: Login to AR + # Once this is verified, removed gcr.io actions. + uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 + with: + registry: us-central1-docker.pkg.dev + username: _json_key + password: ${{ secrets.GCP_CREDS }} - name: Grab the upcoming TF version to tag this container run: | @@ -87,6 +95,8 @@ jobs: tensorflow/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }} gcr.io/tensorflow-sigs/build:latest-${{ matrix.python-version }} gcr.io/tensorflow-sigs/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }} + us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:latest-${{ matrix.python-version }} + us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }} cache-from: type=registry,ref=tensorflow/build:latest-${{ matrix.python-version }} cache-to: type=inline - diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 89c61463462745..17b77f808d9c80 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -253,13 +253,21 @@ There are two ways to run TensorFlow unit tests. export flags="--config=opt -k" ``` - If the tests are to be run on the GPU, add CUDA paths to LD_LIBRARY_PATH and - add the `cuda` option flag + If the tests are to be run on the GPU: + * For TensorFlow versions starting from v.2.18.0: + Add the `cuda` option flag. - ```bash - export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" - export flags="--config=opt --config=cuda -k" - ``` + ```bash + export flags="--config=opt --config=cuda -k" + ``` + + * For TensorFlow versions prior v.2.18.0: + Add CUDA paths to LD_LIBRARY_PATH and add the `cuda` option flag. + + ```bash + export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" + export flags="--config=opt --config=cuda -k" + ``` For example, to run all tests under tensorflow/python, do: diff --git a/RELEASE.md b/RELEASE.md index eb2d9394c95a1a..9adbb494b09b86 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,7 +11,23 @@ * `tf.lite` * C API: - * An optional, fourth parameter was added `TfLiteOperatorCreate` as a step forward towards a cleaner API for `TfLiteOperator`. Function `TfLiteOperatorCreate` was added recently, in TensorFlow Lite version 2.17.0, released on 7/11/2024, and we do not expect there will be much code using this function yet. Any code breakages can be easily resolved by passing nullptr as the new, 4th parameter. + * An optional, fourth parameter was added `TfLiteOperatorCreate` as a step + forward towards a cleaner API for `TfLiteOperator`. Function + `TfLiteOperatorCreate` was added recently, in TensorFlow Lite version 2.17.0, + released on 7/11/2024, and we do not expect there will be much code using this + function yet. Any code breakages can be easily resolved by passing nullptr as + the new, 4th parameter. + * SignatureRunner is now supported for models with no signatures. + +* TensorRT support is disabled in CUDA builds for code health improvement. + +* Hermetic CUDA support is added. + + Hermetic CUDA uses a specific downloadable version of CUDA instead of the + user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL + distributions, and then use CUDA libraries and tools as dependencies in + various Bazel targets. This enables more reproducible builds for Google ML + projects and supported CUDA versions. ### Known Caveats diff --git a/WORKSPACE b/WORKSPACE index f8f467fccf5ce2..32ffd0433108c7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -64,3 +64,50 @@ tf_workspace1() load("@//tensorflow:workspace0.bzl", "tf_workspace0") tf_workspace0() + +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", +) + +cuda_json_init_repository() + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, +) + +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", +) + +cuda_configure(name = "local_config_cuda") + +load( + "@local_tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +nccl_redist_init_repository() + +load( + "@local_tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", +) + +nccl_configure(name = "local_config_nccl") diff --git a/ci/devinfra/docker/windows/Dockerfile b/ci/devinfra/docker/windows/Dockerfile new file mode 100644 index 00000000000000..e1a7f949d5f48b --- /dev/null +++ b/ci/devinfra/docker/windows/Dockerfile @@ -0,0 +1,155 @@ +# This Dockerfile creates an image that has: +# - the correct MTU setting for networking from inside the container to work. +# - Visual Studio 2022 Build Tools +# - MSVC 14.39 +# - LLVM/Clang 18.1.4 +# - MSYS2 + curl, git, patch, vim, unzip, zip +# - Python 3.12.3 +# - Bazelisk 1.19.0 +# - JDK 21 (Azul Zulu) + +FROM mcr.microsoft.com/windows/servercore:ltsc2019 + +SHELL ["powershell.exe", "-ExecutionPolicy", "Bypass", "-Command", \ + "$ErrorActionPreference='Stop'; $ProgressPreference='SilentlyContinue';$VerbosePreference = 'Continue';"] + +# This should only be necessary when running on A GCP VM, on a default +# network, which has the MTU of 1460, +# due to 40 bytes being reserved for GCP's internal usage. +# Note, an invalid sub-interface name will lead to an obscure error, e.g.: +# "The filename, directory name, or volume label syntax is incorrect." +# In such cases, check that the name of the sub-interface is valid: +# `netsh interface show interface` +RUN netsh interface ipv4 set subinterface \"vEthernet (Ethernet)\" mtu=1460 store=persistent + +RUN md C:\TEMP +RUN md C:\TMP + +# Install 7-Zip. +RUN (New-Object Net.WebClient).DownloadFile('https://www.7-zip.org/a/7z2201-x64.msi', '7z.msi'); \ + Start-Process msiexec.exe -ArgumentList \"/i 7z.msi /qn /norestart /log C:\\TEMP\\7z_install_log.txt\" -wait; \ + Remove-Item .\7z.msi; + +# Download the Visual Studio 2022 Installer. +RUN (New-Object Net.WebClient).DownloadFile('https://aka.ms/vs/17/release/vs_community.exe', 'C:\TEMP\vs_community.exe'); +# Install Visual Studio 2022 Build Tools + Compiler +SHELL ["cmd", "/S", "/C"] +# Packages, and component versions, can be found here: +# https://learn.microsoft.com/en-us/visualstudio/install/workload-component-id-vs-build-tools +RUN C:\TEMP\vs_community.exe \ + --quiet --wait --norestart --nocache \ + --add Microsoft.VisualStudio.Component.VC.Tools.x86.x64 \ + --add Microsoft.VisualStudio.Workload.NativeDesktop \ + --add Microsoft.VisualStudio.Component.VC.14.39.17.9.x86.64 \ + --add Microsoft.VisualStudio.Component.Windows11SDK.22621 \ + || IF "%ERRORLEVEL%"=="3010" EXIT 0 + +SHELL ["powershell.exe", "-ExecutionPolicy", "Bypass", "-Command", \ + "$ErrorActionPreference='Stop'; $ProgressPreference='SilentlyContinue'; $VerbosePreference = 'Continue';"] + +# Install Clang. +RUN (New-Object Net.WebClient).DownloadFile( \ + 'https://github.com/llvm/llvm-project/releases/download/llvmorg-18.1.4/LLVM-18.1.4-win64.exe', \ + 'LLVM.exe'); \ + Start-Process -FilePath \"C:\Program Files\7-Zip\7z.exe\" -ArgumentList 'x LLVM.exe -oC:\tools\LLVM' -Wait; \ + $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine') + ';C:\tools\LLVM\bin'; \ + [Environment]::SetEnvironmentVariable('PATH', $env:PATH, 'Machine'); + +# Install MSYS2, and add some extra tools. +RUN (New-Object Net.WebClient).DownloadFile( \ + 'https://repo.msys2.org/distrib/x86_64/msys2-base-x86_64-20240113.tar.xz', \ + 'msys2.tar.xz'); \ + Start-Process -FilePath \"C:\Program Files\7-Zip\7z.exe\" -ArgumentList 'x msys2.tar.xz -oC:\TEMP\msys2.tar' -Wait; \ + Start-Process -FilePath \"C:\Program Files\7-Zip\7z.exe\" -ArgumentList 'x C:\TEMP\msys2.tar -oC:\tools' -Wait; \ + $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine') + ';C:\tools\msys64;C:\tools\msys64\usr\bin\'; \ + [Environment]::SetEnvironmentVariable('PATH', $env:PATH, 'Machine'); + +# Disable signature checking on pacman because we cannot initialize the keyring. +RUN Add-Content -Path C:\tools\msys64\etc\pacman.d\mirrorlist.mingw32 -Value 'SigLevel = Never' +RUN Add-Content -Path C:\tools\msys64\etc\pacman.d\mirrorlist.mingw64 -Value 'SigLevel = Never' +RUN Add-Content -Path C:\tools\msys64\etc\pacman.d\mirrorlist.msys -Value 'SigLevel = Never' + +# Install pacman packages. +RUN C:\tools\msys64\usr\bin\bash.exe -lc \ + 'pacman --noconfirm -Syy curl git patch vim unzip zip' + +# Install Python as a general utility/tool. +ENV PYTHON_VERSION 3.12.3 + +RUN $url = ('https://www.python.org/ftp/python/{0}/python-{0}-amd64.exe' -f $env:PYTHON_VERSION); \ + Write-Host ('Downloading {0} ...' -f $url); \ + [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; \ + (New-Object Net.WebClient).DownloadFile($url, 'C:\tmp\pyinstall.exe'); \ + \ + Write-Host 'Installing...'; \ + Start-Process -FilePath \"C:\tmp\pyinstall.exe\" -ArgumentList '/quiet InstallAllUsers=1 PrependPath=1 TargetDir=C:\Python312' -Wait; \ + \ + Write-Host 'Verifying install ...'; \ + Write-Host ' python --version'; C:\python312\python.exe --version; \ + \ + Write-Host 'Verifying pip install ...'; \ + C:\python312\python.exe -m pip --version; \ + \ + Write-Host 'Removing ...'; \ + Remove-Item C:\tmp\pyinstall.exe -Force; \ + \ + Write-Host 'Complete.'; + +# Install pip packages. +RUN python -m pip install --ignore-installed --force-reinstall --upgrade \ + setuptools packaging + +# Install JDK 21. +RUN \ + Add-Type -AssemblyName \"System.IO.Compression.FileSystem\"; \ + $zulu_pkg = \"zulu21.34.19-ca-jdk21.0.3-win_x64.zip\"; \ + $zulu_url = \"https://cdn.azul.com/zulu/bin/${zulu_pkg}\"; \ + $zulu_zip = \"c:\\temp\\${zulu_pkg}\"; \ + $zulu_extracted_path = \"c:\\temp\\\" + [IO.Path]::GetFileNameWithoutExtension($zulu_zip); \ + $zulu_root = \"c:\\openjdk\"; \ + (New-Object Net.WebClient).DownloadFile($zulu_url, $zulu_zip); \ + [System.IO.Compression.ZipFile]::ExtractToDirectory($zulu_zip, \"c:\\temp\"); \ + Move-Item $zulu_extracted_path -Destination $zulu_root; \ + Remove-Item $zulu_zip; \ + $env:PATH = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\") + \";${zulu_root}\\bin\"; \ + [Environment]::SetEnvironmentVariable(\"PATH\", $env:PATH, \"Machine\"); \ + $env:JAVA_HOME = $zulu_root; \ + [Environment]::SetEnvironmentVariable(\"JAVA_HOME\", $env:JAVA_HOME, \"Machine\") + +# Point to the LLVM installation. +# The Bazel Windows guide claims it can find LLVM automatically, +# but it likely only works if it's installed somewhere inside C:\Program Files. +ENV BAZEL_LLVM "C:\tools\LLVM" + +# These variables may be useful, but so far haven't been. Keeping for posterity. +# ENV CLANG_COMPILER_PATH "C:\tools\llvm\bin\clang.exe" +# ENV CC "C:\tools\llvm\bin\clang.exe" +# ENV BAZEL_COMPILER "C:\tools\llvm\bin\clang.exe" + +ENV BAZEL_SH "C:\tools\msys64\usr\bin\bash.exe" +ENV BAZEL_VS "C:\Program Files\Microsoft Visual Studio\2022\BuildTools" +ENV BAZEL_VC "C:\Program Files\Microsoft Visual Studio\2022\Community\VC" + +# Environment variables to work around MSYS issues. +ENV MSYS_NO_PATHCONV 1 +ENV MSYS2_ARG_CONV_EXCL * + +# This should only be necessary if there are multiple, differently-versioned +# MSVC compilers installed, and a particular one should be used. +# To find exact versions available: +# - Navigate to the relevant folder, e.g. +# C:\Program Files\Microsoft Visual Studio\2022 +# - Search for the `cl.exe` file: `gci -r -fi cl.exe` +# - The version will be part of the found path, e.g. +# 2022\Community\VC\Tools\MSVC\14.39.33519\bin\Hostx64\x64 +# ENV BAZEL_VC_FULL_VERSION 14.39.33519 + +# Install Bazelisk. +RUN md C:\tools\bazel +RUN (New-Object Net.WebClient).DownloadFile( \ + 'https://github.com/bazelbuild/bazelisk/releases/download/v1.19.0/bazelisk-windows-amd64.exe', \ + 'C:\tools\bazel\bazel.exe'); \ + $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine') + ';C:\tools\bazel'; \ + [Environment]::SetEnvironmentVariable('PATH', $env:PATH, 'Machine'); + +SHELL ["cmd.exe", "/s", "/c"] diff --git a/ci/devinfra/docker_windows/Dockerfile b/ci/devinfra/docker_windows/Dockerfile deleted file mode 100644 index 9666b6bef9d319..00000000000000 --- a/ci/devinfra/docker_windows/Dockerfile +++ /dev/null @@ -1,256 +0,0 @@ -FROM mcr.microsoft.com/dotnet/framework/sdk:4.8-windowsservercore-ltsc2019@sha256:04e06ae8f595b48bdee73c3334e82f46ba61217e2fe29702350d7b90e9c4b787 - -# Set default powershell policy for this script (ProgressPreference='SilentlyContinue' makes -# downloads with Invoke-WebRequest not show the progress bar and is MUCH faster). -SHELL ["powershell.exe", "-ExecutionPolicy", "Bypass", "-Command", "$ErrorActionPreference='Stop'; $ProgressPreference='SilentlyContinue'; $VerbosePreference = 'Continue';"] - -# Workaround for networking (b/112379377) was closed as won't fix for MTU setting. -# Remaining lines handle making the metadata server on the VM accessible inside docker. -RUN Get-NetAdapter | Where-Object Name -like "*Ethernet*" | ForEach-Object { \ - & netsh interface ipv4 set subinterface $_.InterfaceIndex mtu=1460 store=persistent }; \ - $gateway = (Get-NetRoute | Where { $_.DestinationPrefix -eq \"0.0.0.0/0\" } | Sort-Object RouteMetric \ - | Select NextHop).NextHop; \ - $ifIndex = (Get-NetAdapter -InterfaceDescription \"Hyper-V Virtual Ethernet*\" | Sort-Object \ - | Select ifIndex).ifIndex; \ - New-NetRoute -DestinationPrefix 169.254.169.254/32 -InterfaceIndex $ifIndex -NextHop $gateway - -# Enable Long Paths for Win32 File/Folder APIs. -RUN New-ItemProperty -Path HKLM:\SYSTEM\CurrentControlSet\Control\FileSystem \ - -Name LongPathsEnabled -Value 1 -PropertyType DWORD -Force - -# Install Visual C++ Redistributable for Visual Studio 2015-2022. -RUN New-Item -Path "C:/" -Name "TEMP" -ItemType "directory"; \ - Invoke-WebRequest "https://aka.ms/vs/17/release/vc_redist.x64.exe" \ - -OutFile C:/TEMP/vc_redist.x64.exe -UseBasicParsing; \ - Start-Process -filepath C:/TEMP/vc_redist.x64.exe -ArgumentList '/install', '/passive', '/norestart' -Wait; \ - Remove-Item C:/TEMP/vc_redist.x64.exe - -# Install Visual Studio 2022 Build Tools. Install ManagedDesktopBuildTools separately to ensure all Optional workloads are installed too. -RUN Invoke-WebRequest "https://aka.ms/vs/17/release/vs_buildtools.exe" \ - -OutFile C:/TEMP/vs_buildtools.exe -UseBasicParsing; \ - Start-Process -FilePath C:/TEMP/vs_buildtools.exe -ArgumentList "--installPath", "C:/VS", \ - "--quiet", "--wait", "--nocache", \ - "--add", "Microsoft.VisualStudio.Workload.VCTools", \ - "--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", \ - "--add", "Microsoft.VisualStudio.Component.Windows10SDK.19041" -Wait; \ - Start-Process -FilePath C:/TEMP/vs_buildtools.exe -ArgumentList "--installPath", "C:/VS", \ - "--quiet", "--wait", "--nocache", "--includeOptional", \ - "--add", "Microsoft.VisualStudio.Workload.ManagedDesktopBuildTools" -Wait; \ - Remove-Item C:/TEMP/vs_buildtools.exe; \ - [Environment]::SetEnvironmentVariable(\"BAZEL_VC\", \"C:\VS\VC\", \"Machine\"); \ - $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\VS\VC\Tools\MSVC\14.33.31629\bin\Hostx64\x64;C:\VS\Common7\Tools;C:\VS\MSBuild\Current\Bin\", \"Machine\"); - -# Add signtool.exe to the PATH. Note this path may need to be edited if updates -# are made to the Windows 10 SDK. -RUN $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Program Files (x86)\Windows Kits\10\App Certification Kit\", \"Machine\"); - -# Install WiX toolset (v4) - Necessary for MSI Installer/Signing builds -RUN dotnet tool install --global wix - -# Install msys2, packages and add to path. -RUN [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; \ - Invoke-WebRequest "https://repo.msys2.org/distrib/x86_64/msys2-base-x86_64-20220319.sfx.exe" \ - -OutFile msys2_install.exe -UseBasicParsing; \ - .\msys2_install.exe -y -oC:\; \ - Remove-Item msys2_install.exe; \ - function msys() { C:\msys64\usr\bin\bash.exe @('-lc') + @Args; } \ - msys ' '; \ - msys 'pacman --noconfirm -Syy bsdcpio bsdtar bzip2'; \ - msys 'pacman --noconfirm -Syy coreutils curl dash file filesystem findutils'; \ - msys 'pacman --noconfirm -Syy flex gawk gcc-libs grep gzip inetutils info'; \ - msys 'pacman --noconfirm -Syy less lndir mintty ncurses pactoys-git patch'; \ - msys 'pacman --noconfirm -Syy pax-git pkgfile rebase sed tar tftp-hpa time tzcode util-linux which'; \ - $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\msys64;C:\msys64\usr\bin\", \"Machine\"); - -# Install Go 1.19.1 -RUN Invoke-WebRequest "https://go.dev/dl/go1.19.1.windows-amd64.msi" \ - -OutFile C:/TEMP/go_install.msi -UseBasicParsing; \ - Start-Process C:/TEMP/go_install.msi -ArgumentList "/quiet", "/log", "C:/TEMP/go_install_log.txt", \ - "InstallAllUsers=1", "PrependPath=1" -wait; \ - Remove-Item C:/TEMP/go_install.msi; \ - Remove-Item C:/TEMP/go_install_log.txt - -# Install Python 3. -RUN Invoke-WebRequest "https://www.python.org/ftp/python/3.10.4/python-3.10.4-amd64.exe" \ - -OutFile C:/TEMP/python_install.exe -UseBasicParsing; \ - Start-Process C:/TEMP/python_install.exe -ArgumentList "/quiet", "/log", "C:/TEMP/python_install_log.txt", \ - "InstallAllUsers=1", "PrependPath=1" -wait; \ - Remove-Item C:/TEMP/python_install.exe; \ - Remove-Item C:/TEMP/python_install_log.txt - -# Install JDK 17 -RUN Add-Type -AssemblyName "System.IO.Compression.FileSystem"; \ - $zulu_url = \"https://cdn.azul.com/zulu/bin/zulu17.32.13-ca-jdk17.0.2-win_x64.zip\"; \ - $zulu_zip = \"c:/temp/jdk_install.zip\"; \ - $zulu_extracted_path = \"c:/temp/\" + [IO.Path]::GetFileNameWithoutExtension($zulu_url); \ - $zulu_root = \"c:/openjdk\"; \ - (New-Object Net.WebClient).DownloadFile($zulu_url, $zulu_zip); \ - [System.IO.Compression.ZipFile]::ExtractToDirectory($zulu_zip, \"c:/temp\"); \ - Move-Item $zulu_extracted_path -Destination $zulu_root; \ - Remove-Item $zulu_zip; \ - $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";${zulu_root}\bin\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"JAVA_HOME\", $zulu_root, \"Machine\") - -# Install gcloud (install.bat installs directly into bin folder of extracted zip contents) -# Install needed gcloud components -RUN Add-Type -AssemblyName "System.IO.Compression.FileSystem"; \ - $pkg_url = \"https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-396.0.0-windows-x86_64.zip\"; \ - $pkg_zip = \"c:/temp/gcloud.zip\"; \ - $pkg_extracted_path = \"c:/google-cloud-sdk\"; \ - (New-Object Net.WebClient).DownloadFile($pkg_url, $pkg_zip); \ - [System.IO.Compression.ZipFile]::ExtractToDirectory($pkg_zip, \"c:/\"); \ - Start-Process cmd.exe -ArgumentList "/c", "/s", "$pkg_extracted_path/install.bat", "-q" -wait; \ - Remove-Item $pkg_zip; \ - $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";${pkg_extracted_path}\bin\", \"Machine\"); \ - $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine'); \ - gcloud components install docker-credential-gcr kubectl gsutil; - -# Install cygwin and packages -# Running a seperate ps1 file since when running inside a Dockerfile, it does -# not work. -COPY install/install_cygwin.ps1 c:/ -RUN c:/install_cygwin.ps1; \ - $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Cygwin64\bin\", \"Machine\"); -RUN Remove-Item c:/install_cygwin.ps1 - -# Install Chocolatey and packages -RUN Invoke-Expression ((New-Object Net.WebClient).DownloadString('https://chocolatey.org/install.ps1')); \ - $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine'); \ - choco feature enable -n allowGlobalConfirmation; \ - choco install 7zip; \ - choco install 7zip.install; \ - choco install 7zip.portable; \ - choco install anaconda2 --version 5.0.1; \ - choco install anaconda3 --version 5.0.1; \ - choco install android-sdk --version 25.2.3.1; \ - choco install AndroidStudio --version 3.0.1.0; \ - choco install ant --version 1.10.1; \ - choco install ccleaner; \ - choco install chocolatey; \ - choco install chocolatey-core.extension; \ - choco install chocolatey-visualstudio.extension; \ - choco install chocolatey-windowsupdate.extension; \ - choco install cmake.install; \ - choco install dotnetcore-sdk; \ - choco install git; \ - choco install git.install; \ - choco install GoogleChrome; \ - choco install gradle --version 4.4.1; \ - choco install jdk8; \ - choco install KB2533623; \ - choco install KB2919355; \ - choco install KB2919442; \ - choco install KB2999226; \ - choco install KB3033929; \ - choco install KB3035131; \ - choco install maven; \ - choco install ninja; \ - choco install nodejs --version 9.3.0; \ - choco install nodejs.install --version 9.3.0; \ - choco install nuget.commandline; \ - choco install openjdk11; \ - choco install peazip; \ - choco install peazip.install; \ - choco install peazip.portable; \ - choco install php --version 7.2.0; \ - choco install protoc --version 3.2.0; \ - choco install ruby --version 2.5.0.1; \ - choco install swig --version 3.0.9; \ - choco install sysinternals; \ - choco install unrar; \ - choco install unzip; \ - choco install vcredist140; \ - choco install vcredist2015; \ - choco install vim; \ - choco install winrar; \ - choco install zip; \ - choco install Firefox; \ - choco install iisexpress; - -RUN cmd /c 'mklink /J c:\Anaconda c:\tools\anaconda2'; -RUN cmd /c 'mklink c:\programdata\chocolatey\bin\rar.exe \"c:\program files\winrar\rar.exe\"'; - -# Installing pip packages -RUN pip install --upgrade setuptools; \ - pip install altgraph appdirs cachetools certifi cffi chardet colorama \ - cryptography cycler Cython decorator google-api-python-client \ - google-auth google-auth-httplib2 grpcio httplib2 idna ipython-genutils \ - kiwisolver macholib matplotlib nose numpy packaging pandas pickleshare pip \ - prompt-toolkit protobuf psutil pyasn1 pyasn1-modules pycparser Pygments \ - pyparsing pyreadline python-dateutil pytz pywin32 requests rsa setuptools \ - simplegeneric six Tempita traitlets uritemplate urllib3 virtualenv wcwidth \ - wheel win-unicode-console; - -# Hardcoding Android license since I did not find any solution on accepting it -# through the docker build command. If the licensing agreement changes, this -# will need to be updated as well. -RUN New-Item -ItemType Directory -Path C:\Android\android-sdk\licenses; \ - Set-Content -Path .\Android\android-sdk\licenses\android-sdk-license -Value "`n24333f8a63b6825ea9c5514f83c2829b004d1fee" -NoNewLine; - -# Add sdkmanager to PATH -RUN $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Android\android-sdk\tools\bin\", \"Machine\"); - -# Install android packages -RUN $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine'); \ - New-Item C:\Users\ContainerAdministrator\.android\repositories.cfg; \ - sdkmanager 'ndk-bundle'; \ - sdkmanager 'platforms;android-33'; \ - sdkmanager 'add-ons;addon-google_apis-google-24'; \ - sdkmanager 'cmake;3.10.2.4988404'; \ - sdkmanager 'cmake;3.18.1'; \ - sdkmanager 'cmake;3.22.1'; \ - sdkmanager 'cmake;3.6.4111459'; \ - sdkmanager 'emulator'; \ - sdkmanager 'system-images;android-27;google_apis;x86'; \ - sdkmanager 'sources;android-27'; \ - sdkmanager 'extras;google;Android_Emulator_Hypervisor_Driver'; \ - sdkmanager 'extras;google;auto'; \ - sdkmanager 'extras;google;google_play_services'; \ - sdkmanager 'extras;google;instantapps'; \ - sdkmanager 'extras;google;m2repository'; \ - sdkmanager 'extras;google;market_apk_expansion'; \ - sdkmanager 'extras;google;market_licensing'; \ - sdkmanager 'extras;google;simulators'; \ - sdkmanager 'extras;google;usb_driver'; \ - sdkmanager 'extras;google;webdriver'; \ - sdkmanager 'extras;android;m2repository'; \ - sdkmanager 'extras;intel;Hardware_Accelerated_Execution_Manager'; \ - sdkmanager 'extras;m2repository;com;android;support;constraint;constraint-layout;1.0.0'; \ - sdkmanager 'extras;m2repository;com;android;support;constraint;constraint-layout-solver;1.0.2'; \ - sdkmanager 'patcher;v4'; \ - sdkmanager 'ndk;25.1.8937393'; \ - sdkmanager 'build-tools;27.0.3'; - -# Install Scoop and packages -RUN iex \"& {$(irm get.scoop.sh)} -RunAsAdmin\"; \ - scoop install perl; \ - scoop install bazel; \ - scoop install cuda; \ - scoop install azure-functions-core-tools; \ - scoop install azure-cli; - -# Setting environment variables -RUN [Environment]::SetEnvironmentVariable('CYGWIN', 'winsymlinks:native', 'Machine'); \ - [Environment]::SetEnvironmentVariable('HOME', 'C:\Users\ContainerAdministrator\', 'Machine'); \ - [Environment]::SetEnvironmentVariable('HOMEDRIVE', 'C:', 'Machine'); \ - [Environment]::SetEnvironmentVariable('HOMEPATH', '\Users\ContainerAdministrator\', 'Machine'); \ - [Environment]::SetEnvironmentVariable('GOROOT', 'C:\Program Files\Go\', 'Machine'); \ - [Environment]::SetEnvironmentVariable('KOKORO_POSIX_ROOT', '/tmpfs', 'Machine'); \ - [Environment]::SetEnvironmentVariable('KOKORO_ROOT', 'T:\', 'Machine'); \ - [Environment]::SetEnvironmentVariable('SHELL', '/bin/bash', 'Machine'); \ - $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Program Files\CMake\bin\", \"Machine\"); - - -# Restore default shell for Windows containers. -SHELL ["cmd.exe", "/s", "/c"] - -# Default to PowerShell if no other command specified. -CMD ["powershell.exe", "-NoLogo", "-ExecutionPolicy", "Bypass"] diff --git a/ci/official/containers/linux_arm64/Dockerfile b/ci/official/containers/linux_arm64/Dockerfile index c2161dfe4ad6f3..428347a5b6a847 100644 --- a/ci/official/containers/linux_arm64/Dockerfile +++ b/ci/official/containers/linux_arm64/Dockerfile @@ -62,6 +62,9 @@ COPY devel.usertools /usertools COPY devel.bashrc /root/.bashrc COPY ld.so.conf /dt10/etc/ +# Make sure clang is on the path +RUN ln -s /usr/lib/llvm-18/bin/clang /usr/bin/clang + # Setup JAX Python environment. FROM devel as jax RUN /setup.packages.sh /cuda.packages.txt diff --git a/ci/official/containers/linux_arm64/devel.usertools/code_check_full.bats b/ci/official/containers/linux_arm64/devel.usertools/code_check_full.bats index 85cbc7b7058148..cdfc81499af7f0 100644 --- a/ci/official/containers/linux_arm64/devel.usertools/code_check_full.bats +++ b/ci/official/containers/linux_arm64/devel.usertools/code_check_full.bats @@ -216,6 +216,8 @@ EOF bazel cquery \ --experimental_cc_shared_library \ --@local_config_cuda//:enable_cuda \ + --repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ + --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" \ "somepath(//tensorflow/tools/pip_package:build_pip_package, " \ "@local_config_cuda//cuda:cudart + "\ "@local_config_cuda//cuda:cudart + "\ @@ -237,6 +239,8 @@ EOF bazel cquery \ --experimental_cc_shared_library \ --@local_config_cuda//:enable_cuda \ + --repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ + --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" \ --define framework_shared_object=false \ "somepath(//tensorflow/tools/pip_package:build_pip_package, " \ "@local_config_cuda//cuda:cudart + "\ diff --git a/ci/official/envs/rbe b/ci/official/envs/rbe index 31f204e281ed39..35f817310b2f36 100644 --- a/ci/official/envs/rbe +++ b/ci/official/envs/rbe @@ -38,7 +38,8 @@ if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then # port-forwarding is required for the container to detect it's running on GCE. export IP_ADDR=$(powershell -command "(Get-NetIPAddress -AddressFamily IPv4 -InterfaceAlias 'vEthernet (nat)').IPAddress") netsh interface portproxy add v4tov4 listenaddress=$IP_ADDR listenport=80 connectaddress=169.254.169.254 connectport=80 - # A firewall rule is added for the Docker container IP in setup_d + # A local firewall rule for the container is added in + # ci/official/utilities/setup_docker.sh. else # The volume mapping flag below shares the user's gcloud credentials, if any, # with the container, in case the user has credentials stored there. diff --git a/ci/official/utilities/code_check_full.bats b/ci/official/utilities/code_check_full.bats index 691ec3a3a025ae..ede80f4372bc14 100644 --- a/ci/official/utilities/code_check_full.bats +++ b/ci/official/utilities/code_check_full.bats @@ -216,6 +216,8 @@ EOF bazel cquery \ --experimental_cc_shared_library \ --@local_config_cuda//:enable_cuda \ + --repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ + --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" \ "somepath(//tensorflow/tools/pip_package:wheel, " \ "@local_config_cuda//cuda:cudart + "\ "@local_config_cuda//cuda:cudart + "\ @@ -237,6 +239,8 @@ EOF bazel cquery \ --experimental_cc_shared_library \ --@local_config_cuda//:enable_cuda \ + --repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ + --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" \ --define framework_shared_object=false \ "somepath(//tensorflow/tools/pip_package:wheel, " \ "@local_config_cuda//cuda:cudart + "\ diff --git a/ci/official/utilities/setup_docker.sh b/ci/official/utilities/setup_docker.sh index c721bad577eee6..61db7c2e124d0a 100755 --- a/ci/official/utilities/setup_docker.sh +++ b/ci/official/utilities/setup_docker.sh @@ -47,7 +47,6 @@ if ! docker container inspect tf >/dev/null 2>&1 ; then sed -iE 's|^TFCI_OUTPUT_DIR=.*|TFCI_OUTPUT_DIR='"$_TFCI_OUTPUT_DIR_WIN"'|g' $env_file WORKING_DIR=$(replace_drive_letter_with_c "$TFCI_GIT_DIR") echo "GCE_METADATA_HOST=$IP_ADDR" > $env_file - # Allow requests from the container. fi docker run $TFCI_DOCKER_ARGS --name tf -w "$WORKING_DIR" -itd --rm \ @@ -58,8 +57,9 @@ if ! docker container inspect tf >/dev/null 2>&1 ; then if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then # Allow requests from the container. - CONTAINER_IP=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' tf) - netsh advfirewall firewall add rule name="Allow Metadata Proxy" dir=in action=allow protocol=TCP localport=80 remoteip="$CONTAINER_IP" + # Additional setup is contained in ci/official/envs/rbe. + CONTAINER_IP_ADDR=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' tf) + netsh advfirewall firewall add rule name="Allow Metadata Proxy" dir=in action=allow protocol=TCP localport=80 remoteip="$CONTAINER_IP_ADDR" fi fi diff --git a/configure.py b/configure.py index 592f5c0d2117e1..50ed76e9f23d14 100644 --- a/configure.py +++ b/configure.py @@ -16,7 +16,6 @@ import argparse import errno -import glob import json import os import platform @@ -31,9 +30,6 @@ from distutils.spawn import find_executable as which # pylint: enable=g-import-not-at-top -_DEFAULT_CUDA_VERSION = '11' -_DEFAULT_CUDNN_VERSION = '2' -_DEFAULT_TENSORRT_VERSION = '6' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _SUPPORTED_ANDROID_NDK_VERSIONS = [19, 20, 21, 25] @@ -128,6 +124,12 @@ def write_action_env_to_bazelrc(var_name, var): write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var))) +def write_repo_env_to_bazelrc(config_name, var_name, var): + write_to_bazelrc( + 'build:{} --repo_env {}="{}"'.format(config_name, var_name, str(var)) + ) + + def run_shell(cmd, allow_non_zero=False, stderr=None): if stderr is None: stderr = sys.stdout @@ -239,7 +241,7 @@ def setup_python(environ_cp): write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path)) environ_cp['PYTHON_BIN_PATH'] = python_bin_path - # If choosen python_lib_path is from a path specified in the PYTHONPATH + # If chosen python_lib_path is from a path specified in the PYTHONPATH # variable, need to tell bazel to include PYTHONPATH if environ_cp.get('PYTHONPATH'): python_paths = environ_cp.get('PYTHONPATH').split(':') @@ -778,11 +780,6 @@ def get_ndk_api_level(environ_cp, android_ndk_home_path): def set_gcc_host_compiler_path(environ_cp): """Set GCC_HOST_COMPILER_PATH.""" default_gcc_host_compiler_path = which('gcc') or '' - cuda_bin_symlink = '%s/bin/gcc' % environ_cp.get('CUDA_TOOLKIT_PATH') - - if os.path.islink(cuda_bin_symlink): - # os.readlink is only available in linux - default_gcc_host_compiler_path = os.path.realpath(cuda_bin_symlink) gcc_host_compiler_path = prompt_loop_or_load_from_env( environ_cp, @@ -947,108 +944,42 @@ def disable_clang_offsetof_extension(clang_version): write_to_bazelrc('build --copt=-Wno-gnu-offsetof-extensions') -def set_tf_cuda_paths(environ_cp): - """Set TF_CUDA_PATHS.""" - ask_cuda_paths = ( - 'Please specify the comma-separated list of base paths to look for CUDA ' - 'libraries and headers. [Leave empty to use the default]: ') - tf_cuda_paths = get_from_env_or_user_or_default(environ_cp, 'TF_CUDA_PATHS', - ask_cuda_paths, '') - if tf_cuda_paths: - environ_cp['TF_CUDA_PATHS'] = tf_cuda_paths - - -def set_tf_cuda_version(environ_cp): - """Set TF_CUDA_VERSION.""" +def set_hermetic_cuda_version(environ_cp): + """Set HERMETIC_CUDA_VERSION.""" ask_cuda_version = ( - 'Please specify the CUDA SDK version you want to use. ' - '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION - tf_cuda_version = get_from_env_or_user_or_default(environ_cp, - 'TF_CUDA_VERSION', - ask_cuda_version, - _DEFAULT_CUDA_VERSION) - environ_cp['TF_CUDA_VERSION'] = tf_cuda_version + 'Please specify the hermetic CUDA version you want to use ' + 'or leave empty to use the default version. ' + ) + hermetic_cuda_version = get_from_env_or_user_or_default( + environ_cp, 'HERMETIC_CUDA_VERSION', ask_cuda_version, None + ) + if hermetic_cuda_version: + environ_cp['HERMETIC_CUDA_VERSION'] = hermetic_cuda_version + write_repo_env_to_bazelrc( + 'cuda', 'HERMETIC_CUDA_VERSION', hermetic_cuda_version + ) -def set_tf_cudnn_version(environ_cp): - """Set TF_CUDNN_VERSION.""" +def set_hermetic_cudnn_version(environ_cp): + """Set HERMETIC_CUDNN_VERSION.""" ask_cudnn_version = ( - 'Please specify the cuDNN version you want to use. ' - '[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION - tf_cudnn_version = get_from_env_or_user_or_default(environ_cp, - 'TF_CUDNN_VERSION', - ask_cudnn_version, - _DEFAULT_CUDNN_VERSION) - environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version - - -def set_tf_tensorrt_version(environ_cp): - """Set TF_TENSORRT_VERSION.""" - if not (is_linux() or is_windows()): - raise ValueError('Currently TensorRT is only supported on Linux platform.') - - if not int(environ_cp.get('TF_NEED_TENSORRT', False)): - return - - ask_tensorrt_version = ( - 'Please specify the TensorRT version you want to use. ' - '[Leave empty to default to TensorRT %s]: ') % _DEFAULT_TENSORRT_VERSION - tf_tensorrt_version = get_from_env_or_user_or_default( - environ_cp, 'TF_TENSORRT_VERSION', ask_tensorrt_version, - _DEFAULT_TENSORRT_VERSION) - environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version - - -def set_tf_nccl_version(environ_cp): - """Set TF_NCCL_VERSION.""" - if not is_linux(): - raise ValueError('Currently NCCL is only supported on Linux platform.') - - if 'TF_NCCL_VERSION' in environ_cp: - return - - ask_nccl_version = ( - 'Please specify the locally installed NCCL version you want to use. ' - '[Leave empty to use http://github.com/nvidia/nccl]: ') - tf_nccl_version = get_from_env_or_user_or_default(environ_cp, - 'TF_NCCL_VERSION', - ask_nccl_version, '') - environ_cp['TF_NCCL_VERSION'] = tf_nccl_version - - -def get_native_cuda_compute_capabilities(environ_cp): - """Get native cuda compute capabilities. - - Args: - environ_cp: copy of the os.environ. - - Returns: - string of native cuda compute capabilities, separated by comma. - """ - device_query_bin = os.path.join( - environ_cp.get('CUDA_TOOLKIT_PATH'), 'extras/demo_suite/deviceQuery') - if os.path.isfile(device_query_bin) and os.access(device_query_bin, os.X_OK): - try: - output = run_shell(device_query_bin).split('\n') - pattern = re.compile('\d*\\.\d*') - output = [pattern.search(x) for x in output if 'Capability' in x] - output = ','.join(x.group() for x in output if x is not None) - except subprocess.CalledProcessError: - output = '' - else: - output = '' - return output + 'Please specify the hermetic cuDNN version you want to use ' + 'or leave empty to use the default version. ' + ) + hermetic_cudnn_version = get_from_env_or_user_or_default( + environ_cp, 'HERMETIC_CUDNN_VERSION', ask_cudnn_version, None + ) + if hermetic_cudnn_version: + environ_cp['HERMETIC_CUDNN_VERSION'] = hermetic_cudnn_version + write_repo_env_to_bazelrc( + 'cuda', 'HERMETIC_CUDNN_VERSION', hermetic_cudnn_version + ) -def set_tf_cuda_compute_capabilities(environ_cp): - """Set TF_CUDA_COMPUTE_CAPABILITIES.""" +def set_hermetic_cuda_compute_capabilities(environ_cp): + """Set HERMETIC_CUDA_COMPUTE_CAPABILITIES.""" while True: - native_cuda_compute_capabilities = get_native_cuda_compute_capabilities( - environ_cp) - if not native_cuda_compute_capabilities: - default_cuda_compute_capabilities = _DEFAULT_CUDA_COMPUTE_CAPABILITIES - else: - default_cuda_compute_capabilities = native_cuda_compute_capabilities + default_cuda_compute_capabilities = _DEFAULT_CUDA_COMPUTE_CAPABILITIES ask_cuda_compute_capabilities = ( 'Please specify a list of comma-separated CUDA compute capabilities ' @@ -1060,15 +991,20 @@ def set_tf_cuda_compute_capabilities(environ_cp): 'significantly increases your build time and binary size, and that ' 'TensorFlow only supports compute capabilities >= 3.5 [Default is: ' '%s]: ' % default_cuda_compute_capabilities) - tf_cuda_compute_capabilities = get_from_env_or_user_or_default( - environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', - ask_cuda_compute_capabilities, default_cuda_compute_capabilities) + hermetic_cuda_compute_capabilities = get_from_env_or_user_or_default( + environ_cp, + 'HERMETIC_CUDA_COMPUTE_CAPABILITIES', + ask_cuda_compute_capabilities, + default_cuda_compute_capabilities, + ) # Check whether all capabilities from the input is valid all_valid = True # Remove all whitespace characters before splitting the string # that users may insert by accident, as this will result in error - tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split()) - for compute_capability in tf_cuda_compute_capabilities.split(','): + hermetic_cuda_compute_capabilities = ''.join( + hermetic_cuda_compute_capabilities.split() + ) + for compute_capability in hermetic_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) if not m: # We now support sm_35,sm_50,sm_60,compute_70. @@ -1103,15 +1039,32 @@ def set_tf_cuda_compute_capabilities(environ_cp): break # Reset and Retry - environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = '' + environ_cp['HERMETIC_CUDA_COMPUTE_CAPABILITIES'] = '' - # Set TF_CUDA_COMPUTE_CAPABILITIES - environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = tf_cuda_compute_capabilities - write_action_env_to_bazelrc( - 'TF_CUDA_COMPUTE_CAPABILITIES', tf_cuda_compute_capabilities + # Set HERMETIC_CUDA_COMPUTE_CAPABILITIES + environ_cp['HERMETIC_CUDA_COMPUTE_CAPABILITIES'] = ( + hermetic_cuda_compute_capabilities + ) + write_repo_env_to_bazelrc( + 'cuda', + 'HERMETIC_CUDA_COMPUTE_CAPABILITIES', + hermetic_cuda_compute_capabilities, ) +def set_cuda_local_path(environ_cp, dist_name, env_var): + ask_path = ( + 'Please specify the local {} path you want to use ' + 'or leave empty to use the default version. ' + ).format(dist_name) + local_path = get_from_env_or_user_or_default( + environ_cp, env_var, ask_path, None + ) + if local_path: + environ_cp[env_var] = local_path + write_repo_env_to_bazelrc('cuda', env_var, local_path) + + def set_other_cuda_vars(environ_cp): """Set other CUDA related variables.""" # If CUDA is enabled, always use GPU during build and test. @@ -1227,73 +1180,6 @@ def configure_ios(environ_cp): symlink_force(filepath, new_filepath) -def validate_cuda_config(environ_cp): - """Run find_cuda_config.py and return cuda_toolkit_path, or None.""" - - def maybe_encode_env(env): - """Encodes unicode in env to str on Windows python 2.x.""" - if not is_windows() or sys.version_info[0] != 2: - return env - for k, v in env.items(): - if isinstance(k, unicode): - k = k.encode('ascii') - if isinstance(v, unicode): - v = v.encode('ascii') - env[k] = v - return env - - cuda_libraries = ['cuda', 'cudnn'] - if is_linux(): - if int(environ_cp.get('TF_NEED_TENSORRT', False)): - cuda_libraries.append('tensorrt') - if environ_cp.get('TF_NCCL_VERSION', None): - cuda_libraries.append('nccl') - if is_windows(): - if int(environ_cp.get('TF_NEED_TENSORRT', False)): - cuda_libraries.append('tensorrt') - print('WARNING: TensorRT support on Windows is experimental\n') - - paths = glob.glob('**/third_party/gpus/find_cuda_config.py', recursive=True) - if not paths: - raise FileNotFoundError( - "Can't find 'find_cuda_config.py' script inside working directory") - proc = subprocess.Popen( - [environ_cp['PYTHON_BIN_PATH'], paths[0]] + cuda_libraries, - stdout=subprocess.PIPE, - env=maybe_encode_env(environ_cp)) - - if proc.wait(): - # Errors from find_cuda_config.py were sent to stderr. - print('Asking for detailed CUDA configuration...\n') - return False - - config = dict( - tuple(line.decode('ascii').rstrip().split(': ')) for line in proc.stdout) - - print('Found CUDA %s in:' % config['cuda_version']) - print(' %s' % config['cuda_library_dir']) - print(' %s' % config['cuda_include_dir']) - - print('Found cuDNN %s in:' % config['cudnn_version']) - print(' %s' % config['cudnn_library_dir']) - print(' %s' % config['cudnn_include_dir']) - - if 'tensorrt_version' in config: - print('Found TensorRT %s in:' % config['tensorrt_version']) - print(' %s' % config['tensorrt_library_dir']) - print(' %s' % config['tensorrt_include_dir']) - - if config.get('nccl_version', None): - print('Found NCCL %s in:' % config['nccl_version']) - print(' %s' % config['nccl_library_dir']) - print(' %s' % config['nccl_include_dir']) - - print('\n') - - environ_cp['CUDA_TOOLKIT_PATH'] = config['cuda_toolkit_path'] - return True - - def get_gcc_compiler(environ_cp): gcc_env = environ_cp.get('CXX') or environ_cp.get('CC') or which('gcc') if gcc_env is not None: @@ -1344,9 +1230,6 @@ def main(): environ_cp['TF_DOWNLOAD_CLANG'] = '0' environ_cp['TF_NEED_MPI'] = '0' - if is_macos(): - environ_cp['TF_NEED_TENSORRT'] = '0' - if is_ppc64le(): # Enable MMA Dynamic Dispatch support if 'gcc' and if linker >= 2.35 gcc_env = get_gcc_compiler(environ_cp) @@ -1395,62 +1278,14 @@ def main(): else: environ_cp['TF_NEED_CUDA'] = str( int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False))) - if (environ_cp.get('TF_NEED_CUDA') == '1' and - 'TF_CUDA_CONFIG_REPO' not in environ_cp): - - set_action_env_var( - environ_cp, - 'TF_NEED_TENSORRT', - 'TensorRT', - False, - bazel_config_name='tensorrt') - - environ_save = dict(environ_cp) - for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): - - if validate_cuda_config(environ_cp): - cuda_env_names = [ - 'TF_CUDA_VERSION', - 'TF_CUBLAS_VERSION', - 'TF_CUDNN_VERSION', - 'TF_TENSORRT_VERSION', - 'TF_NCCL_VERSION', - 'TF_CUDA_PATHS', - # Items below are for backwards compatibility when not using - # TF_CUDA_PATHS. - 'CUDA_TOOLKIT_PATH', - 'CUDNN_INSTALL_PATH', - 'NCCL_INSTALL_PATH', - 'NCCL_HDR_PATH', - 'TENSORRT_INSTALL_PATH' - ] - # Note: set_action_env_var above already writes to bazelrc. - for name in cuda_env_names: - if name in environ_cp: - write_action_env_to_bazelrc(name, environ_cp[name]) - break - - # Restore settings changed below if CUDA config could not be validated. - environ_cp = dict(environ_save) - - set_tf_cuda_version(environ_cp) - set_tf_cudnn_version(environ_cp) - if is_windows(): - set_tf_tensorrt_version(environ_cp) - if is_linux(): - set_tf_tensorrt_version(environ_cp) - set_tf_nccl_version(environ_cp) - - set_tf_cuda_paths(environ_cp) - - else: - raise UserInputError( - 'Invalid CUDA setting were provided %d ' - 'times in a row. Assuming to be a scripting mistake.' - % _DEFAULT_PROMPT_ASK_ATTEMPTS - ) + if environ_cp.get('TF_NEED_CUDA') == '1': + set_hermetic_cuda_version(environ_cp) + set_hermetic_cudnn_version(environ_cp) + set_hermetic_cuda_compute_capabilities(environ_cp) + set_cuda_local_path(environ_cp, 'CUDA', 'LOCAL_CUDA_PATH') + set_cuda_local_path(environ_cp, 'CUDNN', 'LOCAL_CUDNN_PATH') + set_cuda_local_path(environ_cp, 'NCCL', 'LOCAL_NCCL_PATH') - set_tf_cuda_compute_capabilities(environ_cp) if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( 'LD_LIBRARY_PATH') != '1': write_action_env_to_bazelrc('LD_LIBRARY_PATH', diff --git a/tensorflow/BUILD b/tensorflow/BUILD index c96cd8c4244797..5ebf9b1fa20fed 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -1053,7 +1053,7 @@ package_group( "//learning/serving/experimental/remote_predict/...", "//perftools/accelerators/xprof/convert/...", "//perftools/accelerators/xprof/integration_tests/...", - "//smartass/brain/configure/...", + "//smartass/brain/...", "//tensorflow/...", "//tensorflow_decision_forests/...", "//tensorflow_federated/...", @@ -1350,7 +1350,7 @@ tf_cc_shared_library( "//tensorflow/core:tensorflow", "//tensorflow/core/data:standalone", # Exports for pywrap_tensorflow_internal. Many of these are transitive - # depedencies of the above, but must be explicitly listed for + # dependencies of the above, but must be explicitly listed for # cc_shared_library to work. "//tensorflow/c/eager:c_api_experimental", "//tensorflow/c/eager:c_api_internal", diff --git a/tensorflow/c/experimental/ops/gen/model/BUILD b/tensorflow/c/experimental/ops/gen/model/BUILD index 918acaabb6b8cb..89e51ec57df46e 100644 --- a/tensorflow/c/experimental/ops/gen/model/BUILD +++ b/tensorflow/c/experimental/ops/gen/model/BUILD @@ -9,13 +9,10 @@ cc_library( srcs = glob(["*.cc"]), hdrs = glob(["*.h"]), deps = [ - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:str_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/c/experimental/ops/gen/model/arg_spec.cc b/tensorflow/c/experimental/ops/gen/model/arg_spec.cc index 2a9dd4882d92de..43e3b3f0b8bfa9 100644 --- a/tensorflow/c/experimental/ops/gen/model/arg_spec.cc +++ b/tensorflow/c/experimental/ops/gen/model/arg_spec.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/model/arg_spec.h" +#include "tensorflow/c/experimental/ops/gen/model/arg_type.h" +#include "tensorflow/core/framework/op_def.pb.h" + namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/model/arg_type.cc b/tensorflow/c/experimental/ops/gen/model/arg_type.cc index afc05adc16788f..9286e2dd6f09cd 100644 --- a/tensorflow/c/experimental/ops/gen/model/arg_type.cc +++ b/tensorflow/c/experimental/ops/gen/model/arg_type.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/model/arg_type.h" +#include "tensorflow/core/framework/op_def.pb.h" + namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/model/attr_spec.cc b/tensorflow/c/experimental/ops/gen/model/attr_spec.cc index 59a7a16cff3f5e..ae27a352694d98 100644 --- a/tensorflow/c/experimental/ops/gen/model/attr_spec.cc +++ b/tensorflow/c/experimental/ops/gen/model/attr_spec.cc @@ -14,7 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/model/attr_spec.h" -#include "tensorflow/core/lib/strings/str_util.h" +#include "absl/strings/match.h" +#include "tensorflow/core/framework/op_def.pb.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/model/op_spec.cc b/tensorflow/c/experimental/ops/gen/model/op_spec.cc index d590e2dfddc80e..1adc0c45d40291 100644 --- a/tensorflow/c/experimental/ops/gen/model/op_spec.cc +++ b/tensorflow/c/experimental/ops/gen/model/op_spec.cc @@ -17,6 +17,11 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "tensorflow/c/experimental/ops/gen/model/arg_spec.h" +#include "tensorflow/c/experimental/ops/gen/model/attr_spec.h" +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc index 58c7a22fee787c..630638b6c6cc9f 100644 --- a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc +++ b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc @@ -175,7 +175,7 @@ Status InitPluginProfiler(TFInitProfilerFn init_fn) { return factory.CreatePluggableProfiler(options); }; - tensorflow::profiler::RegisterProfilerFactory(std::move(create_func)); + tsl::profiler::RegisterProfilerFactory(std::move(create_func)); return OkStatus(); } diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 2c91c8d0f402d6..36e4b8e5e66b67 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -448,24 +448,21 @@ CPlatform::DescriptionForDevice(int ordinal) const { builder.set_name(name_); return builder.Build(); } -absl::StatusOr CPlatform::ExecutorForDevice(int ordinal) { - stream_executor::StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); +absl::StatusOr CPlatform::FindExisting(int ordinal) { + return executor_cache_.Get(ordinal); } -absl::StatusOr CPlatform::GetExecutor( - const StreamExecutorConfig& config) { +absl::StatusOr CPlatform::ExecutorForDevice(int ordinal) { return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } absl::StatusOr> CPlatform::GetUncachedExecutor( - const StreamExecutorConfig& config) { + int ordinal) { // Fill device creation params SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE}; SP_Device device{SP_DEVICE_STRUCT_SIZE}; device_params.device = &device; device_params.ext = nullptr; - device_params.ordinal = config.ordinal; + device_params.ordinal = ordinal; OwnedTFStatus c_status(TF_NewStatus()); // Create Device diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index 1a525b2e4179e7..769f640d6968d2 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -98,15 +98,14 @@ class CPlatform : public Platform { absl::StatusOr> DescriptionForDevice( int ordinal) const override; absl::StatusOr ExecutorForDevice(int ordinal) override; - absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) override; + absl::StatusOr FindExisting(int ordinal) override; private: - // Returns a device constructed with the options specified in "config" without + // Returns a device constructed with the ordinal without // looking in or storing to the Platform's executor cache. // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config); + int ordinal); SP_Platform platform_; void (*destroy_platform_)(SP_Platform*); diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 4a6e752b984fd3..fc351b2cd829b5 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -167,17 +167,14 @@ cc_library( deps = [ ":tfcompile_lib", "//tensorflow/compiler/tf2xla:tf2xla_proto_cc", - "//tensorflow/compiler/tf2xla:tf2xla_util", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", - "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:status", "@local_xla//xla:debug_options_flags", - "@local_xla//xla/service:compiler", ], ) diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index e2ab2504319e80..b6b70a6f04d0f5 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -18,29 +18,18 @@ limitations under the License. #include #include -#include "absl/strings/match.h" -#include "absl/strings/str_join.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" -#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "xla/debug_options_flags.h" -#include "xla/service/compiler.h" -#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/util/command_line_flags.h" +#include "tsl/platform/status.h" namespace tensorflow { namespace tfcompile { diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 98576eed52361c..6efe665f4c9f99 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -518,7 +518,6 @@ cc_library( ":internal", # We reuse VariableInfo in TFRT's implementation of TpuExecuteOp. "//learning/brain/tfrt/tf_tpu:__pkg__", - "//learning/brain/tfrt/tpu_plugin:__pkg__", "//learning/brain/tfrt/tpu_common:__pkg__", "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", ], @@ -539,9 +538,6 @@ cc_library( ":internal", # We reuse VariableInfo in TFRT's implementation of TpuExecuteOp. "//learning/brain/tfrt/tf_tpu:__pkg__", - "//learning/brain/tfrt/tpu_plugin:__pkg__", - "//learning/brain/tfrt/tpu_common:__pkg__", - "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", ], deps = [ ":variable_info", @@ -612,8 +608,6 @@ cc_library( # We reuse VariableInfo in TFRT's implementation of TpuExecuteOp. "//learning/brain/tfrt/tf_tpu:__pkg__", "//learning/brain/tfrt/tpu_plugin:__pkg__", - "//learning/brain/tfrt/tpu_common:__pkg__", - "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", "//tensorflow/core/tfrt/gpu/kernel:__pkg__", ], deps = [ @@ -726,7 +720,6 @@ cc_library( hdrs = ["xla_compile_util.h"], visibility = [ ":internal", - "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", "//tensorflow/core/tfrt/gpu/kernel:__pkg__", ], deps = [ @@ -770,10 +763,7 @@ cc_library( name = "device_compiler", hdrs = ["device_compiler.h"], copts = tf_copts(), - visibility = [ - ":internal", - "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", - ], + visibility = [":internal"], deps = [ ":device_compilation_cache", ":device_compilation_cluster_signature", @@ -1118,7 +1108,6 @@ cc_library( ], visibility = [ ":internal", - "//tensorflow/core/tfrt/utils:__pkg__", "//third_party/cloud_tpu/inference_converter:__pkg__", "//waymo/onboard/ml/chauffeur_net:__pkg__", ], @@ -1564,10 +1553,7 @@ cc_library( name = "device_compiler_client", srcs = ["device_compiler_client.cc"], hdrs = ["device_compiler_client.h"], - visibility = [ - ":internal", - "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", - ], + visibility = [":internal"], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core/util:determinism", @@ -1596,6 +1582,7 @@ cc_library( cc_library( name = "device_executable_persistor", + srcs = ["device_executable_persistor.cc"], hdrs = ["device_executable_persistor.h"], deps = [ ":xla_compilation_cache_proto_cc", @@ -1608,6 +1595,8 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/core/platform:statusor", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:util", "@local_xla//xla/pjrt:pjrt_client", diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index ef46760f5065b4..5421637e80e5e0 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -734,7 +734,7 @@ static auto const ops_triggering_xla_compilation = "XlaVariadicSort", "XlaWhile"}; -static bool NodeCanTriggerXlaCompilation(const NodeDef& node) { +bool NodeCanTriggerXlaCompilation(const NodeDef& node) { return node.attr().find(kXlaClusterIdAttr) != node.attr().end() || HasBoolAttr(node, kXlaMustCompileAttr) || HasBoolAttr(node, kXlaCompileAttr) || diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 7c38cc92c541b7..18f6e5197b9cae 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -333,6 +333,9 @@ tensorflow::MemoryTypeVector GetOutputMemoryTypes( // Check whether graph can trigger XLA compilation. bool CanTriggerXlaCompilation(const GraphDef& graph); +// Returns true iff the node can trigger XLA compilation. +bool NodeCanTriggerXlaCompilation(const NodeDef& node); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_ diff --git a/tensorflow/compiler/jit/device_executable_persistor.cc b/tensorflow/compiler/jit/device_executable_persistor.cc new file mode 100644 index 00000000000000..b673af75cbdcd9 --- /dev/null +++ b/tensorflow/compiler/jit/device_executable_persistor.cc @@ -0,0 +1,37 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/device_executable_persistor.h" + +#include + +#include "absl/strings/str_cat.h" + +namespace tensorflow { + +std::string XlaSerializedCacheKeyToFileName(const XlaSerializedCacheKey& key) { + static constexpr char kXlaSerializedCacheKeySeparator[] = "__"; + return absl::StrCat( + key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator, + key.signature_fingerprint(), kXlaSerializedCacheKeySeparator, + key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator, + key.device_type(), + key.compiled_using_pjrt() + ? absl::StrCat(kXlaSerializedCacheKeySeparator, "pjrt") + : "", + ".pb"); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/device_executable_persistor.h b/tensorflow/compiler/jit/device_executable_persistor.h index 78d208942ed770..0f546c0f196acc 100644 --- a/tensorflow/compiler/jit/device_executable_persistor.h +++ b/tensorflow/compiler/jit/device_executable_persistor.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/log/log.h" +#include "absl/status/status.h" #include "tensorflow/compiler/jit/xla_compilation_cache.pb.h" #include "tensorflow/compiler/jit/xla_device_compiler_client.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -35,6 +36,9 @@ limitations under the License. namespace tensorflow { +// Returns the persisted compilation cache file name for the given key. +std::string XlaSerializedCacheKeyToFileName(const XlaSerializedCacheKey& key); + // Offers a way to persist and/or load compiled `ExecutableType`s along with the // corresponding HLO (`CompilationResult`) to/from `persistent_cache_directory` // (if one was provided during construction) on disk using `ClientType`. @@ -142,8 +146,6 @@ class DeviceExecutablePersistor { const xla::HloModuleProto& hlo_module, const XlaSerializedCacheEntry& entry) const; - std::string XlaSerializedCacheKeyToString( - const XlaSerializedCacheKey& key) const; std::string GetFilePath(const XlaSerializedCacheKey& key) const; const DeviceType device_type_; @@ -172,25 +174,10 @@ DeviceExecutablePersistor:: persistent_cache_directory_read_only_( config.persistent_cache_directory_read_only) {} -template -std::string DeviceExecutablePersistor:: - XlaSerializedCacheKeyToString(const XlaSerializedCacheKey& key) const { - static constexpr char kXlaSerializedCacheKeySeparator[] = "__"; - return absl::StrCat( - key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator, - key.signature_fingerprint(), kXlaSerializedCacheKeySeparator, - key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator, - key.device_type(), - key.compiled_using_pjrt() - ? absl::StrCat(kXlaSerializedCacheKeySeparator, "pjrt") - : ""); -} - template std::string DeviceExecutablePersistor::GetFilePath( const XlaSerializedCacheKey& key) const { - const std::string file_name = - absl::StrCat(XlaSerializedCacheKeyToString(key), ".pb"); + const std::string file_name = XlaSerializedCacheKeyToFileName(key); return io::JoinPath(persistent_cache_directory_, file_name); } @@ -299,9 +286,10 @@ DeviceExecutablePersistor::SaveSerializedEntry( // Write to temp location, then when that completes, atomically move into the // final location. - std::string temp_path = io::JoinPath( - persistent_cache_directory_, XlaSerializedCacheKeyToString(entry.key())); - if (!env->CreateUniqueFileName(&temp_path, ".pb.tmp")) { + std::string temp_path = + io::JoinPath(persistent_cache_directory_, + XlaSerializedCacheKeyToFileName(entry.key())); + if (!env->CreateUniqueFileName(&temp_path, ".tmp")) { return absl::UnavailableError(absl::StrCat( "Could not create a unique file inside ", persistent_cache_directory_)); } diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index 1ae4f6d4cd9938..0ef7156ef9f593 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -280,7 +280,6 @@ Status ConvertTensorFlowSliceToStaticShapedSlice( void ReplaceTensorFlowSliceWithStaticShapedSlice(Graph* g, Node* slice, Node* static_shaped_slice) { - absl::InlinedVector edges_to_remove; std::vector slice_out_edges; absl::c_copy(slice->out_edges(), std::back_inserter(slice_out_edges)); for (const Edge* e : slice_out_edges) { diff --git a/tensorflow/compiler/mlir/glob_lit_test.bzl b/tensorflow/compiler/mlir/glob_lit_test.bzl index e689b4c0b3191c..c87dc83bdde956 100644 --- a/tensorflow/compiler/mlir/glob_lit_test.bzl +++ b/tensorflow/compiler/mlir/glob_lit_test.bzl @@ -7,6 +7,10 @@ """ load("@bazel_skylib//lib:paths.bzl", "paths") +load( + "@local_xla//xla:lit.bzl", + "lit_script_with_xla_gpu_cuda_data_dir", +) # Default values used by the test runner. _default_test_file_exts = ["mlir", ".pbtxt", ".td"] @@ -76,7 +80,8 @@ def glob_lit_tests( tags_override = {}, driver = _default_driver, features = [], - exec_properties = {}): + exec_properties = {}, + hermetic_cuda_data_dir = None): """Creates all plausible Lit tests (and their inputs) under this directory. Args: @@ -94,6 +99,8 @@ def glob_lit_tests( and specifying a default driver will abort the tests. features: [str], list of extra features to enable. exec_properties: a dictionary of properties to pass on. + hermetic_cuda_data_dir: string. If set, the tests will be run with a + `--xla_gpu_cuda_data_dir` flag set to the hermetic CUDA data directory. """ # Ignore some patterns by default for tests and input data. @@ -108,12 +115,24 @@ def glob_lit_tests( # failure. all_tests = [] for curr_test in tests: - all_tests.append(curr_test + ".test") + final_test_name = curr_test + if hermetic_cuda_data_dir: + output_file = "with_xla_gpu_cuda_data_dir_{}".format(curr_test) + rule_name = "script_{}".format(output_file) + lit_script_with_xla_gpu_cuda_data_dir( + rule_name, + curr_test, + output_file, + hermetic_cuda_data_dir, + ) + final_test_name = output_file + all_tests.append(final_test_name + ".test") # Instantiate this test with updated parameters. _run_lit_test( - name = curr_test + ".test", - data = data + [curr_test] + per_test_extra_data.get(curr_test, []), + name = final_test_name + ".test", + data = data + [final_test_name] + + per_test_extra_data.get(curr_test, []), size = size_override.get(curr_test, default_size), tags = default_tags + tags_override.get(curr_test, []), driver = driver, diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 92a64f24251543..78699e8418cf5e 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -31,6 +31,16 @@ package_group( ], ) +filegroup( + name = "tflite_internal_cc_3p_api_deps_src", + srcs = [ + "allocation.cc", + "allocation.h", + "mmap_allocation.cc", + ], + visibility = ["//tensorflow/lite:__pkg__"], +) + td_library( name = "tensorflow_lite_ops_td_files", srcs = [ @@ -81,7 +91,7 @@ gentbl_cc_library( ( [ "-gen-pass-decls", - "-name=TensorFlowLite", + "-name=TensorFlowLiteTd", ], "transforms/passes.h.inc", ), @@ -318,6 +328,13 @@ cc_library( ], ) +cc_library( + name = "stateful_error_reporter", + hdrs = ["stateful_error_reporter.h"], + compatible_with = get_compatible_with_portable(), + deps = ["//tensorflow/compiler/mlir/lite/core/api:error_reporter"], +) + gentbl_cc_library( name = "tensorflow_lite_canonicalize_inc_gen", compatible_with = get_compatible_with_portable(), @@ -333,9 +350,29 @@ gentbl_cc_library( ) cc_library( - name = "tensorflow_lite", + name = "utils", + hdrs = ["utils/utils.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "attribute_utils", + srcs = ["utils/attribute_utils.cc"], + hdrs = ["utils/attribute_utils.h"], + deps = [ + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "tensorflow_lite_ops", srcs = [ - "ir/tfl_canonicalize.inc", "ir/tfl_ops.cc", "ir/tfl_ops.cc.inc", "ir/tfl_ops.h.inc", @@ -347,22 +384,75 @@ cc_library( "ir/tfl_ops_interface.cc.inc", "ir/tfl_ops_interface.h.inc", "runtime_verifiers.inc", - "utils/attribute_utils.cc", ], hdrs = [ "ir/tfl_ops.h", + ], + deps = [ + ":converter_inc", + ":cost_estimators", + ":size_utils", + ":tensorflow_lite_canonicalize_inc_gen", + ":tensorflow_lite_op_enums_inc_gen", + ":tensorflow_lite_op_interfaces_inc_gen", + ":tensorflow_lite_ops_inc_gen", + ":utils", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/quantization/common/quantization_lib", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_op_interfaces", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_traits", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/core:framework", + "//tensorflow/core/platform:status", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@eigen_archive//:eigen3", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LoopLikeInterface", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "tensorflow_lite", + srcs = [ + "ir/tfl_canonicalize.inc", + ], + hdrs = [ + "ir/tfl_ops.h", + "transforms/optimize.h", "transforms/passes.h", "utils/attribute_utils.h", "utils/utils.h", ], deps = [ + ":attribute_utils", ":converter_inc", ":cost_estimators", ":size_utils", ":tensorflow_lite_canonicalize_inc_gen", ":tensorflow_lite_op_enums_inc_gen", ":tensorflow_lite_op_interfaces_inc_gen", + ":tensorflow_lite_ops", ":tensorflow_lite_ops_inc_gen", + ":tensorflow_lite_optimize", ":tensorflow_lite_passes_inc_gen", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", @@ -389,6 +479,7 @@ cc_library( "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", @@ -772,14 +863,16 @@ cc_library( "transforms/optimize.cc", ], hdrs = [ - "transforms/passes.h", + "transforms/optimize.h", ], deps = [ + ":attribute_utils", ":constant_utils", ":convert_type", - ":tensorflow_lite", + ":tensorflow_lite_ops", ":tensorflow_lite_optimize_inc_gen", ":tensorflow_lite_passes_inc_gen", + ":utils", ":validators", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", @@ -1236,7 +1329,7 @@ cc_library( "utils/convert_type.h", ], deps = [ - ":tensorflow_lite", + ":tensorflow_lite_ops", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:protos_all_cc", @@ -1323,6 +1416,7 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + "//tensorflow/lite/toco:toco_flags_proto_cc", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", ], @@ -1389,10 +1483,12 @@ cc_library( ":tensorflow_lite_quantize", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_quantization_passes", + "//tensorflow/compiler/mlir/lite/stablehlo:build_stablehlo_composite", "//tensorflow/compiler/mlir/lite/stablehlo:compose_uniform_quantized_type_pass", "//tensorflow/compiler/mlir/lite/stablehlo:composite_lowering", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", + "//tensorflow/compiler/mlir/lite/stablehlo:lift_callsite_loc_caller", "//tensorflow/compiler/mlir/lite/stablehlo:prepare_hlo", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main", "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", # buildcleaner: keep @@ -1430,8 +1526,10 @@ cc_library( "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/debug", + "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/metrics:converter_error_data_proto_cc", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:quantize_weights", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_to_vhlo_pass", @@ -1459,7 +1557,6 @@ cc_library( "//tensorflow/core/ir/types:Dialect", "//tensorflow/core/platform:status", "//tensorflow/lite/toco:toco_flags_proto_cc", - "//tensorflow/lite/tools/optimize:quantize_weights", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1477,6 +1574,7 @@ cc_library( "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", "@stablehlo//:stablehlo_ops", @@ -1484,22 +1582,6 @@ cc_library( ], ) -cc_library( - name = "empty_passes", - hdrs = ["transforms/passes.h"], - visibility = [ - "//configs/devtools/hawkeye/tflite:__subpackages__", - "//learning/brain/models/app_benchmarks:__subpackages__", - ], - deps = [ - ":tensorflow_lite_passes_inc_gen", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", - "@com_google_absl//absl/container:flat_hash_set", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Pass", - ], -) - cc_library( name = "offset_buffer", hdrs = ["offset_buffer.h"], @@ -1539,6 +1621,32 @@ cc_library( visibility = ["//tensorflow/lite:__pkg__"], ) +exports_files(srcs = ["allocation.h"]) + +cc_library( + name = "allocation", + srcs = [ + "allocation.cc", + ] + select({ + ":tflite_mmap_disabled": [ + "mmap_allocation_disabled.cc", + ], + "//conditions:default": [ + "mmap_allocation.cc", + ], + }), + hdrs = [ + "allocation.h", + ], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts_warnings(), + visibility = [ + "//tensorflow/compiler/mlir/lite/core:__pkg__", + "//tensorflow/lite:__pkg__", + ], + deps = ["//tensorflow/compiler/mlir/lite/core/api:error_reporter"], +) + exports_files(srcs = ["utils/control_edges.h"]) cc_library( diff --git a/tensorflow/lite/allocation.cc b/tensorflow/compiler/mlir/lite/allocation.cc similarity index 97% rename from tensorflow/lite/allocation.cc rename to tensorflow/compiler/mlir/lite/allocation.cc index bbc41fabacbe17..3cad6908c889ad 100644 --- a/tensorflow/lite/allocation.cc +++ b/tensorflow/compiler/mlir/lite/allocation.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/allocation.h" +#include "tensorflow/compiler/mlir/lite/allocation.h" #include #include @@ -25,7 +25,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/allocation.h b/tensorflow/compiler/mlir/lite/allocation.h new file mode 100644 index 00000000000000..9ee9f4e846b71e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/allocation.h @@ -0,0 +1,156 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// \file +/// +/// Memory management for TF Lite. +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_ALLOCATION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_ALLOCATION_H_ + +#include + +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + +namespace tflite { + +/// A memory allocation handle. This could be a mmap or shared memory. +class Allocation { + public: + virtual ~Allocation() {} + + enum class Type { + kMMap, + kFileCopy, + kMemory, + }; + + /// Base pointer of this allocation + virtual const void* base() const = 0; + /// Size in bytes of the allocation + virtual size_t bytes() const = 0; + /// Whether the allocation is valid + virtual bool valid() const = 0; + /// Return the type of the Allocation. + Type type() const { return type_; } + + protected: + Allocation(ErrorReporter* error_reporter, Type type) + : error_reporter_(error_reporter), type_(type) {} + ErrorReporter* error_reporter_; + + private: + const Type type_; +}; + +/// Note that not all platforms support MMAP-based allocation. +/// Use `IsSupported()` to check. +class MMAPAllocation : public Allocation { + public: + /// Loads and maps the provided file to a memory region. + MMAPAllocation(const char* filename, ErrorReporter* error_reporter); + + /// Maps the provided file descriptor to a memory region. + /// Note: The provided file descriptor will be dup'ed for usage; the caller + /// retains ownership of the provided descriptor and should close accordingly. + MMAPAllocation(int fd, ErrorReporter* error_reporter); + + /// Maps the provided file descriptor, with the given offset and length (both + /// in bytes), to a memory region. + /// Note: The provided file descriptor will be dup'ed for usage; the caller + /// retains ownership of the provided descriptor and should close accordingly. + MMAPAllocation(int fd, size_t offset, size_t length, + ErrorReporter* error_reporter); + + ~MMAPAllocation() override; + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + int fd() const { return mmap_fd_; } + + // The start address of the mmapped buffer. + // This will be base() rounded down to the nearest page boundary. + const void* mmapped_buffer() const { return mmapped_buffer_; } + + // The size of the mmapped buffer. + size_t mmapped_buffer_size() const { return bytes() + offset_in_buffer_; } + + // Offset of mmapped_buffer() in the file referenced by the file descriptor. + size_t mmapped_buffer_offset_in_file() const { + return offset_of_buffer_in_file_; + } + + static bool IsSupported(); + + protected: + // Data required for mmap. + int mmap_fd_ = -1; // mmap file descriptor + const void* mmapped_buffer_; + size_t buffer_size_bytes_ = 0; + // Used when the address to mmap is not page-aligned. + size_t offset_in_buffer_ = 0; + size_t offset_of_buffer_in_file_ = 0; + + private: + // Assumes ownership of the provided `owned_fd` instance. + MMAPAllocation(ErrorReporter* error_reporter, int owned_fd); + + // Assumes ownership of the provided `owned_fd` instance, and uses the given + // offset and length (both in bytes) for memory mapping. + MMAPAllocation(ErrorReporter* error_reporter, int owned_fd, size_t offset, + size_t length); +}; + +class FileCopyAllocation : public Allocation { + public: + /// Loads the provided file into a heap memory region. + FileCopyAllocation(const char* filename, ErrorReporter* error_reporter); + ~FileCopyAllocation() override; + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + std::unique_ptr copied_buffer_; + size_t buffer_size_bytes_ = 0; +}; + +class MemoryAllocation : public Allocation { + public: + /// Provides a (read-only) view of the provided buffer region as an + /// allocation. + /// Note: The caller retains ownership of `ptr`, and must ensure it remains + /// valid for the lifetime of the class instance. + MemoryAllocation(const void* ptr, size_t num_bytes, + ErrorReporter* error_reporter); + ~MemoryAllocation() override; + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + const void* buffer_; +#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) + void* aligned_ptr_ = nullptr; +#endif + size_t buffer_size_bytes_ = 0; +}; + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_ALLOCATION_H_ diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 47f09a037554b8..cdf20cc0913e38 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" namespace mlir { namespace TFL { @@ -98,6 +99,10 @@ struct PassConfig { // Enables the attempt to directly lower composites into tflite ops. bool enable_composite_direct_lowering = true; + + // Specifies the framework of the original model. + toco::TocoFlags::ModelOriginFramework model_origin_framework = + toco::TocoFlags::UNSET; }; inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, @@ -126,7 +131,11 @@ inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, << pass_config.legalize_custom_tensor_list_ops << "\nreduce_type_precision: " << pass_config.reduce_type_precision << "\nconvert_qdq_format: " - << GetQDQQuantModeString(pass_config.qdq_conversion_mode) << "\n"; + << GetQDQQuantModeString(pass_config.qdq_conversion_mode) + << "\nmodel_origin_framework: " + << toco::TocoFlags::ModelOriginFramework_Name( + pass_config.model_origin_framework) + << "\n"; } } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/core/BUILD b/tensorflow/compiler/mlir/lite/core/BUILD index e4c666993dc67d..d76299aa723d51 100644 --- a/tensorflow/compiler/mlir/lite/core/BUILD +++ b/tensorflow/compiler/mlir/lite/core/BUILD @@ -32,10 +32,10 @@ cc_library( ], deps = [ ":macros", + "//tensorflow/compiler/mlir/lite:allocation", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", + "//tensorflow/compiler/mlir/lite/core/api:verifier", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/lite:allocation", - "//tensorflow/lite/core/api:error_reporter", - "//tensorflow/lite/core/api:verifier", "@com_google_absl//absl/strings", "@flatbuffers", ], @@ -53,7 +53,7 @@ cc_library( ], deps = [ ":model_builder_base", - "//tensorflow/lite/core/api:error_reporter", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", "@com_google_absl//absl/log", ], ) diff --git a/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc index 4b83abe5000867..269d81efc0e73e 100644 --- a/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc +++ b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/log/log.h" -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" namespace mlir::TFL { diff --git a/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h index f7a1c0cddd95b0..c3d76e2b03f820 100644 --- a/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h +++ b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h @@ -17,8 +17,8 @@ limitations under the License. #include +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" #include "tensorflow/compiler/mlir/lite/core/model_builder_base.h" -#include "tensorflow/lite/core/api/error_reporter.h" namespace mlir::TFL { diff --git a/tensorflow/compiler/mlir/lite/core/api/BUILD b/tensorflow/compiler/mlir/lite/core/api/BUILD new file mode 100644 index 00000000000000..0aaca3928420d6 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/api/BUILD @@ -0,0 +1,54 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/compiler/mlir/lite:build_def.bzl", "tflite_copts") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir/lite:__subpackages__", + "//tensorflow/lite:__subpackages__", + ], + licenses = ["notice"], +) + +exports_files(["error_reporter.h"]) + +filegroup( + name = "tflite_internal_cc_3p_api_deps_src", + srcs = [ + "error_reporter.cc", + "error_reporter.h", + "verifier.h", + ], + visibility = ["//tensorflow/lite:__pkg__"], +) + +cc_library( + name = "error_reporter", + srcs = ["error_reporter.cc"], + hdrs = ["error_reporter.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts(), + deps = [], +) + +exports_files(["verifier.h"]) + +cc_library( + name = "verifier", + hdrs = ["verifier.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts(), + visibility = ["//visibility:public"], + deps = [":error_reporter"], +) + +tf_cc_test( + name = "error_reporter_test", + size = "small", + srcs = ["error_reporter_test.cc"], + deps = [ + ":error_reporter", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/core/api/error_reporter.cc b/tensorflow/compiler/mlir/lite/core/api/error_reporter.cc similarity index 94% rename from tensorflow/lite/core/api/error_reporter.cc rename to tensorflow/compiler/mlir/lite/core/api/error_reporter.cc index 7070eaa57c589a..96f7561d1440f3 100644 --- a/tensorflow/lite/core/api/error_reporter.cc +++ b/tensorflow/compiler/mlir/lite/core/api/error_reporter.cc @@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + #include namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/core/api/error_reporter.h b/tensorflow/compiler/mlir/lite/core/api/error_reporter.h new file mode 100644 index 00000000000000..79c9fc9365e44a --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/api/error_reporter.h @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_ + +#include + +namespace tflite { + +/// A functor that reports error to supporting system. Invoked similar to +/// printf. +/// +/// Usage: +/// ErrorReporter foo; +/// foo.Report("test %d", 5); +/// or +/// va_list args; +/// foo.Report("test %d", args); // where args is va_list +/// +/// Subclass ErrorReporter to provide another reporting destination. +/// For example, if you have a GUI program, you might redirect to a buffer +/// that drives a GUI error log box. +class ErrorReporter { + public: + virtual ~ErrorReporter() = default; + /// Converts `args` to character equivalents according to `format` string, + /// constructs the error string and report it. + /// Returns number of characters written or zero on success, and negative + /// number on error. + virtual int Report(const char* format, va_list args) = 0; + + /// Converts arguments to character equivalents according to `format` string, + /// constructs the error string and report it. + /// Returns number of characters written or zero on success, and negative + /// number on error. + int Report(const char* format, ...); + + /// Equivalent to `Report` above. The additional `void*` parameter is unused. + /// This method is for compatibility with macros that takes `TfLiteContext`, + /// like TF_LITE_ENSURE and related macros. + int ReportError(void*, const char* format, ...); +}; + +} // namespace tflite + +// You should not make bare calls to the error reporter, instead use the +// TF_LITE_REPORT_ERROR macro, since this allows message strings to be +// stripped when the binary size has to be optimized. If you are looking to +// reduce binary size, define TF_LITE_STRIP_ERROR_STRINGS when compiling and +// every call will be stubbed out, taking no memory. +#ifndef TF_LITE_STRIP_ERROR_STRINGS +#define TF_LITE_REPORT_ERROR(reporter, ...) \ + do { \ + static_cast<::tflite::ErrorReporter*>(reporter)->Report(__VA_ARGS__); \ + } while (false) +#else // TF_LITE_STRIP_ERROR_STRINGS +#define TF_LITE_REPORT_ERROR(reporter, ...) +#endif // TF_LITE_STRIP_ERROR_STRINGS + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_ diff --git a/tensorflow/lite/core/api/error_reporter_test.cc b/tensorflow/compiler/mlir/lite/core/api/error_reporter_test.cc similarity index 96% rename from tensorflow/lite/core/api/error_reporter_test.cc rename to tensorflow/compiler/mlir/lite/core/api/error_reporter_test.cc index 03d6da734eae7d..ca7c4a2bb82ff8 100644 --- a/tensorflow/lite/core/api/error_reporter_test.cc +++ b/tensorflow/compiler/mlir/lite/core/api/error_reporter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" #include diff --git a/tensorflow/compiler/mlir/lite/core/api/verifier.h b/tensorflow/compiler/mlir/lite/core/api/verifier.h new file mode 100644 index 00000000000000..2e24347dd626e4 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/api/verifier.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// \file +/// +/// Abstract interface for verifying a model. +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_ + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + +namespace tflite { + +/// Abstract interface that verifies whether a given model is legit. +/// It facilitates the use-case to verify and build a model without loading it +/// twice. +/// (See also "tensorflow/lite/tools/verifier.h".) +class TfLiteVerifier { + public: + /// Returns true if the model is legit. + virtual bool Verify(const char* data, int length, + ErrorReporter* reporter) = 0; + virtual ~TfLiteVerifier() {} +}; + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_ diff --git a/tensorflow/compiler/mlir/lite/core/model_builder_base.cc b/tensorflow/compiler/mlir/lite/core/model_builder_base.cc index 28306ca8684e49..2ad2b93329be16 100644 --- a/tensorflow/compiler/mlir/lite/core/model_builder_base.cc +++ b/tensorflow/compiler/mlir/lite/core/model_builder_base.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/lite/allocation.h" -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/allocation.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/core/model_builder_base.h b/tensorflow/compiler/mlir/lite/core/model_builder_base.h index b3b78a4b181468..aabd4f959a992d 100644 --- a/tensorflow/compiler/mlir/lite/core/model_builder_base.h +++ b/tensorflow/compiler/mlir/lite/core/model_builder_base.h @@ -40,11 +40,11 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers #include "flatbuffers/verifier.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/allocation.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/verifier.h" #include "tensorflow/compiler/mlir/lite/core/macros.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/lite/allocation.h" -#include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/core/api/verifier.h" namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/delegates/flex/BUILD b/tensorflow/compiler/mlir/lite/delegates/flex/BUILD index 4ad7b874da82b8..2b3d198112d393 100644 --- a/tensorflow/compiler/mlir/lite/delegates/flex/BUILD +++ b/tensorflow/compiler/mlir/lite/delegates/flex/BUILD @@ -2,11 +2,9 @@ load( "//tensorflow:tensorflow.bzl", "if_mobile", "if_not_mobile", - "tf_cc_test", "tf_features_nolayering_check_if_ios", ) load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") -load("//tensorflow/compiler/mlir/lite/delegates/flex:build_def.bzl", "tflite_flex_cc_library") load("//tensorflow/lite:special_rules.bzl", "internal_visibility_allowlist") default_visibility = [ @@ -24,18 +22,6 @@ package( licenses = ["notice"], ) -exports_files([ - "delegate.h", - "exported_symbols.lds", - "version_script.lds", -]) - -tflite_flex_cc_library( - name = "delegate", - compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], -) - cc_library( name = "allowlisted_flex_ops_lib", srcs = [ @@ -54,21 +40,3 @@ cc_library( "//tensorflow/core:framework", ]), ) - -tf_cc_test( - name = "allowlisted_flex_ops_test", - size = "small", - srcs = [ - "allowlisted_flex_ops_test.cc", - ], - features = tf_features_nolayering_check_if_ios(), - deps = [ - ":allowlisted_flex_ops_lib", - ":delegate", - "@com_google_googletest//:gtest_main", - ] + if_mobile([ - "//tensorflow/core:portable_tensorflow_lib_lite", - ]) + if_not_mobile([ - "//tensorflow/core:framework", - ]), -) diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc b/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc index 8313bf2c10e269..4f12a705cc27f7 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc @@ -62,8 +62,7 @@ void TacModule::AddTACPass(mlir::OpPassManager* pass_manager, mlir::createCanonicalizerPass()); pass_manager->addPass( mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true)); - pass_manager->addPass( - mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true)); + pass_manager->addPass(mlir::TFL::CreateOptimizePass()); } pass_manager->addPass(mlir::TFL::tac::CreateComputeCostPass()); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 269f5cd7668062..ecad51df76be37 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -635,6 +635,11 @@ class Translator { mlir::TFL::WhileOp op, const std::vector& operands, const std::vector& results); + // Build while operator where then & else are regions. + std::optional> BuildIfOperator( + mlir::TFL::IfOp op, const std::vector& operands, + const std::vector& results); + // Build call once operator. BufferOffset BuildCallOnceOperator( mlir::TFL::CallOnceOp op, const std::vector& operands, @@ -1335,6 +1340,54 @@ std::optional> Translator::BuildWhileOperator( builtin_options); } +std::optional> Translator::BuildIfOperator( + mlir::TFL::IfOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); + auto get_call_op = [&](mlir::Block& b) -> std::optional { + if (b.getOperations().size() != 2) return std::nullopt; + if (auto call_op = dyn_cast(b.front())) return call_op; + return std::nullopt; + }; + auto then_call_op = get_call_op(op.getThenRegion().front()); + auto else_call_op = get_call_op(op.getElseRegion().front()); + if (!then_call_op || !else_call_op) + return op.emitOpError("only single call then/else while export supported"), + std::nullopt; + auto then_subgraph_index = + subgraph_index_map_.at(then_call_op.value().getCallee().str()); + auto else_subgraph_index = + subgraph_index_map_.at(else_call_op.value().getCallee().str()); + auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, + else_subgraph_index) + .Union(); + + // Get the subgraph index of IF op. + auto subgraph_func = op->getParentOfType(); + auto subgraph_idx = subgraph_index_map_[subgraph_func.getSymName().str()]; + auto new_operands = operands; + + // Then/Else region shares the same operands, only adding once as the new + // operands for the IF op. + if (then_call_op.value().getOperands() != + else_call_op.value().getOperands()) { + return op.emitOpError("Then/Else region does not contain same operands."), + std::nullopt; + } + + for (auto call_arg : then_call_op.value().getOperands()) { + auto name_of_call_arg = name_mapper_.GetUniqueName(call_arg); + const auto call_arg_tensor_id = + tensor_index_map_[subgraph_idx][name_of_call_arg]; + new_operands.push_back(call_arg_tensor_id); + } + auto inputs = builder_.CreateVector(new_operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_IfOptions, + builtin_options); +} + BufferOffset Translator::BuildNumericVerifyOperator( mlir::TFL::NumericVerifyOp op, const std::vector& operands, const std::vector& results) { @@ -2102,6 +2155,9 @@ std::optional> Translator::BuildOperator( } return BuildWhileOperator(whileOp, operands, results); } + if (auto ifOp = dyn_cast(inst)) { + return BuildIfOperator(ifOp, operands, results); + } inst->emitOpError("is not a supported TFLite op"); return std::nullopt; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 1633820bb5bd5e..c9aa62843d743c 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -33,16 +34,21 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/FloatingPointMode.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Threading.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -54,6 +60,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/FoldUtils.h" // from @llvm-project @@ -61,6 +68,7 @@ limitations under the License. #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h" #include "tensorflow/compiler/mlir/lite/utils/size_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" @@ -616,47 +624,50 @@ void IncrementIndex(ArrayRef result_shape, /// attributes `operand1` and `operand2` and returns the result if possible. /// This function assumes the both operands are verified to have value /// attributes of broadcastable types. -template > -Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs, +template > +Attribute ConstFoldBinaryOpDenseDense(ShapedType result_type, + DenseElementsAttr lhs, DenseElementsAttr rhs, const CalculationT& calculate) { - auto type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()) - .dyn_cast_or_null(); + auto type = llvm::dyn_cast_or_null( + OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType())); if (!type) { return {}; } + type = type.clone(result_type.getElementType()); + const bool rhs_is_splat = rhs.isSplat(); const bool lhs_is_splat = lhs.isSplat(); + auto lhs_values = lhs.try_value_begin(); + auto rhs_values = rhs.try_value_begin(); + if (failed(lhs_values) || failed(rhs_values)) { + return {}; + } + // If both of them are splat, compute and return. if (lhs_is_splat && rhs_is_splat) { - auto element_result = AttrElementT::get( - type.getElementType(), calculate(lhs.getSplatValue(), - rhs.getSplatValue())); - if (!element_result) return {}; - - return DenseElementsAttr::get(type, element_result); + return DenseElementsT::get( + type, calculate(*lhs_values.value(), *rhs_values.value())); } auto num_elements = type.getNumElements(); - SmallVector new_values; + SmallVector new_values; new_values.reserve(num_elements); const auto result_shape = type.getShape(); std::vector current_index(type.getRank(), 0); + // Create the new shape with ones padded to the left. - const std::vector lhs_new_shape = + const auto lhs_new_shape = GetPaddedShape(lhs.getType().getShape(), type.getRank()); - const std::vector rhs_new_shape = + const auto rhs_new_shape = GetPaddedShape(rhs.getType().getShape(), type.getRank()); - auto lhs_old_values = lhs.getValues(); - auto rhs_old_values = rhs.getValues(); - // Add each pair of the corresponding values in the dense elements // attributes. for (int64_t i = 0; i < num_elements; ++i) { @@ -669,26 +680,27 @@ Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs, const int64_t rhs_index = rhs_is_splat ? 0 : GetElementIndex(rhs_new_shape, current_index); - new_values.push_back(calculate(*(lhs_old_values.begin() + lhs_index), - *(rhs_old_values.begin() + rhs_index))); + new_values.push_back(calculate(*(lhs_values.value() + lhs_index), + *(rhs_values.value() + rhs_index))); IncrementIndex(result_shape, ¤t_index); } - return DenseElementsAttr::get(type, ArrayRef(new_values)); + return DenseElementsT::get(type, new_values); } /// Performs const folding `calculate` with broadcast behavior on the two /// attributes `operand1` and `operand2` and returns the result if possible. /// This function assumes the two operands are verified to have value /// attributes of broadcastable types. -template > -Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1, +template > +Attribute ConstFoldBinaryOp(ShapedType result_type, Attribute operand1, Attribute operand2, const CalculationT& calculate) { if (operand1.dyn_cast_or_null() && operand2.dyn_cast_or_null()) { - return ConstFoldBinaryOpDenseDense( + return ConstFoldBinaryOpDenseDense( result_type, operand1.cast(), operand2.cast(), calculate); } @@ -703,23 +715,18 @@ Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1, /// Depending on the given `resultType`, either `floatCalculate` or /// `intCalculate` is chosen to conduct the calculate. Attribute ConstFoldBinaryOp( - Type result_type, ArrayRef operands, + ShapedType type, ArrayRef operands, llvm::function_ref float_calculate, llvm::function_ref int_calculate) { - // Note: All types are wrapped in tensor types in TFlite. E.g., f32 is - // represented as tensor. So we are only handling tensor types here. - auto type = result_type.dyn_cast(); - if (!type) return {}; - auto elemType = type.getElementType(); if (elemType.isa()) - return ConstFoldBinaryOp(result_type, operands[0], operands[1], - float_calculate); + return ConstFoldBinaryOp( + type, operands[0], operands[1], float_calculate); if (elemType.isSignlessInteger()) - return ConstFoldBinaryOp(result_type, operands[0], operands[1], - int_calculate); + return ConstFoldBinaryOp( + type, operands[0], operands[1], int_calculate); return {}; } @@ -809,6 +816,73 @@ int64_t AddOp::GetArithmeticCount(Operation* op) { return -1; } +//===----------------------------------------------------------------------===// +// FloorOp +//===----------------------------------------------------------------------===// + +OpFoldResult FloorOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); + auto result_type = getType(); + if (!IsF32ShapedType(result_type)) return {}; + + auto compute = [](APFloat value) -> APFloat { + float f = value.convertToFloat(); + float result = std::floor(f); + return APFloat(result); + }; + + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + +//===----------------------------------------------------------------------===// +// BitwiseXorOp +//===----------------------------------------------------------------------===// + +OpFoldResult BitwiseXorOp::fold(FoldAdaptor adaptor) { + auto compute = [](APInt lhs, APInt rhs) -> APInt { + lhs ^= rhs; + return lhs; + }; + + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), compute); +} + +//===----------------------------------------------------------------------===// +// ExpOp +//===----------------------------------------------------------------------===// + +OpFoldResult ExpOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); + auto result_type = getType(); + if (!IsF32ShapedType(result_type)) return {}; + + auto compute = [](APFloat value) -> APFloat { + float f = value.convertToFloat(); + float result = std::exp(f); + return APFloat(result); + }; + + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + +//===----------------------------------------------------------------------===// +// LogicalNotOp +//===----------------------------------------------------------------------===// + +OpFoldResult LogicalNotOp::fold(FoldAdaptor adaptor) { + auto data = llvm::dyn_cast_or_null(adaptor.getLhs()); + if (!data) { + return {}; + } + + auto compute = [](bool value) { return !value; }; + + return DenseIntElementsAttr::get( + data.getType(), + llvm::to_vector(llvm::map_range(data.getValues(), compute))); +} + //===----------------------------------------------------------------------===// // ConcatenationOp //===----------------------------------------------------------------------===// @@ -1681,6 +1755,38 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { // TODO(b/142478136): Handle fused ops. if (getFusedActivationFunction() != "NONE") return {}; + auto is_zero = [](Attribute a) { + return matchPattern(a, m_Zero()) || matchPattern(a, m_AnyZeroFloat()); + }; + auto is_one = [](Attribute a) { + return matchPattern(a, m_One()) || matchPattern(a, m_OneFloat()); + }; + + // Quantized folding not supported. + const bool is_quantized = + llvm::isa(getType().getElementType()); + + auto lhs = llvm::dyn_cast_or_null(adaptor.getLhs()); + auto rhs = llvm::dyn_cast_or_null(adaptor.getRhs()); + + if (lhs && !is_quantized) { + if (is_zero(lhs) && lhs.getType() == getType()) { + return lhs; + } + if (is_one(lhs) && getRhs().getType() == getType()) { + return getRhs(); + } + } + + if (rhs && !is_quantized) { + if (is_zero(rhs) && rhs.getType() == getType()) { + return rhs; + } + if (is_one(rhs) && getLhs().getType() == getType()) { + return getLhs(); + } + } + // This function is performance critical for op fusion patterns, e.g. // FuseBinaryOpToPrecedingAffine and FuseMulOrDivWithConv2dOrDepthwiseConv2d. // So a few specializations are provided to evaluate the math operation @@ -1688,14 +1794,15 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { // Specialization for f32 type. if (getType().cast().getElementType().isF32()) { - return ConstFoldBinaryOp( + return ConstFoldBinaryOp( getType(), operands[0], operands[1], [](float a, float b) { return a * b; }); } // Specialization for bf16 type. if (getType().cast().getElementType().isBF16()) { - return ConstFoldBinaryOp( + return ConstFoldBinaryOp( getType(), operands[0], operands[1], [](Eigen::bfloat16 a, Eigen::bfloat16 b) { return a * b; }); } @@ -1713,6 +1820,24 @@ int64_t MulOp::GetArithmeticCount(Operation* op) { return -1; } +//===----------------------------------------------------------------------===// +// PowOp +//===----------------------------------------------------------------------===// + +OpFoldResult PowOp::fold(FoldAdaptor adaptor) { + if (getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](float lhs, float rhs) { return std::pow(lhs, rhs); }); + } + if (getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](int32_t lhs, int32_t rhs) { return std::pow(lhs, rhs); }); + } + return {}; +} + //===----------------------------------------------------------------------===// // DivOp //===----------------------------------------------------------------------===// @@ -1721,9 +1846,30 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); // TODO(b/142478136): Handle fused ops. if (getFusedActivationFunction() != "NONE") return {}; - return ConstFoldBinaryOp( - getType(), operands, [](APFloat a, APFloat b) { return a / b; }, - [](APInt a, APInt b) { return a.sdiv(b); }); + + auto rhs = llvm::dyn_cast_or_null(adaptor.getRhs()); + auto lhs = llvm::dyn_cast_or_null(adaptor.getLhs()); + + if (rhs && lhs) { + return ConstFoldBinaryOp( + getType(), operands, [](APFloat a, APFloat b) { return a / b; }, + [](APInt a, APInt b) { return a.sdiv(b); }); + } + + if (llvm::isa(getType().getElementType())) { + // Quantized folding not supported for the following. + return {}; + } + + auto is_one = [](Attribute a) { + return matchPattern(a, m_One()) || matchPattern(a, m_OneFloat()); + }; + + if (rhs && is_one(rhs) && getLhs().getType() == getType()) { + return getLhs(); + } + + return {}; } int64_t DivOp::GetArithmeticCount(Operation* op) { @@ -3080,12 +3226,12 @@ OpFoldResult MaximumOp::fold(FoldAdaptor adaptor) { if (lhs && lhs.isSplat()) { APFloat lhs_value = lhs.getSplatValue(); lhs_value.changeSign(); - if (lhs_value.isLargest()) return getRhs(); + if (lhs_value.isLargest() || lhs_value.isInfinity()) return getRhs(); } if (rhs && rhs.isSplat()) { APFloat rhs_value = rhs.getSplatValue(); rhs_value.changeSign(); - if (rhs_value.isLargest()) return getLhs(); + if (rhs_value.isLargest() || rhs_value.isInfinity()) return getLhs(); } return nullptr; } @@ -3102,13 +3248,184 @@ OpFoldResult MinimumOp::fold(FoldAdaptor adaptor) { auto lhs = adaptor.getLhs().dyn_cast_or_null(); auto rhs = adaptor.getRhs().dyn_cast_or_null(); - if (lhs && lhs.isSplat() && lhs.getSplatValue().isLargest()) - return getRhs(); - if (rhs && rhs.isSplat() && rhs.getSplatValue().isLargest()) - return getLhs(); + if (lhs && lhs.isSplat()) { + auto splat = lhs.getSplatValue(); + if (splat.isLargest() || splat.isInfinity()) return getRhs(); + } + if (rhs && rhs.isSplat()) { + auto splat = rhs.getSplatValue(); + if (splat.isLargest() || splat.isInfinity()) return getLhs(); + } return nullptr; } +//===----------------------------------------------------------------------===// +// Comparison and Logical Ops +//===----------------------------------------------------------------------===// + +OpFoldResult LessOp::fold(FoldAdaptor adaptor) { + if (getLhs().getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](int32_t lhs, int32_t rhs) { return lhs < rhs; }); + } + if (getLhs().getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](float lhs, float rhs) { return lhs < rhs; }); + } + return {}; +} + +OpFoldResult LessEqualOp::fold(FoldAdaptor adaptor) { + if (getLhs().getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](int32_t lhs, int32_t rhs) { return lhs <= rhs; }); + } + if (getLhs().getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](float lhs, float rhs) { return lhs <= rhs; }); + } + return {}; +} + +OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { + if (getLhs().getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](int32_t lhs, int32_t rhs) { return lhs > rhs; }); + } + if (getLhs().getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](float lhs, float rhs) { return lhs > rhs; }); + } + return {}; +} + +OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { + if (getLhs().getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](int32_t lhs, int32_t rhs) { return lhs >= rhs; }); + } + if (getLhs().getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](float lhs, float rhs) { return lhs >= rhs; }); + } + return {}; +} + +OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { + if (getX().getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getX(), adaptor.getY(), + [](int32_t lhs, int32_t rhs) { return lhs == rhs; }); + } + if (getX().getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getX(), adaptor.getY(), + [](float lhs, float rhs) { return lhs == rhs; }); + } + return {}; +} + +OpFoldResult NotEqualOp::fold(FoldAdaptor adaptor) { + if (getLhs().getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](int32_t lhs, int32_t rhs) { return lhs != rhs; }); + } + if (getLhs().getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](float lhs, float rhs) { return lhs != rhs; }); + } + return {}; +} + +OpFoldResult LogicalAndOp::fold(FoldAdaptor adaptor) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](bool lhs, bool rhs) { return lhs && rhs; }); +} + +OpFoldResult LogicalOrOp::fold(FoldAdaptor adaptor) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](bool lhs, bool rhs) { return lhs || rhs; }); +} + +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +// TODO: b/359275356 - Expand this to handle the broadcast case similar +// to `ConstFoldBinaryOpDense`. +OpFoldResult SelectOp::fold(FoldAdaptor adaptor) { + auto lhs_type = getX().getType(); + auto rhs_type = getY().getType(); + auto condition_type = getCondition().getType(); + auto out_type = getType(); + + if (lhs_type != rhs_type) { + return {}; + } + + if (lhs_type.getShape() != condition_type.getShape()) { + // "broadcasted" condition not yet supported. + return {}; + } + + auto condition_vals = + llvm::dyn_cast_or_null(adaptor.getCondition()); + if (!condition_vals || !condition_vals.getElementType().isInteger(1)) { + return {}; + } + + if (condition_vals.isSplat()) { + const bool val = condition_vals.getSplatValue(); + return val ? adaptor.getX() : adaptor.getY(); + } + + auto lhs_vals = llvm::dyn_cast_or_null(adaptor.getX()); + auto rhs_vals = llvm::dyn_cast_or_null(adaptor.getY()); + if (!lhs_vals || !rhs_vals) { + return {}; + } + + llvm::SmallVector results; + results.reserve(condition_type.getNumElements()); + + auto lhs_it = lhs_vals.getValues().begin(); + auto lhs_end = lhs_vals.getValues().end(); + auto rhs_it = rhs_vals.getValues().begin(); + auto rhs_end = rhs_vals.getValues().end(); + + auto condition_it = condition_vals.getValues().begin(); + auto condition_end = condition_vals.getValues().end(); + + while (condition_it < condition_end && lhs_it < lhs_end && rhs_it < rhs_end) { + if (*condition_it++) { + results.push_back(*lhs_it); + } else { + results.push_back(*rhs_it); + } + + if (!lhs_vals.isSplat()) { + lhs_it++; + } + if (!rhs_vals.isSplat()) { + rhs_it++; + } + } + + return DenseElementsAttr::get(out_type, results); +} + //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// @@ -3191,35 +3508,12 @@ void ConstOp::getCanonicalizationPatterns(RewritePatternSet& results, // CastOp //===----------------------------------------------------------------------===// -OpFoldResult CastOp::fold(FoldAdaptor adaptor) { - auto operands = adaptor.getOperands(); - assert(operands.size() == 1); - if (getInput().getType() == getType()) { - return getInput(); - } - - // For now, only supports cast between integer types. - auto elements_attr = operands[0].dyn_cast_or_null(); - if (!elements_attr) { - return nullptr; - } - - auto result_element_type = - getType().cast().getElementType().dyn_cast(); - auto operand_element_type = getInput() - .getType() - .cast() - .getElementType() - .dyn_cast(); - // Returns nullptr if either result/operand element type is not integer. - if (!result_element_type || !operand_element_type) { - return nullptr; - } - - const bool is_unsigned = operand_element_type.isUnsigned(); - const bool involves_bool = operand_element_type.getWidth() == 1 || - result_element_type.getWidth() == 1; - const int output_bitwidth = result_element_type.getWidth(); +OpFoldResult CastIntToInt(DenseIntElementsAttr data, IntegerType in_type, + IntegerType out_type) { + const bool is_unsigned = in_type.isUnsigned(); + const bool involves_bool = + in_type.getWidth() == 1 || out_type.getWidth() == 1; + const int output_bitwidth = out_type.getWidth(); // The integer cast op is the same as C integer cast. Depends on the operand // type's signedness, we will determine whether or not sign extension is // needed. @@ -3230,13 +3524,114 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { // true input should always be cast to 1 and not -1 as the sign extension // would do for signed outputs. Similarly, non-zero inputs should be cast // to true. Truncating even numbers to one bit will result in `false`. - return APInt(result_element_type.getWidth(), value != 0); + return APInt(out_type.getWidth(), value != 0); } return is_unsigned ? value.zextOrTrunc(output_bitwidth) : value.sextOrTrunc(output_bitwidth); }; - return elements_attr.mapValues(result_element_type, cast); + return data.mapValues(out_type, cast); +} + +OpFoldResult CastFloatToInt(DenseFPElementsAttr data, FloatType in_type, + IntegerType out_type) { + const bool from_f32 = in_type.isF32(); + const bool to_i32 = out_type.isSignlessInteger(32); + if (!from_f32 || !to_i32) { + return {}; + } + + auto cast = [&](APFloat value) -> APInt { + APSInt result(32, false); + bool is_exact; + value.convertToInteger(result, llvm::RoundingMode::TowardZero, &is_exact); + return result; + }; + + return data.mapValues(out_type, cast); +} + +template +llvm::SmallVector MapStaticCast(DenseElementsAttr data) { + return llvm::map_to_vector(data.getValues(), + [](InType v) { return static_cast(v); }); +} + +OpFoldResult CastIntToFloat(DenseIntElementsAttr data, IntegerType in_type, + FloatType out_type) { + auto result_type = data.getType().clone(out_type); + if (!out_type.isF32()) { + return {}; + } + + if (in_type.isSignlessInteger(32)) { + return DenseFPElementsAttr::get(result_type, + MapStaticCast(data)); + } + if (in_type.isSignlessInteger(1)) { + return DenseFPElementsAttr::get(result_type, + MapStaticCast(data)); + } + + return {}; +} + +OpFoldResult CastFloatToFloat(DenseFPElementsAttr data, FloatType in_type, + FloatType out_type) { + auto result_type = data.getType().clone(out_type); + if (in_type.isF32() && out_type.isF64()) { + return DenseFPElementsAttr::get(result_type, + MapStaticCast(data)); + } + + if (in_type.isF64() && out_type.isF32()) { + return DenseFPElementsAttr::get(result_type, + MapStaticCast(data)); + } + return {}; +} + +OpFoldResult CastOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); + if (operands.size() != 1) { + return {}; + } + if (getInput().getType() == getType()) { + return getInput(); + } + + auto input = operands[0]; + + auto in_type = getInput().getType().getElementType(); + auto out_type = getType().getElementType(); + + if (auto int_in_type = llvm::dyn_cast_or_null(in_type)) { + auto in_data = llvm::dyn_cast_or_null(input); + if (!in_data) { + return {}; + } + if (auto float_out_type = llvm::dyn_cast_or_null(out_type)) { + return CastIntToFloat(in_data, int_in_type, float_out_type); + } + if (auto int_out_type = llvm::dyn_cast_or_null(out_type)) { + return CastIntToInt(in_data, int_in_type, int_out_type); + } + } + + if (auto float_in_type = llvm::dyn_cast_or_null(in_type)) { + auto in_data = llvm::dyn_cast_or_null(input); + if (!in_data) { + return {}; + } + if (auto float_out_type = llvm::dyn_cast_or_null(out_type)) { + return CastFloatToFloat(in_data, float_in_type, float_out_type); + } + if (auto int_out_type = llvm::dyn_cast_or_null(out_type)) { + return CastFloatToInt(in_data, float_in_type, int_out_type); + } + } + + return {}; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 33f920ebb02d5e..5eda0d01c31b61 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1281,6 +1281,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [ let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -1357,6 +1359,8 @@ def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [ let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -1554,6 +1558,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [ let results = (outs TFL_BoolTensor:$output); + let hasFolder = 1; + let builders = [ OpBuilder<(ins "Value":$lhs, "Value":$rhs), @@ -1681,6 +1687,8 @@ def TFL_EqualOp: TFL_Op<"equal", [ let results = (outs TFL_BoolTensor:$output); let builders = [TFL_ComparisonBinaryBuilder]; + + let hasFolder = 1; } def TFL_ExpOp: TFL_Op<"exp", [ @@ -1697,6 +1705,8 @@ def TFL_ExpOp: TFL_Op<"exp", [ let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y); + let hasFolder = 1; + // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the // elementwise-move reordering pattern in the optimize_patterns.td @@ -1840,6 +1850,8 @@ def TFL_FloorOp: TFL_Op<"floor", [ let results = (outs TFL_FpTensor:$y); + let hasFolder = 1; + let extraClassDeclaration = [{ // Returns whether the return types are compatible. static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { @@ -1925,6 +1937,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [ let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -2036,6 +2050,8 @@ def TFL_LessOp : TFL_Op<"less", [ let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -2046,7 +2062,7 @@ def TFL_LessOp : TFL_Op<"less", [ }]; } -def TFL_LogicalAndOp : TFL_Op<"logical_and", [Pure]> { +def TFL_LogicalAndOp : TFL_Op<"logical_and", [ResultsBroadcastableShape, Pure]> { let summary = "Logical AND operator"; let description = [{ @@ -2061,6 +2077,8 @@ def TFL_LogicalAndOp : TFL_Op<"logical_and", [Pure]> { let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -2083,9 +2101,11 @@ def TFL_LogicalNotOp : TFL_Op<"logical_not", [ let arguments = (ins TFL_BoolTensor:$lhs); let results = (outs TFL_BoolTensor:$output); + + let hasFolder = 1; } -def TFL_LogicalOrOp : TFL_Op<"logical_or", [Pure]> { +def TFL_LogicalOrOp : TFL_Op<"logical_or", [ResultsBroadcastableShape, Pure]> { let summary = "Logical OR operator"; let description = [{ @@ -2100,6 +2120,8 @@ def TFL_LogicalOrOp : TFL_Op<"logical_or", [Pure]> { let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -2803,6 +2825,8 @@ def TFL_PowOp : TFL_Op<"pow", [ let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -3157,6 +3181,8 @@ def TFL_SelectOp : TFL_Op<"select", [ let results = (outs TFL_TensorOf<[F32, I1, I8, I16, I32, I64, UI32, QI8, QUI8, QI16, TFL_Quint8]>:$output); + let hasFolder = 1; + // TODO(jpienaar): autogenerate this. let builders = [ OpBuilder<(ins "Value":$condition, "Value":$x, "Value":$y), @@ -4080,6 +4106,7 @@ def TFL_BitcastOp : TFL_Op<"bitcast", [Pure]> { } def TFL_BitwiseXorOp : TFL_Op<"bitwise_xor", [ + ResultsBroadcastableShape, Commutative, SameOperandsAndResultElementType, Pure]> { @@ -4097,6 +4124,8 @@ def TFL_BitwiseXorOp : TFL_Op<"bitwise_xor", [ let results = (outs TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$output ); + + let hasFolder = 1; } def TFL_RightShiftOp : TFL_Op<"right_shift", [ @@ -4121,6 +4150,7 @@ def TFL_RightShiftOp : TFL_Op<"right_shift", [ //===----------------------------------------------------------------------===// // Quantization ops. //===----------------------------------------------------------------------===// + def TFL_DequantizeOp: TFL_Op<"dequantize", [NoMemoryEffect]> { let summary = "Dequantize operator"; diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc index f95e79d6c927c5..4a28c1474e9be8 100644 --- a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc +++ b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "Eigen/Core" // from @eigen_archive #include "tensorflow/compiler/mlir/lite/core/c/dimension_type.h" namespace tflite_migration { diff --git a/tensorflow/lite/mmap_allocation.cc b/tensorflow/compiler/mlir/lite/mmap_allocation.cc similarity index 97% rename from tensorflow/lite/mmap_allocation.cc rename to tensorflow/compiler/mlir/lite/mmap_allocation.cc index 3d1a7f03e713e1..eb106899228fba 100644 --- a/tensorflow/lite/mmap_allocation.cc +++ b/tensorflow/compiler/mlir/lite/mmap_allocation.cc @@ -21,8 +21,8 @@ limitations under the License. #include -#include "tensorflow/lite/allocation.h" -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/allocation.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" namespace tflite { namespace { diff --git a/tensorflow/lite/mmap_allocation_disabled.cc b/tensorflow/compiler/mlir/lite/mmap_allocation_disabled.cc similarity index 96% rename from tensorflow/lite/mmap_allocation_disabled.cc rename to tensorflow/compiler/mlir/lite/mmap_allocation_disabled.cc index 95c34446797d7c..4e89594285473a 100644 --- a/tensorflow/lite/mmap_allocation_disabled.cc +++ b/tensorflow/compiler/mlir/lite/mmap_allocation_disabled.cc @@ -15,7 +15,7 @@ limitations under the License. #include -#include "tensorflow/lite/allocation.h" +#include "tensorflow/compiler/mlir/lite/allocation.h" namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index 2744c4280038a8..299bb9e2f2bc06 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -198,6 +198,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", "//tensorflow/compiler/mlir/lite/debug:debug_options_proto_cc", "//tensorflow/compiler/mlir/lite/metrics:error_collector", + "//tensorflow/compiler/mlir/lite/python/interpreter_wrapper:python_error_reporter", "//tensorflow/compiler/mlir/lite/python/interpreter_wrapper:python_utils", "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", @@ -206,7 +207,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter", "//tensorflow/lite/toco:model", "//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:toco_convert", diff --git a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc index 4354275aaedc55..881c30019b903a 100644 --- a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc +++ b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/metrics/error_collector.h" #include "tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.h" #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h" +#include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h" #include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h" @@ -43,11 +44,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" #include "tensorflow/lite/toco/logging/conversion_log_util.h" #include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h" #include "tensorflow/lite/toco/model.h" @@ -56,7 +58,6 @@ limitations under the License. #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/toco_graphviz_dump_options.h" #include "tensorflow/lite/toco/toco_tooling.h" -#include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/lite/toco/types.pb.h" @@ -309,7 +310,7 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, bool enable_variable_quantization, bool disable_per_channel_for_dense_layers, PyObject* debug_options_proto_txt_raw) { - using tflite::interpreter_wrapper::PythonErrorReporter; + using tflite_migration::interpreter_wrapper::PythonErrorReporter; char* buf = nullptr; Py_ssize_t length; std::unique_ptr error_reporter(new PythonErrorReporter); @@ -399,7 +400,7 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, } PyObject* MlirSparsifyModel(PyObject* data) { - using tflite::interpreter_wrapper::PythonErrorReporter; + using tflite_migration::interpreter_wrapper::PythonErrorReporter; char* buf = nullptr; Py_ssize_t length; std::unique_ptr error_reporter(new PythonErrorReporter); diff --git a/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc b/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc index 6591251d9e915b..b880df7f74a3ca 100644 --- a/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc +++ b/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" @@ -30,7 +31,6 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD index 8d2cb7a65e4b8d..9268de7ec1de54 100644 --- a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD @@ -15,3 +15,14 @@ cc_library( "//third_party/python_runtime:headers", # buildcleaner: keep ], ) + +cc_library( + name = "python_error_reporter", + srcs = ["python_error_reporter.cc"], + hdrs = ["python_error_reporter.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite:stateful_error_reporter", + "//third_party/python_runtime:headers", # buildcleaner: keep + ], +) diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc new file mode 100644 index 00000000000000..75f9222d7c22d2 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc @@ -0,0 +1,47 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h" + +#include +#include +#include + +namespace tflite_migration { +namespace interpreter_wrapper { + +// Report an error message +int PythonErrorReporter::Report(const char* format, va_list args) { + char buf[1024]; + int formatted = vsnprintf(buf, sizeof(buf), format, args); + buffer_ << buf; + return formatted; +} + +// Set's a Python runtime exception with the last error. +PyObject* PythonErrorReporter::exception() { + std::string last_message = message(); + PyErr_SetString(PyExc_RuntimeError, last_message.c_str()); + return nullptr; +} + +// Gets the last error message and clears the buffer. +std::string PythonErrorReporter::message() { + std::string value = buffer_.str(); + buffer_.clear(); + return value; +} +} // namespace interpreter_wrapper +} // namespace tflite_migration diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h new file mode 100644 index 00000000000000..f98a35227388bb --- /dev/null +++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ + +#include + +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/stateful_error_reporter.h" + +namespace tflite_migration { +namespace interpreter_wrapper { + +class PythonErrorReporter : public tflite_migration::StatefulErrorReporter { + public: + PythonErrorReporter() = default; + + // Report an error message + int Report(const char* format, va_list args) override; + + // Sets a Python runtime exception with the last error and + // clears the error message buffer. + PyObject* exception(); + + // Gets the last error message and clears the buffer. + std::string message() override; + + private: + std::stringstream buffer_; +}; + +} // namespace interpreter_wrapper +} // namespace tflite_migration +#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 7a4567d7bd1c93..0aaedeae200a6e 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -211,6 +211,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer( pass_config.enable_stablehlo_quantizer = toco_flags.has_quantization_config(); pass_config.enable_composite_direct_lowering = toco_flags.enable_composite_direct_lowering(); + pass_config.model_origin_framework = toco_flags.model_origin_framework(); if (toco_flags.qdq_conversion_mode() == "STATIC") { pass_config.quant_specs.qdq_conversion_mode = diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index d6f999eabb3c16..be09317cb52310 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 9e1694a0e95e12..c269b41b596ab5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -223,8 +223,7 @@ cc_library( srcs = ["test_util.cc"], hdrs = ["test_util.h"], deps = [ - "//tensorflow/lite/core/api", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", "@com_google_googletest//:gtest", - "@flatbuffers", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc index e096868eec8807..66c1adef98bc22 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" +#include +#include + #include namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h index b4e317c131888e..8953a384766963 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h @@ -15,7 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ -#include "tensorflow/lite/core/api/error_reporter.h" +#include + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" namespace mlir { namespace lite { diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD index 36240d3fa7d1f9..a275f4ab2fbd66 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD @@ -1,10 +1,11 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//visibility:public", + "//visibility:private", ], licenses = ["notice"], ) @@ -13,6 +14,9 @@ cc_library( name = "portable_tensor_utils", srcs = ["portable_tensor_utils.cc"], hdrs = ["portable_tensor_utils.h"], + visibility = [ + "//tensorflow/compiler/mlir/quantization/common/quantization_lib:__pkg__", + ], ) cc_library( @@ -69,7 +73,68 @@ tf_cc_test( "//tensorflow/lite/core:framework", # to remove when mlir version is ready. "@com_google_absl//absl/status", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:platform_port", + "@local_xla//xla/tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/util:command_line_flags", + ], +) + +cc_library( + name = "quantize_weights", + srcs = select({ + "//tensorflow:ios": ["quantize_weights_portable.cc"], + "//tensorflow:android": ["quantize_weights_portable.cc"], + "//conditions:default": ["quantize_weights.cc"], + }), + hdrs = ["quantize_weights.h"], + compatible_with = get_compatible_with_portable(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:model_utils", + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:portable_tensor_utils", + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:quantization_utils", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_utils", + "//tensorflow/core/platform:logging", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@flatbuffers//:runtime_cc", + ] + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//conditions:default": [ + "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_weights", + ], + }), +) + +tf_cc_test( + name = "quantize_weights_test", + srcs = ["quantize_weights_test.cc"], + args = [ + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", + ], + data = [ + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin", + ], + tags = [ + "tflite_not_portable_android", + "tflite_not_portable_ios", + ], + deps = [ + ":quantize_weights", + "//tensorflow/compiler/mlir/lite/quantization/lite:test_util", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_utils", + "//tensorflow/core:framework_internal", + "//tensorflow/lite/core:framework", # to remove when mlir version is ready. + "@com_google_googletest//:gtest", + "@flatbuffers", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", "@local_xla//xla/tsl/util:command_line_flags", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils_test.cc index 8b39e7fd678d5b..0a7bcd0df79597 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils_test.cc @@ -32,9 +32,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/util/command_line_flags.h" #include "tensorflow/lite/core/model_builder.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/init_main.h" #include "tsl/platform/path.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc new file mode 100644 index 00000000000000..b2d6fe97280174 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc @@ -0,0 +1,751 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/model_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" +#include "tensorflow/core/platform/logging.h" + +namespace mlir { +namespace lite { +namespace toco_legacy { +namespace { + +using absl::flat_hash_set; +using mlir::lite::toco_legacy:: + CustomOpMap; // Use this instead of mlir::lite::CustomOpMap because that + // uses mlir::lite::CustomOpInfo in + // tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h, + // and we need mlir::lite::toco_legacy::CustomOpInfo, in + // tensorflow/compiler/mlir/lite/quantization/lite/optimize/quantize_weights.h +using tflite::BufferT; +using tflite::BuiltinOperator; +using tflite::BuiltinOperator_BATCH_MATMUL; +using tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM; +using tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN; +using tflite::BuiltinOperator_CONV_2D; +using tflite::BuiltinOperator_CUSTOM; +using tflite::BuiltinOperator_DEPTHWISE_CONV_2D; +using tflite::BuiltinOperator_EMBEDDING_LOOKUP; +using tflite::BuiltinOperator_FULLY_CONNECTED; +using tflite::BuiltinOperator_GATHER; +using tflite::BuiltinOperator_LSTM; +using tflite::BuiltinOperator_RNN; +using tflite::BuiltinOperator_SVDF; +using tflite::BuiltinOperator_TRANSPOSE_CONV; +using tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM; +using tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN; +using tflite::FinishModelBuffer; +using tflite::GetBuiltinCode; +using tflite::Model; +using tflite::ModelT; +using tflite::OperatorCodeT; +using tflite::OperatorT; +using tflite::SubGraphT; +using tflite::TensorT; +using tflite::TensorType_FLOAT32; +using tflite::TensorType_INT8; + +struct ConsumerOpInfo { + OperatorT* op; + // The index of the op in the operators vector. + int32_t op_idx; + // The index of the tensor to quantize in subgraph->tensors. + int32_t op_input_idx; +}; + +struct TensorPerChannel { + TensorT* t; + bool is_per_channel; + int channel_dim; +}; + +// The default minimum number of elements a weights array must have to be +// quantized by this transformation. +const int kWeightsMinNumElementsDefault = 1024; + +// Redefined from tensorflow/lite/core/c/common.h as local const int instead of +// discouraged #define macro. +const int kTfLiteOptionalTensor = -1; + +// Convert the MLIR CustomOpMap from the TFlite CustomOpMap as their member +// variables differ. +void ConstructMLIRCustomOpMap(mlir::lite::CustomOpMap& mlir_map, + const CustomOpMap& tflite_map) { + for (const auto& entry : tflite_map) { + mlir_map[entry.first].quantizable_input_indices = + entry.second.quantizable_input_indices; + mlir_map[entry.first].is_weight_only = !entry.second.is_hybrid; + mlir_map[entry.first].no_side_effect = true; + } +} + +// Gets the operators that consume tensor_idx. +std::vector GetTensorConsumers(const ModelT* model, + const SubGraphT* subgraph, + int32_t tensor_idx) { + // TODO(suharshs): If this proves to be too slow, avoid calling it per tensor, + // instead doing one sweep for the entire model. + std::vector consumer_ops; + for (size_t op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) { + OperatorT* op = subgraph->operators[op_idx].get(); + if (op == nullptr) { + continue; + } + for (size_t i = 0; i < op->inputs.size(); ++i) { + if (op->inputs[i] == tensor_idx) { + consumer_ops.push_back( + {op, static_cast(op_idx), static_cast(i)}); + } + } + } + return consumer_ops; +} + +// Gets the list of op->inputs indices of the weights inputs to be quantized for +// the provided op. +std::vector GetWeightInputIndices(const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map) { + const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code); + if (builtin_op_code == BuiltinOperator_CUSTOM) { + const std::string custom_code = op_code->custom_code; + const auto& custom_op_info = custom_op_map.find(custom_code); + if (custom_op_info != custom_op_map.end()) { + return custom_op_info->second.quantizable_input_indices; + } + } else if (builtin_op_code == BuiltinOperator_CONV_2D || + builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D || + builtin_op_code == BuiltinOperator_FULLY_CONNECTED || + builtin_op_code == BuiltinOperator_BATCH_MATMUL || + builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP || + builtin_op_code == BuiltinOperator_TRANSPOSE_CONV) { + return {1}; + } else if (builtin_op_code == BuiltinOperator_SVDF) { + // tensorflow/lite/kernels/svdf.cc + return {1, 2}; + } else if (builtin_op_code == BuiltinOperator_LSTM || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) { + // tensorflow/lite/kernels/lstm.cc + // tensorflow/lite/kernels/unidirectional_sequence_lstm.cc + return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16}; + } else if (builtin_op_code == BuiltinOperator_RNN || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { + // tensorflow/lite/kernels/basic_rnn.cc + // tensorflow/lite/kernels/unidirectional_sequence_rnn.cc + return {1, 2}; + } else if (builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM) { + // tensorflow/lite/kernels/bidirectional_sequence_lstm.cc + return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 33, 40, 41, 42, 43, 44, 45, 46, 47}; + } else if (builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) { + // tensorflow/lite/kernels/bidirectional_sequence_rnn.cc + return {1, 2, 4, 5, 6, 8, 9, 10, 11}; + } else if (builtin_op_code == BuiltinOperator_GATHER) { + // tensorflow/lite/kernels/gather.cc + return {0}; + } + return {}; +} + +// Checks that a specific input can be quantized. +bool IsQuantizedInput(const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map, int op_input_idx) { + const auto quantized_input_indices = + GetWeightInputIndices(op_code, custom_op_map); + return std::find(std::begin(quantized_input_indices), + std::end(quantized_input_indices), + op_input_idx) != std::end(quantized_input_indices); +} + +// Returns true if the operator supports hybrid evaluation. +bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map, + bool use_updated_hybrid_scheme) { + const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code); + // Operations that support hybrid evaluation. + bool eval_hybrid = false; + if (builtin_op_code == BuiltinOperator_CUSTOM) { + const std::string custom_code = op_code->custom_code; + const auto custom_op_info = custom_op_map.find(custom_code); + if (custom_op_info == custom_op_map.end()) { + return {}; + } else { + return custom_op_info->second.is_hybrid; + } + } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED || + builtin_op_code == BuiltinOperator_BATCH_MATMUL || + builtin_op_code == BuiltinOperator_CONV_2D || + builtin_op_code == BuiltinOperator_SVDF || + builtin_op_code == BuiltinOperator_RNN || + builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM || + builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { + eval_hybrid = true; + } else if (builtin_op_code == BuiltinOperator_LSTM) { + const tflite::LSTMOptionsT* options = op->builtin_options.AsLSTMOptions(); + // Only lstm kernel_type full supports hybrid evaluation. + if (options->kernel_type == tflite::LSTMKernelType_FULL) { + eval_hybrid = true; + } + } else if (builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + eval_hybrid = use_updated_hybrid_scheme; + } + return eval_hybrid; +} + +// Returns true if all of the op's inputs are quantized. +bool CheckAllOpInputsQuantized(const SubGraphT* subgraph, const OperatorT* op, + const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map) { + std::vector op_input_indices = + GetWeightInputIndices(op_code, custom_op_map); + for (const int32_t op_input_idx : op_input_indices) { + int32_t tensor_idx = op->inputs[op_input_idx]; + + if (tensor_idx == -1) { + // Optional tensor. + continue; + } + + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + + if (tensor->type != TensorType_INT8) { + return false; + } + } + return true; +} + +// Inserts Tensors for each input tensor of op that should be +// quantized into tensor_map. +absl::Status InsertQuantizableInputTensorsFromOperator( + const ModelT* model, OperatorT* op, uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + absl::flat_hash_map* tensor_map, + int subgraph_index, bool use_updated_hybrid_scheme) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); + const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get(); + auto builtin_code = GetBuiltinCode(op_code); + + std::vector op_input_indices = + GetWeightInputIndices(op_code, custom_op_map); + for (const int32_t op_input_idx : op_input_indices) { + int32_t tensor_idx = op->inputs[op_input_idx]; + if (tensor_idx == -1) { + LOG(INFO) << "Skipping optional tensor input " << op_input_idx + << " of operation " << EnumNameBuiltinOperator(builtin_code); + continue; + } + + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + if (tensor->type != TensorType_FLOAT32) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " that is not type float."; + continue; + } + + uint64_t num_elements; + if (!mlir::lite::toco_legacy::NumElements(*tensor, &num_elements).ok()) { + return absl::InternalError("Error in quantization_utils NumElements"); + } + if (num_elements < weights_min_num_elements) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " because it has fewer than " << weights_min_num_elements + << " elements (" << num_elements << ")."; + continue; + } + + // Some tensors may have a null buffer vector, indicating an intermediate + // array. + if (model->buffers[tensor->buffer]->data.data() == nullptr) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " because it has no allocated buffer."; + continue; + } + + if (builtin_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + tensor_map->insert({tensor_idx, + {tensor, /*is_per_channel=*/use_updated_hybrid_scheme, + /*dim=*/3}}); + } else if (builtin_code == BuiltinOperator_CONV_2D) { + tensor_map->insert({tensor_idx, + {tensor, /*is_per_channel=*/use_updated_hybrid_scheme, + /*dim=*/0}}); + } else { + switch (builtin_code) { + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: + op->builtin_options.AsBidirectionalSequenceLSTMOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: + op->builtin_options.AsBidirectionalSequenceRNNOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_FULLY_CONNECTED: + op->builtin_options.AsFullyConnectedOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_BATCH_MATMUL: + op->builtin_options.AsBatchMatMulOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_LSTM: + op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs = + use_updated_hybrid_scheme; + break; + case BuiltinOperator_RNN: + op->builtin_options.AsRNNOptions()->asymmetric_quantize_inputs = + use_updated_hybrid_scheme; + break; + case BuiltinOperator_SVDF: + op->builtin_options.AsSVDFOptions()->asymmetric_quantize_inputs = + use_updated_hybrid_scheme; + break; + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: + op->builtin_options.AsUnidirectionalSequenceLSTMOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: + op->builtin_options.AsSequenceRNNOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + default: + break; + } + tensor_map->insert({tensor_idx, {tensor, /*is_per_channel=*/false}}); + } + } + + return absl::OkStatus(); +} + +// Updates operator code versions for the operators with INT8 inputs. +void UpdateInt8OperatorVersions(ModelT* model, bool use_updated_hybrid_scheme) { + for (int i = 0, end = model->operator_codes.size(); i < end; ++i) { + const BuiltinOperator& op_code = + GetBuiltinCode(model->operator_codes[i].get()); + if (op_code == BuiltinOperator_RNN || + op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN || + op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM || + op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 3 : 2; + } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM || + op_code == BuiltinOperator_EMBEDDING_LOOKUP) { + model->operator_codes[i]->version = 3; + } else if (op_code == BuiltinOperator_LSTM) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 3; + } else if (op_code == BuiltinOperator_CONV_2D) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 5 : 2; + } else if (op_code == BuiltinOperator_FULLY_CONNECTED) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 9 : 3; + } else if (op_code == BuiltinOperator_BATCH_MATMUL) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 1; + } else if (op_code == BuiltinOperator_SVDF) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 2; + } else if (op_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + model->operator_codes[i]->version = 6; + } + } +} + +// Returns true if the op in consumer_op_infos can pass through quantization. +bool IsQuantizationPassThroughOps( + const ModelT* model, const std::vector& consumer_op_infos) { + if (consumer_op_infos.size() != 1) { + return false; + } + const OperatorT* consumer_op = consumer_op_infos.front().op; + const BuiltinOperator op_code = + GetBuiltinCode(model->operator_codes[consumer_op->opcode_index].get()); + return op_code == BuiltinOperator_GATHER || + op_code == BuiltinOperator_EMBEDDING_LOOKUP; +} + +// Copies quantization parameters from input to output and returns consumers of +// the output tensor as a tuple with values: +// - index of the output tensor +// - pointer to the output tensor +// - vector of consumers ops. +std::tuple> +PassQuantizationAndGetConsumers( + const ModelT* model, const SubGraphT* subgraph, + const std::vector& consumer_op_infos, + const CustomOpMap& custom_op_map) { + const OperatorT* op = consumer_op_infos.front().op; + const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get(); + if (op->outputs.size() != 1) { + LOG(ERROR) + << "An op that passes quantization has more than one quantized output"; + return std::make_tuple(-1, nullptr, std::vector()); + } + const int32_t output_tensor_idx = op->outputs.front(); + const auto input_idx = GetWeightInputIndices(op_code, custom_op_map); + if (input_idx.size() != 1) { + LOG(ERROR) + << "An op that passes quantization has more than one quantized input"; + return std::make_tuple(-1, nullptr, std::vector()); + } + const int32_t input_tensor_idx = op->inputs[input_idx.front()]; + + // Propagate quantization params. + const TensorT* input_tensor = subgraph->tensors[input_tensor_idx].get(); + TensorT* output_tensor = subgraph->tensors[output_tensor_idx].get(); + if (!output_tensor->quantization) { + output_tensor->quantization = + std::make_unique(); + } + *output_tensor->quantization = *input_tensor->quantization; + output_tensor->type = TensorType_INT8; + return std::make_tuple( + output_tensor_idx, output_tensor, + GetTensorConsumers(model, subgraph, output_tensor_idx)); +} + +inline bool IsOpDenylisted(const flat_hash_set& op_denylist, + const BuiltinOperator op_code) { + return op_denylist.find(op_code) != op_denylist.end(); +} + +absl::Status QuantizeWeightsInt8( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + bool use_hybrid_evaluation, uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, bool use_updated_hybrid_scheme, + const absl::flat_hash_set& op_denylist = {}) { + std::unique_ptr model; + model.reset(input_model->UnPack()); + + for (int subgraph_index = 0, end = model->subgraphs.size(); + subgraph_index < end; ++subgraph_index) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); + + absl::flat_hash_map tensor_map; + for (int i = 0; i < subgraph->operators.size(); ++i) { + OperatorT* op = subgraph->operators[i].get(); + absl::Status status = InsertQuantizableInputTensorsFromOperator( + model.get(), op, weights_min_num_elements, custom_op_map, &tensor_map, + subgraph_index, use_updated_hybrid_scheme); + if (!status.ok()) return status; + } + + for (std::pair tensor_pair : tensor_map) { + // Quantize the tensor. + if (tensor_pair.second.is_per_channel) { + if (!mlir::lite::toco_legacy::SymmetricQuantizeTensorPerChannel( + model.get(), tensor_pair.second.t, + tensor_pair.second.channel_dim) + .ok()) { + return absl::InternalError( + "SymmetricQuantizeTensorPerChannel failed"); + } + } else { + if (!mlir::lite::toco_legacy::SymmetricQuantizeTensor( + model.get(), tensor_pair.second.t) + .ok()) { + return absl::InternalError("SymmetricQuantizeTensor failed"); + } + } + } + + // Examine the tensor consumers to determine which require dequantize ops. + for (const auto& tensor_pair : tensor_map) { + int32_t tensor_idx = tensor_pair.first; + TensorT* tensor = tensor_pair.second.t; + std::vector consumer_op_infos = + GetTensorConsumers(model.get(), subgraph, tensor_idx); + if (IsQuantizationPassThroughOps(model.get(), consumer_op_infos)) { + std::tie(tensor_idx, tensor, consumer_op_infos) = + PassQuantizationAndGetConsumers(model.get(), subgraph, + consumer_op_infos, custom_op_map); + if (tensor_idx < 0) { + // Error message is already logged by PassQuantizationAndGetConsumers. + return absl::InternalError("PassQuantizationAndGetConsumers failed"); + } + } + + std::vector dequant_op_infos; // Ops that need dequants. + for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) { + OperatorT* consumer_op = consumer_op_info.op; + const OperatorCodeT* consumer_op_code = + model->operator_codes[consumer_op->opcode_index].get(); + // If the op is a hybrid op and all the required tensors are quantized, + // we have no further work to do, but for all ops that require + // dequantization we need to add a Dequantize op. + bool eval_hybrid = + use_hybrid_evaluation && + !IsOpDenylisted(op_denylist, GetBuiltinCode(consumer_op_code)) && + IsHybridEvaluationOp(consumer_op, consumer_op_code, custom_op_map, + use_updated_hybrid_scheme) && + CheckAllOpInputsQuantized(subgraph, consumer_op, consumer_op_code, + custom_op_map) && + IsQuantizedInput(consumer_op_code, custom_op_map, + consumer_op_info.op_input_idx); + if (!eval_hybrid) { + dequant_op_infos.push_back(consumer_op_info); + } + } + + // Check if this tensor is an output tensor. + int32_t output_index = -1; + for (int32_t i = 0; i < subgraph->outputs.size(); ++i) { + if (subgraph->outputs[i] == tensor_idx) { + output_index = i; + break; + } + } + + // If no ops require dequant and it is not output, we are done for this + // tensor. + if (dequant_op_infos.empty() && output_index < 0) { + continue; + } + + // Create a new tensor to be the output of the dequantize op. + std::unique_ptr dequantize_output; + const std::string dequant_name = tensor->name + "_dequantize"; + mlir::lite::toco_legacy::MakeTensor( + dequant_name, tensor->shape, tensor->shape_signature, + TensorType_FLOAT32, &dequantize_output); + const int32_t dequantize_output_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(dequantize_output)); + + // Create the Dequantize operation. + std::unique_ptr dequantize_op; + mlir::lite::toco_legacy::MakeDequantizeOperator( + model.get(), &dequantize_op, tensor_idx, dequantize_output_idx); + + // Update the op_input of all the ops that need the created dequantize + // operation. + int32_t min_op_idx = subgraph->operators.size(); + for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) { + dequant_op_info.op->inputs[dequant_op_info.op_input_idx] = + dequantize_output_idx; + min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx); + } + // Update output name. + if (output_index >= 0) { + subgraph->outputs[output_index] = dequantize_output_idx; + } + + // Insert the newly created Dequantize operation before the earliest + // consumer, since TFLite requires operators to be topo-sorted. + subgraph->operators.insert(subgraph->operators.begin() + min_op_idx, + std::move(dequantize_op)); + } + } + + // Update the modified operator code versions. + UpdateInt8OperatorVersions(model.get(), use_updated_hybrid_scheme); + + flatbuffers::Offset output_model_location = + Model::Pack(*builder, model.get()); + FinishModelBuffer(*builder, output_model_location); + + return absl::OkStatus(); +} + +absl::Status QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model) { + std::unique_ptr model; + model.reset(input_model->UnPack()); + + for (int subgraph_index = 0, end = model->subgraphs.size(); + subgraph_index < end; ++subgraph_index) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); + + absl::flat_hash_map tensor_map; + for (int i = 0, sub_end = subgraph->operators.size(); i < sub_end; ++i) { + OperatorT* op = subgraph->operators[i].get(); + for (auto tensor_idx : op->inputs) { + // Skip optional tensors. + if (tensor_idx == kTfLiteOptionalTensor) { + continue; + } + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + BufferT* buffer = model->buffers[tensor->buffer].get(); + if (buffer == nullptr) { + return absl::InternalError("Buffer is null"); + } + // Quantize tensors that have data to quantize. + bool is_constant = !model->buffers[tensor->buffer].get()->data.empty(); + if (tensor->type == TensorType_FLOAT32 && is_constant) { + tensor_map.insert({tensor_idx, tensor}); + } + } + } + + // The hash map ensures that we quantize each tensor exactly once. + for (std::pair tensor_pair : tensor_map) { + // Quantize the tensor. + if (!mlir::lite::toco_legacy::QuantizeTensorFloat16(model.get(), + tensor_pair.second) + .ok()) { + return absl::InternalError("QuantizeTensorFloat16 failed"); + } + + int32_t tensor_idx = tensor_pair.first; + TensorT* tensor = tensor_pair.second; + std::vector dequant_op_infos = + GetTensorConsumers(model.get(), subgraph, tensor_idx); + + // Create a new tensor to be the output of the dequantize op. + std::unique_ptr dequantize_output; + const std::string dequant_name = tensor->name + "_dequantize"; + mlir::lite::toco_legacy::MakeTensor( + dequant_name, tensor->shape, tensor->shape_signature, + TensorType_FLOAT32, &dequantize_output); + const int32_t dequantize_output_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(dequantize_output)); + + // Create the Dequantize operation. + std::unique_ptr dequantize_op; + mlir::lite::toco_legacy::MakeDequantizeOperator( + model.get(), &dequantize_op, tensor_idx, dequantize_output_idx); + + // Update the op_input of all the ops that need the created dequantize + // operation. + int32_t min_op_idx = subgraph->operators.size(); + for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) { + dequant_op_info.op->inputs[dequant_op_info.op_input_idx] = + dequantize_output_idx; + min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx); + } + + // Insert the newly created Dequantize operation before the earliest + // consumer, since TFLite requires operators to be topo-sorted. + subgraph->operators.insert(subgraph->operators.begin() + min_op_idx, + std::move(dequantize_op)); + } + } + + flatbuffers::Offset output_model_location = + Model::Pack(*builder, model.get()); + FinishModelBuffer(*builder, output_model_location); + return absl::OkStatus(); +} +} // namespace + +namespace internal { +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + bool use_hybrid_evaluation, + QuantizerType quantizer_type) { + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + return mlir::lite::QuantizeWeights( + builder, input_model, weights_min_num_elements, use_hybrid_evaluation); + } + CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation, + weights_min_num_elements, custom_op_map, + kUseUpdatedHybridSchemeDefault); +} +} // namespace internal + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + return mlir::lite::QuantizeWeights(builder, input_model, + weights_min_num_elements); + } + CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, true, + weights_min_num_elements, custom_op_map, + kUseUpdatedHybridSchemeDefault); +} + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, BufferType quant_type, + bool use_updated_hybrid_scheme, + QuantizerType quantizer_type) { + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + return mlir::lite::QuantizeWeights(builder, input_model, + (mlir::lite::BufferType)quant_type, + use_updated_hybrid_scheme); + } + switch (quant_type) { + case BufferType::QUANTIZED_INT8: { + mlir::lite::toco_legacy::CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, true, + kWeightsMinNumElementsDefault, custom_op_map, + use_updated_hybrid_scheme); + } + case BufferType::QUANTIZED_FLOAT16: + return QuantizeWeightsFloat16(builder, input_model); + } +} + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + mlir::lite::CustomOpMap mlir_custom_op_map; + ConstructMLIRCustomOpMap(mlir_custom_op_map, custom_op_map); + return mlir::lite::QuantizeWeights( + builder, input_model, weights_min_num_elements, mlir_custom_op_map); + } + return QuantizeWeightsInt8(builder, input_model, true, + weights_min_num_elements, custom_op_map, + kUseUpdatedHybridSchemeDefault); +} + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + bool use_updated_hybrid_scheme, + const flat_hash_set& op_denylist, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + mlir::lite::CustomOpMap mlir_custom_op_map; + ConstructMLIRCustomOpMap(mlir_custom_op_map, custom_op_map); + return mlir::lite::QuantizeWeights( + builder, input_model, weights_min_num_elements, mlir_custom_op_map, + use_updated_hybrid_scheme, op_denylist); + } + return QuantizeWeightsInt8(builder, input_model, + /*use_hybrid_evaluation=*/true, + weights_min_num_elements, custom_op_map, + use_updated_hybrid_scheme, op_denylist); +} + +} // namespace toco_legacy +} // namespace lite +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h new file mode 100644 index 00000000000000..039c18d8e1d256 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h @@ -0,0 +1,109 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZE_WEIGHTS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZE_WEIGHTS_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace mlir { +namespace lite { +namespace toco_legacy { + +using ::tflite::BuiltinOperator; +using ::tflite::Model; + +// Supported resulting types from quantization process. +enum class BufferType { QUANTIZED_INT8, QUANTIZED_FLOAT16 }; +enum class QuantizerType { OLD_QUANTIZER, MLIR_QUANTIZER }; + +// Stores information about how to quantize a user-specified custom operation. +struct CustomOpInfo { + std::vector quantizable_input_indices; + bool is_hybrid; +}; + +// Map from custom op code to custom op quantization information. +using CustomOpMap = std::unordered_map; + +// This macro is for internal use for conversions requiring previous behavior. +#ifdef TFLITE_USE_PREVIOUS_HYBRID_SCHEME +// Use asymmetric quantized activations and per-channel quantized weights. +constexpr bool kUseUpdatedHybridSchemeDefault = false; +#else +// Use symmetric quantized activations and per-channel quantized weights. +constexpr bool kUseUpdatedHybridSchemeDefault = true; +#endif + +// Quantizes input_model and populates the provided builder with the new model. +// By default only weights tensors weight more than 1024 elements will be +// quantized. +// +// A tflite::Model can be obtained from the builder with: +// const uint8_t* buffer = builder->GetBufferPointer(); +// tflite::Model* model = GetModel(buffer); +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + BufferType quant_type = BufferType::QUANTIZED_INT8, + bool use_updated_hybrid_scheme = kUseUpdatedHybridSchemeDefault, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); + +// Same as above, but only weights with greater than or equal +// weights_min_num_elements elements will be quantized. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + uint64_t weights_min_num_elements, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); + +// Same as above, but with entry point of quantizing custom ops. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); + +// Same as above, but if use updated_hybrid_scheme is false, +// use previous quantization scheme. Optional op_denylist argument +// disables hybrid evaluation for provided BuiltinOperators. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, + bool use_updated_hybrid_scheme, + const absl::flat_hash_set& op_denylist = {}, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); + +namespace internal { +// If use_hybrid_evaluation is false, will disable using hybrid eval for +// operations that support it. +// +// We use this internal QuantizeWeights call to test models with hybrid +// evaluation disabled. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + uint64_t weights_min_num_elements, bool use_hybrid_evaluation, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); +} // namespace internal + +} // namespace toco_legacy +} // namespace lite +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZE_WEIGHTS_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights_portable.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights_portable.cc new file mode 100644 index 00000000000000..91030d4cf57e27 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights_portable.cc @@ -0,0 +1,692 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// clang-format off +#include "tensorflow/lite/tools/toco_legacy/quantize_weights.h" +// clang-format on + +#include +#include +#include +#include + +#include "flatbuffers/flexbuffers.h" +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +// #include "tensorflow/lite/context.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/model_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" +#include "tensorflow/lite/core/model.h" // to be replaced with unda's model_builder + +namespace tflite { +namespace optimize { + +namespace { + +struct ConsumerOpInfo { + OperatorT* op; + // The index of the op in the operators vector. + int32_t op_idx; + // The index of the tensor to quantize in subgraph->tensors. + int32_t op_input_idx; +}; + +struct TensorPerChannel { + TensorT* t; + bool is_per_channel; + int channel_dim; +}; + +// The default minimum number of elements a weights array must have to be +// quantized by this transformation. +const int kWeightsMinNumElementsDefault = 1024; + +// Gets the operators that consume tensor_idx. +std::vector GetTensorConsumers(const ModelT* model, + const SubGraphT* subgraph, + int32_t tensor_idx) { + // TODO(suharshs): If this proves to be too slow, avoid calling it per tensor, + // instead doing one sweep for the entire model. + std::vector consumer_ops; + for (size_t op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) { + OperatorT* op = subgraph->operators[op_idx].get(); + if (op == nullptr) { + continue; + } + for (size_t i = 0; i < op->inputs.size(); ++i) { + if (op->inputs[i] == tensor_idx) { + consumer_ops.push_back( + {op, static_cast(op_idx), static_cast(i)}); + } + } + } + return consumer_ops; +} + +// Gets the list of op->inputs indices of the weights inputs to be quantized for +// the provided op. +std::vector GetWeightInputIndices(const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map) { + const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code); + if (builtin_op_code == BuiltinOperator_CUSTOM) { + const std::string custom_code = op_code->custom_code; + const auto& custom_op_info = custom_op_map.find(custom_code); + if (custom_op_info != custom_op_map.end()) { + return custom_op_info->second.quantizable_input_indices; + } + } else if (builtin_op_code == BuiltinOperator_CONV_2D || + builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D || + builtin_op_code == BuiltinOperator_FULLY_CONNECTED || + builtin_op_code == BuiltinOperator_BATCH_MATMUL || + builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP || + builtin_op_code == BuiltinOperator_TRANSPOSE_CONV) { + return {1}; + } else if (builtin_op_code == BuiltinOperator_SVDF) { + // https://www.tensorflow.org/code/tensorflow/lite/kernels/svdf.cc + return {1, 2}; + } else if (builtin_op_code == BuiltinOperator_LSTM || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) { + // https://www.tensorflow.org/code/tensorflow/lite/kernels/lstm.cc + // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc + return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16}; + } else if (builtin_op_code == BuiltinOperator_RNN || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { + // https://www.tensorflow.org/code/tensorflow/lite/kernels/basic_rnn.cc + // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc + return {1, 2}; + } else if (builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM) { + // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc + return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 33, 40, 41, 42, 43, 44, 45, 46, 47}; + } else if (builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) { + // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc + return {1, 2, 4, 5, 6, 8, 9, 10, 11}; + } else if (builtin_op_code == BuiltinOperator_GATHER) { + // https://www.tensorflow.org/code/tensorflow/lite/kernels/gather.cc + return {0}; + } + return {}; +} + +// Checks that a specific input can be quantized. +bool IsQuantizedInput(const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map, int op_input_idx) { + const auto quantized_input_indices = + GetWeightInputIndices(op_code, custom_op_map); + return std::find(std::begin(quantized_input_indices), + std::end(quantized_input_indices), + op_input_idx) != std::end(quantized_input_indices); +} + +// Returns true if the operator supports hybrid evaluation. +bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map, + bool use_updated_hybrid_scheme) { + const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code); + // Operations that support hybrid evaluation. + bool eval_hybrid = false; + if (builtin_op_code == BuiltinOperator_CUSTOM) { + const std::string custom_code = op_code->custom_code; + const auto custom_op_info = custom_op_map.find(custom_code); + if (custom_op_info == custom_op_map.end()) { + return {}; + } else { + return custom_op_info->second.is_hybrid; + } + } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED || + builtin_op_code == BuiltinOperator_BATCH_MATMUL || + builtin_op_code == BuiltinOperator_CONV_2D || + builtin_op_code == BuiltinOperator_SVDF || + builtin_op_code == BuiltinOperator_RNN || + builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM || + builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { + eval_hybrid = true; + } else if (builtin_op_code == BuiltinOperator_LSTM) { + const LSTMOptionsT* options = op->builtin_options.AsLSTMOptions(); + // Only lstm kernel_type full supports hybrid evaluation. + if (options->kernel_type == LSTMKernelType_FULL) { + eval_hybrid = true; + } + } else if (builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + eval_hybrid = use_updated_hybrid_scheme; + } + return eval_hybrid; +} + +// Returns true if all of the op's inputs are quantized. +bool CheckAllOpInputsQuantized(const SubGraphT* subgraph, const OperatorT* op, + const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map) { + std::vector op_input_indices = + GetWeightInputIndices(op_code, custom_op_map); + for (const int32_t op_input_idx : op_input_indices) { + int32_t tensor_idx = op->inputs[op_input_idx]; + + if (tensor_idx == -1) { + // Optional tensor. + continue; + } + + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + + if (tensor->type != TensorType_INT8) { + return false; + } + } + return true; +} + +// Inserts Tensors for each input tensor of op that should be +// quantized into tensor_map. +TfLiteStatus InsertQuantizableInputTensorsFromOperator( + const ModelT* model, OperatorT* op, uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + absl::flat_hash_map* tensor_map, + int subgraph_index, bool use_updated_hybrid_scheme) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); + const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get(); + auto builtin_code = GetBuiltinCode(op_code); + + std::vector op_input_indices = + GetWeightInputIndices(op_code, custom_op_map); + for (const int32_t op_input_idx : op_input_indices) { + int32_t tensor_idx = op->inputs[op_input_idx]; + if (tensor_idx == -1) { + LOG(INFO) << "Skipping optional tensor input " << op_input_idx + << " of operation " << EnumNameBuiltinOperator(builtin_code); + continue; + } + + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + if (tensor->type != TensorType_FLOAT32) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " that is not type float."; + continue; + } + + uint64_t num_elements; + TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements)); + if (num_elements < weights_min_num_elements) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " because it has fewer than " << weights_min_num_elements + << " elements (" << num_elements << ")."; + continue; + } + + // Some tensors may have a null buffer vector, indicating an intermediate + // array. + if (model->buffers[tensor->buffer]->data.data() == nullptr) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " because it has no allocated buffer."; + continue; + } + + if (builtin_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + tensor_map->insert({tensor_idx, + {tensor, /*is_per_channel=*/use_updated_hybrid_scheme, + /*dim=*/3}}); + } else if (builtin_code == BuiltinOperator_CONV_2D) { + tensor_map->insert({tensor_idx, + {tensor, /*is_per_channel=*/use_updated_hybrid_scheme, + /*dim=*/0}}); + } else { + switch (builtin_code) { + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: + op->builtin_options.AsBidirectionalSequenceLSTMOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: + op->builtin_options.AsBidirectionalSequenceRNNOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_FULLY_CONNECTED: + op->builtin_options.AsFullyConnectedOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_BATCH_MATMUL: + op->builtin_options.AsBatchMatMulOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_LSTM: + op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs = + use_updated_hybrid_scheme; + break; + case BuiltinOperator_RNN: + op->builtin_options.AsRNNOptions()->asymmetric_quantize_inputs = + use_updated_hybrid_scheme; + break; + case BuiltinOperator_SVDF: + op->builtin_options.AsSVDFOptions()->asymmetric_quantize_inputs = + use_updated_hybrid_scheme; + break; + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: + op->builtin_options.AsUnidirectionalSequenceLSTMOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: + op->builtin_options.AsSequenceRNNOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + default: + break; + } + tensor_map->insert({tensor_idx, {tensor, /*is_per_channel=*/false}}); + } + } + + return kTfLiteOk; +} + +// Updates operator code versions for the operators with INT8 inputs. +void UpdateInt8OperatorVersions(ModelT* model, bool use_updated_hybrid_scheme) { + for (int i = 0, end = model->operator_codes.size(); i < end; ++i) { + const BuiltinOperator& op_code = + GetBuiltinCode(model->operator_codes[i].get()); + if (op_code == BuiltinOperator_RNN || + op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN || + op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM || + op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 3 : 2; + } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM || + op_code == BuiltinOperator_EMBEDDING_LOOKUP) { + model->operator_codes[i]->version = 3; + } else if (op_code == BuiltinOperator_LSTM) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 3; + } else if (op_code == BuiltinOperator_CONV_2D) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 5 : 2; + } else if (op_code == BuiltinOperator_FULLY_CONNECTED) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 9 : 3; + } else if (op_code == BuiltinOperator_BATCH_MATMUL) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 1; + } else if (op_code == BuiltinOperator_SVDF) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 2; + } else if (op_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + model->operator_codes[i]->version = 6; + } + } +} + +// Returns true if the op in consumer_op_infos can pass through quantization. +bool IsQuantizationPassThroughOps( + const ModelT* model, const std::vector& consumer_op_infos) { + if (consumer_op_infos.size() != 1) { + return false; + } + const OperatorT* consumer_op = consumer_op_infos.front().op; + const BuiltinOperator op_code = + GetBuiltinCode(model->operator_codes[consumer_op->opcode_index].get()); + return op_code == BuiltinOperator_GATHER || + op_code == BuiltinOperator_EMBEDDING_LOOKUP; +} + +// Copies quantization parameters from input to output and returns consumers of +// the output tensor as a tuple with values: +// - index of the output tensor +// - pointer to the output tensor +// - vector of consumers ops. +std::tuple> +PassQuantizationAndGetConsumers( + const ModelT* model, const SubGraphT* subgraph, + const std::vector& consumer_op_infos, + const CustomOpMap& custom_op_map) { + const OperatorT* op = consumer_op_infos.front().op; + const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get(); + if (op->outputs.size() != 1) { + LOG(ERROR) + << "An op that passes quantization has more than one quantized output"; + return std::make_tuple(-1, nullptr, std::vector()); + } + const int32_t output_tensor_idx = op->outputs.front(); + const auto input_idx = GetWeightInputIndices(op_code, custom_op_map); + if (input_idx.size() != 1) { + LOG(ERROR) + << "An op that passes quantization has more than one quantized input"; + return std::make_tuple(-1, nullptr, std::vector()); + } + const int32_t input_tensor_idx = op->inputs[input_idx.front()]; + + // Propagate quantization params. + const TensorT* input_tensor = subgraph->tensors[input_tensor_idx].get(); + TensorT* output_tensor = subgraph->tensors[output_tensor_idx].get(); + if (!output_tensor->quantization) { + output_tensor->quantization = std::make_unique(); + } + *output_tensor->quantization = *input_tensor->quantization; + output_tensor->type = TensorType_INT8; + return std::make_tuple( + output_tensor_idx, output_tensor, + GetTensorConsumers(model, subgraph, output_tensor_idx)); +} + +inline bool IsOpDenylisted(const flat_hash_set& op_denylist, + const BuiltinOperator op_code) { + return op_denylist.find(op_code) != op_denylist.end(); +} + +absl::Status QuantizeWeightsInt8( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + bool use_hybrid_evaluation, uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, bool use_updated_hybrid_scheme, + const flat_hash_set& op_denylist = {}) { + std::unique_ptr model; + model.reset(input_model->UnPack()); + + for (int subgraph_index = 0, end = model->subgraphs.size(); + subgraph_index < end; ++subgraph_index) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); + + absl::flat_hash_map tensor_map; + for (int i = 0; i < subgraph->operators.size(); ++i) { + OperatorT* op = subgraph->operators[i].get(); + if (InsertQuantizableInputTensorsFromOperator( + model.get(), op, weights_min_num_elements, custom_op_map, + &tensor_map, subgraph_index, + use_updated_hybrid_scheme) != kTfLiteOk) { + return absl::InternalError( + "Failed to insert quantizable input tensors from operator"); + } + } + + for (std::pair tensor_pair : tensor_map) { + // Quantize the tensor. + if (tensor_pair.second.is_per_channel) { + if (utils::SymmetricQuantizeTensorPerChannel( + model.get(), tensor_pair.second.t, + tensor_pair.second.channel_dim, nullptr) != kTfLiteOk) { + return absl::InternalError("Failed to quantize tensor per channel"); + } + } else { + if (utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second.t) != + kTfLiteOk) { + return absl::InternalError("Failed to quantize tensor"); + } + } + } + + // Examine the tensor consumers to determine which require dequantize ops. + for (const auto& tensor_pair : tensor_map) { + int32_t tensor_idx = tensor_pair.first; + TensorT* tensor = tensor_pair.second.t; + std::vector consumer_op_infos = + GetTensorConsumers(model.get(), subgraph, tensor_idx); + if (IsQuantizationPassThroughOps(model.get(), consumer_op_infos)) { + std::tie(tensor_idx, tensor, consumer_op_infos) = + PassQuantizationAndGetConsumers(model.get(), subgraph, + consumer_op_infos, custom_op_map); + if (tensor_idx < 0) { + // Error message is already logged by PassQuantizationAndGetConsumers. + return absl::InternalError( + "Failed to pass quantization and get consumers"); + } + } + + std::vector dequant_op_infos; // Ops that need dequants. + for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) { + OperatorT* consumer_op = consumer_op_info.op; + const OperatorCodeT* consumer_op_code = + model->operator_codes[consumer_op->opcode_index].get(); + // If the op is a hybrid op and all the required tensors are quantized, + // we have no further work to do, but for all ops that require + // dequantization we need to add a Dequantize op. + bool eval_hybrid = + use_hybrid_evaluation && + !IsOpDenylisted(op_denylist, GetBuiltinCode(consumer_op_code)) && + IsHybridEvaluationOp(consumer_op, consumer_op_code, custom_op_map, + use_updated_hybrid_scheme) && + CheckAllOpInputsQuantized(subgraph, consumer_op, consumer_op_code, + custom_op_map) && + IsQuantizedInput(consumer_op_code, custom_op_map, + consumer_op_info.op_input_idx); + if (!eval_hybrid) { + dequant_op_infos.push_back(consumer_op_info); + } + } + + // Check if this tensor is an output tensor. + int32_t output_index = -1; + for (int32_t i = 0; i < subgraph->outputs.size(); ++i) { + if (subgraph->outputs[i] == tensor_idx) { + output_index = i; + break; + } + } + + // If no ops require dequant and it is not output, we are done for this + // tensor. + if (dequant_op_infos.empty() && output_index < 0) { + continue; + } + + // Create a new tensor to be the output of the dequantize op. + std::unique_ptr dequantize_output; + const string dequant_name = tensor->name + "_dequantize"; + utils::MakeTensor(dequant_name, tensor->shape, tensor->shape_signature, + TensorType_FLOAT32, &dequantize_output); + const int32_t dequantize_output_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(dequantize_output)); + + // Create the Dequantize operation. + std::unique_ptr dequantize_op; + utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx, + dequantize_output_idx); + + // Update the op_input of all the ops that need the created dequantize + // operation. + int32_t min_op_idx = subgraph->operators.size(); + for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) { + dequant_op_info.op->inputs[dequant_op_info.op_input_idx] = + dequantize_output_idx; + min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx); + } + // Update output name. + if (output_index >= 0) { + subgraph->outputs[output_index] = dequantize_output_idx; + } + + // Insert the newly created Dequantize operation before the earliest + // consumer, since TFLite requires operators to be topo-sorted. + subgraph->operators.insert(subgraph->operators.begin() + min_op_idx, + std::move(dequantize_op)); + } + } + + // Update the modified operator code versions. + UpdateInt8OperatorVersions(model.get(), use_updated_hybrid_scheme); + + flatbuffers::Offset output_model_location = + Model::Pack(*builder, model.get()); + FinishModelBuffer(*builder, output_model_location); + + return absl::OkStatus(); +} + +absl::Status QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model) { + std::unique_ptr model; + model.reset(input_model->UnPack()); + + for (int subgraph_index = 0, end = model->subgraphs.size(); + subgraph_index < end; ++subgraph_index) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); + + absl::flat_hash_map tensor_map; + for (int i = 0, sub_end = subgraph->operators.size(); i < sub_end; ++i) { + OperatorT* op = subgraph->operators[i].get(); + for (auto tensor_idx : op->inputs) { + // Skip optional tensors. + if (tensor_idx == kTfLiteOptionalTensor) { + continue; + } + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + BufferT* buffer = model->buffers[tensor->buffer].get(); + if (buffer == nullptr) { + return absl::InternalError("Buffer is null"); + } + // Quantize tensors that have data to quantize. + bool is_constant = !model->buffers[tensor->buffer].get()->data.empty(); + if (tensor->type == TensorType_FLOAT32 && is_constant) { + tensor_map.insert({tensor_idx, tensor}); + } + } + } + + // The hash map ensures that we quantize each tensor exactly once. + for (std::pair tensor_pair : tensor_map) { + // Quantize the tensor. + if (utils::QuantizeTensorFloat16(model.get(), tensor_pair.second) != + kTfLiteOk) { + return absl::InternalError("QuantizeTensorFloat16 failed"); + } + + int32_t tensor_idx = tensor_pair.first; + TensorT* tensor = tensor_pair.second; + std::vector dequant_op_infos = + GetTensorConsumers(model.get(), subgraph, tensor_idx); + + // Create a new tensor to be the output of the dequantize op. + std::unique_ptr dequantize_output; + const string dequant_name = tensor->name + "_dequantize"; + utils::MakeTensor(dequant_name, tensor->shape, tensor->shape_signature, + TensorType_FLOAT32, &dequantize_output); + const int32_t dequantize_output_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(dequantize_output)); + + // Create the Dequantize operation. + std::unique_ptr dequantize_op; + utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx, + dequantize_output_idx); + + // Update the op_input of all the ops that need the created dequantize + // operation. + int32_t min_op_idx = subgraph->operators.size(); + for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) { + dequant_op_info.op->inputs[dequant_op_info.op_input_idx] = + dequantize_output_idx; + min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx); + } + + // Insert the newly created Dequantize operation before the earliest + // consumer, since TFLite requires operators to be topo-sorted. + subgraph->operators.insert(subgraph->operators.begin() + min_op_idx, + std::move(dequantize_op)); + } + } + + flatbuffers::Offset output_model_location = + Model::Pack(*builder, model.get()); + FinishModelBuffer(*builder, output_model_location); + return absl::OkStatus(); +} +} // namespace + +namespace internal { +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + bool use_hybrid_evaluation, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); + } + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. + CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation, + weights_min_num_elements, custom_op_map, + kUseUpdatedHybridSchemeDefault); +} +} // namespace internal + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); + } + CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, true, + weights_min_num_elements, custom_op_map, + kUseUpdatedHybridSchemeDefault); +} + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, BufferType quant_type, + bool use_updated_hybrid_scheme, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); + } + switch (quant_type) { + case BufferType::QUANTIZED_INT8: { + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. + CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, true, + kWeightsMinNumElementsDefault, custom_op_map, + use_updated_hybrid_scheme); + } + case BufferType::QUANTIZED_FLOAT16: + return QuantizeWeightsFloat16(builder, input_model); + } +} + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); + } + return QuantizeWeightsInt8(builder, input_model, true, + weights_min_num_elements, custom_op_map, + kUseUpdatedHybridSchemeDefault); +} + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + bool use_updated_hybrid_scheme, + const flat_hash_set& op_denylist, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); + } + return QuantizeWeightsInt8(builder, input_model, + /*use_hybrid_evaluation=*/true, + weights_min_num_elements, custom_op_map, + use_updated_hybrid_scheme, op_denylist); +} + +} // namespace optimize +} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights_test.cc new file mode 100644 index 00000000000000..7277e1dfbbe438 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights_test.cc @@ -0,0 +1,702 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" +#include "xla/tsl/util/command_line_flags.h" +#include "tensorflow/lite/core/model_builder.h" // TODO: b/321735756 - replace with mlir model_builder +#include "tsl/platform/init_main.h" +#include "tsl/platform/path.h" + +namespace { +std::string* g_test_model_dir = nullptr; +} // namespace + +namespace mlir { +namespace lite { +namespace toco_legacy { +namespace { + +using tflite::BuiltinOperator_CONV_2D; +using tflite::BuiltinOperator_CUSTOM; +using tflite::BuiltinOperator_DEQUANTIZE; +using tflite::FlatBufferModel; // to remove when mlir version is ready, from + // model.h +using tflite::GetModel; +using tflite::Model; +using tflite::TensorType_FLOAT16; +using tflite::TensorType_FLOAT32; +using tflite::TensorType_INT8; + +std::unique_ptr ReadTestModel() { + auto model_path = tsl::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights); + return FlatBufferModel::BuildFromFile(model_path.c_str()); +} + +std::unique_ptr ReadSharedWeightsTestModel() { + auto model_path = tsl::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights); + return FlatBufferModel::BuildFromFile(model_path.c_str()); +} + +std::unique_ptr ReadGatherTestModel() { + auto model_path = tsl::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather); + return FlatBufferModel::BuildFromFile(model_path.c_str()); +} + +std::unique_ptr ReadCustomOpTestModel() { + auto model_path = tsl::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp); + return FlatBufferModel::BuildFromFile(model_path.c_str()); +} + +template +std::vector GetAsVector(const flatbuffers::Vector* vec) { + return std::vector(vec->begin(), vec->end()); +} + +class QuantizeWeightsTest : public testing::Test { + protected: + QuantizeWeightsTest() = default; + + void LoadBasicModel() { + input_model_ = ReadTestModel(); + model_ = input_model_->GetModel(); + } + + void LoadSharedWeightsModel() { + input_model_ = ReadSharedWeightsTestModel(); + model_ = input_model_->GetModel(); + } + + void LoadGatherTestModel() { + input_model_ = ReadGatherTestModel(); + model_ = input_model_->GetModel(); + } + + void LoadCustomOpTestModel() { + input_model_ = ReadCustomOpTestModel(); + model_ = input_model_->GetModel(); + } + + std::unique_ptr input_model_; + const Model* model_; + + bool IsModelInputOrOutput(const Model* model, uint32_t tensor_idx) { + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto subgraph = model->subgraphs()->Get(subgraph_idx); + for (size_t i = 0; i < subgraph->inputs()->size(); ++i) { + if (subgraph->inputs()->Get(i) == tensor_idx) { + return true; + } + } + for (size_t i = 0; i < subgraph->outputs()->size(); ++i) { + if (subgraph->outputs()->Get(i) == tensor_idx) { + return true; + } + } + } + return false; + } + + // Returns the producer op code of the specified tensor_idx. + bool GetProducerOpCode(const Model* model, uint32_t subgraph_idx, + uint32_t tensor_idx, + tflite::BuiltinOperator* op_code) { + const auto subgraph = model->subgraphs()->Get(subgraph_idx); + for (size_t op_idx = 0; op_idx < subgraph->operators()->size(); ++op_idx) { + const auto op = subgraph->operators()->Get(op_idx); + for (size_t i = 0; i < op->outputs()->size(); ++i) { + if (op->outputs()->Get(i) == tensor_idx) { + const uint32_t op_code_idx = op->opcode_index(); + *op_code = GetBuiltinCode(model->operator_codes()->Get(op_code_idx)); + return true; + } + } + } + return false; + } +}; + +TEST_F(QuantizeWeightsTest, QuantizationSucceeds) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE( + QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER).ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); +} + +TEST_F(QuantizeWeightsTest, WeightsMinNumElements) { + LoadBasicModel(); + // Make weights_min_size sufficiently large such that no quantization should + // happen, i.e. the original model is the same size as the old one. + flatbuffers::FlatBufferBuilder builder; + const uint64_t kWeightsMinNumElements = 1000000; + ASSERT_TRUE(QuantizeWeights(&builder, model_, kWeightsMinNumElements, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + subgraph_idx++) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size()); + for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + const auto float_tensor = float_graph->tensors()->Get(i); + // Everything should remain equal between the two graphs. + EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer()); + EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable()); + EXPECT_EQ(GetAsVector(quant_tensor->shape()), + GetAsVector(float_tensor->shape())); + EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str()); + EXPECT_EQ(quant_tensor->type(), float_tensor->type()); + } + } +} + +TEST_F(QuantizeWeightsTest, HybridConv) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE( + QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER).ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + // Nothing should change. + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + subgraph_idx++) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size()); + // Make sure the graph only has one Conv operation. + ASSERT_EQ(quantized_graph->operators()->size(), 1); + const auto op = quantized_graph->operators()->Get(0); + const uint32_t op_code_idx = op->opcode_index(); + ASSERT_EQ(GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)), + BuiltinOperator_CONV_2D); + for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + const auto float_tensor = float_graph->tensors()->Get(i); + EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer()); + EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable()); + EXPECT_EQ(GetAsVector(quant_tensor->shape()), + GetAsVector(float_tensor->shape())); + EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str()); + // If the tensor is a weight, it should have type INT8, otherwise it + // should stay with type FLOAT32. + // If the tensor is a bias, it should have type FLOAT32. + if (quant_tensor->name()->str() == "conv_bias") { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (IsModelInputOrOutput(output_model, i)) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->buffer() != 0) { + EXPECT_EQ(quant_tensor->type(), TensorType_INT8) + << quant_tensor->name()->str(); + auto shape = GetAsVector(quant_tensor->shape()); + if (kUseUpdatedHybridSchemeDefault) { + EXPECT_EQ(quant_tensor->quantization()->scale()->size(), shape[0]); + } else { + EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1); + } + } else { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } + } + } +} + +TEST_F(QuantizeWeightsTest, DequantizeConv) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE(internal::QuantizeWeights(&builder, model_, 0, + /*use_hybrid_evaluation=*/false, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + // The output graph should have an extra tensor from the added dequantize + // op. + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size() + 1); + // Check that a dequantize op exists. + int32_t dequant_input_idx = -1; + int32_t dequant_output_idx = -1; + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) == + BuiltinOperator_DEQUANTIZE) { + dequant_input_idx = op->inputs()->Get(0); + dequant_output_idx = op->outputs()->Get(0); + } + } + ASSERT_GT(dequant_input_idx, -1); + ASSERT_GT(dequant_output_idx, -1); + for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + // If the tensor is a weight, it should have type INT8. + // If the tensor is a bias, it should have type FLOAT32. + // If the tensor is an input or output it should have type FLOAT32. + // The input to dequantize should be INT8, and all other tensors should be + // FLOAT32. + if (i == dequant_input_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_INT8); + } else if (i == dequant_output_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (IsModelInputOrOutput(output_model, i)) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->name()->str() == "conv_bias") { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->buffer() != 0) { + // If it's a non-bias constant tensor, it must be the weight. + EXPECT_EQ(quant_tensor->type(), TensorType_INT8); + } else { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } + } + } +} + +TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE(QuantizeWeights(&builder, model_, BufferType::QUANTIZED_FLOAT16, + kUseUpdatedHybridSchemeDefault, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + // The output graph should have two extra tensors from the added dequantize + // op. + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size() + 2); + // Check that a dequantize op exists. + int32_t dequant_input_idx = -1; + int32_t dequant_output_idx = -1; + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) == + BuiltinOperator_DEQUANTIZE) { + dequant_input_idx = op->inputs()->Get(0); + dequant_output_idx = op->outputs()->Get(0); + } + } + ASSERT_GT(dequant_input_idx, -1); + ASSERT_GT(dequant_output_idx, -1); + for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + // If the tensor is a weight, it should have type FLOAT16. + // If the tensor is a bias, it should have type FLOAT16. + // If the tensor is an input or output it should have type FLOAT32. + // The input to dequantize should be FLOAT16, and all other tensors should + // be FLOAT32. + if (i == dequant_input_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16); + } else if (i == dequant_output_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (IsModelInputOrOutput(output_model, i)) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->name()->str() == "conv_bias") { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16); + } else if (quant_tensor->buffer() != 0) { + // If it's a non-bias constant tensor, it must be the weight. + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16); + } else { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } + } + } +} + +TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) { + LoadSharedWeightsModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE( + QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER).ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + uint32_t num_conv_ops = 0; + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + const auto op_code = + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); + if (op_code == BuiltinOperator_CONV_2D) { + num_conv_ops++; + // Ensure that each convolution's weights tensor is now INT8. + const auto weights_tensor = + quantized_graph->tensors()->Get(op->inputs()->Get(1)); + EXPECT_EQ(weights_tensor->type(), TensorType_INT8); + } + } + } + // Ensure that there were exactly two convolutions in the model. + EXPECT_EQ(num_conv_ops, 2); +} + +TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) { + LoadSharedWeightsModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE(internal::QuantizeWeights(&builder, model_, 0, + /*use_hybrid_evaluation*/ false, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + uint32_t num_conv_ops = 0; + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + const auto op_code = + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); + if (op_code == BuiltinOperator_CONV_2D) { + num_conv_ops++; + // Ensure that each convolution's weights tensor is still FLOAT + // (the output of the dequantize). + uint32_t weights_tensor_index = op->inputs()->Get(1); + const auto weights_tensor = + quantized_graph->tensors()->Get(weights_tensor_index); + EXPECT_EQ(weights_tensor->type(), TensorType_FLOAT32); + + // Check that it comes from a dequantize operation. + BuiltinOperator producer_op_code; + ASSERT_TRUE(GetProducerOpCode(output_model, subgraph_idx, + weights_tensor_index, &producer_op_code)); + EXPECT_EQ(producer_op_code, BuiltinOperator_DEQUANTIZE); + } + } + } + // Ensure that there were exactly two convolutions in the model. + EXPECT_EQ(num_conv_ops, 2); +} + +TEST_F(QuantizeWeightsTest, VerifyGatherQuantization) { + LoadGatherTestModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE( + QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER).ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + const auto op_code = + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); + if (op_code == tflite::BuiltinOperator_GATHER) { + uint32_t input_tensor_index = op->inputs()->Get(0); + const auto weights_tensor = + quantized_graph->tensors()->Get(input_tensor_index); + EXPECT_EQ(weights_tensor->type(), TensorType_INT8); + } + } + } +} + +TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationDequantize) { + LoadCustomOpTestModel(); + + // The custom op is not hybrid, and the second input is a constant that can + // be quantized. + CustomOpMap custom_op_map; + custom_op_map["CustomTestOp"] = { + .quantizable_input_indices = {1}, + .is_hybrid = false, + }; + + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE(QuantizeWeights(&builder, model_, 0, custom_op_map, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + const auto quantized_graph = output_model->subgraphs()->Get(0); + // A dequantize op should be added. + ASSERT_EQ(quantized_graph->operators()->size(), + model_->subgraphs()->Get(0)->operators()->size() + 1); + int num_custom_ops_found = 0; + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + const auto op_code = + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); + if (op_code == BuiltinOperator_CUSTOM) { + uint32_t weights_tensor_index = op->inputs()->Get(1); + const auto weights_tensor = + quantized_graph->tensors()->Get(weights_tensor_index); + EXPECT_EQ(weights_tensor->type(), TensorType_FLOAT32); + + // Check that it comes from a dequantize operation. + BuiltinOperator producer_op_code; + ASSERT_TRUE(GetProducerOpCode(output_model, 0, weights_tensor_index, + &producer_op_code)); + EXPECT_EQ(producer_op_code, BuiltinOperator_DEQUANTIZE); + num_custom_ops_found++; + } + } + EXPECT_EQ(num_custom_ops_found, 1); +} + +TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationHybrid) { + LoadCustomOpTestModel(); + + // The custom op is hybrid, and the second input is a constant that can + // be quantized. + CustomOpMap custom_op_map; + custom_op_map["CustomTestOp"] = { + .quantizable_input_indices = {1}, + .is_hybrid = true, + }; + + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE(QuantizeWeights(&builder, model_, 0, custom_op_map, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + const auto quantized_graph = output_model->subgraphs()->Get(0); + ASSERT_EQ(quantized_graph->operators()->size(), + model_->subgraphs()->Get(0)->operators()->size()); + int num_custom_ops_found = 0; + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + const auto op_code = + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); + if (op_code == BuiltinOperator_CUSTOM) { + uint32_t weights_tensor_index = op->inputs()->Get(1); + const auto weights_tensor = + quantized_graph->tensors()->Get(weights_tensor_index); + EXPECT_EQ(weights_tensor->type(), TensorType_INT8); + num_custom_ops_found++; + } + } + EXPECT_EQ(num_custom_ops_found, 1); +} + +TEST_F(QuantizeWeightsTest, VerifyUpdatedHybridSchemeFalseQuantizationHybrid) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + const CustomOpMap custom_op_map; + ASSERT_TRUE(QuantizeWeights(&builder, model_, 0, custom_op_map, + /*use_updated_hybrid_scheme=*/false, + /*op_denylist=*/{}, QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + // Nothing should change. + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + subgraph_idx++) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size()); + // Make sure the graph only has one Conv operation. + ASSERT_EQ(quantized_graph->operators()->size(), 1); + const auto op = quantized_graph->operators()->Get(0); + const uint32_t op_code_idx = op->opcode_index(); + ASSERT_EQ(GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)), + BuiltinOperator_CONV_2D); + for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + const auto float_tensor = float_graph->tensors()->Get(i); + EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer()); + EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable()); + EXPECT_EQ(GetAsVector(quant_tensor->shape()), + GetAsVector(float_tensor->shape())); + EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str()); + // If the tensor is a weight, it should have type INT8, otherwise it + // should stay with type FLOAT32. + // If the tensor is a bias, it should have type FLOAT32. + if (quant_tensor->name()->str() == "conv_bias") { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (IsModelInputOrOutput(output_model, i)) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->buffer() != 0) { + EXPECT_EQ(quant_tensor->type(), TensorType_INT8) + << quant_tensor->name()->str(); + auto shape = GetAsVector(quant_tensor->shape()); + EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1); + } else { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } + } + } +} + +TEST_F(QuantizeWeightsTest, DequantizeConvBlocklisted) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + const CustomOpMap custom_op_map; + ASSERT_TRUE(QuantizeWeights(&builder, model_, 0, custom_op_map, + /*use_updated_hybrid_scheme=*/true, + /*op_denylist*/ {BuiltinOperator_CONV_2D}, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + // The output graph should have an extra tensor from the added dequantize + // op. + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size() + 1); + // Check that a dequantize op exists. + int32_t dequant_input_idx = -1; + int32_t dequant_output_idx = -1; + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) == + BuiltinOperator_DEQUANTIZE) { + dequant_input_idx = op->inputs()->Get(0); + dequant_output_idx = op->outputs()->Get(0); + } + } + ASSERT_GT(dequant_input_idx, -1); + ASSERT_GT(dequant_output_idx, -1); + for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + // If the tensor is a weight, it should have type INT8. + // If the tensor is a bias, it should have type FLOAT32. + // If the tensor is an input or output it should have type FLOAT32. + // The input to dequantize should be INT8, and all other tensors should be + // FLOAT32. + if (i == dequant_input_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_INT8); + // The dequantize should still be quantized per-channel + EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 5); + EXPECT_EQ(quant_tensor->quantization()->quantized_dimension(), 0); + } else if (i == dequant_output_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (IsModelInputOrOutput(output_model, i)) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->name()->str() == "conv_bias") { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->buffer() != 0) { + // If it's a non-bias constant tensor, it must be the weight. + EXPECT_EQ(quant_tensor->type(), TensorType_INT8); + } else { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } + } + } +} + +} // namespace +} // namespace toco_legacy +} // namespace lite +} // namespace mlir + +int main(int argc, char** argv) { + std::string model_file; + const std::vector flag_list = { + tsl::Flag("test_model_file", &model_file, + "Path to test tflite model file."), + }; + + const bool parse_result = tsl::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + std::cerr << "Required test_model_file\n"; + std::abort(); + } + g_test_model_dir = new std::string(tsl::io::Dirname(model_file)); + ::tsl::port::InitMain(argv[0], &argc, &argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 5d96df6238ca0d..067d1cc185c2c6 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -549,7 +549,6 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], alwayslink = 1, @@ -641,6 +640,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:custom_call", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:dot_general", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:gather", + "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:iota", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:pad", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:reduce", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:reduce_window", @@ -648,6 +648,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:sort", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:util", "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -802,6 +803,39 @@ cc_library( ], ) +cc_library( + name = "lift_callsite_loc_caller", + srcs = ["transforms/torch/lift_callsite_loc_caller_pass.cc"], + copts = ["-Ithird_party"], + deps = [ + ":passes_inc_gen", + ":prepare_hlo", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], + alwayslink = True, +) + +cc_library( + name = "build_stablehlo_composite", + srcs = ["transforms/torch/build_stablehlo_composite_pass.cc"], + copts = ["-Ithird_party"], + deps = [ + ":passes_inc_gen", + ":prepare_hlo", + "@com_google_absl//absl/strings", + "@jsoncpp_git//:jsoncpp", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], + alwayslink = True, +) + cc_library( name = "composite_lowering", srcs = [ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir index 7db47a1a3e7703..a06886a8d4688a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir @@ -408,3 +408,107 @@ func.func @testConvertReshapeDotRhsToBatchedDot(%arg0: tensor<1x72x72xf32>, %arg // CHECK-SAME: >}> : (tensor<1x72x72xf32>, tensor<1x72x128xf32>) -> tensor<1x72x128xf32> // CHECK: return %[[R]] : tensor<1x72x128xf32> } + +// ----- + +// CHECK-LABEL: broadcast_reshape_one_non_unit_dimnsion +func.func @broadcast_reshape_one_non_unit_dimnsion(%arg0: tensor<1x1x1x63xf32>) -> tensor<32x1x63xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x1x1x63xf32>) -> tensor<1x32x1x63xf32> + %1 = mhlo.reshape %0 : (tensor<1x32x1x63xf32>) -> tensor<32x1x63xf32> + return %1 : tensor<32x1x63xf32> +} + +// CHECK: %0 = mhlo.reshape %arg0 : (tensor<1x1x1x63xf32>) -> tensor<63xf32> +// CHECK: %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<63xf32>) -> tensor<32x1x63xf32> +// CHECK: return %1 : tensor<32x1x63xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_one_non_unit_dimnsion_trailing_zeros +func.func @broadcast_reshape_one_non_unit_dimnsion_trailing_zeros(%arg0: tensor<63x1x1x1xf32>) -> tensor<63x1x2xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<63x1x1x1xf32>) -> tensor<63x1x1x2xf32> + %1 = mhlo.reshape %0 : (tensor<63x1x1x2xf32>) -> tensor<63x1x2xf32> + return %1 : tensor<63x1x2xf32> +} + +// CHECK: %0 = mhlo.reshape %arg0 : (tensor<63x1x1x1xf32>) -> tensor<63xf32> +// CHECK: %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<63xf32>) -> tensor<63x1x2xf32> +// CHECK: return %1 : tensor<63x1x2xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_multiple_non_unit_dimension +func.func @broadcast_reshape_multiple_non_unit_dimension(%arg0: tensor<1x2x1x63xf32>) -> tensor<2x3x63xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x2x1x63xf32>) -> tensor<1x2x3x63xf32> + %1 = mhlo.reshape %0 : (tensor<1x2x3x63xf32>) -> tensor<2x3x63xf32> + return %1 : tensor<2x3x63xf32> +} + +// CHECK: %0 = mhlo.reshape %arg0 : (tensor<1x2x1x63xf32>) -> tensor<2x63xf32> +// CHECK: %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>}> : (tensor<2x63xf32>) -> tensor<2x3x63xf32> +// CHECK: return %1 : tensor<2x3x63xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_multiple_non_unit_dimension_unsorted_broadcast_dims +func.func @broadcast_reshape_multiple_non_unit_dimension_unsorted_broadcast_dims(%arg0: tensor<1x2x1x63xf32>) -> tensor<3x2x63xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 2, 1, 3]> : tensor<4xi64>}> : (tensor<1x2x1x63xf32>) -> tensor<3x1x2x63xf32> + %1 = mhlo.reshape %0 : (tensor<3x1x2x63xf32>) -> tensor<3x2x63xf32> + return %1 : tensor<3x2x63xf32> +} + +// CHECK: %0 = mhlo.reshape %arg0 : (tensor<1x2x1x63xf32>) -> tensor<2x63xf32> +// CHECK: %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> : (tensor<2x63xf32>) -> tensor<3x2x63xf32> +// CHECK: return %1 : tensor<3x2x63xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_broadcast_increases_rank +func.func @broadcast_reshape_broadcast_increases_rank(%arg0: tensor<1x2x1x63xf32>) -> tensor<2x3x63xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 4]> : tensor<4xi64>}> : (tensor<1x2x1x63xf32>) -> tensor<1x2x3x1x63xf32> + %1 = mhlo.reshape %0 : (tensor<1x2x3x1x63xf32>) -> tensor<2x3x63xf32> + return %1 : tensor<2x3x63xf32> +} + +// CHECK: %0 = mhlo.reshape %arg0 : (tensor<1x2x1x63xf32>) -> tensor<2x63xf32> +// CHECK: %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>}> : (tensor<2x63xf32>) -> tensor<2x3x63xf32> +// CHECK: return %1 : tensor<2x3x63xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_not_same_non_unit_dims +func.func @broadcast_reshape_not_same_non_unit_dims(%arg0: tensor<63x1x1x1xf32>) -> tensor<2x1x63xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<63x1x1x1xf32>) -> tensor<63x1x1x2xf32> + %1 = mhlo.reshape %0 : (tensor<63x1x1x2xf32>) -> tensor<2x1x63xf32> + return %1 : tensor<2x1x63xf32> +} + +// CHECK: %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<63x1x1x1xf32>) -> tensor<63x1x1x2xf32> +// CHECK: %1 = mhlo.reshape %0 : (tensor<63x1x1x2xf32>) -> tensor<2x1x63xf32> +// CHECK: return %1 : tensor<2x1x63xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_multi_use +func.func @broadcast_reshape_multi_use(%arg0: tensor<1x1x1x63xf32>) -> (tensor<32x1x63xf32>, tensor<1x32x1x63xf32>) { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x1x1x63xf32>) -> tensor<1x32x1x63xf32> + %1 = mhlo.reshape %0 : (tensor<1x32x1x63xf32>) -> tensor<32x1x63xf32> + return %1, %0 : tensor<32x1x63xf32>, tensor<1x32x1x63xf32> +} + +// CHECK: %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x1x1x63xf32>) -> tensor<1x32x1x63xf32> +// CHECK: %1 = mhlo.reshape %0 : (tensor<1x32x1x63xf32>) -> tensor<32x1x63xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_rank_increase +func.func @broadcast_reshape_rank_increase(%arg0: tensor<1x1x1x63xf32>) -> tensor<32x1x1x1x1x63xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x1x1x63xf32>) -> tensor<1x32x1x63xf32> + %1 = mhlo.reshape %0 : (tensor<1x32x1x63xf32>) -> tensor<32x1x1x1x1x63xf32> + return %1 : tensor<32x1x1x1x1x63xf32> +} + +// CHECK: %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x1x1x63xf32>) -> tensor<1x32x1x63xf32> +// CHECK: %1 = mhlo.reshape %0 : (tensor<1x32x1x63xf32>) -> tensor<32x1x1x1x1x63xf32> + + diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir index db09ecae4bead0..38fc6d57d93015 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir @@ -15,6 +15,75 @@ func.func @main(%arg0: tensor) -> tensor { // 2D //=-- +// CHECK-LABEL: transpose_conv2d_same_padding_nchw_ihwo +func.func @transpose_conv2d_same_padding_nchw_ihwo(%input: tensor<1x2x256x256xf32>, %filter:tensor<2x2x4x4xf32>) -> tensor<1x2x512x512xf32> { + %1 = mhlo.convolution(%input, %filter) + dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], + window = {pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<1x2x256x256xf32>, tensor<2x2x4x4xf32>) -> tensor<1x2x512x512xf32> + func.return %1 : tensor<1x2x512x512xf32> +} + +// CHECK: %[[TRANSPOSED_INPUT:.*]] = "mhlo.transpose"(%arg0) +// CHECK-SAME: permutation +// CHECK-SAME: [0, 2, 3, 1] +// CHECK: %[[TRANSPOSED_KERNEL:.*]] = "mhlo.transpose"(%arg1) +// CHECK-SAME: permutation +// CHECK-SAME: [1, 2, 3, 0] +// CHECK: %[[CONV_OUT:.*]] = mhlo.convolution(%[[TRANSPOSED_INPUT]], %[[TRANSPOSED_KERNEL]]) +// CHECK-SAME: [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f] +// CHECK: "mhlo.transpose"(%[[CONV_OUT]]) +// CHECK-SAME: permutation +// CHECK-SAME: [0, 3, 1, 2] + +// CHECK-LABEL: transpose_conv2d_same_padding_nchw_oihw +func.func @transpose_conv2d_same_padding_nchw_oihw(%input: tensor<1x2x256x256xf32>, %filter:tensor<2x2x4x4xf32>) -> tensor<1x2x512x512xf32> { + %0 = mhlo.convolution(%input, %filter) + dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], + window = {pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2]} { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<1x2x256x256xf32>, tensor<2x2x4x4xf32>) -> tensor<1x2x512x512xf32> + func.return %0 : tensor<1x2x512x512xf32> +} + +// CHECK: %[[TRANSPOSED_INPUT:.*]] = "mhlo.transpose"(%arg0) +// CHECK-SAME: permutation +// CHECK-SAME: [0, 2, 3, 1] +// CHECK: %[[TRANSPOSED_KERNEL:.*]] = "mhlo.transpose"(%arg1) +// CHECK-SAME: permutation +// CHECK-SAME: [0, 2, 3, 1] +// CHECK: %[[CONV_OUT:.*]] = mhlo.convolution(%[[TRANSPOSED_INPUT]], %[[TRANSPOSED_KERNEL]]) +// CHECK-SAME: [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f] +// CHECK: "mhlo.transpose"(%[[CONV_OUT]]) +// CHECK-SAME: permutation +// CHECK-SAME: [0, 3, 1, 2] + +// ----- + +// CHECK-LABEL: depthwise_transpose_conv2d_same_padding_nchw_hwoi +func.func @depthwise_transpose_conv2d_same_padding_nchw_hwoi(%input: tensor<1x2x20x20xf32>, %filter:tensor<8x8x2x1xf32>) -> tensor<1x2x80x80xf32> { + %1 = mhlo.convolution(%input, %filter) + dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], + window = {pad = [[5, 5], [5, 5]], lhs_dilate = [4, 4]} + {batch_group_count = 1 : i64, feature_group_count = 2 : i64} + : (tensor<1x2x20x20xf32>, tensor<8x8x2x1xf32>) -> tensor<1x2x80x80xf32> + func.return %1 : tensor<1x2x80x80xf32> + + // CHECK: %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<1x2x20x20xf32>) -> tensor<1x20x20x2xf32> + // CHECK: %1 = "mhlo.transpose"(%arg1) <{permutation = dense<[2, 0, 1, 3]> : tensor<4xi64>}> : (tensor<8x8x2x1xf32>) -> tensor<2x8x8x1xf32> + // CHECK: %2 = "mhlo.slice"(%0) <{limit_indices = dense<[1, 20, 20, 1]> : tensor<4xi64>, start_indices = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<4xi64>}> : (tensor<1x20x20x2xf32>) -> tensor<1x20x20x1xf32> + // CHECK: %3 = "mhlo.slice"(%1) <{limit_indices = dense<[1, 8, 8, 1]> : tensor<4xi64>, start_indices = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<4xi64>}> : (tensor<2x8x8x1xf32>) -> tensor<1x8x8x1xf32> + // CHECK: %4 = mhlo.convolution(%2, %3) dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], window = {pad = {{\[\[}}5, 5], [5, 5]], lhs_dilate = [4, 4]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x20x20x1xf32>, tensor<1x8x8x1xf32>) -> tensor<1x80x80x1xf32> + // CHECK: %5 = "mhlo.slice"(%0) <{limit_indices = dense<[1, 20, 20, 2]> : tensor<4xi64>, start_indices = dense<[0, 0, 0, 1]> : tensor<4xi64>, strides = dense<1> : tensor<4xi64>}> : (tensor<1x20x20x2xf32>) -> tensor<1x20x20x1xf32> + // CHECK: %6 = "mhlo.slice"(%1) <{limit_indices = dense<[2, 8, 8, 1]> : tensor<4xi64>, start_indices = dense<[1, 0, 0, 0]> : tensor<4xi64>, strides = dense<1> : tensor<4xi64>}> : (tensor<2x8x8x1xf32>) -> tensor<1x8x8x1xf32> + // CHECK: %7 = mhlo.convolution(%5, %6) dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], window = {pad = {{\[\[}}5, 5], [5, 5]], lhs_dilate = [4, 4]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x20x20x1xf32>, tensor<1x8x8x1xf32>) -> tensor<1x80x80x1xf32> + // CHECK: %8 = "mhlo.concatenate"(%4, %7) <{dimension = 3 : i64}> : (tensor<1x80x80x1xf32>, tensor<1x80x80x1xf32>) -> tensor<1x80x80x2xf32> + // CHECK: %9 = "mhlo.transpose"(%8) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<1x80x80x2xf32>) -> tensor<1x2x80x80xf32> + // CHECK: return %9 : tensor<1x2x80x80xf32> +} + // CHECK-LABEL: conv2d_nhwc_ohwi_nhwc func.func @conv2d_nhwc_ohwi_nhwc(%input: tensor<1x256x256x3xf32>, %filter: tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> { %0 = mhlo.convolution(%input, %filter) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir index a6efef43a8fecf..60e40cf1082419 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir @@ -72,13 +72,13 @@ func.func @dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) func.return %0 : tensor<3x5x1x4xf32> } -// CHECK: %[[TRANSPOSED_0:.*]] = "tfl.transpose" -// CHECK: %[[TRANSPOSED_1:.*]] = "tfl.transpose" -// CHECK-NEXT: %[[RESHAPED_0:.*]] = mhlo.reshape %[[TRANSPOSED_0]] -// CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %[[TRANSPOSED_1]] -// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> -// CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]] -// CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<3x5x1x4xf32> +// CHECK: %[[TRANSPOSED_0:.*]] = "tfl.transpose" +// CHECK: %[[TRANSPOSED_1:.*]] = "tfl.transpose" +// CHECK: %[[RESHAPED_0:.*]] = "tfl.reshape"(%[[TRANSPOSED_0]] +// CHECK: %[[RESHAPED_1:.*]] = "tfl.reshape"(%[[TRANSPOSED_1]] +// CHECK: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> +// CHECK: %[[RESHAPED_BMM:.*]] = "tfl.reshape"(%[[BMM_0]] +// CHECK: return %[[RESHAPED_BMM]] : tensor<3x5x1x4xf32> // ----- @@ -96,11 +96,10 @@ func.func @dot_general_repeated(%arg0: tensor<1x1x1024xf32>, %arg1: tensor<1024x func.return %0 : tensor<1x1x1024xf32> } -// CHECK: %[[RESHAPED_0:.*]] = mhlo.reshape %arg0 -// CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %arg1 -// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x1024xf32> -// CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]] -// CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<1x1x1024xf32> +// CHECK: %[[RESHAPED_0:.*]] = "tfl.reshape"(%arg0 +// CHECK: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x1024xf32> +// CHECK: %[[RESHAPED_BMM:.*]] = "tfl.reshape"(%[[BMM_0]] +// CHECK: return %[[RESHAPED_BMM]] : tensor<1x1x1024xf32> // ----- @@ -115,11 +114,10 @@ func.func @dot_general_int8(%arg0: tensor<256xi8>, %arg1: tensor<256x8xi8>) -> t func.return %0 : tensor<8xi32> } -// CHECK: %[[RESHAPED_0:.*]] = mhlo.reshape %arg0 -// CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %arg1 -// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x8xi32> -// CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]] -// CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<8xi32> +// CHECK: %[[RESHAPED_0:.*]] = "tfl.reshape"(%arg0 +// CHECK: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x8xi32> +// CHECK: %[[RESHAPED_BMM:.*]] = "tfl.reshape"(%[[BMM_0]] +// CHECK: return %[[RESHAPED_BMM]] : tensor<8xi32> // ----- @@ -135,29 +133,30 @@ func.func @dot_general_dynamic_rhs_out_dim(%arg0: tensor<4x4x256xf32>, %arg1: te func.return %0 : tensor<4x4x?xf32> } -// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> -// CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<3xi64>) -> tensor<3xi32> -// CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x256x?xf32> -// CHECK-NEXT: %3 = mhlo.reshape %arg0 : (tensor<4x4x256xf32>) -> tensor<4x4x256xf32> -// CHECK-NEXT: %4 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> -// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %7 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%4, %5, %7) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %9 = "tfl.unsorted_segment_prod"(%4, %6, %7) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %10 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK-NEXT: %11 = "tfl.concatenation"(%10, %9, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: %12 = mhlo.dynamic_reshape %2, %11 : (tensor<4x256x?xf32>, tensor<3xi32>) -> tensor<4x256x?xf32> -// CHECK-NEXT: %13 = "tfl.batch_matmul"(%3, %12) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32> -// CHECK-NEXT: %14 = "tfl.shape"(%arg0) : (tensor<4x4x256xf32>) -> tensor<3xi32> -// CHECK-NEXT: %15 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> -// CHECK-NEXT: %16 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> -// CHECK-NEXT: %17 = "tfl.gather"(%14, %16) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32> -// CHECK-NEXT: %18 = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64> -// CHECK-NEXT: %19 = "tfl.gather"(%15, %18) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32> -// CHECK-NEXT: %20 = "tfl.concatenation"(%17, %19) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: %21 = mhlo.dynamic_reshape %13, %20 : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32> -// CHECK-NEXT: return %21 : tensor<4x4x?xf32> +// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %1 = "tfl.cast"(%0) : (tensor<3xi64>) -> tensor<3xi32> +// CHECK: %2 = "tfl.transpose"(%arg1, %1) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x256x?xf32> +// CHECK: %3 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> +// CHECK-DAG: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> +// CHECK: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> +// CHECK: %9 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %10 = "tfl.concatenation"(%9, %8, %7) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: %11 = "tfl.cast"(%10) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %12 = "tfl.reshape"(%2, %11) : (tensor<4x256x?xf32>, tensor<3xi32>) -> tensor<4x256x?xf32> +// CHECK: %13 = "tfl.batch_matmul"(%arg0, %12) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32> +// CHECK: %14 = "tfl.shape"(%arg0) : (tensor<4x4x256xf32>) -> tensor<3xi32> +// CHECK: %15 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> +// CHECK: %16 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK: %17 = "tfl.gather"(%14, %16) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32> +// CHECK: %18 = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64> +// CHECK: %19 = "tfl.gather"(%15, %18) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32> +// CHECK: %20 = "tfl.concatenation"(%17, %19) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: %21 = "tfl.cast"(%20) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %22 = "tfl.reshape"(%13, %21) : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32> +// CHECK: return %22 : tensor<4x4x?xf32> // ----- @@ -173,43 +172,45 @@ func.func @dot_general_dynamic_batch_dim(%arg0: tensor<2x?x2x3xf32>, %arg1: tens func.return %0 : tensor<2x?x2x4xf32> } -// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> -// CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32> -// CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x?x4x3xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32> -// CHECK-NEXT: %3 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %4 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %9 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> -// CHECK-NEXT: %10 = "tfl.gather"(%3, %9) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> -// CHECK-NEXT: %11 = "tfl.concatenation"(%10, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> -// CHECK-NEXT: %12 = mhlo.dynamic_reshape %arg0, %11 : (tensor<2x?x2x3xf32>, tensor<4xi32>) -> tensor<2x?x2x3xf32> -// CHECK-NEXT: %13 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %14 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %16 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %17 = "tfl.unsorted_segment_prod"(%13, %14, %16) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %18 = "tfl.unsorted_segment_prod"(%13, %15, %16) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %19 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> -// CHECK-NEXT: %20 = "tfl.gather"(%13, %19) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> -// CHECK-NEXT: %21 = "tfl.concatenation"(%20, %18, %17) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> -// CHECK-NEXT: %22 = mhlo.dynamic_reshape %2, %21 : (tensor<2x?x3x4xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32> -// CHECK-NEXT: %23 = "tfl.batch_matmul"(%12, %22) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32> -// CHECK-NEXT: %24 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %25 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %26 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> -// CHECK-NEXT: %27 = "tfl.gather"(%24, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32> -// CHECK-NEXT: %28 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64> -// CHECK-NEXT: %29 = "tfl.gather"(%25, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32> -// CHECK-NEXT: %30 = "tfl.concatenation"(%27, %29) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32> -// CHECK-NEXT: %31 = mhlo.dynamic_reshape %23, %30 : (tensor<2x?x2x4xf32>, tensor<4xi32>) -> tensor<2x?x2x4xf32> -// CHECK-NEXT: return %31 : tensor<2x?x2x4xf32> +// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32> +// CHECK: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x?x4x3xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32> +// CHECK: %3 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32> +// CHECK-DAG: %4 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %9 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK: %10 = "tfl.gather"(%3, %9) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> +// CHECK: %11 = "tfl.concatenation"(%10, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %12 = "tfl.cast"(%11) : (tensor<4xi32>) -> tensor<4xi32> +// CHECK: %13 = "tfl.reshape"(%arg0, %12) : (tensor<2x?x2x3xf32>, tensor<4xi32>) -> tensor<2x?x2x3xf32> +// CHECK: %14 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32> +// CHECK-DAG: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %16 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %17 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %18 = "tfl.unsorted_segment_prod"(%14, %15, %17) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %19 = "tfl.unsorted_segment_prod"(%14, %16, %17) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %20 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK: %21 = "tfl.gather"(%14, %20) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> +// CHECK: %22 = "tfl.concatenation"(%21, %19, %18) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %23 = "tfl.cast"(%22) : (tensor<4xi32>) -> tensor<4xi32> +// CHECK: %24 = "tfl.reshape"(%2, %23) : (tensor<2x?x3x4xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32> +// CHECK: %25 = "tfl.batch_matmul"(%13, %24) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32> +// CHECK: %26 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32> +// CHECK: %27 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32> +// CHECK: %28 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %29 = "tfl.gather"(%26, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32> +// CHECK: %30 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64> +// CHECK: %31 = "tfl.gather"(%27, %30) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32> +// CHECK: %32 = "tfl.concatenation"(%29, %31) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %33 = "tfl.cast"(%32) : (tensor<4xi32>) -> tensor<4xi32> +// CHECK: %34 = "tfl.reshape"(%25, %33) : (tensor<2x?x2x4xf32>, tensor<4xi32>) -> tensor<2x?x2x4xf32> +// CHECK: return %34 : tensor<2x?x2x4xf32> // ----- - // CHECK-LABEL: dot_general_dynamic_lhs_rhs_out_dims func.func @dot_general_dynamic_lhs_rhs_out_dims(%arg0: tensor<2x2x?x3xf32>, %arg1: tensor<2x4x?x3xf32>) -> tensor<2x2x?x4x?xf32> { %0 = "mhlo.dot_general"(%arg0, %arg1) { @@ -222,37 +223,40 @@ func.func @dot_general_dynamic_lhs_rhs_out_dims(%arg0: tensor<2x2x?x3xf32>, %arg func.return %0 : tensor<2x2x?x4x?xf32> } -// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> -// CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32> -// CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x4x?x3xf32>, tensor<4xi32>) -> tensor<2x3x4x?xf32> -// CHECK-NEXT: %3 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %9 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK-NEXT: %10 = "tfl.concatenation"(%9, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: %11 = mhlo.dynamic_reshape %arg0, %10 : (tensor<2x2x?x3xf32>, tensor<3xi32>) -> tensor<2x?x3xf32> -// CHECK-NEXT: %12 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %13 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %14 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %16 = "tfl.unsorted_segment_prod"(%12, %13, %15) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %17 = "tfl.unsorted_segment_prod"(%12, %14, %15) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %18 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK-NEXT: %19 = "tfl.concatenation"(%18, %17, %16) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: %20 = mhlo.dynamic_reshape %2, %19 : (tensor<2x3x4x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32> -// CHECK-NEXT: %21 = "tfl.batch_matmul"(%11, %20) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32> -// CHECK-NEXT: %22 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %23 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %24 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> -// CHECK-NEXT: %25 = "tfl.gather"(%22, %24) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32> -// CHECK-NEXT: %26 = "tfl.pseudo_const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64> -// CHECK-NEXT: %27 = "tfl.gather"(%23, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> -// CHECK-NEXT: %28 = "tfl.concatenation"(%25, %27) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32> -// CHECK-NEXT: %29 = mhlo.dynamic_reshape %21, %28 : (tensor<2x?x?xf32>, tensor<5xi32>) -> tensor<2x2x?x4x?xf32> -// CHECK-NEXT: return %29 : tensor<2x2x?x4x?xf32> +// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32> +// CHECK: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x4x?x3xf32>, tensor<4xi32>) -> tensor<2x3x4x?xf32> +// CHECK: %3 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32> +// CHECK-DAG: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %9 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %10 = "tfl.concatenation"(%9, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: %11 = "tfl.cast"(%10) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %12 = "tfl.reshape"(%arg0, %11) : (tensor<2x2x?x3xf32>, tensor<3xi32>) -> tensor<2x?x3xf32> +// CHECK: %13 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32> +// CHECK-DAG: %14 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %16 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %17 = "tfl.unsorted_segment_prod"(%13, %14, %16) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %18 = "tfl.unsorted_segment_prod"(%13, %15, %16) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %19 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %20 = "tfl.concatenation"(%19, %18, %17) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: %21 = "tfl.cast"(%20) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %22 = "tfl.reshape"(%2, %21) : (tensor<2x3x4x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32> +// CHECK: %23 = "tfl.batch_matmul"(%12, %22) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32> +// CHECK: %24 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32> +// CHECK: %25 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32> +// CHECK: %26 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %27 = "tfl.gather"(%24, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32> +// CHECK: %28 = "tfl.pseudo_const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK: %29 = "tfl.gather"(%25, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> +// CHECK: %30 = "tfl.concatenation"(%27, %29) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32> +// CHECK: %31 = "tfl.cast"(%30) : (tensor<5xi32>) -> tensor<5xi32> +// CHECK: %32 = "tfl.reshape"(%23, %31) : (tensor<2x?x?xf32>, tensor<5xi32>) -> tensor<2x2x?x4x?xf32> +// CHECK: return %32 : tensor<2x2x?x4x?xf32 // ----- @@ -268,27 +272,28 @@ func.func @dot_general_dynamic_contracting_dim(%arg0: tensor<4x4x?xf32>, %arg1: func.return %0 : tensor<4x4x256xf32> } -// CHECK: %0 = "tfl.shape"(%arg0) : (tensor<4x4x?xf32>) -> tensor<3xi32> -// CHECK-NEXT: %1 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %2 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %3 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %4 = "tfl.unsorted_segment_prod"(%0, %1, %3) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %5 = "tfl.unsorted_segment_prod"(%0, %2, %3) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK-NEXT: %7 = "tfl.concatenation"(%6, %4, %5) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: %8 = mhlo.dynamic_reshape %arg0, %7 : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32> -// CHECK-NEXT: %9 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> -// CHECK-NEXT: %10 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %11 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %12 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %13 = "tfl.unsorted_segment_prod"(%9, %10, %12) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %14 = "tfl.unsorted_segment_prod"(%9, %11, %12) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK-NEXT: %16 = "tfl.concatenation"(%15, %14, %13) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: %17 = mhlo.dynamic_reshape %arg1, %16 : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x?x256xf32> -// CHECK-NEXT: %18 = "tfl.batch_matmul"(%8, %17) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32> -// CHECK-NEXT: %19 = mhlo.reshape %18 : (tensor<4x4x256xf32>) -> tensor<4x4x256xf32> -// CHECK-NEXT: return %19 : tensor<4x4x256xf32> +// CHECK: %0 = "tfl.shape"(%arg0) : (tensor<4x4x?xf32>) -> tensor<3xi32> +// CHECK-DAG: %1 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: %2 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: %3 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %4 = "tfl.unsorted_segment_prod"(%0, %1, %3) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> +// CHECK: %5 = "tfl.unsorted_segment_prod"(%0, %2, %3) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> +// CHECK: %6 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %7 = "tfl.concatenation"(%6, %4, %5) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: %8 = "tfl.cast"(%7) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %9 = "tfl.reshape"(%arg0, %8) : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32> +// CHECK: %10 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> +// CHECK-DAG: %11 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: %12 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: %13 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %14 = "tfl.unsorted_segment_prod"(%10, %11, %13) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> +// CHECK: %15 = "tfl.unsorted_segment_prod"(%10, %12, %13) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> +// CHECK: %16 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %17 = "tfl.concatenation"(%16, %15, %14) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: %18 = "tfl.cast"(%17) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %19 = "tfl.reshape"(%arg1, %18) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x?x256xf32> +// CHECK: %20 = "tfl.batch_matmul"(%9, %19) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32> +// CHECK: return %20 : tensor<4x4x256xf32> // ----- @@ -318,14 +323,10 @@ func.func @argmax(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32 func.return %4#0, %4#1 : tensor<4x32xf32>, tensor<4x32xi32> } -// CHECK: %0 = mhlo.constant dense<0xFF800000> : tensor -// CHECK-DAG: %1 = mhlo.constant dense<0> : tensor -// CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32> -// CHECK: %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32> -// CHECK: %cst = arith.constant dense<2> : tensor<1xi32> -// CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> -// CHECK: %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xi32> -// CHECK: return %4, %5 : tensor<4x32xf32>, tensor<4x32xi32> +// CHECK: %[[CST:.*]] = arith.constant dense<2> : tensor<1xi32> +// CHECK: %[[REDUCE:.*]] = "tfl.reduce_max"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> +// CHECK: %[[ARG:.*]] = "tfl.arg_max"(%arg0, %[[CST]]) : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xi32> +// CHECK: return %[[REDUCE]], %[[ARG]] : tensor<4x32xf32>, tensor<4x32xi32> // ----- @@ -410,12 +411,11 @@ func.func @argmax_bool(%arg0: tensor<2xi1>) -> tensor { return %3#1 : tensor } -// CHECK: %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> // CHECK-DAG: %1 = mhlo.constant dense : tensor -// CHECK: %2 = mhlo.constant dense<0> : tensor -// CHECK: %cst = arith.constant dense<0> : tensor<1xi32> -// CHECK: %3 = "tfl.reduce_any"(%arg0, %cst) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi32>) -> tensor -// CHECK: %4 = "tfl.arg_max"(%arg0, %cst) : (tensor<2xi1>, tensor<1xi32>) -> tensor +// CHECK-DAG: %2 = mhlo.constant dense<0> : tensor +// CHECK: %[[CST:.*]] = arith.constant dense<0> : tensor<1xi32> +// CHECK: %3 = "tfl.reduce_any"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi32>) -> tensor +// CHECK: %4 = "tfl.arg_max"(%arg0, %[[CST]]) : (tensor<2xi1>, tensor<1xi32>) -> tensor // CHECK: return %4 : tensor // ----- @@ -442,14 +442,10 @@ func.func @argmin(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32 func.return %4#0, %4#1 : tensor<4x32xf32>, tensor<4x32xi32> } -// CHECK-DAG: %0 = mhlo.constant dense<0x7F800000> : tensor -// CHECK: %1 = mhlo.constant dense<0> : tensor -// CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32> -// CHECK: %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32> -// CHECK: %cst = arith.constant dense<2> : tensor<1xi32> -// CHECK: %4 = "tfl.reduce_min"(%arg0, %cst) <{keep_dims = false}> : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> -// CHECK: %5 = "tfl.arg_min"(%arg0, %cst) : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xi32> -// CHECK: return %4, %5 : tensor<4x32xf32>, tensor<4x32xi32> +// CHECK: %[[CST:.*]] = arith.constant dense<2> : tensor<1xi32> +// CHECK: %[[REDUCE:.*]] = "tfl.reduce_min"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> +// CHECK: %[[ARG:.*]] = "tfl.arg_min"(%arg0, %[[CST]]) : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xi32> +// CHECK: return %[[REDUCE]], %[[ARG]] : tensor<4x32xf32>, tensor<4x32xi32> // ----- @@ -474,14 +470,10 @@ func.func @argmin_i16(%arg0: tensor<2xi16>) -> (tensor, tensor) { func.return %4#0, %4#1 : tensor, tensor } -// CHECK: %0 = mhlo.constant dense : tensor -// CHECK: %1 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> -// CHECK-DAG: %2 = mhlo.constant dense<32767> : tensor -// CHECK: %3 = mhlo.constant dense<0> : tensor -// CHECK: %cst = arith.constant dense<0> : tensor<1xi32> -// CHECK: %4 = "tfl.reduce_min"(%arg0, %cst) <{keep_dims = false}> : (tensor<2xi16>, tensor<1xi32>) -> tensor -// CHECK: %5 = "tfl.arg_min"(%arg0, %cst) : (tensor<2xi16>, tensor<1xi32>) -> tensor -// CHECK: return %4, %5 : tensor, tensor +// CHECK: %[[CST:.*]] = arith.constant dense<0> : tensor<1xi32> +// CHECK: %[[REDUCE:.*]] = "tfl.reduce_min"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<2xi16>, tensor<1xi32>) -> tensor +// CHECK: %[[ARG:.*]] = "tfl.arg_min"(%arg0, %[[CST]]) : (tensor<2xi16>, tensor<1xi32>) -> tensor +// CHECK: return %[[REDUCE]], %[[ARG]] : tensor, tensor // ----- @@ -535,12 +527,11 @@ func.func @argmin_bool(%arg0: tensor<2xi1>) -> tensor { return %3#1 : tensor } -// CHECK: %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> // CHECK-DAG: %1 = mhlo.constant dense : tensor -// CHECK: %2 = mhlo.constant dense<0> : tensor -// CHECK: %cst = arith.constant dense<0> : tensor<1xi32> -// CHECK: %3 = "tfl.reduce_all"(%arg0, %cst) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi32>) -> tensor -// CHECK: %4 = "tfl.arg_min"(%arg0, %cst) : (tensor<2xi1>, tensor<1xi32>) -> tensor +// CHECK-DAG: %2 = mhlo.constant dense<0> : tensor +// CHECK: %[[CST:.*]] = arith.constant dense<0> : tensor<1xi32> +// CHECK: %3 = "tfl.reduce_all"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi32>) -> tensor +// CHECK: %4 = "tfl.arg_min"(%arg0, %[[CST]]) : (tensor<2xi1>, tensor<1xi32>) -> tensor // CHECK: return %4 : tensor // ----- @@ -567,14 +558,10 @@ func.func @argmax_with_reshaped_iota(%arg0: tensor<1x32x1xf32>) -> (tensor<1x1xf func.return %4#0, %4#1 : tensor<1x1xf32>, tensor<1x1xi32> } -// CHECK-DAG: %0 = mhlo.constant dense<0xFF800000> : tensor -// CHECK: %1 = mhlo.constant dense<0> : tensor -// CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<32xi32> -// CHECK: %3 = mhlo.reshape %2 : (tensor<32xi32>) -> tensor<1x32x1xi32> -// CHECK: %cst = arith.constant dense<1> : tensor<1xi32> -// CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xf32> -// CHECK: %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xi32> -// CHECK: return %4, %5 : tensor<1x1xf32>, tensor<1x1xi32> +// CHECK: %[[CST:.*]] = arith.constant dense<1> : tensor<1xi32> +// CHECK: %[[REDUCE:.*]] = "tfl.reduce_max"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xf32> +// CHECK: %[[ARG:.*]] = "tfl.arg_max"(%arg0, %[[CST]]) : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: return %[[REDUCE]], %[[ARG]] : tensor<1x1xf32>, tensor<1x1xi32> // ----- @@ -597,14 +584,9 @@ func.func @pytorch_argmax(%arg0: tensor<1x9xi32>) -> tensor<1xi32> { func.return %4#1 : tensor<1xi32> } -// CHECK: %0 = mhlo.constant dense<0> : tensor -// CHECK-DAG: %1 = mhlo.constant dense<-2147483648> : tensor -// CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<9xi32> -// CHECK: %3 = mhlo.reshape %2 : (tensor<9xi32>) -> tensor<1x9xi32> -// CHECK: %cst = arith.constant dense<1> : tensor<1xi32> -// CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32> -// CHECK: %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32> -// CHECK: return %5 : tensor<1xi32> +// CHECK: %[[CST:.*]] = arith.constant dense<1> : tensor<1xi32> +// CHECK: %[[ARG:.*]] = "tfl.arg_max"(%arg0, %[[CST]]) : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32> +// CHECK: return %[[ARG]] : tensor<1xi32> // ----- @@ -618,11 +600,11 @@ func.func @cbrt_f32(%arg0: tensor<1x32x1xf32>) -> tensor<1x32x1xf32> { func.return %0 : tensor<1x32x1xf32> } -// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor -// CHECK: %cst_0 = arith.constant dense<3.000000e+00> : tensor -// CHECK: %0 = tfl.div %cst, %cst_0 {fused_activation_function = "NONE"} : tensor -// CHECK: %1 = tfl.pow(%arg0, %0) : (tensor<1x32x1xf32>, tensor) -> tensor<1x32x1xf32> -// CHECK: return %1 : tensor<1x32x1xf32> +// CHECK-DAG: %cst = arith.constant dense<1.000000e+00> : tensor +// CHECK-DAG: %cst_0 = arith.constant dense<3.000000e+00> : tensor +// CHECK: %0 = tfl.div %cst, %cst_0 {fused_activation_function = "NONE"} : tensor +// CHECK: %1 = tfl.pow(%arg0, %0) : (tensor<1x32x1xf32>, tensor) -> tensor<1x32x1xf32> +// CHECK: return %1 : tensor<1x32x1xf32> // ----- @@ -636,6 +618,100 @@ func.func @cbrt_f64(%arg0: tensor<1x32x1xf64>) -> tensor<1x32x1xf64> { // ----- +//===----------------------------------------------------------------------===// +// mhlo.(dynamic)reshape +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: reshape +func.func @reshape(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2xf32> + func.return %0 : tensor<3x2xf32> +} + +// CHECK: %cst = arith.constant dense<[3, 2]> : tensor<2xi64> +// CHECK: %0 = "tfl.cast"(%cst) : (tensor<2xi64>) -> tensor<2xi32> +// CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + +// ----- + +// CHECK-LABEL: dynamic_reshape_i32 +func.func @dynamic_reshape_i32(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor { + %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor + func.return %0 : tensor +} + +// CHECK: %0 = "tfl.cast"(%arg1) : (tensor<2xi32>) -> tensor<2xi32> +// CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor + +// ----- + +// CHECK-LABEL: dynamic_reshape_i64 +func.func @dynamic_reshape_i64(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi64>) -> tensor { + %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor + func.return %0 : tensor +} + +// CHECK: %0 = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32> +// CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo binary bit-wise ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: logical_and +func.func @logical_and(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { + %0 = mhlo.and %arg0, %arg1 : tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.logical_and +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: bitwise_and +func.func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.and %arg0, %arg1 : tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +// CHECK: mhlo.and +// CHECK-NOT: tfl + +// ----- + +// CHECK-LABEL: logical_or +func.func @logical_or(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { + %0 = mhlo.or %arg0, %arg1 : tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.logical_or +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: bitwise_or +func.func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.or %arg0, %arg1 : tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +// CHECK: mhlo.or +// CHECK-NOT: tfl + +// ----- + +// CHECK-LABEL: logical_xor +func.func @logical_xor(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { + %0 = mhlo.xor %arg0, %arg1 : tensor<4xi1> + func.return %0 : tensor<4xi1> +} + +// ----- + //===----------------------------------------------------------------------===// // mhlo.convolution //===----------------------------------------------------------------------===// @@ -644,6 +720,73 @@ func.func @cbrt_f64(%arg0: tensor<1x32x1xf64>) -> tensor<1x32x1xf64> { // 2D //=--- +// CHECK-LABEL: transpose_conv2d_valid_padding_odd +func.func @transpose_conv2d_valid_padding_odd(%arg0: tensor<1x200x198x4xf32>, %arg1: tensor<4x4x4x4xf32>) -> tensor<1x402x398x4xf32> { + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], + window = {pad = [[3, 3], [3, 3]],lhs_dilate = [2, 2]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<1x200x198x4xf32>, tensor<4x4x4x4xf32>) -> tensor<1x402x398x4xf32> + func.return %0 : tensor<1x402x398x4xf32> + // CHECK %cst = arith.constant dense<0.000000e+00> : tensor<4xf32> + // CHECK %cst_0 = arith.constant dense<[1, 2]> : tensor<2xi32> + // CHECK %0 = "tfl.reverse_v2"(%arg1, %cst_0) : (tensor<4x4x4x4xf32>, tensor<2xi32>) -> tensor<4x4x4x4xf32> + // CHECK %cst_1 = arith.constant dense<[1, 402, 398, 4]> : tensor<4xi32> + // CHECK %1 = "tfl.transpose_conv"(%cst_1, %0, %arg0, %cst) <{fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4xi32>, tensor<4x4x4x4xf32>, tensor<1x200x198x4xf32>, tensor<4xf32>) -> tensor<1x402x398x4xf32> + // CHECK return %1 : tensor<1x402x398x4xf32> +} + +// CHECK-LABEL: transpose_conv2d_same_padding +func.func @transpose_conv2d_same_padding(%input: tensor<1x256x256x2xf32>, %filter:tensor<2x4x4x2xf32>) -> tensor<1x512x512x2xf32> { + %0 = mhlo.convolution(%input, %filter) + dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], + window = {pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<1x256x256x2xf32>, tensor<2x4x4x2xf32>) -> tensor<1x512x512x2xf32> + func.return %0 : tensor<1x512x512x2xf32> + // CHECK %cst = arith.constant dense<0.000000e+00> : tensor<2xf32> + // CHECK %cst_0 = arith.constant dense<[1, 2]> : tensor<2xi32> + // CHECK %0 = "tfl.reverse_v2"(%arg1, %cst_0) : (tensor<2x4x4x2xf32>, tensor<2xi32>) -> tensor<2x4x4x2xf32> + // CHECK %1 = "tfl.pseudo_const"() <{value = dense<[1, 512, 512, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> + // CHECK %2 = "tfl.transpose_conv"(%1, %0, %arg0, %cst) <{fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4xi32>, tensor<2x4x4x2xf32>, tensor<1x256x256x2xf32>, tensor<2xf32>) -> tensor<1x512x512x2xf32> + // CHECK return %2 : tensor<1x512x512x2xf32> +} + +// ----- + +// CHECK-LABEL: transpose_conv2d_valid_padding +func.func @transpose_conv2d_valid_padding(%input: tensor<1x256x256x2xf32>, %filter:tensor<2x4x4x2xf32>) -> tensor<1x514x514x2xf32> { + %0 = mhlo.convolution(%input, %filter) + dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], + window = {pad = [[3, 3], [3, 3]], lhs_dilate = [2, 2]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<1x256x256x2xf32>, tensor<2x4x4x2xf32>) -> tensor<1x514x514x2xf32> + func.return %0 : tensor<1x514x514x2xf32> + // CHECK %cst = arith.constant dense<0.000000e+00> : tensor<2xf32> + // CHECK %cst_0 = arith.constant dense<[1, 2]> : tensor<2xi32> + // CHECK %0 = "tfl.reverse_v2"(%arg1, %cst_0) : (tensor<2x4x4x2xf32>, tensor<2xi32>) -> tensor<2x4x4x2xf32> + // CHECK %1 = "tfl.pseudo_const"() <{value = dense<[1, 514, 514, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> + // CHECK %2 = "tfl.transpose_conv"(%1, %0, %arg0, %cst) <{fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4xi32>, tensor<2x4x4x2xf32>, tensor<1x256x256x2xf32>, tensor<2xf32>) -> tensor<1x514x514x2xf32> + // CHECK return %2 : tensor<1x514x514x2xf32> +} + +// ----- + +// CHECK-LABEL: transpose_conv2d_valid_padding_equal_strides +func.func @transpose_conv2d_valid_padding_equal_strides(%arg0: tensor<1x200x198x3xf32>, %arg1: tensor<3x3x3x3xf32>) -> tensor<1x401x397x3xf32> { + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], + window = {pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<1x200x198x3xf32>, tensor<3x3x3x3xf32>) -> tensor<1x401x397x3xf32> + func.return %0 : tensor<1x401x397x3xf32> + // CHECK %cst = arith.constant dense<0.000000e+00> : tensor<3xf32> + // CHECK %cst_0 = arith.constant dense<[1, 2]> : tensor<2xi32> + // CHECK %0 = "tfl.reverse_v2"(%arg1, %cst_0) : (tensor<3x3x3x3xf32>, tensor<2xi32>) -> tensor<3x3x3x3xf32> + // CHECK %cst_1 = arith.constant dense<[1, 401, 397, 3]> : tensor<4xi32> + // CHECK %1 = "tfl.transpose_conv"(%cst_1, %0, %arg0, %cst) <{fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4xi32>, tensor<3x3x3x3xf32>, tensor<1x200x198x3xf32>, tensor<3xf32>) -> tensor<1x401x397x3xf32> + // CHECK return %1 : tensor<1x401x397x3xf32> +} // CHECK-LABEL: conv2d_nhwc_ohwi_nhwc func.func @conv2d_nhwc_ohwi_nhwc(%input: tensor<1x256x256x3xf32>, %filter: tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> { %0 = mhlo.convolution(%input, %filter) @@ -906,7 +1049,6 @@ func.func @depthwise_conv2d_nhwc_ihwo_nhwc_non_trivial_depth_multiplier(%arg0: t // ----- -// TODO: b/351437662 - Add support for conv to resize. // CHECK-LABEL: conv2d_resize_perferred_nhwc_hwoi_nhwc func.func @conv2d_resize_perferred_nhwc_hwoi_nhwc(%arg0: tensor<1x56x1248x16xf32>, %arg1: tensor<16x3x1x1xf32>) -> tensor<1x111x1248x16xf32> { %0 = "mhlo.convolution"(%arg0, %arg1) { @@ -920,13 +1062,13 @@ func.func @conv2d_resize_perferred_nhwc_hwoi_nhwc(%arg0: tensor<1x56x1248x16xf32 window_strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<1x56x1248x16xf32>, tensor<16x3x1x1xf32>) -> tensor<1x111x1248x16xf32> func.return %0 : tensor<1x111x1248x16xf32> + // CHECK %0 = "tfl.pseudo_const"() <{value = dense<[111, 1248]> : tensor<2xi32>}> : () -> tensor<2xi32> + // CHECK %1 = "tfl.resize_bilinear"(%arg0, %0) <{align_corners = false, half_pixel_centers = false}> : (tensor<1x56x1248x16xf32>, tensor<2xi32>) -> tensor<1x111x1248x16xf32> + // CHECK return %1 : tensor<1x111x1248x16xf32> } -// CHECK-NOT: tfl - // ----- -// TODO: b/351437662 - Add support for conv to resize. // CHECK-LABEL: conv2d_to_resize_nhwc_hwoi_nhwc func.func @conv2d_to_resize_nhwc_hwoi_nhwc(%arg0: tensor<1x56x624x16xf32>, %arg1: tensor<16x1x257x1xf32>) -> tensor<1x56x904x16xf32> { %0 = "mhlo.convolution"(%arg0, %arg1) { @@ -940,10 +1082,11 @@ func.func @conv2d_to_resize_nhwc_hwoi_nhwc(%arg0: tensor<1x56x624x16xf32>, %arg1 window_strides = dense<[1, 89]> : tensor<2xi64> } : (tensor<1x56x624x16xf32>, tensor<16x1x257x1xf32>) -> tensor<1x56x904x16xf32> func.return %0 : tensor<1x56x904x16xf32> + // CHECK %0 = "tfl.pseudo_const"() <{value = dense<[56, 904]> : tensor<2xi32>}> : () -> tensor<2xi32> + // CHECK %1 = "tfl.resize_bilinear"(%arg0, %0) <{align_corners = true, half_pixel_centers = false}> : (tensor<1x56x624x16xf32>, tensor<2xi32>) -> tensor<1x56x904x16xf32> + // CHECK return %1 : tensor<1x56x904x16xf32> } -// CHECK-NOT: tfl - // ----- // @@ -1473,7 +1616,7 @@ func.func @gather_nd(%arg0: tensor<98x128xf32>, %arg1: tensor<4x64xi32>) -> tens func.return %0 : tensor<4x64x128xf32> } -// CHECK: %[[VAL_0:.*]] = mhlo.reshape %arg1 : (tensor<4x64xi32>) -> tensor<4x64x1xi32> +// CHECK: %[[VAL_0:.*]] = "tfl.reshape"(%arg1, %0) : (tensor<4x64xi32>, tensor<3xi32>) -> tensor<4x64x1xi32 // CHECK: %[[VAL_1:.*]] = "tfl.gather_nd"(%arg0, %[[VAL_0]]) : (tensor<98x128xf32>, tensor<4x64x1xi32>) -> tensor<4x64x128xf32> // ----- @@ -1900,6 +2043,26 @@ func.func @maxpool_same_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x // CHECK: return // CHECK-SAME: tensor<4x3x8x8xf32> +// ----- + +//===------------------------------------------------------------------------=== +// mhlo.reduce_window -> tfl.cumsum +//===------------------------------------------------------------------------=== + +// CHECK-LABEL: reduce_window_sum +func.func @reduce_window_sum(%arg0: tensor<4x12xf32>) -> tensor<4x12xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %2 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%2) : (tensor) -> () + }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[3, 0], [0, 0]]> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[4, 1]> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<4x12xf32>, tensor) -> tensor<4x12xf32> + func.return %1 : tensor<4x12xf32> +} + +// CHECK: %[[AXIS:.*]] = arith.constant dense<0> : tensor +// CHECK: "tfl.cumsum"(%arg0, %[[AXIS]]) <{exclusive = false, reverse = false}> : (tensor<4x12xf32>, tensor) -> tensor<4x12xf32> + // ----- @@ -1939,8 +2102,9 @@ func.func @sort_to_topk_iota_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf3 func.return %2#0, %2#1 : tensor<3x6xf32>, tensor<3x6xi32> } -// CHECK: %cst = arith.constant dense<6> : tensor -// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %cst) : (tensor<3x6xf32>, tensor) -> (tensor<3x6xf32>, tensor<3x6xi32>) +// CHECK: arith.constant dense<6> : tensor +// CHECK: %[[CST:.*]] = arith.constant dense<6> : tensor +// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %[[CST]]) : (tensor<3x6xf32>, tensor) -> (tensor<3x6xf32>, tensor<3x6xi32>) // ----- @@ -1956,8 +2120,8 @@ func.func @sort_to_topk_iota_cst_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x func.return %2#0, %2#1 : tensor<3x6xf32>, tensor<3x6xi32> } -// CHECK: %cst = arith.constant dense<6> : tensor -// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %cst) : (tensor<3x6xf32>, tensor) -> (tensor<3x6xf32>, tensor<3x6xi32>) +// CHECK: %[[CST:.*]] = arith.constant dense<6> : tensor +// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %[[CST]]) : (tensor<3x6xf32>, tensor) -> (tensor<3x6xf32>, tensor<3x6xi32>) // ----- @@ -1972,8 +2136,43 @@ func.func @sort_to_topk_const(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tenso func.return %1#0, %1#1 : tensor<3x6xf32>, tensor<3x6xi32> } -// CHECK: %cst = arith.constant dense<6> : tensor -// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %cst) : (tensor<3x6xf32>, tensor) -> (tensor<3x6xf32>, tensor<3x6xi32> +// CHECK: %[[CST:.*]] = arith.constant dense<6> : tensor +// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %[[CST]]) : (tensor<3x6xf32>, tensor) -> (tensor<3x6xf32>, tensor<3x6xi32> + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.iota +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: iota_1d +func.func @iota_1d() -> tensor<123xf32> { + %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<123xf32> + func.return %0 : tensor<123xf32> +} + +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : tensor +// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<1.230000e+02> : tensor +// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<1.000000e+00> : tensor +// CHECK: "tfl.range"(%[[CST_1]], %[[CST_2]], %[[CST_3]]) : (tensor, tensor, tensor) -> tensor<123xf32> + +// ----- + +// CHECK-LABEL: iota_3d +func.func @iota_3d() -> tensor<5x7x9xi32> { + %0 = "mhlo.iota"() <{ iota_dimension = 1 : i64 }> : () -> tensor<5x7x9xi32> + func.return %0 : tensor<5x7x9xi32> +} + +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<0> : tensor +// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<7> : tensor +// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<1> : tensor +// CHECK: %[[RANGE:.*]] = "tfl.range"(%[[CST_1]], %[[CST_2]], %[[CST_3]]) : (tensor, tensor, tensor) -> tensor<7xi32> +// CHECK: %[[CST_4:.*]] = arith.constant dense<[1, 7, 1]> : tensor<3xi64> +// CHECK: %[[CAST:.*]] = "tfl.cast"(%[[CST_4]]) : (tensor<3xi64>) -> tensor<3xi32> +// CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%[[RANGE]], %[[CAST]]) : (tensor<7xi32>, tensor<3xi32>) -> tensor<1x7x1xi32> +// CHECK: %[[CST_5:.*]] = arith.constant dense<[5, 7, 9]> : tensor<3xi64> +// CHECK: "tfl.broadcast_to"(%[[RESHAPE]], %[[CST_5]]) : (tensor<1x7x1xi32>, tensor<3xi64>) -> tensor<5x7x9xi32> // ----- @@ -2038,6 +2237,587 @@ func.func @dynamic_slice_splat_sizes(%arg0: tensor<7x3xf32>, %arg1: tensor, // ----- +//===----------------------------------------------------------------------===// +// mhlo.dynamic_update_slice +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: dynamic_update_slice +func.func @dynamic_update_slice(%arg0: tensor<28x1x100xf32>, %arg1: tensor<1x1x100xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor<28x1x100xf32> { + %0 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<28x1x100xf32>, tensor<1x1x100xf32>, tensor, tensor, tensor) -> tensor<28x1x100xf32> + func.return %0 : tensor<28x1x100xf32> +} + +// CHECK: %0 = "tfl.pack"(%arg2, %arg3, %arg4) <{axis = 0 : i32, values_count = 3 : i32}> : (tensor, tensor, tensor) -> tensor<3xi32> +// CHECK: %1 = "tfl.dynamic_update_slice"(%arg0, %arg1, %0) : (tensor<28x1x100xf32>, tensor<1x1x100xf32>, tensor<3xi32>) -> tensor<28x1x100xf32> + +// ----- + +// CHECK-LABEL: dynamic_update_slice_inputs_have_dynamic_dim +func.func @dynamic_update_slice_inputs_have_dynamic_dim(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor { + %0 = mhlo.dynamic_update_slice %arg0, %arg1, %arg2, %arg3 : (tensor, tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: %0 = "tfl.pack"(%arg2, %arg3) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor, tensor) -> tensor<2xi32> +// CHECK: %1 = "tfl.dynamic_update_slice"(%arg0, %arg1, %0) : (tensor, tensor, tensor<2xi32>) -> tensor + +// ----- + +// CHECK-LABEL: dynamic_update_slice_operand_has_dynamic_dim +func.func @dynamic_update_slice_operand_has_dynamic_dim(%arg0: tensor<1x?x256xf32>, %arg1: tensor<1x1x256xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor<1x?x256xf32> { + %0 = mhlo.dynamic_update_slice %arg0, %arg1, %arg2, %arg3, %arg4 : (tensor<1x?x256xf32>, tensor<1x1x256xf32>, tensor, tensor, tensor) -> tensor<1x?x256xf32> + func.return %0 : tensor<1x?x256xf32> +} + +// CHECK: %0 = "tfl.pack"(%arg2, %arg3, %arg4) <{axis = 0 : i32, values_count = 3 : i32}> : (tensor, tensor, tensor) -> tensor<3xi32> +// CHECK: %1 = "tfl.dynamic_update_slice"(%arg0, %arg1, %0) : (tensor<1x?x256xf32>, tensor<1x1x256xf32>, tensor<3xi32>) -> tensor<1x?x256xf32> + +// ----- + +//===----------------------------------------------------------------------===// +// rounding +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: round +func.func @round(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = mhlo.constant dense<2.000000e+00> : tensor<8x128xf32> + %1 = mhlo.constant dense<5.000000e-01> : tensor<8x128xf32> + %2 = mhlo.constant dense<1.000000e+00> : tensor<8x128xf32> + %3 = "mhlo.floor"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32> + %4 = mhlo.subtract %arg0, %3 : tensor<8x128xf32> + %5 = "mhlo.compare"(%4, %1) {comparison_direction = #mhlo} : (tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xi1> + %6 = "mhlo.compare"(%4, %1) {comparison_direction = #mhlo} : (tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xi1> + %7 = mhlo.multiply %arg0, %1 : tensor<8x128xf32> + %8 = "mhlo.floor"(%7) : (tensor<8x128xf32>) -> tensor<8x128xf32> + %9 = mhlo.multiply %8, %0 : tensor<8x128xf32> + %10 = mhlo.subtract %3, %9 : tensor<8x128xf32> + %11 = "mhlo.compare"(%10, %2) {comparison_direction = #mhlo} : (tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xi1> + %12 = mhlo.and %6, %11 : tensor<8x128xi1> + %13 = mhlo.or %5, %12 : tensor<8x128xi1> + %14 = mhlo.add %3, %2 : tensor<8x128xf32> + %15 = "mhlo.select"(%13, %14, %3) : (tensor<8x128xi1>, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + func.return %15 : tensor<8x128xf32> +} + +// CHECK: "tfl.round"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32> + +// ----- + +// CHECK-LABEL: floor_mod_float +func.func @floor_mod_float(%arg0: tensor<192x8xf32>, %arg1: tensor<192x8xf32>) -> tensor<192x8xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor<192x8xf32> + %1 = mhlo.remainder %arg0, %arg1 : tensor<192x8xf32> + %2 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1> + %3 = "mhlo.compare"(%arg1, %0) {comparison_direction = #mhlo} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1> + %4 = "mhlo.compare"(%2, %3) {comparison_direction = #mhlo} : (tensor<192x8xi1>, tensor<192x8xi1>) -> tensor<192x8xi1> + %5 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1> + %6 = mhlo.and %4, %5 : tensor<192x8xi1> + %7 = mhlo.add %1, %arg1 : tensor<192x8xf32> + %8 = "mhlo.select"(%6, %7, %1) : (tensor<192x8xi1>, tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32> + func.return %8 : tensor<192x8xf32> +} + +// CHECK: "tfl.floor_mod"(%arg0, %arg1) : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32> + +// ----- + +// CHECK-LABEL: floor_mod_int +func.func @floor_mod_int(%arg0: tensor<192x8xi32>, %arg1: tensor<192x8xi32>) -> tensor<192x8xi32> { + %0 = mhlo.constant dense<0> : tensor<192x8xi32> + %1 = mhlo.remainder %arg0, %arg1 : tensor<192x8xi32> + %2 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1> + %3 = "mhlo.compare"(%arg1, %0) {comparison_direction = #mhlo} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1> + %4 = "mhlo.compare"(%2, %3) {comparison_direction = #mhlo} : (tensor<192x8xi1>, tensor<192x8xi1>) -> tensor<192x8xi1> + %5 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1> + %6 = mhlo.and %4, %5 : tensor<192x8xi1> + %7 = mhlo.add %1, %arg1 : tensor<192x8xi32> + %8 = "mhlo.select"(%6, %7, %1) : (tensor<192x8xi1>, tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32> + func.return %8 : tensor<192x8xi32> +} + +// CHECK: "tfl.floor_mod"(%arg0, %arg1) : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32> + +// ----- + +// CHECK-LABEL: floor_mod_float_cst +func.func @floor_mod_float_cst(%arg0: tensor<192x8xf32>) -> tensor<192x8xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor<192x8xf32> + %1 = mhlo.constant dense<2.000000e+00> : tensor<192x8xf32> + %2 = mhlo.remainder %arg0, %1 : tensor<192x8xf32> + %3 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1> + %4 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1> + %5 = mhlo.and %3, %4 : tensor<192x8xi1> + %6 = mhlo.add %2, %1 : tensor<192x8xf32> + %7 = "mhlo.select"(%5, %6, %2) : (tensor<192x8xi1>, tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32> + func.return %7 : tensor<192x8xf32> +} + +// CHECK: %cst = arith.constant dense<2.000000e+00> : tensor<192x8xf32> +// CHECK: "tfl.floor_mod"(%arg0, %cst) : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32> + +// ----- + +// CHECK-LABEL: floor_mod_int_cst +func.func @floor_mod_int_cst(%arg0: tensor<192x8xi32>) -> tensor<192x8xi32> { + %0 = mhlo.constant dense<0> : tensor<192x8xi32> + %1 = mhlo.constant dense<2> : tensor<192x8xi32> + %2 = mhlo.remainder %arg0, %1 : tensor<192x8xi32> + %3 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1> + %4 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1> + %5 = mhlo.and %3, %4 : tensor<192x8xi1> + %6 = mhlo.add %2, %1 : tensor<192x8xi32> + %7 = "mhlo.select"(%5, %6, %2) : (tensor<192x8xi1>, tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32> + func.return %7 : tensor<192x8xi32> +} + +// CHECK: %cst = arith.constant dense<2> : tensor<192x8xi32> +// CHECK: "tfl.floor_mod"(%arg0, %cst) : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32> + +// ----- + +// CHECK-LABEL: floor_div +func.func @floor_div(%arg0: tensor<10x10xf32>, %arg1: tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor<10x10xf32> + %1 = mhlo.constant dense<-1.000000e+00> : tensor<10x10xf32> + %2 = mhlo.remainder %arg0, %arg1 : tensor<10x10xf32> + %3 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1> + %4 = "mhlo.sign"(%arg1) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %5 = "mhlo.sign"(%2) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %6 = "mhlo.compare"(%4, %5) {comparison_direction = #mhlo} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1> + %7 = mhlo.and %3, %6 : tensor<10x10xi1> + %8 = mhlo.subtract %arg0, %2 : tensor<10x10xf32> + %9 = mhlo.divide %8, %arg1 : tensor<10x10xf32> + %10 = mhlo.add %9, %1 : tensor<10x10xf32> + %11 = "mhlo.select"(%7, %10, %9) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %12 = "mhlo.round_nearest_afz"(%11) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %13 = "mhlo.tuple"(%12) : (tensor<10x10xf32>) -> tuple> + func.return %12 : tensor<10x10xf32> +} + +// CHECK: tfl.floor_div %arg0, %arg1 : tensor<10x10xf32 + +// ----- + +// CHECK-LABEL: floor_div_cst +func.func @floor_div_cst(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32> + %1 = mhlo.constant dense<0.000000e+00> : tensor<10x10xf32> + %2 = mhlo.constant dense<1.000000e+00> : tensor<10x10xf32> + %3 = mhlo.constant dense<5.000000e-01> : tensor<10x10xf32> + %4 = mhlo.constant dense<-1.000000e+00> : tensor<10x10xf32> + %5 = mhlo.remainder %arg0, %0 : tensor<10x10xf32> + %6 = "mhlo.compare"(%5, %1) {comparison_direction = #mhlo} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1> + %7 = "mhlo.sign"(%5) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %8 = "mhlo.compare"(%2, %7) {comparison_direction = #mhlo} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1> + %9 = mhlo.and %6, %8 : tensor<10x10xi1> + %10 = mhlo.subtract %arg0, %5 : tensor<10x10xf32> + %11 = mhlo.multiply %10, %3 : tensor<10x10xf32> + %12 = mhlo.add %11, %4 : tensor<10x10xf32> + %13 = "mhlo.select"(%9, %12, %11) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %14 = "mhlo.round_nearest_afz"(%13) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %15 = "mhlo.tuple"(%14) : (tensor<10x10xf32>) -> tuple> + func.return %14 : tensor<10x10xf32> +} + +// CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32> +// CHECK: tfl.floor_div %arg0, %[[CST]] : tensor<10x10xf32> + +// ----- + +// CHECK-LABEL: floor_div_cst2 +func.func @floor_div_cst2(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = mhlo.constant dense<1.000000e+00> : tensor<10x10xf32> + %1 = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32> + %2 = mhlo.constant dense<0.000000e+00> : tensor<10x10xf32> + %3 = mhlo.constant dense<-1.000000e+00> : tensor<10x10xf32> + %4 = mhlo.remainder %arg0, %1 : tensor<10x10xf32> + %5 = "mhlo.compare"(%4, %2) {comparison_direction = #mhlo} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1> + %6 = "mhlo.sign"(%4) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %7 = "mhlo.compare"(%0, %6) {comparison_direction = #mhlo} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1> + %8 = mhlo.and %5, %7 : tensor<10x10xi1> + %9 = mhlo.subtract %arg0, %4 : tensor<10x10xf32> + %10 = mhlo.divide %9, %1 : tensor<10x10xf32> + %11 = mhlo.add %10, %3 : tensor<10x10xf32> + %12 = "mhlo.select"(%8, %11, %10) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %13 = "mhlo.round_nearest_afz"(%12) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %14 = "mhlo.tuple"(%13) : (tensor<10x10xf32>) -> tuple> + func.return %13 : tensor<10x10xf32> +} + +// CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32> +// CHECK: tfl.floor_div %arg0, %[[CST]] : tensor<10x10xf32> + +// ----- + +// CHECK-LABEL: floor_div_broadcast_cst +func.func @floor_div_broadcast_cst(%arg0: tensor<10x8xf32>) -> tensor<10x8xf32> { + %0 = mhlo.constant dense<1.000000e+00> : tensor<10x8xf32> + %1 = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00, 1.600000e+01, 3.200000e+01, 6.400000e+01, 1.280000e+02]> : tensor<8xf32> + %2 = mhlo.constant dense<0.000000e+00> : tensor<10x8xf32> + %3 = mhlo.constant dense<-1.000000e+00> : tensor<10x8xf32> + %5 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<8xf32>) -> tensor<10x8xf32> + %6 = mhlo.remainder %arg0, %5 : tensor<10x8xf32> + %7 = "mhlo.compare"(%6, %2) {comparison_direction = #mhlo} : (tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xi1> + %8 = "mhlo.sign"(%6) : (tensor<10x8xf32>) -> tensor<10x8xf32> + %9 = "mhlo.compare"(%0, %8) {comparison_direction = #mhlo} : (tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xi1> + %10 = mhlo.and %7, %9 : tensor<10x8xi1> + %11 = mhlo.subtract %arg0, %6 : tensor<10x8xf32> + %12 = mhlo.divide %11, %5 : tensor<10x8xf32> + %13 = mhlo.add %12, %3 : tensor<10x8xf32> + %14 = "mhlo.select"(%10, %13, %12) : (tensor<10x8xi1>, tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xf32> + %15 = "mhlo.round_nearest_afz"(%14) : (tensor<10x8xf32>) -> tensor<10x8xf32> + %16 = "mhlo.tuple"(%15) : (tensor<10x8xf32>) -> tuple> + func.return %15 : tensor<10x8xf32> +} + +// CHECK: %[[BCAST:.*]] = "mhlo.broadcast_in_dim"(%1) +// CHECK: tfl.floor_div %arg0, %[[BCAST]] : tensor<10x8xf32> + +// ----- + +//===----------------------------------------------------------------------===// +// unary elementwise +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: convert_i32_f32 +func.func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> { + %0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.cast + +// ----- + +// CHECK-LABEL: abs +func.func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.abs + +// ----- + +// CHECK-LABEL: abs_dynamic +func.func @abs_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.abs + +// ----- + +// CHECK-LABEL: ceil +func.func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.ceil + +// ----- + +// CHECK-LABEL: ceil_dynamic +func.func @ceil_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.ceil"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.ceil + +// ----- + +// CHECK-LABEL: complex_abs +func.func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { + %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK-NOT: tfl + +// ----- + +func.func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { + %0 = "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: %0 = tfl.sub %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2xf32> +// CHECK: %cst = arith.constant dense<0.000000e+00> : tensor +// CHECK: %1 = "tfl.equal"(%0, %cst) : (tensor<2xf32>, tensor) -> tensor<2xi1> +// CHECK: return %1 : tensor<2xi1> + +// ----- + +func.func @is_finite_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.is_finite"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: %0 = tfl.sub %arg0, %arg0 {fused_activation_function = "NONE"} : tensor +// CHECK: %cst = arith.constant dense<0.000000e+00> : tensor +// CHECK: %1 = "tfl.equal"(%0, %cst) : (tensor, tensor) -> tensor + +// ----- + +// CHECK-LABEL: cos +func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.cos + +// ----- + +// CHECK-LABEL: cos_dynamic +func.func @cos_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.cosine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.cos + +// ----- + +// CHECK-LABEL: logistic +func.func @logistic(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.logistic"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.logistic + +// ----- + +// CHECK-LABEL: exp +func.func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.exp + +// ----- + +// CHECK-LABEL: exp_dynamic +func.func @exp_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.exponential"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.exp + +// ----- + +// CHECK-LABEL: expm1 +func.func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.exponential_minus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: %0 = "tfl.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor +// CHECK: %1 = tfl.sub(%0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2xf32>, tensor) -> tensor<2xf32> + +// ----- + +// CHECK-LABEL: floor +func.func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.floor + +// ----- + +// CHECK-LABEL: floor_dynamic +func.func @floor_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.floor"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.floor + +// ----- + +// CHECK-LABEL: log +func.func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.log + +// ----- + +// CHECK-LABEL: log_dynamic +func.func @log_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.log"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.log + +// ----- + +// CHECK-LABEL: log1p +func.func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor +// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2xf32>, tensor) -> tensor<2xf32> +// CHECK: %1 = "tfl.log"(%0) : (tensor<2xf32>) -> tensor<2xf32> + +// ----- + +// CHECK-LABEL: log1p_dynamic +func.func @log1p_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.log_plus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor +// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor, tensor) -> tensor +// CHECK: %1 = "tfl.log"(%0) : (tensor) -> tensor + +// ----- + +// CHECK-LABEL: neg +func.func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.neg + +// ----- + +// CHECK-LABEL: neg_dynamic +func.func @neg_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.negate"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.neg + +// ----- + +// CHECK-LABEL: sin +func.func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.sin + +// ----- + +// CHECK-LABEL: sin_dynamic +func.func @sin_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.sine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.sin + +// ----- + +// CHECK-LABEL: rsqrt +func.func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.rsqrt + +// ----- + +// CHECK-LABEL: rsqrt_dynamic +func.func @rsqrt_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.rsqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.rsqrt + +// ----- + +// CHECK-LABEL: @sqrt +func.func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.sqrt + +// ----- + +// CHECK-LABEL: sqrt_dynamic +func.func @sqrt_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.sqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.sqrt + +// ----- + +// CHECK-LABEL: tanh +func.func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.tanh + +// ----- + +// CHECK-LABEL: tanh_dynamic +func.func @tanh_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.tanh"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.tanh + +// ----- + +// CHECK-LABEL: bitcast +func.func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.bitcast + +// ----- + +// CHECK-LABEL: bitcast_dynamic +func.func @bitcast_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.bitcast_convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.bitcast + +// ----- + +// CHECK-LABEL: bitcast_same_widths +func.func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { + %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> + func.return %0 : tensor<2xi32> +} + +// CHECK: tfl.bitcast + +// ----- + //===----------------------------------------------------------------------===// // logical and bitwise ops //===----------------------------------------------------------------------===// @@ -2115,3 +2895,174 @@ func.func @not_ui32(%arg0: tensor<7x9x11xui32>) -> tensor<7x9x11xui32> { // CHECK: %cst = arith.constant dense<4294967295> : tensor // CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xui32>, tensor) -> tensor<7x9x11xui32> + +// ----- + +//===----------------------------------------------------------------------===// +// binary ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: remainder +func.func @remainder(%arg0: tensor<10x8xi32>, %arg1: tensor<10x8xi32>) -> tensor<10x8xi32> { + %0 = mhlo.remainder %arg0, %arg1 : tensor<10x8xi32> + func.return %0 : tensor<10x8xi32> +} + +// CHECK: %0 = "tfl.floor_mod"(%arg0, %arg1) : (tensor<10x8xi32>, tensor<10x8xi32>) -> tensor<10x8xi32> + +// ----- + +// CHECK-LABEL: shift_right_arith +func.func @shift_right_arith(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +// CHECK: %0 = "tfl.right_shift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + +// ----- + +// CHECK-LABEL: shift_right_logical +func.func @shift_right_logical(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.shift_right_logical %arg0, %arg1 : tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +// CHECK: %0 = "tfl.right_shift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.compare +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: greater_unsupported_compare_type +func.func @greater_unsupported_compare_type(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK-NOT: tfl +// CHECK: mhlo.compare + +// ----- + +// CHECK-LABEL: equal +func.func @equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.equal + +// ----- + +// CHECK-LABEL: notequal +func.func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.not_equal + +// ----- + +// CHECK-LABEL: greater +func.func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.greater + +// ----- + +// CHECK-LABEL: greater_equal +func.func @greater_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.greater_equal + +// ----- + +// CHECK-LABEL: less +func.func @less(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.less + +// ----- + +// CHECK-LABEL: less_equal +func.func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.less_equal + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo binary element-wise ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: maximum +func.func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = mhlo.maximum %arg0, %arg1 : tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK: "tfl.maximum"(%arg0, %arg1) +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: minimum +func.func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = mhlo.minimum %arg0, %arg1 : tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK: "tfl.minimum"(%arg0, %arg1) +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: mul +func.func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32> + func.return %0 : tensor<2xi32> +} + +// CHECK: tfl.mul %arg0, %arg0 +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: pow +func.func @pow(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %0 = mhlo.power %arg0, %arg0 : tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK: tfl.pow +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: clamp +func.func @clamp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "mhlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-NEXT: %0 = "tfl.minimum"(%arg1, %arg2) +// CHECK-NEXT: %1 = "tfl.maximum"(%0, %arg0) +// CHECK-NEXT: return %1 : tensor + + diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc index eb866dc64931d0..2b96254e04fc3d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc @@ -1144,7 +1144,8 @@ class ComposeUniformQuantizedDotGeneralOp .clone(output_uniform_quantized_type), /*lhs=*/op.getLhs(), /*rhs=*/op.getRhs(), /*dot_dimension_numbers=*/op.getDotDimensionNumbers(), - /*precision_config=*/op.getPrecisionConfigAttr()); + /*precision_config=*/op.getPrecisionConfigAttr(), + /*algorithm=*/op.getAlgorithmAttr()); rewriter.replaceAllUsesWith(op.getResult(), new_dot_general_op.getResult()); @@ -1489,7 +1490,8 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations .clone(output_uniform_quantized_type), /*lhs=*/op.getLhs(), /*rhs=*/op.getRhs(), /*dot_dimension_numbers=*/op.getDotDimensionNumbers(), - /*precision_config=*/op.getPrecisionConfigAttr()); + /*precision_config=*/op.getPrecisionConfigAttr(), + /*algorithm=*/op.getAlgorithmAttr()); rewriter.replaceAllUsesWith(op.getResult(), new_dot_general_op.getResult()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD index e5d0a59c7fd82d..27e655c3aa51d3 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD @@ -124,6 +124,7 @@ cc_library( hdrs = ["conv.h"], deps = [ ":conv_util", + ":op_util_common", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", @@ -140,6 +141,7 @@ cc_library( hdrs = ["conv_util.h"], deps = [ ":op_util_common", + "//tensorflow/core/lib/math:math_util", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", @@ -204,6 +206,7 @@ cc_library( ":util", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", @@ -266,3 +269,19 @@ cc_library( "@local_xla//xla/mlir_hlo", ], ) + +cc_library( + name = "iota", + srcs = ["iota.cc"], + hdrs = ["iota.h"], + deps = [ + ":op_util_common", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], +) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc index 87a429d8ff9835..1ad6e7bfc044e3 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc @@ -15,19 +15,23 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h" #include +#include #include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir::odml { @@ -74,13 +78,17 @@ bool IsWindowReversalSupported(const ConvView& data) { bool IsConvLegal(mhlo::ConvolutionOp op) { const ConvView data(op); - const bool supported_conv_type = - IsStandardConv(data) || IsDepthwiseConv(data); + const bool supported_conv_type = IsStandardConv(data) || + IsDepthwiseConv(data) || + IsSupportedNonTrivialConv(data); + + const bool is_non_supported_trivial_conv = + (!IsSupportedNonTrivialConv(data) && + (!IsPaddingSupported(data) || !IsInputDilationSupported(data))); return !supported_conv_type || !IsBatchGroupSupported(data) || - !IsInputDilationSupported(data) || !AreShapesSupported(data) || - !IsTFLNativeLayout(data) || !IsPaddingSupported(data) || - !IsWindowReversalSupported(data); + !AreShapesSupported(data) || !IsTFLNativeLayout(data) || + is_non_supported_trivial_conv || !IsWindowReversalSupported(data); } //===----------------------------------------------------------------------===// @@ -285,6 +293,295 @@ LogicalResult LegalizeConv3D::matchAndRewrite( return success(); } +//===----------------------------------------------------------------------===// +// mhlo.convolution -> TFL::ResizeBilinearOp +//===----------------------------------------------------------------------===// + +// Convert a 2d mhlo.convolution op to a tfl.resize_bilinear +class ConvertNonTrivialConvToResizeBilinearOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::ConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult ConvertNonTrivialConvToResizeBilinearOp::matchAndRewrite( + mhlo::ConvolutionOp conv_op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + const ConvView data(conv_op); + bool align_corners; + if (!MatchWithResizeBilinearOp(data, align_corners)) { + return rewriter.notifyMatchFailure( + conv_op, "op does not match with resize_bilinear op"); + } + + // The output size attribute is an array of 32bit values. + SmallVector output_shape_i32; + for (int64_t spatial_dim : data.InputLayout().Spatials()) { + output_shape_i32.push_back( + static_cast(data.OutputShape()[spatial_dim])); + } + Value output_sizes_attr = rewriter.create( + conv_op.getLoc(), rewriter.getI32TensorAttr(output_shape_i32)); + // The value of half_pixel_centers couldn't be inferred from the IR and XLA + // only support half_pixel_centers=True as in 01/11/2022. Here + // half_pixel_centers=False is hardcoded. + rewriter.replaceOpWithNewOp( + conv_op, conv_op.getType(), conv_op.getLhs(), output_sizes_attr, + /*align_corners=*/rewriter.getBoolAttr(align_corners), + /*half_pixel_centers=*/rewriter.getBoolAttr(false)); + + return success(); +} + +//===----------------------------------------------------------------------===// +// mhlo.convolution -> TFL::TransposeConv2dOp +//===----------------------------------------------------------------------===// + +// Convert a 2d mhlo.convolution op to a tfl.transpose_conv2d +class ConvertNonTrivialConvToTransposeConvOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::ConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult ConvertNonTrivialConvToTransposeConvOp::matchAndRewrite( + mhlo::ConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + const ConvView data(op); + + // + // Test if the op is a supported non-trivial convolution. + //===----- + + if (!IsSupportedNonTrivialConv(data)) { + return rewriter.notifyMatchFailure(op, "Not a non-trivial convolution."); + } + + // For depthwise and group convolutions, feature_group_count != 1 + if (op.getFeatureGroupCount() != 1) { + // Depthwise or Group convolution is not supported yet. + return rewriter.notifyMatchFailure( + op, "group or depthwise convolution is not supported"); + } + + // + // strides + //===----- + + // TFL::TravsposeConv2D applies strides on LHS. strides == lhs_dilation + auto strides = data.InputDilations(); + auto tfl_h_stride = rewriter.getI32IntegerAttr(strides[0]); + auto tfl_w_stride = rewriter.getI32IntegerAttr(strides[1]); + + // + // padding + //===----- + + std::string padding; + SmallVector padding_array; + for (auto& padding : data.Padding()) { + padding_array.push_back(padding.Lo()); + padding_array.push_back(padding.Hi()); + } + + if (IsTransposeConvPaddingValid(op, /*num_spatial_dims*/ 2, strides, + padding_array)) { + padding = "VALID"; + } else if (IsTransposeConvPaddingSame(op, /*num_spatial_dims*/ 2, strides, + padding_array)) { + padding = "SAME"; + } else { + return rewriter.notifyMatchFailure(op, + "requires padding to be SAME or VALID"); + } + + // + // build tfl op + //===------- + + auto bias = BuildEmptyBias(rewriter, op->getLoc(), data); + auto tfl_faf_none = rewriter.getStringAttr("NONE"); + + // Need to reverse the kernel data inorder to run TFL::TransposeConv2d + // The axis along which to reverse. In this case, we want to mirror the + // kernel's spatial dimensions. + SmallVector kernel_spatial_dims_i32( + data.KernelLayout().Spatials().begin(), + data.KernelLayout().Spatials().end()); + Value axis = rewriter.create( + op.getLoc(), rewriter.getI32TensorAttr(kernel_spatial_dims_i32)); + + // Create the tfl::ReverseV2Op + auto filter = rewriter.create( + op.getLoc(), op.getRhs().getType(), op.getRhs(), axis); + + // Calculate the output size and shape for TFL::TransposeConv2dOp + SmallVector output_shape_i32(data.OutputShape().begin(), + data.OutputShape().end()); + + auto output_sizes = rewriter.create( + op.getLoc(), rewriter.getI32TensorAttr(output_shape_i32)); + + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), /*output_shape=*/output_sizes, + /*filter=*/filter, /*input=*/op.getLhs(), /*bias=*/bias, + /*padding=*/rewriter.getStringAttr(padding), + /*stride_h=*/tfl_h_stride, /*stride_w=*/tfl_w_stride, + /*fused_activation_function=*/tfl_faf_none); + + return success(); +} + +//===----------------------------------------------------------------------===// + +class SliceDepthwiseTransposedConvolution + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mhlo::ConvolutionOp op, + PatternRewriter& rewriter) const final; +}; + +// Pattern rewriter to match a depthwise transposed convolution and rewrite it +// to depth-times slices of input and filter to perform the transposed +// convolution on individual slices of tensors and concatenate the results of. +// the convolutions. This is a. workaround because the TFLite runtime doesn't +// support depthwise-transposed-conv op natively. +LogicalResult SliceDepthwiseTransposedConvolution::matchAndRewrite( + mhlo::ConvolutionOp conv_op, PatternRewriter& rewriter) const { + const ConvView data(conv_op); + + // + // Test if the op is a supported non-trivial convolution. + //===----- + if (!IsSupportedNonTrivialConv(data)) { + return rewriter.notifyMatchFailure(conv_op, + "Not a non-trivial convolution."); + } + + // These checks narrow down the support to depthwise transpose conv2d. + mhlo::ConvDimensionNumbersAttr dnums = conv_op.getDimensionNumbers(); + const int64_t input_feature_dimension = dnums.getInputFeatureDimension(); + const int64_t input_channels = + mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_feature_dimension); + const int64_t feature_group_count = conv_op.getFeatureGroupCount(); + const int64_t kernel_input_feature_dimension = + dnums.getKernelInputFeatureDimension(); + const int64_t kernel_input_channels = + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_input_feature_dimension); + const int64_t kernel_output_feature_dimension = + dnums.getKernelOutputFeatureDimension(); + const int64_t kernel_output_channels = + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_output_feature_dimension); + + // To support a depthwise convolution, we need- + // 1. feature_group_count != 1 (except when input_channels==1) + // 2. feature_group_count == input_channels + // 3. kernel_input_channels == 1 + // 4. kernel_output_channels % kernel_input_channels == 0 + if (feature_group_count == 1) { + return rewriter.notifyMatchFailure(conv_op, "Not a depthwise convolution"); + } + + if (input_channels != feature_group_count) { + return rewriter.notifyMatchFailure( + conv_op, "Not a detphwise transposed convolution"); + } + + if (MatchWithResizeBilinearOp(data)) { + return rewriter.notifyMatchFailure( + conv_op, "Op will be legalized to ResizeBilinearOp"); + } + + if ((kernel_output_channels % feature_group_count != 0) || + (kernel_input_channels != 1)) { + return rewriter.notifyMatchFailure( + conv_op, "Not a supported detphwise transposed convolution"); + } + + // This needs to be checked because the TFLite runtime generated incorrect + // results for depthwise transpose convolutions with non-1 channel + // multiplier. + if ((kernel_output_channels / feature_group_count) != 1) { + return rewriter.notifyMatchFailure( + conv_op, + "Unsupported detphwise transpose convolution with non-1 channel " + "multiplier"); + } + + // Slicing with dynamic offsets (helper method advised) + auto create_slice = [&](mlir::Value tensor, int64_t depth_idx, + int64_t channel_idx, + bool is_kernel = false) -> mlir::Value { + auto tensor_shape = + mlir::cast(tensor.getType()).getShape().vec(); + + // Calculate offsets based on depth_idx, channel_idx and tensor_shape + llvm::SmallVector start_indices(tensor_shape.size(), 0); + auto limit_indices = tensor_shape; + const llvm::SmallVector strides(tensor_shape.size(), 1); + start_indices[channel_idx] = depth_idx; + if (is_kernel) { + // kernel can have a channel_multiplier that needs to be accounted for + limit_indices[channel_idx] = + depth_idx + (kernel_output_channels / feature_group_count); + } else { + limit_indices[channel_idx] = depth_idx + 1; + } + return rewriter.create( + conv_op.getLoc(), tensor, rewriter.getI64TensorAttr(start_indices), + rewriter.getI64TensorAttr(limit_indices), + rewriter.getI64TensorAttr(strides)); + }; + + // Storage for smaller convolution results + llvm::SmallVector conv_results; + + // Iterative Slicing and Convolutions + for (int i = 0; i < feature_group_count; ++i) { + auto sliced_input = + create_slice(conv_op.getLhs(), i, input_feature_dimension); + auto sliced_kernel = create_slice(conv_op.getRhs(), i, + kernel_output_feature_dimension, true); + + // Calculate convolution output_type based on sliced_input and + // sliced_kernel + auto output_type = mlir::cast(conv_op->getResult(0).getType()); + auto new_output_shape = output_type.getShape().vec(); + new_output_shape[dnums.getOutputFeatureDimension()] /= feature_group_count; + auto new_output_type = + RankedTensorType::get(new_output_shape, output_type.getElementType()); + + // Create a Smaller Convolution (Ensure compatibility) + auto conv_result = rewriter.create( + conv_op.getLoc(), new_output_type, sliced_input, sliced_kernel, + conv_op.getWindowStridesAttr(), conv_op.getPaddingAttr(), + conv_op.getLhsDilationAttr(), conv_op.getRhsDilationAttr(), + conv_op.getWindowReversalAttr(), conv_op.getDimensionNumbers(), + /*feature_group_count*/ 1, /*batch_group_count*/ 1, + conv_op.getPrecisionConfigAttr()); + + conv_results.push_back(conv_result); + } + + auto final_output = rewriter.create( + conv_op.getLoc(), conv_results, + rewriter.getI64IntegerAttr(dnums.getOutputFeatureDimension())); + rewriter.replaceOp(conv_op, final_output.getResult()); + return success(); +} + +//===----------------------------------------------------------------------===// + // Convert a 1-D convolution into a 2-D convolution (which TF supports) so that // it can be rewritten by the pattern `Convert2DConvOp`. class Conv1DToConv2D : public OpRewritePattern { @@ -436,12 +733,14 @@ LogicalResult Conv1DToConv2D::matchAndRewrite(mhlo::ConvolutionOp op, void PopulateLegalizeConvPatterns(MLIRContext* ctx, RewritePatternSet& patterns, ConversionTarget& target) { - patterns.add(ctx); + patterns.add(ctx); target.addDynamicallyLegalOp(IsConvLegal); } void PopulatePrepareConvPatterns(MLIRContext* ctx, RewritePatternSet& patterns) { - patterns.add(ctx); + patterns.add(ctx); } } // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc index e2dff5f5e5a7ce..70c0ab5acc5f1e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc @@ -16,11 +16,16 @@ limitations under the License. #include #include +#include +#include #include +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" @@ -115,4 +120,151 @@ Value CreatePadOpFromConvPadding(OpBuilder& b, mhlo::ConvolutionOp op) { return pad_op; } +bool MatchWithResizeBilinearOp(const ConvView& data, bool& align_corners) { + if (data.InputLayout().Rank() != 4 || data.KernelLayout().Rank() != 4 || + data.OutputLayout().Rank() != 4 || + data.InputLayout().Spatials() != data.OutputLayout().Spatials()) { + return false; + } + + if (data.InputDilations().size() != 2 || + !(llvm::all_of(data.KernelDilations(), [](auto d) { return d == 1; })) || + data.Strides().size() != 2 || data.Padding().size() != 2) { + return false; + } + + // This is based on method in compiler/tf2xla/kernels/image_resize_ops.cc + auto can_convert_to_bilinear = + [](bool align_corners, int64_t dilation, int64_t padding, int64_t stride, + int64_t input_spatial, int64_t output_spatial) { + int64_t input_spatial_size = + align_corners ? input_spatial - 1 : input_spatial; + int64_t output_spatial_size = + align_corners ? output_spatial - 1 : output_spatial; + + int64_t gcd = std::gcd(static_cast(input_spatial_size), + static_cast(output_spatial_size)); + + if ((gcd == 0) || (input_spatial_size % gcd != 0) || + (input_spatial_size / gcd != stride) || (dilation - 1 != padding)) { + return false; + } + return true; + }; + + if (data.InputDilations()[0] != 1 && data.InputDilations()[1] == 1) { + if (can_convert_to_bilinear( + /*align_corners=*/true, data.InputDilations()[0], + data.Padding()[0].Lo(), data.Strides()[0], + data.InputShape()[data.InputLayout().Spatials()[0]], + data.OutputShape()[data.OutputLayout().Spatials()[0]])) { + align_corners = true; + return true; + } else if (can_convert_to_bilinear( + /*align_corners=*/false, data.InputDilations()[0], + data.Padding()[0].Lo(), data.Strides()[0], + data.InputShape()[data.InputLayout().Spatials()[0]], + data.OutputShape()[data.OutputLayout().Spatials()[0]])) { + align_corners = false; + return true; + }; + } else if (data.InputDilations()[0] == 1 && data.InputDilations()[1] != 1) { + if (can_convert_to_bilinear( + /*align_corners=*/true, data.InputDilations()[1], + data.Padding()[1].Lo(), data.Strides()[1], + data.InputShape()[data.InputLayout().Spatials()[1]], + data.OutputShape()[data.OutputLayout().Spatials()[1]])) { + align_corners = true; + return true; + } else if (can_convert_to_bilinear( + /*align_corners=*/false, data.InputDilations()[1], + data.Padding()[1].Lo(), data.Strides()[1], + data.InputShape()[data.InputLayout().Spatials()[1]], + data.OutputShape()[data.OutputLayout().Spatials()[1]])) { + align_corners = false; + return true; + }; + } + + return false; +} + +bool IsTransposeConvPaddingValid(mhlo::ConvolutionOp conv_op, + size_t num_spatial_dims, + const ArrayRef& strides, + const ArrayRef& padding) { + auto dnums = conv_op.getDimensionNumbers(); + // The newly added spatial dimension requires zero left and right padding. + ArrayRef input_spatial_dims = dnums.getInputSpatialDimensions(); + ArrayRef kernel_spatial_dims = dnums.getKernelSpatialDimensions(); + ArrayRef output_spatial_dims = dnums.getOutputSpatialDimensions(); + + for (size_t i = 0; i < num_spatial_dims; ++i) { + int64_t stride = strides[i]; + int64_t input_size = mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_spatial_dims[i]); + int64_t kernel_size = mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_spatial_dims[i]); + int64_t output_size = mlir::cast(conv_op.getType()) + .getDimSize(output_spatial_dims[i]); + + // stablehlo.convolution op needs explicit padding to be set to model any + // Transposed-Convolution in JAX/PT. Checking to see if- + // 1. Pre set padding matches to the desired padding + // 2. Output size respects the `VALID` padding scenario + if ((padding[2 * i] == padding[2 * i + 1]) && + (((kernel_size - 1) != padding[2 * i]) || + (output_size != (stride * (input_size - 1)) + kernel_size))) { + // padding[2 * i] == padding[2 * i + 1] means equal padding is applied + // on both sides of a spatial dimension. + // This happens when kernel_dim >= stride + return false; + } else if ((padding[2 * i] != padding[2 * i + 1]) && + (((kernel_size - 1) != padding[2 * i]) || + ((stride - 1) != padding[2 * i + 1]) || + (output_size != (stride * input_size)))) { + return false; + } + } + + return true; +} + +bool IsTransposeConvPaddingSame(mhlo::ConvolutionOp conv_op, + size_t num_spatial_dims, + const ArrayRef& strides, + const ArrayRef& padding) { + auto dnums = conv_op.getDimensionNumbers(); + + // The newly added spatial dimension requires zero left and right padding. + ArrayRef input_spatial_dims = dnums.getInputSpatialDimensions(); + ArrayRef output_spatial_dims = dnums.getOutputSpatialDimensions(); + for (size_t i = 0; i < num_spatial_dims; ++i) { + // In some cases the total padding is odd, so we have 1 leftover, which is + // why below we check pad_delta > 1. + int64_t pad_delta = std::abs(padding[2 * i] - padding[2 * i + 1]); + if (pad_delta > 1) { + return false; + } + int64_t stride = strides[i]; + int64_t input_size = mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_spatial_dims[i]); + int64_t output_size = mlir::cast(conv_op.getType()) + .getDimSize(output_spatial_dims[i]); + // The reason for the below check is as follows: + // When computing the output, we have the following relation between + // o - output dim size, i - input dim size, s - stride, P - total pads + // o = (i-k+1) + (s-1)(i-1) + P + // Where the first term is the kernel applications on the input, + // the second term is the additional applications from the stride + // and P is a term that captures the total padding. After expanding we get + // o = si + k - s + 2 + P + // Here JAX sets P to cancel k-s+2, leading to the expression below + if (output_size != input_size * stride) { + return false; + } + } + return true; +} + } // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h index d20ad087b20410..ed8b06e036d816 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h @@ -15,8 +15,14 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CONV_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CONV_UTIL_H_ +#include +#include +#include + #include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" @@ -100,10 +106,48 @@ inline bool HasSupportedOutFeatureDims(const ConvView& data) { return kernel_out_features == out_features; } -inline bool IsNonTrivialConv(const ConvView& data) { +inline bool IsTrivialConv(const ConvView& data) { return llvm::all_of(data.InputDilations(), [](auto d) { return d == 1; }); } +// +// Supported non-trivial conv predicates +//=----- + +bool MatchWithResizeBilinearOp(const ConvView& data, bool& align_corners); + +inline bool MatchWithResizeBilinearOp(const ConvView& data) { + bool align_corners = false; + return MatchWithResizeBilinearOp(data, align_corners); +} + +bool IsTransposeConvPaddingValid(mhlo::ConvolutionOp conv_op, + size_t num_spatial_dims, + const ArrayRef& strides, + const ArrayRef& padding); + +bool IsTransposeConvPaddingSame(mhlo::ConvolutionOp conv_op, + size_t num_spatial_dims, + const ArrayRef& strides, + const ArrayRef& padding); + +inline bool IsSupportedNonTrivialConv(const ConvView& data) { + // Only non-trivial 2d convolutions are supported. + const bool valid_rank = data.InputLayout().Rank() == 4; + + // Negative padding is unsupported. + bool has_nagative_padding = llvm::all_of( + data.Padding(), + [](const DimPadding& p) { return p.Hi() < 0 || p.Lo() < 0; }); + + return (valid_rank && !IsTrivialConv(data) && !has_nagative_padding); +} + +inline bool IsSupportedNonTrivialConv(mhlo::ConvolutionOp op) { + const ConvView data(op); + return IsSupportedNonTrivialConv(data); +} + // // Standard conv predicates //=----- @@ -122,7 +166,7 @@ inline bool HasStandardConvInFeatureDims(const ConvView& data) { } inline bool IsStandardConv(const ConvView& data) { - return HasSupportedRank(data) && IsNonTrivialConv(data) && + return HasSupportedRank(data) && IsTrivialConv(data) && HasStandardConvInFeatureDims(data) && HasSupportedOutFeatureDims(data); } @@ -140,7 +184,7 @@ inline bool IsStandardConv(mhlo::ConvolutionOp op) { inline bool IsDepthwiseConv(const ConvView& data) { const bool valid_rank = data.InputLayout().Rank() == 4; if (!valid_rank || !HasSupportedOutFeatureDims(data) || - !IsNonTrivialConv(data)) { + !IsTrivialConv(data)) { return false; } const int64_t in_channel_dim = @@ -197,7 +241,7 @@ inline bool IsTFLNativeLayout(const ConvView& data) { std::optional native_kernel_layout = std::nullopt; if (IsDepthwiseConv(data)) { native_kernel_layout = GetTFLNativeDepthwiseConvKernelLayout(); - } else if (IsStandardConv(data)) { + } else if (IsStandardConv(data) || IsSupportedNonTrivialConv(data)) { native_kernel_layout = GetTFLNativeStandardConvKernelLayout(rank); } if (!native_kernel_layout.has_value()) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.cc new file mode 100644 index 00000000000000..74aaa81519ea88 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.cc @@ -0,0 +1,118 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h" + +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { +namespace { + +class LegalizeIota : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::IotaOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +std::tuple +BuildRangeParams(Type e_type, int64_t iota_dim_size, OpBuilder& b) { + if (e_type.isInteger()) { + return std::tuple(BuildScalarDense(e_type, 0), + BuildScalarDense(e_type, iota_dim_size), + BuildScalarDense(e_type, 1)); + } + return std::tuple(BuildScalarDense(e_type, 0.0), + BuildScalarDense(e_type, iota_dim_size), + BuildScalarDense(e_type, 1.0)); +} + +LogicalResult LegalizeIota::matchAndRewrite( + mhlo::IotaOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + auto type = llvm::cast(op.getType()); + if (!type.getElementType().isIntOrFloat()) { + return rewriter.notifyMatchFailure(op, "Must be int or float"); + } + + auto e_type = type.getElementType(); + const int64_t iota_dim_size = type.getDimSize(op.getIotaDimension()); + + auto [start, limit, delta] = + BuildRangeParams(e_type, iota_dim_size, rewriter); + + auto start_op = rewriter.create(op->getLoc(), start); + auto limit_op = rewriter.create(op->getLoc(), limit); + auto delta_op = rewriter.create(op->getLoc(), delta); + + auto range_type = RankedTensorType::get({iota_dim_size}, e_type); + auto range_op = rewriter.create(op->getLoc(), range_type, + start_op, limit_op, delta_op); + + if (type.getRank() == 1) { + rewriter.replaceOp(op, range_op); + return success(); + } + + // mhlo.iota allows filling ND tensors iota-style. Reshape and broadcast + // tfl 1D range output. + + llvm::SmallVector reshape_shape(type.getRank(), 1); + reshape_shape[op.getIotaDimension()] = iota_dim_size; + Value reshape_shape_cst = rewriter.create( + op->getLoc(), rewriter.getI64TensorAttr(reshape_shape)); + reshape_shape_cst = rewriter.create( + op->getLoc(), + llvm::cast(reshape_shape_cst.getType()) + .clone(rewriter.getI32Type()), + reshape_shape_cst); + + auto reshape_type = RankedTensorType::get(reshape_shape, e_type); + auto reshape_op = rewriter.create( + op->getLoc(), reshape_type, range_op, reshape_shape_cst); + + auto broad_cast_shape_cst = rewriter.create( + op->getLoc(), rewriter.getI64TensorAttr(type.getShape())); + + rewriter.replaceOpWithNewOp(op, type, reshape_op, + broad_cast_shape_cst); + + return success(); +} + +} // namespace + +void PopulateIotaPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target) { + patterns.add(ctx); + target.addIllegalOp(); +} + +} // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h new file mode 100644 index 00000000000000..a53bdeda2a2097 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h @@ -0,0 +1,28 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ + +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::odml { + +void PopulateIotaPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h index e3f89412225693..3c2c8ae5ced600 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir::odml { @@ -134,6 +135,12 @@ llvm::SmallVector ResolvePadding( bool IsSamePaddingOnDim(int64_t in, int64_t dilate, int64_t stride, int64_t k, const DimPadding& pad); +template +inline DenseElementsAttr BuildScalarDense(Type e_type, T val) { + auto type = RankedTensorType::get({}, e_type); + return DenseElementsAttr::get(type, val); +} + } // namespace mlir::odml #endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc index 443a30591482fb..a00ee33c45a8ca 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -101,6 +102,13 @@ bool IsCstFloatZero(Value val) { initial_value.getValues()[0].isZero(); } +bool IsCstIntZero(Value val) { + DenseIntElementsAttr initial_value; + return matchPattern(val, m_Constant(&initial_value)) && + initial_value.getNumElements() == 1 && + initial_value.getValues()[0].isZero(); +} + llvm::SmallVector Permute(llvm::ArrayRef data, llvm::ArrayRef perm) { llvm::SmallVector res(data.size()); @@ -293,6 +301,126 @@ LogicalResult RelayoutReduceWindow::matchAndRewrite( return success(); } +//===------------------------------------------------------------------------=== +// mhlo.reduce_window -> tfl.cum_sum +//===------------------------------------------------------------------------=== + +class LegalizeCumSum : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::ReduceWindowOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult LegalizeCumSum::matchAndRewrite( + mhlo::ReduceWindowOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + // + // check singular params and trivial attrs + //=----- + + auto opt_input_init = GetInputAndInitIfValid(op); + if (!opt_input_init.has_value()) { + return rewriter.notifyMatchFailure(op, + "Must have 1 input, init and result."); + } + auto [input, init] = opt_input_init.value(); + + if (failed(MatchBinaryReduceFunction(op.getBody()))) { + return rewriter.notifyMatchFailure(op, "Requires scalar add in region."); + } + + if (!IsCstFloatZero(init) && !IsCstIntZero(init)) { + return rewriter.notifyMatchFailure(op, "Requires 0 for init value."); + } + + const ReduceWindowView view(op); + + auto trivial = [](int64_t v) { return v == 1; }; + const bool trivial_window_dilate = + llvm::all_of(view.WindowDilations(), trivial); + const bool trivial_base_dilate = llvm::all_of(view.BaseDilations(), trivial); + const bool trivial_stride = llvm::all_of(view.WindowStrides(), trivial); + if (!trivial_window_dilate || !trivial_stride || !trivial_base_dilate) { + return rewriter.notifyMatchFailure( + op, "Requires trivial strides and dilations attributes."); + } + + // + // figure out the implicit axis of reduction + //=----- + + auto input_type = llvm::cast(input.getType()); + if (view.WindowDims().size() != input_type.getRank()) { + return rewriter.notifyMatchFailure(op, "Splat window dims not supported."); + } + int64_t axis = -1; + for (auto [ind, val] : llvm::enumerate(view.WindowDims())) { + if (val == 1) { + continue; + } + + if (axis != -1) { + return rewriter.notifyMatchFailure(op, "Multiple non 1 dimensions."); + } + + if (val != input_type.getShape()[ind]) { + return rewriter.notifyMatchFailure( + op, "Axis dimension requires size be same as input shape's."); + } + axis = ind; + } + + if (axis == -1) { + return rewriter.notifyMatchFailure(op, "Could not identify axis."); + } + + const int64_t axis_size = input_type.getShape()[axis]; + + // + // validate padding is [N-1, 0] on axis and zero elsewhere + //=----- + + for (const auto& [ind, dim_pad] : llvm::enumerate(view.Paddings())) { + if (dim_pad.Hi() != 0) { + return rewriter.notifyMatchFailure(op, "Has non trivial high padding."); + } + + if (ind != axis) { + if (!dim_pad.Trivial()) { + return rewriter.notifyMatchFailure( + op, "Has non trivial padding on non axis dim."); + } + } else { + if (dim_pad.Lo() != axis_size - 1) { + return rewriter.notifyMatchFailure( + op, "Requires low padding on axis dim to be N - 1."); + } + } + } + + // + // build axis constant and tfl op + //=----- + + auto axis_cst_attr = DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getI32Type()), + static_cast(axis)); + auto axis_cst = + rewriter.create(op->getLoc(), axis_cst_attr); + + auto tfl_exclusive_attr = rewriter.getBoolAttr(false); + auto tfl_reverse_attr = rewriter.getBoolAttr(false); + + rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], input, + axis_cst, tfl_exclusive_attr, + tfl_reverse_attr); + + return success(); +} + //===------------------------------------------------------------------------=== // mhlo.reduce_window -> tfl.max_pool //===------------------------------------------------------------------------=== @@ -601,7 +729,7 @@ LogicalResult LegalizeAvgPool::matchAndRewrite( void PopulateLegalizeReduceWindowPatterns(MLIRContext* ctx, RewritePatternSet& patterns, ConversionTarget& target) { - patterns.add(ctx); + patterns.add(ctx); target.addDynamicallyLegalOp(IsReduceWindowLegal); target.addDynamicallyLegalOp(IsDivideLegal); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc index 76c843e550c631..07f9f0368ad665 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc @@ -23,6 +23,8 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project @@ -33,6 +35,21 @@ limitations under the License. namespace mlir::odml { namespace { +// mhlo encodes ND indice arguments as a variadiac of scalars. Pack them +// into a single tensor for use in TFL. +Value PackScalarIndices(mlir::ValueRange indices, OpBuilder& b) { + auto e_type = + llvm::cast(indices.front().getType()).getElementType(); + const int64_t num_indices = indices.size(); + auto packed_indices_type = RankedTensorType::get({num_indices}, e_type); + + auto values_count_attr = b.getI32IntegerAttr(num_indices); + auto pack_axis_attr = b.getI32IntegerAttr(0); + + return b.create(indices.back().getLoc(), packed_indices_type, + indices, values_count_attr, pack_axis_attr); +} + //===----------------------------------------------------------------------===// // mhlo.slice //===----------------------------------------------------------------------===// @@ -166,43 +183,53 @@ LogicalResult LegalizeDynamicSliceOp::matchAndRewrite( new_start_indices.push_back(new_start_ind); } - // - // pack variadic scalar start indices into one tensor - //=----- - - const int64_t packed_start_indices_size = new_start_indices.size(); - auto packed_start_indices_type = - RankedTensorType::get({packed_start_indices_size}, start_e_type); - - auto values_count_attr = - rewriter.getI32IntegerAttr(packed_start_indices_size); - auto pack_axis_attr = rewriter.getI32IntegerAttr(0); - - auto packed_start_inds = rewriter.create( - op->getLoc(), packed_start_indices_type, new_start_indices, - values_count_attr, pack_axis_attr); - // // build tfl //=----- + auto packed_indices = PackScalarIndices(new_start_indices, rewriter); + auto slice_sizes_cst = rewriter.create(op->getLoc(), op.getSliceSizes()); rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand(), - packed_start_inds, slice_sizes_cst); + packed_indices, slice_sizes_cst); return success(); } +//===----------------------------------------------------------------------===// +// mhlo.dynamic_update_slice +//===----------------------------------------------------------------------===// + +class LegalizeDynamicUpdateSliceOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult LegalizeDynamicUpdateSliceOp::matchAndRewrite( + mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + auto packed_indices = PackScalarIndices(op.getStartIndices(), rewriter); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getOperand(), op.getUpdate(), packed_indices); + return success(); +}; + } // namespace void PopulateLegalizeSlicePatterns(MLIRContext* ctx, RewritePatternSet& patterns, ConversionTarget& target) { - patterns.add(ctx); + patterns.add(ctx); - target.addIllegalOp(); + target.addIllegalOp(); target.addDynamicallyLegalOp(IsDynamicSliceLegal); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td index fe988ba9b20265..49d38d78cb6f2a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td @@ -380,7 +380,7 @@ def ConvertDotGeneralOp : NativeCodeCall<"ConvertDotGeneralOp($_builder, " def : Pat<(MHLO_DotGeneralOp:$old_value RankedTensorOf<[TF_ElementType]>:$lhs, RankedTensorOf<[TF_ElementType]>:$rhs, - $dot_dimension_numbers, $precision_config), + $dot_dimension_numbers, $precision_config, $algorithm), (ConvertDotGeneralOp $old_value)>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc index d9c23dfa12b8ae..062222f72b3b9a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc @@ -58,7 +58,7 @@ LogicalResult ConvertDotToDotGeneral(mhlo::DotOp op, /*rhsBatchingDimensions=*/{}, /*lhsContractingDimensions=*/{lhs_type.getRank() - 1}, /*rhsContractingDimensions=*/{0}), - op.getPrecisionConfigAttr()); + op.getPrecisionConfigAttr(), mhlo::DotAlgorithmAttr{}); return success(); } @@ -161,7 +161,7 @@ LogicalResult RemoveReshapeAroundDotGeneral(mhlo::ReshapeOp reshape_after, range(batch_dims_count + shape_y1.size(), contracting_dims_count), /*rhsContractingDimensions=*/ range(batch_dims_count, contracting_dims_count)), - dot.getPrecisionConfigAttr()); + dot.getPrecisionConfigAttr(), dot.getAlgorithmAttr()); return success(); } @@ -273,7 +273,8 @@ LogicalResult LiftDotConcatLHS(mhlo::ConcatenateOp concat, rewriter.getI64IntegerAttr(new_concat_dim)); rewriter.replaceOpWithNewOp( concat, concat.getType(), new_concat, first_dot.getRhs(), - first_dot.getDotDimensionNumbers(), first_dot.getPrecisionConfigAttr()); + first_dot.getDotDimensionNumbers(), first_dot.getPrecisionConfigAttr(), + first_dot.getAlgorithmAttr()); return success(); } @@ -374,7 +375,8 @@ LogicalResult LiftDotConcatLHSAndRHS(mhlo::ConcatenateOp concat, all_dot_rhs, rewriter.getI64IntegerAttr(rhs_batch_dim)); rewriter.replaceOpWithNewOp( concat, concat.getType(), lhs_new_concat, rhs_new_concat, - first_dot.getDotDimensionNumbers(), first_dot.getPrecisionConfigAttr()); + first_dot.getDotDimensionNumbers(), first_dot.getPrecisionConfigAttr(), + first_dot.getAlgorithmAttr()); return success(); } @@ -611,10 +613,134 @@ LogicalResult ConvertReshapeDotRhsToBatchedDot(mhlo::DotGeneralOp dot, /*rhsBatchingDimensions=*/{0}, /*lhsContractingDimensions=*/dim_nums.getLhsContractingDimensions(), /*rhsContractingDimensions=*/new_rhs_contracting_dims), - dot.getPrecisionConfigAttr()); + dot.getPrecisionConfigAttr(), dot.getAlgorithmAttr()); return success(); } +//===----------------------------------------------------------------------===// +// BroadcastInDimsOp +//===----------------------------------------------------------------------===// + +// Minimizing unit dimensions in reshape(broadcast(X)). +// +// There are situations where X, or broadcast(X) have some number of `1` (unit) +// sized dimensions which are not meaningful to the computation. E.g. +// +// ``` +// x = [1x1x1x3] +// b = broadast(x) : [1x2x1x3] +// r = reshape(b) : [2x3] +// ``` +// +// Provided the relative broadcast dims are preserved, removing any number +// of unit dims from the input or output shape of a broadcast has no effect on +// the semantic of the computation. +// +// Assume a reshape(broadcast(x)) where the shape of the broadcast and reshape +// have the same non-unit dims in the same order. In this case we can +// change the broadcast shape into the reshape shape simply by adding or +// removing unit-dims, and the reshape can be replaced with the broadcast. +// +// When removing unit dims from the broadcast in this way, we may also need +// to remove the corresponding unit dim from the input shape. This pattern takes +// the approach of removing all unit dims for the broadcast input +// rather than explicitly checking each. +// +// The result on the above example: +// +// ``` +// x = [1x1x1x3] +// r = reshape(x) : [3] +// b = broadast(r) : [2x3] +// ``` +// +// Note that the ability of removing unit dims from the input or output shape of +// a broascast is not contingent on matching and replacing a reshaped output. We +// require however for this pattern to not increase the net number of reshapes. +// Additionally, we want to minimize the rank of broadcasts so only considered +// are cases where rank(reshape) < rank(broadcast). +class SimplifyBroadcastInDimsReshape + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mhlo::BroadcastInDimOp op, + PatternRewriter &rewriter) const override { + if (!op->hasOneUse()) { + return rewriter.notifyMatchFailure(op, "has more than one use."); + } + + auto reshape = mlir::dyn_cast(*op->getUsers().begin()); + if (!reshape) { + return rewriter.notifyMatchFailure(op, "user not reshape."); + } + + auto broadcast_type = mlir::cast(op.getType()); + auto broadcast_input_type = + mlir::cast(op.getOperand().getType()); + auto reshape_type = mlir::cast(reshape.getType()); + + // Reshape must be squeezing unit dimensions. + if (!(reshape_type.getRank() < broadcast_type.getRank())) { + return rewriter.notifyMatchFailure(op, "reshape doesn't reduce rank."); + } + + // Reshape and broadcast must have the same non-unit dims in the + // same order. + llvm::SmallVector broadcast_dim_to_reshape_dim( + broadcast_type.getRank()); + int64_t reshape_dim_idx = -1; + for (auto [idx, dim] : llvm::enumerate(broadcast_type.getShape())) { + if (dim == 1) { + continue; + } + + int64_t reshape_dim_size = 1; + while (reshape_dim_idx < reshape_type.getRank() - 1) { + reshape_dim_size = reshape_type.getDimSize(++reshape_dim_idx); + if (reshape_dim_size != 1) { + break; + } + } + + if (dim != reshape_dim_size) { + return rewriter.notifyMatchFailure( + op, "reshape and broadcast have different non-unit dim sizes."); + } + + // Maps index of non-unit broadcast dims to corresponding reshape dim. + broadcast_dim_to_reshape_dim[idx] = reshape_dim_idx; + } + // Unchecked reshape dim sizes are guaranteed to be unit at this point. + + llvm::SmallVector current_broadcast_dims( + op.getBroadcastDimensions().getValues()); + llvm::SmallVector new_broadcast_dims; + llvm::SmallVector new_broadcast_input_shape; + + for (auto [idx, dim] : llvm::enumerate(broadcast_input_type.getShape())) { + if (dim == 1) { + continue; + } + // If dim != 1 then it must be broadcasted to a non-unit dimension + // and must have a corresponding reshape dimension in our vectors. + new_broadcast_dims.push_back( + broadcast_dim_to_reshape_dim[current_broadcast_dims[idx]]); + new_broadcast_input_shape.push_back(dim); + } + + auto new_broadcast_input_type = RankedTensorType::get( + new_broadcast_input_shape, broadcast_type.getElementType()); + auto new_broadcast_input = rewriter.create( + op->getLoc(), new_broadcast_input_type, op.getOperand()); + auto new_broadcast_dims_attr = + rewriter.getI64TensorAttr(new_broadcast_dims); + + rewriter.replaceOpWithNewOp( + reshape, reshape_type, new_broadcast_input, new_broadcast_dims_attr); + + return success(); + } +}; + class OptimizePass : public PassWrapper> { public: @@ -632,6 +758,7 @@ class OptimizePass patterns.add(FuseSliceConcat); patterns.add(ConvertReshapeDotRhsToBatchedDot); patterns.add(MergeConsecutivePad); + patterns.add(&getContext()); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td index c789b3bde293c6..3eb051d38d8917 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td @@ -151,3 +151,19 @@ def PrepareHloPass ]; } +def LiftCallSiteLocCallerPass : Pass<"lift-callsite-loc-caller", "ModuleOp"> { + let summary = "Lifts CallSites in pytorch generated stablehlo."; + let description = [{ + Lifts CallSites in pytorch generated stablehlo to make the Loc's consitent + after inlining. + }]; + let dependentDialects = ["func::FuncDialect"]; +} + +def BuildStableHLOCompositePass : Pass<"build-stablehlo-composite", "ModuleOp"> { + let summary = "Build stablehlo.composite from inlined stablehlo.custom_call mark_tensor ops."; + let description = [{ + Build stablehlo.composite from inlined stablehlo.custom_call mark_tensor ops. + }]; + let dependentDialects = ["func::FuncDialect", "stablehlo::StablehloDialect"]; +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td index a3a5845c3367dd..9b6f6efbfcf4f6 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td @@ -67,6 +67,15 @@ def IsStandardConv : Constraint())">>; +def IsSupportedNonTrivialConv : Constraint())">>; + +def IsSupportedConv : Constraint>; + +def IsSupportedStandardOrNonTrivialConv : Constraint>; + def IsStandardOrDepthwiseConv : Constraint>; @@ -135,7 +144,7 @@ def ReLayoutConvInput : Pat<(MHLO_ConvolutionOp:$conv [(AreDnumsFullyDefined $conv), (InputHasIotaSpatials $dnums), (IsInputNotTFLNativeLayout $dnums), - (IsStandardOrDepthwiseConv $conv)], + (IsSupportedConv $conv)], [], (addBenefit 1)>; @@ -210,7 +219,7 @@ def ReLayoutConvKernel : Pat<(MHLO_ConvolutionOp:$conv [(AreDnumsFullyDefined $conv), (KernelHasIotaSpatials $dnums), (IsKernelNotTFLNativeStandardConvLayout $dnums), - (IsStandardConv $conv)], + (IsSupportedStandardOrNonTrivialConv $conv)], [], (addBenefit 1)>; @@ -344,7 +353,7 @@ def ReLayoutConvOutput : Pat<(MHLO_ConvolutionOp:$conv [(AreDnumsFullyDefined $conv), (KernelHasIotaSpatials $dnums), (IsOutputNotTFLNativeLayout $dnums), - (IsStandardOrDepthwiseConv $conv)]>; + (IsSupportedConv $conv)]>; // Pull out non-trivial padding into separate explicit pad_op. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc index d23ec076de2ef0..6cd284a73dd576 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc @@ -14,12 +14,16 @@ limitations under the License. ==============================================================================*/ // The kept headers are provided for the included file `passes.h.inc`. +#include #include #include #include +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -35,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gather.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h" @@ -50,9 +55,137 @@ namespace mlir { namespace odml { namespace { +// Returns the shape of the given value in a Constant Op. +arith::ConstantOp ShapeToConst(PatternRewriter& rewriter, Value value) { + ArrayRef shape = mlir::cast(value.getType()).getShape(); + auto attr_type = RankedTensorType::get({static_cast(shape.size())}, + rewriter.getIntegerType(64)); + auto attr = DenseElementsAttr::get(attr_type, shape); + return rewriter.create(value.getLoc(), attr_type, attr); +} + +bool IsSign(APInt a, APInt sign) { + if (a.isZero()) return a == sign; + if (a.isNegative()) return sign == -1; + return sign == 1; +} + +bool IsSign(APFloat a, APFloat sign) { + if (a.isNaN() || a.isZero()) return a == sign; + if (a.isNegative()) return sign.isExactlyValue(-1.0); + return sign.isExactlyValue(1.0); +} + +bool IsDenseSplatIntAttr(ElementsAttr float_or_int) { + return mlir::isa(float_or_int) && + mlir::isa(float_or_int); +} + +bool IsDenseSplatFloatAttr(ElementsAttr float_or_int) { + return mlir::isa(float_or_int) && + mlir::isa(float_or_int); +} + +bool ValueEquals(ElementsAttr float_or_int, double rhs) { + if (IsDenseSplatFloatAttr(float_or_int)) { + return mlir::cast(float_or_int) + .getSplatValue() + .isExactlyValue(rhs); + } else if (IsDenseSplatIntAttr(float_or_int)) { + return mlir::cast(float_or_int).getSplatValue() == + static_cast(rhs); + } + return false; +} + +// Returns whether the splat constant is the sign of the int or float Tensor. +bool TensorIsSign(PatternRewriter& rewriter, ElementsAttr float_or_int, + ElementsAttr sgn_cst) { + auto sgn_splat = llvm::dyn_cast(sgn_cst); + if (!sgn_splat) return false; + + auto splat = dyn_cast(float_or_int); + if (auto float_spl = llvm::dyn_cast_if_present(splat), + sgn_cst_spl = llvm::dyn_cast_if_present(sgn_splat); + float_spl && sgn_cst_spl) { + return IsSign(float_spl.getValue(), sgn_cst_spl.getValue()); + } + if (auto int_spl = llvm::dyn_cast_if_present(splat), + sgn_cst_spl = llvm::dyn_cast_if_present(sgn_splat); + int_spl && sgn_cst_spl) { + return IsSign(int_spl.getValue(), sgn_cst_spl.getValue()); + } + if (mlir::isa(float_or_int)) { + auto sgn_splat_value = sgn_splat.getSplatValue(); + return llvm::all_of(float_or_int.getValues(), [&](APFloat value) { + return IsSign(value, sgn_splat_value); + }); + } + if (mlir::isa(float_or_int)) { + auto sgn_splat_value = sgn_splat.getSplatValue(); + return llvm::all_of(float_or_int.getValues(), [&](APInt value) { + return IsSign(value, sgn_splat_value); + }); + } + return false; +} + +bool SameTypeOrDefaultCompare(mhlo::ComparisonTypeAttr comparison_type_attr, + ElementsAttr cst) { + if (!comparison_type_attr) return true; + auto comparison_type_attr_value = comparison_type_attr.getValue(); + if (comparison_type_attr_value == mhlo::ComparisonType::FLOAT && + IsDenseSplatFloatAttr(cst)) { + return true; + } + if ((comparison_type_attr_value == mhlo::ComparisonType::SIGNED || + comparison_type_attr_value == mhlo::ComparisonType::UNSIGNED) && + IsDenseSplatIntAttr(cst)) { + return true; + } + return false; +} + +bool ValueIsReciprocal(ElementsAttr float_or_int, ElementsAttr rhs) { + if (IsDenseSplatFloatAttr(float_or_int) && + IsDenseSplatFloatAttr(float_or_int)) { + return (mlir::cast(float_or_int) + .getSplatValue() * + mlir::cast(rhs).getSplatValue()) + .isExactlyValue(1.0); + } else if (IsDenseSplatIntAttr(float_or_int) && + IsDenseSplatIntAttr(float_or_int)) { + return (mlir::cast(float_or_int).getSplatValue() * + mlir::cast(rhs).getSplatValue()) == 1; + } + return false; +} + +bool ValueGreaterThanZero(ElementsAttr float_or_int) { + if (IsDenseSplatIntAttr(float_or_int)) { + auto value = + mlir::cast(float_or_int).getSplatValue(); + return !value.isNegative() && !value.isZero(); + } else if (IsDenseSplatFloatAttr(float_or_int)) { + auto value = + mlir::cast(float_or_int).getSplatValue(); + return !value.isNaN() && !value.isNegative() && !value.isZero(); + } + return false; +} + #define GEN_PASS_DEF_LEGALIZEHLOTOTFLITEPASS #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +bool SupportedComparisonType(mhlo::ComparisonTypeAttr comp_type) { + if (!comp_type) return true; + auto c_ty = comp_type.getValue(); + return c_ty == mhlo::ComparisonType::FLOAT || + c_ty == mhlo::ComparisonType::SIGNED || + c_ty == mhlo::ComparisonType::UNSIGNED || + c_ty == mhlo::ComparisonType::NOTYPE; +} + class LegalizeHloToTfLitePass : public impl::LegalizeHloToTfLitePassBase { public: @@ -69,6 +202,81 @@ bool IsNotOpLegal(mhlo::NotOp op) { return op.getType().getElementType().isInteger(64); } +// Mark possible target ops from rounding patterns as having "unknown" +// legality. This is required to schedule patterns on these ops even +// though MhloDialect is explicitly marked legal (which cannot be changed +// easily). +void AddRoundingOpsAsUnknown(ConversionTarget& target) { + target.addDynamicallyLegalOp< + // go/keep-sorted start + // clang-format off + mhlo::AddOp, + mhlo::BroadcastInDimOp, + mhlo::ConstantOp, + mhlo::DivOp, + mhlo::FloorOp, + mhlo::MulOp, + mhlo::RemOp, + mhlo::RoundOp, + mhlo::SelectOp, + mhlo::SignOp, + mhlo::SubtractOp, + mhlo::TupleOp + // clang-format on + // go/keep-sorted end + >([](Operation* op) { return std::nullopt; }); +} +bool IsCompareLegal(mhlo::CompareOp op) { + return !SupportedComparisonType(op.getCompareTypeAttr()); +} + +void SetUnaryOpLegal(ConversionTarget& target) { + auto is_legal = [](Operation* op) { + return !llvm::cast(op->getOperand(0).getType()) + .getElementType() + .isIntOrFloat(); + }; + target.addDynamicallyLegalOp< + // go/keep-sorted start + // clang-format off + mhlo::AbsOp, + mhlo::BitcastConvertOp, + mhlo::CeilOp, + mhlo::ConvertOp, + mhlo::CosineOp, + mhlo::ExpOp, + mhlo::Expm1Op, + mhlo::FloorOp, + mhlo::ImagOp, + mhlo::IsFiniteOp, + mhlo::Log1pOp, + mhlo::LogOp, + mhlo::LogisticOp, + mhlo::NegOp, + mhlo::RealOp, + mhlo::RsqrtOp, + mhlo::SignOp, + mhlo::SineOp, + mhlo::SqrtOp, + mhlo::TanhOp + // clang-format on + // go/keep-sorted end + >(is_legal); +} + +// mhlo "bitwise ops" can be both bitwise (floats/ints) or logical (bools). +// TFL ops are only one of logical or bitwise. +void SetBinaryBitwiseLegal(ConversionTarget& target) { + auto is_logical = [](Operation* op) { + return llvm::cast(op->getResultTypes()[0]) + .getElementType() + .isInteger(1); + }; + auto is_bitwise = [&](Operation* op) { return !is_logical(op); }; + target.addDynamicallyLegalOp(is_bitwise); + target.addDynamicallyLegalOp(is_logical); +} + #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/generated_tflite_legalize_hlo.inc" void LegalizeHloToTfLitePass::runOnOperation() { MLIRContext* context = &getContext(); @@ -79,10 +287,35 @@ void LegalizeHloToTfLitePass::runOnOperation() { ConversionTarget target(*context); target.addLegalDialect(); target.addLegalOp(); + target.addDynamicallyLegalOp(IsCustomCallLegal); target.addDynamicallyLegalOp(IsCbrtLegal); - target.addIllegalOp(); target.addDynamicallyLegalOp(IsNotOpLegal); + target.addDynamicallyLegalOp(IsCompareLegal); + + target.addIllegalOp< + // go/keep-sorted start + // clang-format off + mhlo::ClampOp, + mhlo::DotGeneralOp, + mhlo::DotOp, + mhlo::DynamicReshapeOp, + mhlo::MaxOp, + mhlo::MinOp, + mhlo::MulOp, + mhlo::PowOp, + mhlo::RemOp, + mhlo::ReshapeOp, + mhlo::ShiftRightArithmeticOp, + mhlo::ShiftRightLogicalOp, + mhlo::TransposeOp + // clang-format on + // go/keep-sorted end + >(); + + AddRoundingOpsAsUnknown(target); + SetUnaryOpLegal(target); + SetBinaryBitwiseLegal(target); PopulatePadPatterns(context, patterns, target); PopulateReducePatterns(context, patterns, target); @@ -91,6 +324,7 @@ void LegalizeHloToTfLitePass::runOnOperation() { PopulateLegalizeConvPatterns(context, patterns, target); PopulateLegalizeSlicePatterns(context, patterns, target); PopulateSortPatterns(context, patterns, target); + PopulateIotaPatterns(context, patterns, target); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td index 72cbe00c822aa4..55e76560da365f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td @@ -13,14 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -include "mlir/IR/OpBase.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" include "mlir/Dialect/Func/IR/FuncOps.td" -include "mhlo/IR/hlo_ops.td" +include "mlir/IR/CommonAttrConstraints.td" include "mlir/IR/CommonAttrConstraints.td" include "mlir/IR/CommonTypeConstraints.td" -include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" +include "mlir/IR/OpBase.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td" -include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" +include "mhlo/IR/hlo_ops.td" + + +def ShapeToConst : NativeCodeCall<"ShapeToConst($_builder, $0)">; def CreateTFLCastToInt32Op : NativeCodeCall< "CreateCastToInt32($0, $_loc, $_builder)">; @@ -29,14 +33,12 @@ def LegalizeTranspose : Pat<(MHLO_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, (CreateTFLCastToInt32Op (TFL_ConstOp $perm)))>; -def LowerCbrt : Pat<(MHLO_CbrtOp $opr), - (TFL_PowOp $opr, - (TFL_DivOp - (Arith_ConstantOp ConstantAttr, "1.0f">), - (Arith_ConstantOp ConstantAttr, "3.0f">), - TFL_AF_None)), - [(F32Tensor $opr)]>; +def LegalizeReshape : Pat<(MHLO_ReshapeOp:$output $input), + (TFL_ReshapeOp $input, + (CreateTFLCastToInt32Op (ShapeToConst $output)))>; +def LegalizeDynamicReshape : Pat<(MHLO_DynamicReshapeOp $input, $shape), + (TFL_ReshapeOp $input, (CreateTFLCastToInt32Op $shape))>; //===----------------------------------------------------------------------===// // logical and bitwise ops @@ -84,3 +86,527 @@ def : Pat<(MHLO_NotOp TensorOf<[UI32]>:$input), (TFL_BitwiseXorOp $input, (Arith_ConstantOp (GetRankedScalarAttr<"u", 32, ", false", "0xFFFFFFFFUL">)))>; + +foreach pair = [ + [MHLO_AndOp, TFL_LogicalAndOp], + [MHLO_OrOp, TFL_LogicalOrOp], +] in { + def : Pat< + (pair[0] TFL_BoolTensor:$l, TFL_BoolTensor:$r), + (pair[1] $l, $r)>; +} + +def LegalizeXor : Pat< + (MHLO_XorOp + TFL_IntTensor:$l, + TFL_IntTensor:$r), + (TFL_BitwiseXorOp $l, $r)>; + +//===----------------------------------------------------------------------===// +// binary element-wise ops +//===----------------------------------------------------------------------===// + +def : Pat< + (MHLO_ShiftRightArithmeticOp $l, $r), + (TFL_RightShiftOp $l, $r)>; + +def : Pat< + (MHLO_ShiftRightLogicalOp $l, $r), + (TFL_RightShiftOp $l, $r)>; + +def : Pat< + (MHLO_RemOp $l, $r), + (TFL_FloorModOp $l, $r)>; + +// Binary ops with no attrs. +foreach pair = [ + [MHLO_MaxOp, TFL_MaximumOp], + [MHLO_MinOp, TFL_MinimumOp], + [MHLO_PowOp, TFL_PowOp], +] in { + def : Pat< + (pair[0] $l, $r), + (pair[1] $l, $r)>; +} + +// Binary ops with fused activiation attr. +foreach pair = [ + [MHLO_MulOp, TFL_MulOp], +] in { + def : Pat< + (pair[0] $l, $r), + (pair[1] $l, $r, TFL_AF_None)>; +} + + + +//===----------------------------------------------------------------------===// +// comparison ops +//===----------------------------------------------------------------------===// + +// Check implicit bool cast of `$_self` to ensure Attribute is non-null before +// casting. +def HasSupportedComparisonType : AttrConstraint< + CPred<"!$_self || SupportedComparisonType($_self.cast())">>; + +class MHLO_ComparisonDirectionValue : + ConstantAttr; + +foreach p = [ + [TFL_EqualOp, MHLO_ComparisonDirectionValue<"EQ">], + [TFL_NotEqualOp, MHLO_ComparisonDirectionValue<"NE">], + [TFL_GreaterEqualOp, MHLO_ComparisonDirectionValue<"GE">], + [TFL_LessEqualOp, MHLO_ComparisonDirectionValue<"LE">], + [TFL_GreaterOp, MHLO_ComparisonDirectionValue<"GT">], + [TFL_LessOp, MHLO_ComparisonDirectionValue<"LT">]] +in { + def : Pat< + (MHLO_CompareOp $l, $r, p[1], HasSupportedComparisonType), + (p[0] $l, $r)>; +} + +//===----------------------------------------------------------------------===// +// unary element-wise op +//===----------------------------------------------------------------------===// + +def LowerCbrt : Pat<(MHLO_CbrtOp $opr), + (TFL_PowOp $opr, + (TFL_DivOp + (Arith_ConstantOp ConstantAttr, "1.0f">), + (Arith_ConstantOp ConstantAttr, "3.0f">), + TFL_AF_None)), + [(F32Tensor $opr)]>; + + +foreach pair = [ + [MHLO_AbsOp, TFL_AbsOp], + [MHLO_BitcastConvertOp, TFL_BitcastOp], + [MHLO_CeilOp, TFL_CeilOp], + [MHLO_CosineOp, TFL_CosOp], + [MHLO_ExpOp, TFL_ExpOp], + [MHLO_FloorOp, TFL_FloorOp], + [MHLO_ImagOp, TFL_ImagOp], + [MHLO_LogOp, TFL_LogOp], + [MHLO_LogisticOp, TFL_LogisticOp], + [MHLO_NegOp, TFL_NegOp], + [MHLO_RealOp, TFL_RealOp], + [MHLO_RsqrtOp, TFL_RsqrtOp], + [MHLO_SineOp, TFL_SinOp], + [MHLO_SignOp, TFL_SignOp], + [MHLO_SqrtOp, TFL_SqrtOp], + [MHLO_TanhOp, TFL_TanhOp] +] in { + def : Pat< + (pair[0] $input), + (pair[1] $input)>; +} + +def : Pat< + (MHLO_ConvertOp $input), + (TFL_CastOp $input)>; + +def : Pat< + (MHLO_Expm1Op F32Tensor:$x), + (TFL_SubOp + (TFL_ExpOp $x), + (Arith_ConstantOp + ConstantAttr, "1.0f">), + TFL_AF_None)>; + +def : Pat< + (MHLO_IsFiniteOp F32Tensor:$x), + (TFL_EqualOp + (TFL_SubOp $x, $x, TFL_AF_None), + (Arith_ConstantOp + ConstantAttr, "0.0f">))>; + +def : Pat< + (MHLO_Log1pOp F32Tensor:$x), + (TFL_LogOp + (TFL_AddOp + $x, + (Arith_ConstantOp + ConstantAttr, "1.0f">), + TFL_AF_None))>; + +//===----------------------------------------------------------------------===// +// rounding +//===----------------------------------------------------------------------===// + +class ValueEquals : + Constraint>; + +def SameValue : + Constraint>; + +def FloatOrDefaultCompare : + Constraint>; + +def SameTypeOrDefaultCompare : + Constraint>; + +def ValueIsReciprocal : + Constraint>; + +def TensorIsSign : + Constraint>; + +def ValueGreaterThanZero : + Constraint>; + + +// Converts a dag of HLOs representing banker rounding (round x.5 to nearest +// even) to tfl.round. This only supports float types because mhlo.floor only +// supports float types. tf.round with integer input type will become an +// identity op, so we will never face an mhlo.floor with an integer input type. +// The pattern matched executes the following computation: +// frac = x - floor(x) +// to_even = (floor(x) - 2 * floor(0.5 * x)) == 1 +// if frac > 0.5 || (frac == 0.5 && to_even) +// return floor(x) + 1 +// else +// return floor(x) +def Round : Pat<(MHLO_SelectOp + (MHLO_OrOp + (MHLO_CompareOp (MHLO_SubtractOp:$frac + $input, + (MHLO_FloorOp:$floor $input)), + (MHLO_ConstantOp $half), + MHLO_ComparisonDirectionValue<"GT">, + $compare_type0), + (MHLO_AndOp + (MHLO_CompareOp + $frac1, + (MHLO_ConstantOp $half1), + MHLO_ComparisonDirectionValue<"EQ">, + $compare_type1), + (MHLO_CompareOp + (MHLO_SubtractOp + $floor1, + (MHLO_MulOp + (MHLO_FloorOp (MHLO_MulOp $input, (MHLO_ConstantOp $half2))), + (MHLO_ConstantOp $two))), + (MHLO_ConstantOp $one1), + MHLO_ComparisonDirectionValue<"EQ">, + $compare_type2))), + (MHLO_AddOp $floor2, (MHLO_ConstantOp $one)), + $floor3), + (TFL_RoundOp $input), + [(ValueEquals<"1.0"> $one), + (ValueEquals<"1.0"> $one1), + (ValueEquals<"2.0"> $two), + (ValueEquals<"0.5"> $half), + (ValueEquals<"0.5"> $half1), + (ValueEquals<"0.5"> $half2), + (SameValue $floor, $floor1), + (SameValue $floor, $floor2), + (SameValue $floor, $floor3), + (SameValue $frac, $frac1), + (FloatOrDefaultCompare $compare_type0), + (FloatOrDefaultCompare $compare_type1), + (FloatOrDefaultCompare $compare_type2)]>; + +// Converts a dag of HLOs representing floor_mod to tfl.floor_mod. +// The pattern matched executes the following computation: +// +// rem = remainder(arg0, arg1) +// for i in 0 to len(arg1): +// if ((rem[i] < 0) != (arg0[i] < 0) && arg0[i] != 0) +// rem[i] += arg1[i] +// return rem +def : Pat<(MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_CompareOp:$rltz + (MHLO_RemOp:$rem $arg, $arg1), + (MHLO_ConstantOp $cst), + MHLO_ComparisonDirectionValue<"LT">, + $compare_type), + (MHLO_CompareOp:$arg1ltz $arg1, (MHLO_ConstantOp $cst1), MHLO_ComparisonDirectionValue<"LT">, $compare_type1), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type2), + (MHLO_CompareOp:$rnz $rem1, (MHLO_ConstantOp $cst2), MHLO_ComparisonDirectionValue<"NE">, $compare_type3)), + (MHLO_AddOp $rem2, $arg1), + $rem3), + (TFL_FloorModOp $arg, $arg1), + [(ValueEquals<"0.0"> $cst), + (ValueEquals<"0.0"> $cst1), + (ValueEquals<"0.0"> $cst2), + (SameValue $rem, $rem1), + (SameValue $rem, $rem2), + (SameValue $rem, $rem3), + (SameTypeOrDefaultCompare $compare_type, $cst), + (SameTypeOrDefaultCompare $compare_type1, $cst1)]>; + +// Converts a dag of HLOs representing floor_mod with a constant to +// tfl.floor_mod. The pattern matched executes the following computation: +// +// cst = value that is > 0 +// rem = remainder(arg0, cst) +// for i in 0 to len(arg1): +// if (rem[i] < 0 && rem[i] != 0) +// rem[i] += cst +// return rem +def : Pat<(MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp:$rltz + (MHLO_RemOp:$rem $arg, (MHLO_ConstantOp $cst)), + (MHLO_ConstantOp $cst1), + MHLO_ComparisonDirectionValue<"LT">, + $compare_type), + (MHLO_CompareOp:$rnz $rem1, (MHLO_ConstantOp $cst2), MHLO_ComparisonDirectionValue<"NE">, $compare_type3)), + (MHLO_AddOp $rem2, (MHLO_ConstantOp $cst3)), + $rem3), + (TFL_FloorModOp $arg, (Arith_ConstantOp $cst3)), + [(ValueGreaterThanZero $cst), + (ValueEquals<"0.0"> $cst1), + (ValueEquals<"0.0"> $cst2), + (SameValue $cst, $cst3), + (SameValue $rem, $rem1), + (SameValue $rem, $rem2), + (SameValue $rem, $rem3), + (SameTypeOrDefaultCompare $compare_type, $cst1), + (SameTypeOrDefaultCompare $compare_type3, $cst2)]>; + +// Converts a dag of HLOs representing floor_div to tfl.floor_div. +// The pattern matched executes the following computation: +// +// rem = remainder(arg0, arg1) +// for i in 0 to len(arg1): +// rem[i] = arg0[i] - rem[i] / arg1[i] +// if (rem[i] != 0 && sign(arg1[i]) != sign(rem[i])) +// rem[i] -= 1.0 +// return round_nearest_afz(rem) +// As a dag this looks like the following: +// round +// | +// -------- select +// | | \ +// && + div +// / | / \ +// != != div -1 +// / | / | / | +// rem 0.0 sn sn1 - $1 +// / | | | / | +// $0 $1 $1 rem $0 rem +// Note that named operators like 'sn' and 'sn1' are different values produced by +// the same function in this case the sign function. Named values like 'div' +// refer to the same value produced by the same function, in this case division. +// Mathematical symbols do not indicate a re-use of the value. +def : Pat<(MHLO_RoundOp + (MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_RemOp:$rem $arg0, $arg1), + (MHLO_ConstantOp $cst), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type), + (MHLO_CompareOp + (MHLO_SignOp $arg1), + (MHLO_SignOp $rem1), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type1)), + (MHLO_AddOp + (MHLO_DivOp:$div + (MHLO_SubtractOp $arg0, $rem2), + $arg1b), + (MHLO_ConstantOp $cst_neg1)), + $div1)), + (TFL_FloorDivOp $arg0, $arg1), + [(ValueEquals<"0.0"> $cst), + (ValueEquals<"-1.0"> $cst_neg1), + (SameValue $div, $div1), + (SameValue $rem, $rem1), + (SameValue $rem, $rem2), + (FloatOrDefaultCompare $compare_type, $cst), + (FloatOrDefaultCompare $compare_type1, $cst)]>; + +// Converts a dag of HLOs representing floor_div with a splat constant to +// tfl.floor_div. The pattern matched executes the following computation: +// This particular pattern matches multiplication with the reciprocal of the +// constant instead of dividing by the constant. +// rem = remainder(arg0, cst) +// for i in 0 to len(arg0): +// rem[i] = (arg0[i] - rem[i]) * 1 / cst +// if (rem[i] != 0 && sign(cst) != sign(rem[i])) +// rem[i] += -1.0 +// return round_nearest_afz(rem) +// As a dag this looks like the following: +// round +// | +// -------- select +// | | \ +// && + mul +// / | / \ +// != != mul -1 +// / | / | / | +// rem 0.0 cs1 sn1 - cs2 +// / | | / | +// $0 cst rem $0 rem +// cs1 == sign(cst) +// cs2 = 1 / cst i.e. the reciprocal +// Note that named operators like 'sn' and 'sn1' are different values produced by +// the same function in this case the sign function. Named values like 'div' +// refer to the same value produced by the same function, in this case division. +// Mathematical symbols do not indicate a re-use of the value. +def : Pat<(MHLO_RoundOp + (MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_RemOp:$rem $arg0, (MHLO_ConstantOp:$cst $cstv)), + (MHLO_ConstantOp $cst_zero), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type), + (MHLO_CompareOp + (MHLO_ConstantOp $cst_sgn), + (MHLO_SignOp $rem1), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type1)), + (MHLO_AddOp + (MHLO_MulOp:$mul + (MHLO_SubtractOp $arg0, $rem2), + (MHLO_ConstantOp $cst_recip)), + (MHLO_ConstantOp $cst_neg1)), + $mul1)), + (TFL_FloorDivOp $arg0, $cst), + [(ValueEquals<"0.0"> $cst_zero), + (ValueEquals<"-1.0"> $cst_neg1), + (TensorIsSign $cstv, $cst_sgn), + (ValueIsReciprocal $cstv, $cst_recip), + (SameValue $mul, $mul1), + (SameValue $rem, $rem1), + (SameValue $rem, $rem2), + (FloatOrDefaultCompare $compare_type, $cst_zero), + (FloatOrDefaultCompare $compare_type1, $cst_sgn)]>; + +// Converts a dag of HLOs representing floor_div with a splat constant to +// tfl.floor_div. The pattern matched executes the following computation: +// This particular pattern matches division with the constant. +// . +// rem = remainder(arg0, cst) +// for i in 0 to len(arg0): +// rem[i] = (arg0[i] - rem[i]) / cst +// if (rem[i] != 0 && sign(cst) != sign(rem[i])) +// rem[i] -= 1.0 +// return round_nearest_afz(rem) +// As a dag this looks like the following: +// round +// | +// -------- select +// | | \ +// && + div +// / | / \ +// != != div -1 +// / | / | / | +// rem 0.0 cs1 sn1 - cs2 +// / | | / | +// $0 cst rem $0 rem +// cs1 == sign(cst) +// cs2 = 1 / cst i.e. the reciprocal +// Note that named operators like 'sn' and 'sn1' are different values produced by +// the same function in this case the sign function. Named values like 'div' +// refer to the same value produced by the same function, in this case division. +// Mathematical symbols do not indicate a re-use of the value. +def : Pat<(MHLO_RoundOp + (MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_RemOp:$rem $arg0, (MHLO_ConstantOp:$cst $cstv)), + (MHLO_ConstantOp $cst_zero), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type), + (MHLO_CompareOp + (MHLO_ConstantOp $cst_sgn), + (MHLO_SignOp $rem1), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type1)), + (MHLO_AddOp + (MHLO_DivOp:$div + (MHLO_SubtractOp $arg0, $rem2), + (MHLO_ConstantOp $cstv1)), + (MHLO_ConstantOp $cst_neg1)), + $div1)), + (TFL_FloorDivOp $arg0, $cst), + [(ValueEquals<"0.0"> $cst_zero), + (ValueEquals<"-1.0"> $cst_neg1), + (TensorIsSign $cstv, $cst_sgn), + (SameValue $div, $div1), + (SameValue $rem, $rem1), + (SameValue $rem, $rem2), + (SameValue $cstv1, $cstv), + (FloatOrDefaultCompare $compare_type, $cst_zero), + (FloatOrDefaultCompare $compare_type1, $cst_sgn)]>; + +// Converts a dag of HLOs representing floor_div with a broadcasted vector +// constant to tfl.floor_div. The pattern matched executes the following +// computation: +// scs = sign(cst) +// bcst = broadcast(cst) +// rem = remainder(arg0, bcst) +// for i in 0 to len(arg0): +// rem[i] = arg0[i] - rem[i] * / bcst +// if (rem[i] != 0 && scs != sign(rem[i])) +// rem[i] -= 1.0 +// return round_nearest_afz(rem) +// Where scs is a splat constant folded sign on the unbroadcasted tensor. +// +// As a dag this looks like the following: +// round +// | +// -------- select +// | | \ +// && + div +// / | / \ +// != != div -1 +// / | / | / | +// rem 0.0 scs sn1 - bcst +// / | | / | +// $0 bcst rem $0 rem +// | +// cst +// scs == sign(cst) == sign(bcst) +// Note that named operators like 'sn' and 'sn1' are different values produced by +// the same function in this case the sign function. Named values like 'div' +// refer to the same value produced by the same function, in this case division. +// Mathematical symbols do not indicate a re-use of the value. +def : Pat<(MHLO_RoundOp + (MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_RemOp:$rem $arg0, + (MHLO_BroadcastInDimOp:$bcst + (MHLO_ConstantOp $cstv), + $broadcast_dimension)), + (MHLO_ConstantOp $cst_zero), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type), + (MHLO_CompareOp + (MHLO_ConstantOp $cst_sgn), + (MHLO_SignOp $rem1), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type1)), + (MHLO_AddOp + (MHLO_DivOp:$div + (MHLO_SubtractOp $arg0, $rem2), + $bcst1), + (MHLO_ConstantOp $cst_neg1)), + $div1)), + (TFL_FloorDivOp $arg0, $bcst), + [(ValueEquals<"0.0"> $cst_zero), + (ValueEquals<"-1.0"> $cst_neg1), + (TensorIsSign $cstv, $cst_sgn), + (SameValue $bcst, $bcst1), + (SameValue $div, $div1), + (SameValue $rem, $rem1), + (SameValue $rem, $rem2), + (FloatOrDefaultCompare $compare_type, $cst_zero), + (FloatOrDefaultCompare $compare_type1, $cst_sgn)]>; + + +//===----------------------------------------------------------------------===// +// ternary op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(MHLO_ClampOp $min, $arg, $max), + (TFL_MaximumOp (TFL_MinimumOp $arg, $max), $min)>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/build_stablehlo_composite_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/build_stablehlo_composite_pass.cc new file mode 100644 index 00000000000000..e717114610b527 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/build_stablehlo_composite_pass.cc @@ -0,0 +1,555 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "json/json.h" +#include "json/reader.h" +#include "json/value.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Analysis/TopologicalSortUtils.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" // IWYU pragma: keep + +namespace mlir { +namespace odml { + +#define GEN_PASS_DEF_BUILDSTABLEHLOCOMPOSITEPASS +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" + +namespace { + +// Checks if this operation is a MarkTensor operation used to mark the +// boundaries of a composite. +static bool IsMarkTensorOp(mlir::Operation* op) { + if (op == nullptr) { + return false; + } + if (op->getNumOperands() != 1 || op->getNumResults() != 1) { + return false; + } + if (!llvm::isa(op)) { + return false; + } + auto target_name = + mlir::dyn_cast(op->getAttr("call_target_name")); + if (target_name == nullptr || target_name.str() != "mark_tensor") { + return false; + } + return true; +} + +struct BoundaryMetadata { + std::string name; + std::string id; + int64_t pos; + bool is_input; + std::unordered_map attrs; + + auto boundary_key() const { return absl::StrCat(name, "__@@__", id); } + + auto uid() const { return std::forward_as_tuple(name, id, pos, is_input); } + + bool operator==(const BoundaryMetadata& other) const { + return uid() == other.uid(); + } + bool operator<(const BoundaryMetadata& other) const { + return uid() < other.uid(); + } + + static std::unique_ptr Parse(llvm::StringRef str_ref) { + Json::Value root; + Json::Reader reader; + if (!reader.parse(str_ref.str(), root)) { + return nullptr; + } + return Build(root); + } + + private: + template + static bool CopyJsonValue(const Json::Value& json, llvm::StringRef key, + Json::ValueType expected_type, T* to) { + if (!json.isMember(key.str()) || json[key.str()].type() != expected_type) { + return false; + } + + *to = json[key.str()].as(); + return true; + } + + static std::unique_ptr Build(const Json::Value& json) { + BoundaryMetadata metadata; + + bool is_valid_metadata_json = + CopyJsonValue(json, "name", Json::stringValue, &metadata.name) && + CopyJsonValue(json, "id", Json::stringValue, &metadata.id) && + CopyJsonValue(json, "pos", Json::intValue, &metadata.pos) && + CopyJsonValue(json, "is_input", Json::booleanValue, &metadata.is_input); + + if (!is_valid_metadata_json) { + return nullptr; + } + + Json::Value attrs_value = json["attr"]; + if (attrs_value.type() == Json::objectValue) { + for (const auto& key_value : attrs_value.getMemberNames()) { + metadata.attrs.insert({key_value, attrs_value[key_value]}); + } + } + return std::make_unique(std::move(metadata)); + } +}; + +class BuildStableHLOCompositePass + : public impl::BuildStableHLOCompositePassBase< + BuildStableHLOCompositePass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BuildStableHLOCompositePass); + + void runOnOperation() override { + mlir::ModuleOp module_op = getOperation(); + llvm::SmallVector func_ops( + module_op.getOps()); + for (mlir::func::FuncOp& func_op : func_ops) { + llvm::DenseMap op_order_map = + BuildOpOrderMap(func_op); + std::unordered_map> + boundary_output_ops_map = BuildBoundaryOutputOpsMap(func_op); + + for (const auto& [unused, ops] : boundary_output_ops_map) { + if (mlir::failed(BuildStableHLOComposite(ops, op_order_map))) { + func_op.emitError() << "failed to build composite."; + return signalPassFailure(); + } + } + } + + // Remove mark_tensor custom_call ops. + getOperation()->walk([](mlir::stablehlo::CustomCallOp op) { + if (!IsMarkTensorOp(op.getOperation())) { + return; + } + mlir::Value original_value = op.getOperand(0); + + for (mlir::Value result : op.getResults()) { + result.replaceAllUsesWith(original_value); + } + op.erase(); + }); + } + + private: + llvm::DenseMap BuildOpOrderMap( + mlir::func::FuncOp func_op) const { + llvm::DenseMap op_order_map; + for (const auto& op : llvm::enumerate(func_op.getOps())) { + op_order_map[&op.value()] = op.index(); + } + return op_order_map; + } + + std::unordered_map> + BuildBoundaryOutputOpsMap(mlir::func::FuncOp func_op) { + std::unordered_map> + boundary_output_ops; + + for (auto op : func_op.getOps()) { + auto metadata_or = GetBoundaryMetadata(op); + if (mlir::failed(metadata_or)) { + continue; + } + + std::unique_ptr metadata = std::move(*metadata_or); + if (metadata == nullptr || metadata->is_input) { + continue; + } + + auto& output_ops = boundary_output_ops[metadata->boundary_key()]; + if (metadata->pos >= output_ops.size()) { + output_ops.resize(metadata->pos + 1, nullptr); + } + output_ops[metadata->pos] = op.getOperation(); + } + return boundary_output_ops; + } + + mlir::FailureOr> GetBoundaryMetadata( + mlir::Operation* op) { + if (!IsMarkTensorOp(op)) { + return mlir::FailureOr>(nullptr); + } + auto backend_config = + mlir::dyn_cast(op->getAttr("backend_config")); + if (backend_config == nullptr) { + return mlir::FailureOr>(nullptr); + } + std::unique_ptr metadata = + BoundaryMetadata::Parse(backend_config); + if (metadata == nullptr) { + return op->emitError() << "invalid boundary metadata JSON."; + } + return metadata; + } + + mlir::FailureOr BuildAttrFromJson( + mlir::OpBuilder& builder, mlir::Operation* op, + const Json::Value& json_value) { + switch (json_value.type()) { + case Json::intValue: + case Json::uintValue: + return builder.getI64IntegerAttr(json_value.as()); + case Json::ValueType::realValue: + return builder.getF32FloatAttr(json_value.as()); + case Json::ValueType::booleanValue: + return builder.getBoolAttr(json_value.as()); + case Json::ValueType::stringValue: + return builder.getStringAttr(json_value.as()); + case Json::ValueType::arrayValue: { + if (json_value.empty()) { + return builder.getArrayAttr({}); + } + auto get_json_type = [](const Json::Value& json_value) { + auto ty = json_value.type(); + if (ty == Json::uintValue) { + return Json::intValue; + } + return ty; + }; + + auto head_type = get_json_type(json_value[0]); + bool is_homogeneous = llvm::all_of(json_value, [&](auto& el) { + return get_json_type(el) == head_type; + }); + if (!is_homogeneous) { + return op->emitError() + << "invalid JSON to MLIR, arrays must be homogeneous"; + } + + switch (head_type) { + case Json::intValue: { + llvm::SmallVector int_values; + for (const auto& json_value : json_value) { + int_values.push_back(json_value.as()); + } + return builder.getI64TensorAttr(int_values); + } + case Json::realValue: { + llvm::SmallVector float_values; + for (const auto& json_value : json_value) { + float_values.push_back(json_value.as()); + } + return mlir::DenseFPElementsAttr::get( + mlir::RankedTensorType::get(json_value.size(), + builder.getF32Type()), + float_values); + } + case Json::booleanValue: { + llvm::SmallVector bool_values; + for (const auto& json_value : json_value) { + bool_values.push_back(json_value.as()); + } + return mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get(json_value.size(), + builder.getI1Type()), + bool_values); + } + default: + return op->emitError() + << "invalid JSON to MLIR: invalid array type. arrays must " + "be " + "1-D homogeneous arrays of supported primitive types"; + } + } + default: + return op->emitError() + << "invalid JSON to MLIR: unsupported json value type"; + } + } + + mlir::FailureOr BuildDictionaryAttrFromJsonMap( + mlir::OpBuilder& builder, mlir::Operation* op, + const std::unordered_map& json_map) { + llvm::SmallVector named_attrs; + for (auto& [key, json] : json_map) { + mlir::FailureOr attribute_or = + BuildAttrFromJson(builder, op, json); + if (mlir::failed(attribute_or)) { + return mlir::failure(); + } + named_attrs.push_back({builder.getStringAttr(key), *attribute_or}); + } + return builder.getDictionaryAttr(named_attrs); + } + + mlir::LogicalResult BuildStableHLOComposite( + const llvm::SmallVector& output_ops, + const llvm::DenseMap& op_order_map) { + if (output_ops.empty()) { + return mlir::success(); + } + + // Get the output op with minimum order num as the representative. + mlir::Operation* first_output_op = output_ops[0]; + for (mlir::Operation* op : output_ops) { + if (op_order_map.at(op) < op_order_map.at(first_output_op)) { + first_output_op = op; + } + } + + auto metadata_or = GetBoundaryMetadata(first_output_op); + if (mlir::failed(metadata_or)) { + return mlir::failure(); + } + + std::unique_ptr metadata = std::move(*metadata_or); + if (metadata == nullptr || metadata->is_input) { + // There should always be a valid boundary output metadata associated with + // each op in output_ops. + return mlir::failure(); + } + + auto args_ops_or = + GetBoundaryArgsAndOps(output_ops, *metadata, op_order_map); + if (mlir::failed(args_ops_or)) { + return mlir::failure(); + } + + auto [args, impl_ops] = *args_ops_or; + + mlir::func::FuncOp impl_func = BuildStableHLOCompositeImplFunc( + output_ops, absl::StrCat(metadata->name, ".impl"), args, impl_ops); + mlir::FailureOr composite_op_or = + BuildStableHLOCompositeOp(first_output_op, impl_func, args, *metadata); + if (mlir::failed(composite_op_or)) { + return mlir::failure(); + } + mlir::Operation* composite_op = *composite_op_or; + + // Updates all users of this op's result(s) to use the results(s) of impl + // func call. + size_t composite_result_i = 0; + for (mlir::Operation* op : output_ops) { + for (size_t i = 0; i < op->getNumResults(); ++i) { + mlir::OpResult result = op->getResult(i); + result.replaceAllUsesWith( + composite_op->getResult(composite_result_i++)); + } + } + + if (!mlir::sortTopologically(composite_op->getBlock())) { + composite_op->emitError() + << "The graph is not acyclic after BuildStableHLOCompositePass pass."; + return mlir::failure(); + } + // The unused impl_ops will be eliminated with canonicalizer. + return mlir::success(); + } + + mlir::FailureOr, + llvm::SmallVector>> + GetBoundaryArgsAndOps( + const llvm::SmallVector boundary_output_ops, + const BoundaryMetadata& metadata, + const llvm::DenseMap& op_order_map) { + llvm::SetVector impl_ops_setvec; + llvm::SetVector> arg_pos_setvec; + llvm::SmallVector processing(boundary_output_ops.begin(), + boundary_output_ops.end()); + + // Reverse graph traversal: from boundary output op to boundary input op, + // global function arg, or stablehlo constant. + while (!processing.empty()) { + mlir::Operation* curr_op = processing.back(); + processing.pop_back(); + if (impl_ops_setvec.contains(curr_op)) { + continue; + } + + auto curr_metadata_or = GetBoundaryMetadata(curr_op); + if (mlir::failed(curr_metadata_or)) { + return mlir::failure(); + } + std::unique_ptr curr_metadata = + std::move(*curr_metadata_or); + if (curr_metadata != nullptr) { + if (curr_metadata->is_input && + curr_metadata->boundary_key() == metadata.boundary_key()) { + // Terminal condition: boundary input op. + + arg_pos_setvec.insert( + {mlir::dyn_cast(curr_op->getResult(0)), + curr_metadata->pos}); + continue; + } + } + + impl_ops_setvec.insert(curr_op); + for (mlir::Value value : curr_op->getOperands()) { + mlir::Operation* def_op = value.getDefiningOp(); + if (def_op == nullptr) { + // Terminal condition: global function arg + arg_pos_setvec.insert({value, std::numeric_limits::max()}); + } else if (llvm::isa(def_op)) { + // Terminal condition: constant + impl_ops_setvec.insert(def_op); + } else { + processing.push_back(def_op); + } + } + } + // Sorts all ops within the boundary by their line numbers in the input + // MLIR. The ops will be duplicated to the impl function following this + // order. + llvm::SmallVector impl_ops = impl_ops_setvec.takeVector(); + for (auto& op : impl_ops) { + if (!op_order_map.contains(op)) { + return op->emitError() + << "does not have a ordering number in its outer func."; + } + } + std::sort(impl_ops.begin(), impl_ops.end(), + [&op_order_map](const auto& a, const auto& b) { + return op_order_map.at(a) < op_order_map.at(b); + }); + + // Sorts boundary args by their positions. Note that the args of the + // composite and impl function may be more than the boundary inputs, because + // the MLIR is lowered from the functionalized graph and additional args may + // be Pytorch constants. In such case the position of those args would be + // undetermined, while they would always come after boundary inputs. + auto arg_pos_pairs = arg_pos_setvec.takeVector(); + std::stable_sort( + arg_pos_pairs.begin(), arg_pos_pairs.end(), + [](const auto& a, const auto& b) { return a.second < b.second; }); + llvm::SmallVector args; + args.reserve(arg_pos_pairs.size()); + for (auto& [arg, unused] : arg_pos_pairs) { + args.push_back(arg); + } + + return std::make_pair(std::move(args), std::move(impl_ops)); + } + + mlir::func::FuncOp BuildStableHLOCompositeImplFunc( + const llvm::SmallVector boundary_output_ops, + llvm::StringRef func_name, const llvm::SmallVector& args, + const llvm::SmallVector& impl_ops) { + mlir::ModuleOp module_op = getOperation(); + mlir::MLIRContext* context = &getContext(); + mlir::OpBuilder builder(context); + + // Creates composite impl function and duplicates all ops within the + // boundary in the function. + llvm::SmallVector arg_locs; + llvm::SmallVector arg_types; + for (auto& arg : args) { + arg_types.push_back(arg.getType()); + arg_locs.push_back(arg.getLoc()); + } + llvm::SmallVector result_types; + for (mlir::Operation* op : boundary_output_ops) { + result_types.append(op->getResultTypes().begin(), + op->getResultTypes().end()); + } + + mlir::func::FuncOp impl_func = builder.create( + module_op.getLoc(), func_name, + mlir::FunctionType::get(context, arg_types, result_types)); + mlir::IRMapping mapping; + builder.createBlock(&impl_func.getBody(), impl_func.begin(), arg_types, + arg_locs); + for (const auto& arg : llvm::enumerate(args)) { + mapping.map(arg.value(), impl_func.getArgument(arg.index())); + } + for (mlir::Operation* original_op : impl_ops) { + mlir::Operation* cloned_op = builder.clone(*original_op, mapping); + mapping.map(original_op, cloned_op); + } + + llvm::SmallVector results; + for (mlir::Operation* op : boundary_output_ops) { + results.append(mapping.lookup(op)->getResults().begin(), + mapping.lookup(op)->getResults().end()); + } + builder.create(impl_func.getBody().getLoc(), results); + + // Adds the new function to symbol table. + mlir::SymbolTable symbol_table(module_op); + impl_func.setPrivate(); + symbol_table.insert(impl_func); + + return impl_func; + } + + mlir::FailureOr BuildStableHLOCompositeOp( + mlir::Operation* boundary_output_op, mlir::func::FuncOp impl_func, + const llvm::SmallVector& args, + const BoundaryMetadata& metadata) { + mlir::MLIRContext* context = &getContext(); + mlir::OpBuilder builder(context); + + mlir::FailureOr attributes_or = + BuildDictionaryAttrFromJsonMap(builder, boundary_output_op, + metadata.attrs); + if (mlir::failed(attributes_or)) { + return boundary_output_op->emitError() + << "failed to transform boundary attr " + "JSON into composite attributes."; + } + + // Creates and inserts composite call op. + builder.setInsertionPointAfter(boundary_output_op); + mlir::Operation* composite_op = + builder.create( + boundary_output_op->getLoc(), + impl_func.getFunctionType().getResults(), args, metadata.name, + *attributes_or, impl_func.getSymName()); + return composite_op; + } +}; + +} // namespace +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/lift_callsite_loc_caller_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/lift_callsite_loc_caller_pass.cc new file mode 100644 index 00000000000000..89e23c6edc4a24 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/lift_callsite_loc_caller_pass.cc @@ -0,0 +1,54 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" // IWYU pragma: keep + +namespace mlir { +namespace odml { +#define GEN_PASS_DEF_LIFTCALLSITELOCCALLERPASS +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" + +namespace { + +// JAX bridge generates a func.call for each op lowering +// These are inlined but loc will be messed up after the inline pass. This pass +// normalize the loc after inline pass. + +class LiftCallSiteLocCallerPass + : public impl::LiftCallSiteLocCallerPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LiftCallSiteLocCallerPass); + + void runOnOperation() override { + getOperation()->walk([](func::FuncOp func_op) { + for (Operation& op : func_op.getOps()) { + if (!mlir::isa(op.getLoc())) { + continue; + } + + auto loc = op.getLoc().dyn_cast(); + op.setLoc(loc.getCaller()); + } + }); + } +}; + +} // namespace +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stateful_error_reporter.h b/tensorflow/compiler/mlir/lite/stateful_error_reporter.h new file mode 100644 index 00000000000000..fbb82d3e54d121 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stateful_error_reporter.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_ + +// LINT.IfChange +#include + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + +namespace tflite_migration { + +// Similar to tflite::ErrorReporter, except that it allows callers to get the +// last error message. +class StatefulErrorReporter : public tflite::ErrorReporter { + public: + // Returns last error message. Returns empty string if no error is reported. + virtual std::string message() = 0; +}; + +} // namespace tflite_migration +// LINT.ThenChange(//tensorflow/lite/stateful_error_reporter.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_ diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index 9626a292b8eb6d..7b949d3d551151 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -151,6 +151,75 @@ func.func @mul_f16() -> (tensor, tensor<4xf16>, tensor<4xf16>, tensor<4xf16 func.return %5, %6, %7, %8 : tensor, tensor<4xf16>, tensor<4xf16>, tensor<4xf16> } +// CHECK-LABEL: @mul_zero +func.func @mul_zero(%arg0: tensor<4xi32>, %arg1: tensor<4xf32>) -> (tensor<4xi32>, tensor<4xf32>) { + %zero_int = arith.constant dense<0> : tensor<4xi32> + %zero_float = arith.constant dense<0.0> : tensor<4xf32> + + // CHECK-NOT: tfl.mul + // CHECK: return %cst, %cst_0 + + %0 = "tfl.mul"(%arg0, %zero_int) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "tfl.mul"(%arg1, %zero_float) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + func.return %0, %1 : tensor<4xi32>, tensor<4xf32> +} + +// CHECK-LABEL: @mul_zero_lhs +func.func @mul_zero_lhs(%arg0: tensor<4xi32>, %arg1: tensor<4xf32>) -> (tensor<4xi32>, tensor<4xf32>) { + %zero_int = arith.constant dense<0> : tensor<4xi32> + %zero_float = arith.constant dense<0.0> : tensor<4xf32> + + // CHECK-NOT: tfl.mul + // CHECK: return %cst, %cst_0 + + %0 = "tfl.mul"(%zero_int, %arg0) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "tfl.mul"(%zero_float, %arg1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + func.return %0, %1 : tensor<4xi32>, tensor<4xf32> +} + +// CHECK-LABEL: @mul_one +func.func @mul_one(%arg0: tensor<4xi32>, %arg1: tensor<4xf32>) -> (tensor<4xi32>, tensor<4xf32>) { + %one_int = arith.constant dense<1> : tensor<4xi32> + %one_float = arith.constant dense<1.0> : tensor<4xf32> + + // CHECK-NOT: tfl.mul + // CHECK: return %arg0, %arg1 + + %0 = "tfl.mul"(%arg0, %one_int) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "tfl.mul"(%arg1, %one_float) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + func.return %0, %1 : tensor<4xi32>, tensor<4xf32> +} + +// CHECK-LABEL: @mul_one_lhs +func.func @mul_one_lhs(%arg0: tensor<4xi32>, %arg1: tensor<4xf32>) -> (tensor<4xi32>, tensor<4xf32>) { + %one_int = arith.constant dense<1> : tensor<4xi32> + %one_float = arith.constant dense<1.0> : tensor<4xf32> + + // CHECK-NOT: tfl.mul + // CHECK: return %arg0, %arg1 + + %0 = "tfl.mul"(%one_int, %arg0) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "tfl.mul"(%one_float, %arg1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + func.return %0, %1 : tensor<4xi32>, tensor<4xf32> +} + +// CHECK-LABEL: @mul_one_quant +func.func @mul_one_quant(%arg0: tensor<32x!quant.uniform>) -> tensor<32x!quant.uniform> { + %one = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<1> : tensor<32xi8>} : () -> tensor<32x!quant.uniform> + + // CHECK: %[[MUL:.*]] = tfl.mul + // CHECK: return %[[MUL]] + + %0 = "tfl.mul"(%one, %arg0) {fused_activation_function = "NONE"} : (tensor<32x!quant.uniform>, tensor<32x!quant.uniform>) -> tensor<32x!quant.uniform> + + func.return %0 : tensor<32x!quant.uniform> +} + + // CHECK-LABEL: @elementwise_unary_ops func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor) { %0 = arith.constant dense<-1.0> : tensor @@ -191,6 +260,15 @@ func.func @max_with_neg_f32_max_val(%arg0 : tensor) -> (tensor, tensor // CHECK: return %[[ARG0]], %[[ARG0]] } +// CHECK-LABEL: @max_with_neg_inf +func.func @max_with_neg_inf(%arg0 : tensor) -> (tensor, tensor) { + %neg_inf = arith.constant dense<0xFF800000> : tensor + %0 = "tfl.maximum"(%arg0, %neg_inf) : (tensor, tensor) -> tensor + %1 = "tfl.maximum"(%neg_inf, %arg0) : (tensor, tensor) -> tensor + func.return %0, %1 : tensor, tensor + // CHECK: return %[[ARG0]], %[[ARG0]] +} + // CHECK-LABEL: @min_with_f32_max_val // CHECK-SAME: (%[[ARG0:.+]]: tensor) func.func @min_with_f32_max_val(%arg0 : tensor) -> (tensor, tensor) { @@ -201,6 +279,15 @@ func.func @min_with_f32_max_val(%arg0 : tensor) -> (tensor, tensor) -> (tensor, tensor) { + %inf = arith.constant dense<0x7F800000> : tensor + %0 = "tfl.minimum"(%arg0, %inf) : (tensor, tensor) -> tensor + %1 = "tfl.minimum"(%inf, %arg0) : (tensor, tensor) -> tensor + func.return %0, %1 : tensor, tensor + // CHECK: return %[[ARG0]], %[[ARG0]] +} + // CHECK-LABEL: @max_with_neg_f64_max_val // CHECK-SAME: (%[[ARG0:.+]]: tensor) func.func @max_with_neg_f64_max_val(%arg0 : tensor) -> (tensor, tensor) { @@ -672,6 +759,32 @@ func.func @div_dense_different_rank() -> tensor<1x2x2xf32> { // CHECK: return %[[CST]] } +// CHECK-LABEL: @div_one +func.func @div_one(%arg0: tensor<4xi32>, %arg1: tensor<4xf32>) -> (tensor<4xi32>, tensor<4xf32>) { + %one_int = arith.constant dense<1> : tensor<4xi32> + %one_float = arith.constant dense<1.0> : tensor<4xf32> + + // CHECK-NOT: tfl.div + // CHECK: return %arg0, %arg1 + + %0 = "tfl.div"(%arg0, %one_int) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "tfl.div"(%arg1, %one_float) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + func.return %0, %1 : tensor<4xi32>, tensor<4xf32> +} + +// CHECK-LABEL: @div_one_quant +func.func @div_one_quant(%arg0: tensor<32x!quant.uniform>) -> tensor<32x!quant.uniform> { + %one = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<1> : tensor<32xi8>} : () -> tensor<32x!quant.uniform> + + // CHECK: %[[DIV:.*]] = tfl.div + // CHECK: return %[[DIV]] + + %0 = "tfl.div"(%arg0, %one) {fused_activation_function = "NONE"} : (tensor<32x!quant.uniform>, tensor<32x!quant.uniform>) -> tensor<32x!quant.uniform> + + func.return %0 : tensor<32x!quant.uniform> +} + // CHECK-LABEL: @rsqrt_bf16 func.func @rsqrt_bf16() -> tensor { %cst = arith.constant dense<4.0> : tensor @@ -779,6 +892,51 @@ func.func @cast_ui8_to_i1() -> tensor<4xi1> { // CHECK: return %[[CST]] } +// CHECK-LABEL: @cast_f32_to_i32 +func.func @cast_f32_to_i32() -> tensor<8xi32> { + %cst = arith.constant dense<[-1.0, 0.0, 1.5, 0.99, 1.175494351e-38, 3.402823466e+38, -3.402823466e+38, -1.175494351e-38]> : tensor<8xf32> + %0 = "tfl.cast"(%cst) : (tensor<8xf32>) -> tensor<8xi32> + func.return %0 : tensor<8xi32> +} + +// CHECK: %cst = arith.constant dense<[-1, 0, 1, 0, 0, 2147483647, -2147483648, 0]> : tensor<8xi32> + +// CHECK-LABEL: @cast_i32_to_f32 +func.func @cast_i32_to_f32() -> tensor<5xf32> { + %cst = arith.constant dense<[-1, 0, 2, 2147483647, -2147483648]> : tensor<5xi32> + %0 = "tfl.cast"(%cst) : (tensor<5xi32>) -> tensor<5xf32> + func.return %0 : tensor<5xf32> +} + +// CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 2.000000e+00, 2.14748365E+9, -2.14748365E+9]> : tensor<5xf32> + +// CHECK-LABEL: @cast_bool_to_f32 +func.func @cast_bool_to_f32() -> tensor<2xf32> { + %cst = arith.constant dense<[true, false]> : tensor<2xi1> + %0 = "tfl.cast"(%cst) : (tensor<2xi1>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: %cst = arith.constant dense<[1.000000e+00, 0.000000e+00]> : tensor<2xf32> + +// CHECK-LABEL: @cast_f64_to_f32 +func.func @cast_f64_to_f32() -> tensor<4xf32> { + %cst = arith.constant dense<[-1.0, 0.0, 1.5, 100.0]> : tensor<4xf64> + %0 = "tfl.cast"(%cst) : (tensor<4xf64>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 1.500000e+00, 1.000000e+02]> : tensor<4xf32> + +// CHECK-LABEL: @cast_f32_to_f64 +func.func @cast_f32_to_f64() -> tensor<4xf64> { + %cst = arith.constant dense<[-1.0, 0.0, 1.5, 100.0]> : tensor<4xf32> + %0 = "tfl.cast"(%cst) : (tensor<4xf32>) -> tensor<4xf64> + func.return %0 : tensor<4xf64> +} + +// CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 1.500000e+00, 1.000000e+02]> : tensor<4xf64> + // CHECK-LABEL: @ConstantFoldFullyConnectedSmall func.func @ConstantFoldFullyConnectedSmall() -> tensor<3xf32> { %cst_input = arith.constant dense<[2.0, 3.0]> : tensor<2xf32> @@ -942,3 +1100,336 @@ func.func @ConstFoldEmbeddingLookup() -> (tensor<5x2xf32>, tensor<3x2x2xf32>) { // CHECK-DAG: %[[LOOKUP1:.*]] = arith.constant dense<{{\[\[\[}}1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]], {{\[\[}}5.000000e+00, 6.000000e+00], [7.000000e+00, 8.000000e+00]], {{\[\[}}1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]]> : tensor<3x2x2xf32> // CHECK: return %[[LOOKUP0]], %[[LOOKUP1]] : tensor<5x2xf32>, tensor<3x2x2xf32> } + +// CHECK-LABEL: @less_int_both_splat +func.func @less_int_both_splat() -> tensor<4xi1> { + %0 = arith.constant dense<3> : tensor<4xi32> + %1 = arith.constant dense<10> : tensor<4xi32> + + %2 = "tfl.less"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense : tensor<4xi1> + +// CHECK-LABEL: @less_int_one_splat +func.func @less_int_one_splat() -> tensor<4xi1> { + %0 = arith.constant dense<3> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.less"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK:%cst = arith.constant dense<[true, false, false, false]> : tensor<4xi1> + +// CHECK-LABEL: @less_int +func.func @less_int() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.less"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, false, false, true]> : tensor<4xi1> + +// CHECK-LABEL: @less_float +func.func @less_float() -> tensor<4xi1> { + %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32> + %1 = arith.constant dense<[10.0, 2.0, -1.0, 3.0]> : tensor<4xf32> + + %2 = "tfl.less"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, false, false, true]> : tensor<4xi1> + +// CHECK-LABEL: @less_equal_int +func.func @less_equal_int() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.less_equal"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, true, false, true]> : tensor<4xi1> + +// CHECK-LABEL: @less_equal_float +func.func @less_equal_float() -> tensor<4xi1> { + %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32> + %1 = arith.constant dense<[10.0, 2.0, -1.0, 3.0]> : tensor<4xf32> + + %2 = "tfl.less_equal"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, true, false, true]> : tensor<4xi1> + +// CHECK-LABEL: @greater_int +func.func @greater_int() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.greater"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[true, false, true, false]> : tensor<4xi1> + +// CHECK-LABEL: @greater_float +func.func @greater_float() -> tensor<4xi1> { + %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32> + %1 = arith.constant dense<[10.0, 2.0, -1.0, 3.0]> : tensor<4xf32> + + %2 = "tfl.greater"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[true, false, true, false]> : tensor<4xi1> + +// CHECK-LABEL: @greater_equal_int +func.func @greater_equal_int() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.greater_equal"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[true, true, true, false]> : tensor<4xi1> + +// CHECK-LABEL: @greater_equal_float +func.func @greater_equal_float() -> tensor<4xi1> { + %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32> + %1 = arith.constant dense<[10.0, 2.0, -1.0, 3.0]> : tensor<4xf32> + + %2 = "tfl.greater_equal"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[true, true, true, false]> : tensor<4xi1> + +// CHECK-LABEL: @equal_int +func.func @equal_int() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.equal"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, true, false, false]> : tensor<4xi1> + +// CHECK-LABEL: @equal_float +func.func @equal_float() -> tensor<4xi1> { + %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32> + %1 = arith.constant dense<[10.0, 2.0, -1.0, 3.0]> : tensor<4xf32> + + %2 = "tfl.equal"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, true, false, false]> : tensor<4xi1> + +// CHECK-LABEL: @not_equal_int +func.func @not_equal_int() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.not_equal"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[true, false, true, true]> : tensor<4xi1> + +// CHECK-LABEL: @not_equal_float +func.func @not_equal_float() -> tensor<4xi1> { + %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32> + %1 = arith.constant dense<[10.0, 2.0, -1.0, 3.0]> : tensor<4xf32> + + %2 = "tfl.not_equal"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[true, false, true, true]> : tensor<4xi1> + +// CHECK-LABEL: @logical_or +func.func @logical_or() -> tensor<3xi1> { + %0 = arith.constant dense<[true, false, true]> : tensor<3xi1> + %1 = arith.constant dense<[false, false, true]> : tensor<3xi1> + + %2 = "tfl.logical_or"(%0, %1) : (tensor<3xi1>, tensor<3xi1>) -> tensor<3xi1> + + func.return %2 : tensor<3xi1> +} + +// CHECK: %cst = arith.constant dense<[true, false, true]> : tensor<3xi1> + +// CHECK-LABEL: @logical_and +func.func @logical_and() -> tensor<3xi1> { + %0 = arith.constant dense<[true, false, true]> : tensor<3xi1> + %1 = arith.constant dense<[false, false, true]> : tensor<3xi1> + + %2 = "tfl.logical_and"(%0, %1) : (tensor<3xi1>, tensor<3xi1>) -> tensor<3xi1> + + func.return %2 : tensor<3xi1> +} + +// CHECK: %cst = arith.constant dense<[false, false, true]> : tensor<3xi1> + +// CHECK-LABEL: @select_splat_cond +func.func @select_splat_cond() -> tensor<4xi32> { + %cond = arith.constant dense : tensor<4xi1> + %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + %1 = arith.constant dense<[-1, -2, -3, -4]> : tensor<4xi32> + + %2 = "tfl.select"(%cond, %0, %1) : (tensor<4xi1>, tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + + func.return %2 : tensor<4xi32> +} + +// CHECK: %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + +// CHECK-LABEL: select_splat_lhs +func.func @select_splat_lhs() -> tensor<4xi32> { + %cond = arith.constant dense<[true, true, false, false]> : tensor<4xi1> + %0 = arith.constant dense<0> : tensor<4xi32> + %1 = arith.constant dense<[-1, -2, -3, -4]> : tensor<4xi32> + + %2 = "tfl.select"(%cond, %0, %1) : (tensor<4xi1>, tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + + func.return %2 : tensor<4xi32> +} + +// CHECK: %cst = arith.constant dense<[0, 0, -3, -4]> : tensor<4xi32> + +// CHECK-LABEL: select_float +func.func @select_float() -> tensor<4xf32> { + %cond = arith.constant dense<[true, true, false, false]> : tensor<4xi1> + %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> + %1 = arith.constant dense<[-1.0, -2.0, -3.0, -4.0]> : tensor<4xf32> + + %2 = "tfl.select"(%cond, %0, %1) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + func.return %2 : tensor<4xf32> +} + +// CHECK: %cst = arith.constant dense<[1.000000e+00, 2.000000e+00, -3.000000e+00, -4.000000e+00]> : tensor<4xf32 + +// CHECK-LABEL: floor +func.func @floor() -> tensor<3xf32> { + %cst = arith.constant dense<[-1.0, 0.0, 0.99]> : tensor<3xf32> + %0 = "tfl.floor"(%cst) : (tensor<3xf32>) -> tensor<3xf32> + func.return %0 : tensor<3xf32> +} + +// CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 0.000000e+00]> : tensor<3xf32> + +// CHECK-LABEL: floor_f64 +func.func @floor_f64() -> tensor<3xf64> { + %cst = arith.constant dense<[-1.0, 0.0, 0.99]> : tensor<3xf64> + %0 = "tfl.floor"(%cst) : (tensor<3xf64>) -> tensor<3xf64> + func.return %0 : tensor<3xf64> +} + +// CHECK: tfl.floor + +// CHECK-LABEL: exp +func.func @exp() -> tensor<4xf32> { + %cst = arith.constant dense<[-1.0, 0.0, 0.99, 0.36787944117]> : tensor<4xf32> + %0 = "tfl.exp"(%cst) : (tensor<4xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK: %cst = arith.constant dense<[0.36787945, 1.000000e+00, 2.69123459, 1.44466782]> : tensor<4xf32> + +// CHECK-LABEL: exp_f64 +func.func @exp_f64() -> tensor<4xf64> { + %cst = arith.constant dense<[-1.0, 0.0, 0.99, 0.36787944117]> : tensor<4xf64> + %0 = "tfl.exp"(%cst) : (tensor<4xf64>) -> tensor<4xf64> + func.return %0 : tensor<4xf64> +} + +// CHECK: tfl.exp + +// CHECK-LABEL: pow_float +func.func @pow_float() -> tensor<3xf32> { + %0 = arith.constant dense<[1.0, 0.0, 2.0]> : tensor<3xf32> + %1 = arith.constant dense<[2.0, 3.0, -1.5]> : tensor<3xf32> + + %2 = "tfl.pow"(%0, %1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + + func.return %2 : tensor<3xf32> +} + +// CHECK: %cst = arith.constant dense<[1.000000e+00, 0.000000e+00, 0.353553385]> : tensor<3xf32> + +// CHECK-LABEL: pow_int +func.func @pow_int() -> tensor<3xi32> { + %0 = arith.constant dense<[1, 0, 2]> : tensor<3xi32> + %1 = arith.constant dense<[2, 3, -1]> : tensor<3xi32> + + %2 = "tfl.pow"(%0, %1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + + func.return %2 : tensor<3xi32> +} + +// CHECK: %cst = arith.constant dense<[1, 0, 0]> : tensor<3xi32> + +// CHECK-LABEL: logical_not +func.func @logical_not() -> tensor<3xi1> { + %cst = arith.constant dense<[false, true, false]> : tensor<3xi1> + %0 = "tfl.logical_not"(%cst) : (tensor<3xi1>) -> tensor<3xi1> + func.return %0 : tensor<3xi1> +} + +// CHECK: %cst = arith.constant dense<[true, false, true]> : tensor<3xi1> + +// CHECK-LABEL: logical_not_splat +func.func @logical_not_splat() -> tensor<3xi1> { + %cst = arith.constant dense : tensor<3xi1> + %0 = "tfl.logical_not"(%cst) : (tensor<3xi1>) -> tensor<3xi1> + func.return %0 : tensor<3xi1> +} + +// CHECK: %cst = arith.constant dense : tensor<3xi1> + +// CHECK-LABEL: bitwise_xor_i32 +func.func @bitwise_xor_i32() -> tensor<3xi32> { + %0 = arith.constant dense<[0, 5, 3]> : tensor<3xi32> + %1 = arith.constant dense<[5, 0, 7]> : tensor<3xi32> + + %2 = "tfl.bitwise_xor"(%0, %1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + + func.return %2 : tensor<3xi32> +} + +// CHECK: %cst = arith.constant dense<[5, 5, 4]> : tensor<3xi32> + +// CHECK-LABEL: bitwise_xor_ui8 +func.func @bitwise_xor_ui8() -> tensor<3xui8> { + %0 = arith.constant dense<[0, 5, 3]> : tensor<3xui8> + %1 = arith.constant dense<[5, 0, 7]> : tensor<3xui8> + + %2 = "tfl.bitwise_xor"(%0, %1) : (tensor<3xui8>, tensor<3xui8>) -> tensor<3xui8> + + func.return %2 : tensor<3xui8> +} + +// CHECK: %cst = arith.constant dense<[5, 5, 4]> : tensor<3xui8> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir index 7ea7e48777522e..77edd7a648fcaa 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir @@ -16,3 +16,50 @@ func.func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf3 %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> func.return %0 : tensor<*xf32> } + +// ----- + +func.func @tfl_if(%arg0: tensor) -> tensor { +// CHECK: %{{.*}} = "tf.If"(%{{.*}}, %{{.*}}) <{else_branch = @tfl.if_else, is_stateless = false, then_branch = @tfl.if_then}> : (tensor, tensor) -> tensor + %cst = arith.constant dense<0> : tensor + %0 = tfl.add %cst, %cst {fused_activation_function = "NONE"} : tensor + %1 = "tfl.if"(%arg0) ({ + %2 = func.call @tfl.if_then(%0) : (tensor) -> tensor + "tfl.yield"(%2) : (tensor) -> () + }, { + %2 = func.call @tfl.if_else(%0) : (tensor) -> tensor + "tfl.yield"(%2) : (tensor) -> () + }) : (tensor) -> tensor + return %1 : tensor +} +func.func private @tfl.if_then(%arg0: tensor) -> tensor { + return %arg0 : tensor +} +func.func private @tfl.if_else(%arg0: tensor) -> tensor { + %0 = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor + return %0 : tensor +} + +// ----- + +func.func @tfl_if_multi_args(%arg0: tensor) -> tensor { +// CHECK: %{{.*}} = "tf.If"(%{{.*}}, %{{.*}}, %{{.*}}) <{else_branch = @tfl.if_else_1, is_stateless = false, then_branch = @tfl.if_then_1}> : (tensor, tensor, tensor) -> tensor + %cst = arith.constant dense<0> : tensor + %0 = tfl.add %cst, %cst {fused_activation_function = "NONE"} : tensor + %1 = tfl.mul %cst, %cst {fused_activation_function = "NONE"} : tensor + %2 = "tfl.if"(%arg0) ({ + %2 = func.call @tfl.if_then_1(%0, %1) : (tensor, tensor) -> tensor + "tfl.yield"(%2) : (tensor) -> () + }, { + %2 = func.call @tfl.if_else_1(%0, %1) : (tensor, tensor) -> tensor + "tfl.yield"(%2) : (tensor) -> () + }) : (tensor) -> tensor + return %1 : tensor +} +func.func private @tfl.if_then_1(%arg0: tensor, %arg1: tensor) -> tensor { + return %arg0 : tensor +} +func.func private @tfl.if_else_1(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor + return %0 : tensor +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir index 32cd4552f0b15d..9d5bad8c7d6181 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir @@ -3,7 +3,7 @@ func.func @main(tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>): - %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %0 = "tfl.pseudo_const" () {value = dense<2.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") %1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference") // CHECK: %[[MUL:.*]] = tfl.mul %2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul") diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir index 060d5fc871665a..cb87e4f0a2147f 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir @@ -3,7 +3,7 @@ func.func @main(tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>): - %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %0 = "tfl.pseudo_const" () {value = dense<2.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") %1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference") // CHECK: %[[MUL:.*]] = tfl.mul %2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul") diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 109407255f6780..4301cbf8627b79 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -1,11 +1,11 @@ // Run optimize pass only and check the results. -// RUN: tf-opt %s -tfl-optimize | FileCheck %s +// RUN: tf-opt %s -tfl-optimize='enable-canonicalization=false' | FileCheck %s // Run optimize pass and then canonicalize pass, and make sure some folding is applied. -// RUN: tf-opt %s -tfl-optimize='enable-canonicalization=true' | FileCheck --check-prefix=FOLD %s +// RUN: tf-opt %s -tfl-optimize | FileCheck --check-prefix=FOLD %s // Run legalize pass and then optimize pass, and make sure some fusing is applied. -// RUN: tf-opt %s -tfl-legalize-tf -tfl-optimize | FileCheck --check-prefix=Fusing %s +// RUN: tf-opt %s -tfl-legalize-tf -tfl-optimize='enable-canonicalization=false' | FileCheck --check-prefix=Fusing %s // Run legalize pass and then optimize pass, and make sure some fusing is applied, but no mul->fc. -// RUN: tf-opt %s -tfl-legalize-tf -tfl-optimize='disable-fuse-mul-and-fc=true' | FileCheck --check-prefix=NoFusing %s +// RUN: tf-opt %s -tfl-legalize-tf -tfl-optimize='enable-canonicalization=false disable-fuse-mul-and-fc=true' | FileCheck --check-prefix=NoFusing %s // CHECK-LABEL: fusedConv2dRelu func.func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x32x32x16xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/push-tpose-through-ewise.mlir b/tensorflow/compiler/mlir/lite/tests/push-tpose-through-ewise.mlir index a5da33ca90191b..8796d690f72796 100644 --- a/tensorflow/compiler/mlir/lite/tests/push-tpose-through-ewise.mlir +++ b/tensorflow/compiler/mlir/lite/tests/push-tpose-through-ewise.mlir @@ -164,4 +164,35 @@ func.func @pushTposeBcastScalarCstInput(%arg0: tensor<2x3x4x5xf32>) -> tensor<5x // CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x3x4x5xf32>, tensor) -> tensor<2x3x4x5xf32> // CHECK: %1 = "tfl.transpose"(%0, %cst_0) : (tensor<2x3x4x5xf32>, tensor<4xi32>) -> tensor<5x2x3x4xf32> +// ----- + +// CHECK-LABEL: pushTposeDynamicBcastScalarCstInput +func.func @pushTposeDynamicBcastScalarCstInput(%arg0: tensor) -> tensor<5x?x?x4xf32> { + %perm = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> + %0 = "tfl.transpose"(%arg0, %perm) : (tensor, tensor<4xi32>) -> tensor<5x?x?x4xf32> + %cst = arith.constant dense<1.0> : tensor + %1 = "tfl.add"(%0, %cst) { fused_activation_function = "NONE" } : (tensor<5x?x?x4xf32>, tensor) -> tensor<5x?x?x4xf32> + func.return %1 : tensor<5x?x?x4xf32> +} + +// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor +// CHECK: %cst_0 = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> +// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor, tensor) -> tensor +// CHECK: %1 = "tfl.transpose"(%0, %cst_0) : (tensor, tensor<4xi32>) -> tensor<5x?x?x4xf32> +// ----- + +// CHECK-LABEL: doubleTposeDynamicInput +func.func @doubleTposeDynamicInput(%arg0: tensor, %arg1: tensor) -> tensor<5x?x?x4xf32> { + %perm = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> + %0 = "tfl.transpose"(%arg0, %perm) : (tensor, tensor<4xi32>) -> tensor<5x?x?x4xf32> + %perm1 = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> + %1 = "tfl.transpose"(%arg1, %perm1) : (tensor, tensor<4xi32>) -> tensor<5x?x?x4xf32> + %2 = tfl.add %0, %1 { fused_activation_function = "NONE" } : tensor<5x?x?x4xf32> + func.return %2 : tensor<5x?x?x4xf32> +} + +// CHECK: %cst = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> +// CHECK: %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor +// CHECK: %1 = "tfl.transpose"(%0, %cst) : (tensor, tensor<4xi32>) -> tensor<5x?x?x4xf32> +// CHECK: return %1 : tensor<5x?x?x4xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir b/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir index 477315d696783c..d9382fdeb3341b 100644 --- a/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: custom_op func.func @custom_op(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %0 = "arith.constant" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> + %0 = "arith.constant" () {value = dense<2.0> : tensor<4xf32>} : () -> tensor<4xf32> %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // will be preserved since it has uses. %2 = "tf.MyCustomOp"(%1, %0) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> @@ -11,7 +11,7 @@ func.func @custom_op(%arg0: tensor<4xf32>) -> tensor<4xf32> { "tf.MyCustomOp"(%1, %0) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %2 : tensor<4xf32> -// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32> +// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<2.000000e+00> : tensor<4xf32> // CHECK-NEXT: %[[MUL:.*]] = tfl.mul %arg0, %[[CST]] {fused_activation_function = "NONE"} : tensor<4xf32> // CHECK-NEXT: %[[CUSTOM_1:.*]] = "tfl.custom_tf"(%[[MUL]], %[[CST]]) ({ // CHECK-NEXT: ^bb0(%arg1: tensor<4xf32>, %arg2: tensor<4xf32>): diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 782552599024ea..878026e6b47913 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -102,10 +102,8 @@ void AddQuantizationPasses(const mlir::TFL::PassConfig& pass_config, mlir::TFL::CreateOptimizeBatchMatmulPass()); } // Add TFLite optimize pass. - mlir::TFL::OptimizePassOptions optimize_pass_options; - optimize_pass_options.enable_canonicalization_ = true; pass_manager.addNestedPass( - mlir::TFL::CreateOptimizePass(optimize_pass_options)); + mlir::TFL::CreateOptimizePass()); } void AddVariableFreezingFromGlobalTensorsPasses( @@ -161,10 +159,16 @@ void AddDynamicRangeQuantizationPasses(const mlir::TFL::PassConfig& pass_config, } // Add TFLite optimize pass. - mlir::TFL::OptimizePassOptions optimize_pass_options; - optimize_pass_options.enable_canonicalization_ = true; pass_manager.addNestedPass( - mlir::TFL::CreateOptimizePass(optimize_pass_options)); + mlir::TFL::CreateOptimizePass()); +} + +void AddPytorchPasses(mlir::OpPassManager& pass_manager) { + pass_manager.addNestedPass(mlir::createCSEPass()); + pass_manager.addPass(mlir::odml::createBuildStableHLOCompositePass()); + pass_manager.addPass(mlir::createInlinerPass()); + pass_manager.addPass(mlir::odml::createLiftCallSiteLocCallerPass()); + pass_manager.addNestedPass(mlir::createCSEPass()); } void AddPreQuantizationStableHloToTfPasses( @@ -174,6 +178,10 @@ void AddPreQuantizationStableHloToTfPasses( pass_manager.addPass( mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass()); + if (pass_config.model_origin_framework == toco::TocoFlags::PYTORCH) { + AddPytorchPasses(pass_manager); + } + // Legalize MHLO to StableHLO should be moved closer to where it is needed // There are some entry points that start with HLO->MHLO like // jax_to_tfl_flatbuffer.cc which can likely be updated to emit StableHLO @@ -501,21 +509,30 @@ void AddPostVariableFreezingTFToTFLConversionPasses( pass_manager->addPass(mlir::TFL::CreateAnalyzeVariablesPass()); pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass()); pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass()); - if (!pass_config.unfold_batch_matmul) { - // Enable an optimization pass that transforms FC to BatchMatmul only when - // `unfold_batch_matmul=false`. - pass_manager->addNestedPass( - mlir::TFL::CreateOptimizeBatchMatmulPass()); - } - pass_manager->addPass(mlir::TFL::CreatePushTransposeThroughEwisePass()); - // Add TFLite optimize pass. mlir::TFL::OptimizePassOptions optimize_pass_options; - optimize_pass_options.enable_canonicalization_ = true; - optimize_pass_options.disable_fuse_mul_and_fc_ = + optimize_pass_options.disable_fuse_mul_and_fc = toco_flags.disable_fuse_mul_and_fc(); - pass_manager->addNestedPass( - mlir::TFL::CreateOptimizePass(optimize_pass_options)); + + auto add_tfl_optimization_passes = [&]() { + if (!pass_config.unfold_batch_matmul) { + // Enable an optimization pass that transforms FC to BatchMatmul only + // when `unfold_batch_matmul=false`. + pass_manager->addNestedPass( + mlir::TFL::CreateOptimizeBatchMatmulPass()); + } + pass_manager->addPass(mlir::TFL::CreatePushTransposeThroughEwisePass()); + + // Add TFLite optimize pass. + pass_manager->addNestedPass( + mlir::TFL::CreateOptimizePass(optimize_pass_options)); + }; + + // Run TFL optimization passes set multiple times as op fusion and + // reordering in later passes may enable further optimizations with earlier + // passes. + add_tfl_optimization_passes(); + add_tfl_optimization_passes(); // This pass operates on TensorFlow ops but is triggered after legalization // so that it can target constants introduced once TensorFlow Identity ops diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 02c4ac491573dd..bf3353e874bc87 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -61,10 +61,12 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/debug/debug.h" +#include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/metrics/converter_error_data.pb.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" @@ -96,7 +98,6 @@ limitations under the License. #include "tensorflow/core/ir/types/dialect.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" -#include "tensorflow/lite/tools/optimize/quantize_weights.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/platform/statusor.h" @@ -309,13 +310,13 @@ absl::Status ApplyDynamicRangeQuantizationFromOldQuantizer( reinterpret_cast(translated_result.c_str()); const ::tflite::Model* input_model = ::tflite::GetModel(buffer); - ::tflite::optimize::BufferType quantized_type; + mlir::lite::toco_legacy::BufferType quantized_type; switch (quant_specs.inference_type) { case DT_QINT8: - quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8; + quantized_type = mlir::lite::toco_legacy::BufferType::QUANTIZED_INT8; break; case DT_HALF: - quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16; + quantized_type = mlir::lite::toco_legacy::BufferType::QUANTIZED_FLOAT16; break; default: return absl::InvalidArgumentError("Quantized type not supported"); @@ -323,9 +324,10 @@ absl::Status ApplyDynamicRangeQuantizationFromOldQuantizer( } bool use_updated_hybrid_scheme = !quant_specs.disable_per_channel; - absl::Status quantize_weights_status = ::tflite::optimize::QuantizeWeights( - &q_builder, input_model, quantized_type, use_updated_hybrid_scheme, - ::tflite::optimize::QuantizerType::OLD_QUANTIZER); + absl::Status quantize_weights_status = + mlir::lite::toco_legacy::QuantizeWeights( + &q_builder, input_model, quantized_type, use_updated_hybrid_scheme, + mlir::lite::toco_legacy::QuantizerType::OLD_QUANTIZER); if (!quantize_weights_status.ok()) return quantize_weights_status; const uint8_t* q_buffer = q_builder.GetBufferPointer(); *result = diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 841b81eddec8bc..b36fe6b55bbd93 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -16,6 +16,8 @@ limitations under the License. // This transformation pass takes operations in TensorFlowLite dialect and // optimizes them to resulting operations in TensorFlowLite dialect. +#include "tensorflow/compiler/mlir/lite/transforms/optimize.h" + #include #include #include @@ -54,7 +56,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" @@ -70,8 +71,6 @@ namespace TFL { //===----------------------------------------------------------------------===// // The actual Optimize Pass. namespace { -#define GEN_PASS_DEF_OPTIMIZEPASS -#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc" constexpr char kRelu[] = "RELU"; constexpr char kRelu6[] = "RELU6"; @@ -236,27 +235,6 @@ bool HasSameStridedShape(TFL::Conv3DOp op, ArrayRef pre_pad_shape) { using ::llvm::cast; -// Optimize TFLite operations in functions. -class OptimizePass : public impl::OptimizePassBase { - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizePass) - - OptimizePass() = default; - OptimizePass(const OptimizePass &) {} - explicit OptimizePass(bool enable_canonicalization, - bool disable_fuse_mul_and_fc = false) { - this->enable_canonicalization_ = enable_canonicalization; - this->disable_fuse_mul_and_fc_ = disable_fuse_mul_and_fc; - } - - explicit OptimizePass(const OptimizePassOptions &options) { - this->enable_canonicalization_ = options.enable_canonicalization_; - this->disable_fuse_mul_and_fc_ = options.disable_fuse_mul_and_fc_; - } - - void runOnOperation() override; -}; - // Return true if the product of dimension values of a subsection of the // tensor is equal to the non-contracting dimension after a reshape bool BroadcastDimsProductEqual(Value input, Value output, @@ -2636,6 +2614,7 @@ void AddCanonicalizationPatterns(MLIRContext *context, for (auto op : context->getRegisteredOperations()) op.getCanonicalizationPatterns(*patterns, context); } +} // namespace void OptimizePass::runOnOperation() { RewritePatternSet patterns(&getContext()); @@ -2692,14 +2671,6 @@ void OptimizePass::runOnOperation() { AddCanonicalizationPatterns(ctx, &phase_2_patterns); (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns)); } -} // namespace - -// Creates an instance of the TensorFlow Lite dialect Optimize pass. -std::unique_ptr> CreateOptimizePass( - bool enable_canonicalization, bool disable_fuse_mul_and_fc) { - return std::make_unique(enable_canonicalization, - disable_fuse_mul_and_fc); -} // Creates an instance of the TensorFlow Lite dialect Optimize pass. std::unique_ptr> CreateOptimizePass( diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.h b/tensorflow/compiler/mlir/lite/transforms/optimize.h new file mode 100644 index 00000000000000..477d2d23d7ad07 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.h @@ -0,0 +1,95 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_H_ + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +namespace TFL { + +struct OptimizePassOptions { + bool enable_canonicalization = true; + bool disable_fuse_mul_and_fc = false; +}; + +class OptimizePass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizePass) + + OptimizePass() = default; + OptimizePass(const OptimizePass &) {} + explicit OptimizePass(bool enable_canonicalization, + bool disable_fuse_mul_and_fc = false) { + this->enable_canonicalization_ = enable_canonicalization; + this->disable_fuse_mul_and_fc_ = disable_fuse_mul_and_fc; + } + + explicit OptimizePass(const OptimizePassOptions &options) { + this->enable_canonicalization_ = options.enable_canonicalization; + this->disable_fuse_mul_and_fc_ = options.disable_fuse_mul_and_fc; + } + + void runOnOperation() final; + + /// Returns the command-line argument attached to this pass. + static constexpr llvm::StringLiteral getArgumentName() { + return llvm::StringLiteral("tfl-optimize"); + } + llvm::StringRef getArgument() const final { return "tfl-optimize"; } + + llvm::StringRef getDescription() const final { + return "Optimize within the TensorFlow Lite dialect"; + } + + /// Returns the derived pass name. + static constexpr llvm::StringLiteral getPassName() { + return llvm::StringLiteral("OptimizePass"); + } + llvm::StringRef getName() const final { return "OptimizePass"; } + + /// Return the dialect that must be loaded in the context before this pass. + void getDependentDialects(mlir::DialectRegistry ®istry) const final { + registry.insert(); + } + + private: + mlir::Pass::Option enable_canonicalization_{ + *this, "enable-canonicalization", + llvm::cl::desc("Enable canonicalization during optimization pass."), + llvm::cl::init(true)}; + mlir::Pass::Option disable_fuse_mul_and_fc_{ + *this, "disable-fuse-mul-and-fc", + llvm::cl::desc("Disable folding mul and fully connected ops during " + "optimization pass."), + llvm::cl::init(false)}; +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 91b0ec519c59e2..a11b20000222ad 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -21,6 +21,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/transforms/optimize.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" namespace mlir { @@ -60,8 +62,6 @@ std::unique_ptr> CreateLegalizeTFPass( std::unique_ptr> CreateLegalizeTFPass(); // Creates an instance of the TensorFlow Lite dialect Optimize pass. -std::unique_ptr> CreateOptimizePass( - bool enable_canonicalization, bool disable_fuse_mul_and_fc = false); std::unique_ptr> CreateOptimizePass(); // Creates an instance of the Tensorflow Lite batch matmul Optimize pass. @@ -287,6 +287,17 @@ std::unique_ptr> CreateRaiseCustomOpsPass( // quantization parameters. std::unique_ptr> CreateDefaultQuantParamsPass( const DefaultQuantParamsPassOptions& options); + +inline void registerOptimizePass() { + mlir::registerPass( + []() -> std::unique_ptr<::mlir::Pass> { return CreateOptimizePass(); }); +} + +inline void registerTensorFlowLitePasses() { + registerTensorFlowLiteTdPasses(); + registerOptimizePass(); +} + } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td index d0d2f3158f3a60..8db083175a2226 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/transforms/passes.td @@ -180,20 +180,6 @@ def ModifyIONodesPass : Pass<"tfl-modify-io-nodes", "mlir::func::FuncOp"> { ]; } -def OptimizePass : Pass<"tfl-optimize", "mlir::func::FuncOp"> { - let summary = "Optimize within the TensorFlow Lite dialect"; - let constructor = "CreateOptimizePass()"; - let dependentDialects = ["TFL::TensorFlowLiteDialect"]; - let options = [ - Option<"enable_canonicalization_", "enable-canonicalization", - "bool", "false", - "Enable canonicalization during optimization pass.">, - Option<"disable_fuse_mul_and_fc_", "disable-fuse-mul-and-fc", - "bool", "false", - "Disable folding mul and fully connected ops during optimization pass.">, - ]; -} - def OptimizeBatchMatmulPass : Pass<"tfl-optimize-batch-matmul", "mlir::func::FuncOp"> { let summary = "Optimize FC with BatchMatmul within the TensorFlow Lite dialect"; let constructor = "CreateOptimizeBatchMatmulPass()"; diff --git a/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc b/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc index 7a8b35e4be7cde..f01d8aa737d2e0 100644 --- a/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc +++ b/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc @@ -72,7 +72,7 @@ llvm::SmallVector PermuteShape(llvm::ArrayRef shape, // Determine if op commutes with transposes. Requires a strict // definition of Elementwise, all i/o shapes and types must be same-rank -// broadcastable and fully static. Consider moving this into attribute later. +// broadcastable. Consider moving this into attribute later. bool IsElementwise(Operation *op) { if (!(llvm::isa(op))) { @@ -90,11 +90,6 @@ bool IsElementwise(Operation *op) { return false; } - if (!opr1_type.hasStaticShape() && opr2_type.hasStaticShape() && - res_type.hasStaticShape()) { - return false; - } - return true; } diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index cf8d4487f2f593..8d9802deaeaa66 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -360,7 +360,7 @@ Status MlirFunctionOptimizationPass::Run( timings.Reset({kTfMlirCategory, "convert_mlir_to_graph"}); // Some or all passes are enabled. Convert MLIR module and return back // resulted graph. - Status status = tensorflow::tf2xla::v2::ConvertMlirToGraph( + Status status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( *module_ref, export_config, graph, flib_def, &control_ret_nodes); if (!status.ok()) { errors::AppendToMessage(&status, @@ -476,10 +476,11 @@ Status MlirV1CompatGraphOptimizationPass::Run( GraphExportConfig export_config; absl::flat_hash_set control_ret_nodes; - TF_RETURN_WITH_CONTEXT_IF_ERROR(tensorflow::tf2xla::v2::ConvertMlirToGraph( - *module_ref, export_config, options.graph, - options.flib_def, &control_ret_nodes), - "Error converting MLIR module back to graph"); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( + *module_ref, export_config, options.graph, options.flib_def, + &control_ret_nodes), + "Error converting MLIR module back to graph"); return absl::OkStatus(); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.cc index d6859b7b95c84e..491fcb9f5e7946 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.cc @@ -240,7 +240,7 @@ absl::StatusOr ConvertMlirModuleToExportedModel( FunctionDefLibrary()}; std::unique_ptr graph; absl::flat_hash_set control_ret_nodes{}; - TF_RETURN_IF_ERROR(tensorflow::tf2xla::v2::ConvertMlirToGraph( + TF_RETURN_IF_ERROR(tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( module_op, config, &graph, &flib_def, &control_ret_nodes)); GraphDef graph_def{}; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td index 75940a24cf484f..5f6449dbfa03bb 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td @@ -30,14 +30,15 @@ include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td" def LiftDotGeneralWithBiasSameShape : Pat< (StableHLO_AddOp:$res (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), $bias), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_same_shape_fn"> (ArgumentList $lhs, $rhs, $bias), (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 5)>; def LiftConvWithBiasSameShape : Pat< @@ -86,14 +87,15 @@ def LiftConvWithBias : Pat< def LiftDotGeneralWithBias : Pat< (StableHLO_AddOp:$res (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_BroadcastInDimOp $bias, $dims)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_fn"> (ArgumentList $lhs, $rhs, $bias), (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 5)>; def LiftConvWithBiasDynamic : Pat< @@ -121,7 +123,7 @@ def LiftConvWithBiasDynamic : Pat< def LiftDotGeneralWithBiasDynamic : Pat< (StableHLO_AddOp:$res - (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config), + (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_DynamicBroadcastInDimOp $bias, (Shape_ShapeOfOp $dot_general_1), $_, $_, $_)), @@ -130,7 +132,8 @@ def LiftDotGeneralWithBiasDynamic : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (AreTheSameValue $dot_general_0, $dot_general_1)], [], (addBenefit 10)>; //===----------------------------------------------------------------------===// @@ -161,14 +164,15 @@ def LiftConvWithRelu : Pat< def LiftDotGeneralWithRelu : Pat< (StableHLO_MaxOp:$res (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_ConstantOp $cst)), (LiftAsTFXlaCallModule<"composite_dot_general_with_relu_fn"> (ArgumentList $lhs, $rhs), (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst)], [], (addBenefit 10)>; @@ -198,7 +202,7 @@ def LiftConvWithReluDynamic : Pat< def LiftDotGeneralWithReluDynamic : Pat< (StableHLO_MaxOp:$res - (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config), + (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_DynamicBroadcastInDimOp (StableHLO_ConstantOp $cst), (Shape_ShapeOfOp $dot_general_1), $_, $_, $_)), @@ -207,7 +211,8 @@ def LiftDotGeneralWithReluDynamic : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst), (AreTheSameValue $dot_general_0, $dot_general_1)], [], (addBenefit 15)>; @@ -237,14 +242,15 @@ def LiftDotGeneralWithRelu6 : Pat< (StableHLO_ClampOp:$res (StableHLO_ConstantOp $cst_0), (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_ConstantOp $cst_1)), (LiftAsTFXlaCallModule<"composite_dot_general_with_relu6_fn"> (ArgumentList $lhs, $rhs), (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 10)>; //===----------------------------------------------------------------------===// @@ -255,7 +261,7 @@ def LiftDotGeneralWithBiasSameShapeAndRelu : Pat< (StableHLO_MaxOp:$res (StableHLO_AddOp (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), $bias), (StableHLO_ConstantOp $cst)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_same_shape_and_relu_fn"> @@ -263,7 +269,8 @@ def LiftDotGeneralWithBiasSameShapeAndRelu : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>; @@ -320,7 +327,7 @@ def LiftDotGeneralWithBiasAndRelu : Pat< (StableHLO_MaxOp:$res (StableHLO_AddOp (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_BroadcastInDimOp $bias, $dims)), (StableHLO_ConstantOp $cst)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_and_relu_fn"> @@ -328,7 +335,8 @@ def LiftDotGeneralWithBiasAndRelu : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>; @@ -363,7 +371,7 @@ def LiftConvWithBiasAndReluDynamic : Pat< def LiftDotGeneralWithBiasAndReluDynamic : Pat< (StableHLO_MaxOp:$res (StableHLO_AddOp:$add_0 - (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config), + (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_DynamicBroadcastInDimOp $bias, (Shape_ShapeOfOp $dot_general_1), $_, $_, $_)), @@ -375,7 +383,8 @@ def LiftDotGeneralWithBiasAndReluDynamic : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias), (AreTheSameValue $dot_general_0, $dot_general_1), (AreTheSameValue $add_0, $add_1)], [], (addBenefit 15)>; @@ -384,7 +393,7 @@ def LiftDotGeneralWithBiasSameShapeAndRelu6 : Pat< (StableHLO_ConstantOp $cst_0), (StableHLO_AddOp (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), $bias), (StableHLO_ConstantOp $cst_1)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_same_shape_and_relu6_fn"> @@ -392,7 +401,8 @@ def LiftDotGeneralWithBiasSameShapeAndRelu6 : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 10)>; def LiftConvWithBiasAndRelu6 : Pat< @@ -424,7 +434,7 @@ def LiftDotGeneralWithBiasAndRelu6 : Pat< (StableHLO_ConstantOp $cst_0), (StableHLO_AddOp (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_BroadcastInDimOp $bias, $dims)), (StableHLO_ConstantOp $cst_1)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_and_relu6_fn"> @@ -432,7 +442,8 @@ def LiftDotGeneralWithBiasAndRelu6 : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 10)>; def LiftConvWithBiasAndRelu6Dynamic : Pat< @@ -466,7 +477,7 @@ def LiftDotGeneralWithBiasAndRelu6Dynamic : Pat< (StableHLO_ConstantOp $cst_0), (StableHLO_AddOp (StableHLO_DotGeneralOp:$dot_general_0 - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_DynamicBroadcastInDimOp $bias, (Shape_ShapeOfOp $dot_general_1), $_, $_, $_)), @@ -476,5 +487,6 @@ def LiftDotGeneralWithBiasAndRelu6Dynamic : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1), (AreTheSameValue $dot_general_0, $dot_general_1)], [], (addBenefit 15)>; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td index eaa8a9092f41f2..db0103fea2b7e5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td @@ -47,13 +47,14 @@ def LiftConv : Pat< def LiftDotGeneral : Pat< (StableHLO_DotGeneralOp:$res - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (LiftAsTFXlaCallModule<"composite_dot_general_fn"> (ArgumentList $lhs, $rhs), (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res)], [], (addBenefit 1)>; def LiftGather : Pat< diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index 0999d37da524c2..5ca03bfc209656 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -46,10 +46,7 @@ pytype_strict_library( # testonly = 1, # srcs = ["integration_test/quantize_model_test_base.py"], # tags = ["no_pip"], -# visibility = [ -# "//learning/brain/mlir/quantization/stablehlo:__subpackages__", -# "//tensorflow/compiler/mlir/quantization:__subpackages__", -# ], +# visibility = ["//visibility:private"], # deps = [ # "//third_party/py/mlir:ir", # "//third_party/py/mlir:stablehlo_dialect", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 24f3215c003e7e..cbde97456aca43 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1137,7 +1137,7 @@ to be batched.}]>:$captured_tensors, DefaultValuedOptionalAttr:$low_priority_allowed_batch_sizes, DefaultValuedOptionalAttr:$low_priority_max_enqueued_batches, DefaultValuedOptionalAttr, "\"low_priority_padding_with_max_batch_size\"">:$mixed_priority_policy, - DefaultValuedOptionalAttr, "\"PAD_UP\"">:$batch_padding_policy, + DefaultValuedOptionalAttr, "\"PAD_UP\"">:$batch_padding_policy, DefaultValuedOptionalAttr:$enable_large_batch_splitting ); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index f7bac4ba31b50a..b6e8e1c9b9ca07 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -974,7 +974,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-LABEL: func @call_in_graph_1 func.func @call_in_graph_1(%arg0: tensor, %arg1: tensor<5x5x1x32xbf16>) -> tensor<*xbf16> { - // CHECK: tf_executor.fetch %outputs : tensor + // CHECK: tf_executor.fetch %outputs : tensor %0 = tf_executor.graph { %1:2 = tf_executor.island wraps "tf.PartitionedCall"(%arg0, %arg1) { config = "", config_proto = "", executor_type = "", f = @call_in_graph_func_1} : (tensor, tensor<5x5x1x32xbf16>) -> tensor<*xbf16> @@ -985,7 +985,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-LABEL: func @call_in_graph_func_1 func.func @call_in_graph_func_1(%arg0: tensor, %arg1: tensor<5x5x1x32xbf16>) -> tensor { - // CHECK: tf_executor.fetch %outputs : tensor + // CHECK: tf_executor.fetch %outputs : tensor %0 = tf_executor.graph { %1:2 = tf_executor.island wraps "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}: (tensor, tensor<5x5x1x32xbf16>) -> tensor tf_executor.fetch %1#0 : tensor @@ -2265,4 +2265,14 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr return %3#1, %3#2, %4, %5 : tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32> } + // CHCK-LABEL: func @infer_return_type_static_out + func.func @infer_return_type_static_out(%arg0: tensor, %arg1: tensor) -> tensor<1x28x28x3xf32> { + %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}: (tensor, tensor) -> tensor<1x28x28x3xf32> + func.return %0 : tensor<1x28x28x3xf32> + } + + // CHCK: %0 = "tf.Conv2D"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> {device = ""} : (tensor, tensor) -> tensor<1x28x28x3xf32> + + } + diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index d9110a2bff1fba..89e00142aa4a8f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -126,9 +126,9 @@ MLIRContext::Threading GetMlirContextThreading() { } // Compute a refined type between two types `lhs` and `rhs`, the result type -// is always more refined (i.e. has more static information) than `lhs` -// This method will actually merge the information contained in the -// types, it is capable of refining: +// is always at least as refined as (i.e. has more static information) than +// `lhs` This method will actually merge the information contained in the types, +// it is capable of refining: // tensor>> // and: // tensor>> @@ -2329,12 +2329,15 @@ bool ShapeInference::RefineWithInferTypeOpInterface( // Map each of the results of the call to the returned type of the // function. bool changed = false; - for (auto result : zip(op->getResults(), inferred)) { - if (std::get<0>(result).getType() == std::get<1>(result)) continue; - - if (!UpdateTypeAndInsertIncompatibleUseCasts(std::get<1>(result), - std::get<0>(result))) + for (auto [result, inferred_type] : zip(op->getResults(), inferred)) { + auto result_type = result.getType(); + auto new_type = TypeMeet(inferred_type, result_type); + if (new_type == result_type) { + continue; + } + if (!UpdateTypeAndInsertIncompatibleUseCasts(new_type, result)) { continue; + } changed = true; } return changed; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 3735199d8a33c8..1ccfc8775d1c44 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -75,7 +75,7 @@ void GraphOptPass::runOnOperation() { GraphExportConfig confs; auto graph = std::make_unique(flib_def); absl::flat_hash_set control_ret_nodes; - Status status = tensorflow::tf2xla::v2::ConvertMlirToGraph( + Status status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( module_in, confs, &graph, &flib_def, &control_ret_nodes); if (!status.ok()) { mlir::emitError(mlir::UnknownLoc::get(&ctx)) << status.message(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/BUILD b/tensorflow/compiler/mlir/tensorflow/translate/BUILD index ad9befdfe5fb28..5ffc11344ba3ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/translate/BUILD @@ -254,3 +254,44 @@ cc_library( "@llvm-project//llvm:Support", ], ) + +cc_library( + name = "node_order", + srcs = ["node_order.cc"], + hdrs = ["node_order.h"], + deps = [ + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "node_order_test", + size = "small", + srcs = [ + "node_order_test.cc", + ], + deps = [ + ":node_order", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/core", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:direct_session_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/node_order.cc b/tensorflow/compiler/mlir/tensorflow/translate/node_order.cc new file mode 100644 index 00000000000000..58a2751cb4b2a5 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/translate/node_order.cc @@ -0,0 +1,108 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/translate/node_order.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +void TopologicalOrdering( + const Graph& g, const std::function& emit, + const std::function& get_grouping_key) { + std::unordered_map group_key_string_to_integer; + absl::flat_hash_map node_to_group; + absl::flat_hash_map remaining_incoming_nodes; + using Ready = std::vector; + std::vector group_members_that_are_ready; + std::set groups_that_are_ready; + + // Visit all nodes once, for initialization. It doesn't matter whether we use + // BFS or DFS. + DFS( + g, [](Node*) {}, + [&](Node* n) { + // Find which group this node belongs to. + std::string group_key_string = get_grouping_key(n); + auto entry = group_key_string_to_integer.try_emplace( + group_key_string, group_key_string_to_integer.size()); + int group_key = entry.first->second; + node_to_group[n] = group_key; + if (!entry.second) { + group_members_that_are_ready.push_back({}); + } + + // Count the incoming nodes and store. Also remember nodes ("sources") + // that don't have any inputs. + auto in_nodes = n->in_nodes(); + int num_incoming = std::distance(in_nodes.begin(), in_nodes.end()); + remaining_incoming_nodes[n] = num_incoming; + if (num_incoming == 0) { + // NO_CDC: This array is max(group_key) + 1. + group_members_that_are_ready[group_key].push_back(n); + groups_that_are_ready.emplace(group_key); + } + }); + + int num_nodes = remaining_incoming_nodes.size(); + + // We emit one node per step, thus we just run this as often as we have nodes. + int current_group = 0; + for (int i = 0; i < num_nodes; i++) { + if (groups_that_are_ready.find(current_group) == + groups_that_are_ready.end()) { + current_group = *groups_that_are_ready.begin(); + } + + // NO_CDC: This array is max(group_key) + 1. + int size = group_members_that_are_ready[current_group].size(); + assert(size); + // NO_CDC: This array is max(group_key) + 1. + Node* node = group_members_that_are_ready[current_group][--size]; + // NO_CDC: This array is max(group_key) + 1. + group_members_that_are_ready[current_group].pop_back(); + if (size == 0) { + groups_that_are_ready.erase(current_group); + } + + // Emit the operation and make its results available. + emit(node); + + for (Node* out : node->out_nodes()) { + remaining_incoming_nodes[out]--; + if (remaining_incoming_nodes[out] == 0) { + int group_key = node_to_group[out]; + // NO_CDC: This array is max(group_key) + 1. + if (group_members_that_are_ready[group_key].empty()) { + groups_that_are_ready.emplace(group_key); + } + // NO_CDC: This array is max(group_key) + 1. + group_members_that_are_ready[group_key].push_back(out); + } + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/node_order.h b/tensorflow/compiler/mlir/tensorflow/translate/node_order.h new file mode 100644 index 00000000000000..4cb8e75efa7613 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/translate/node_order.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_NODE_ORDER_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_NODE_ORDER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +struct GroupByDevice { + std::string operator()(const Node* node) const { + return node->requested_device(); + } +}; + +// Performs a topological ordering of nodes. +// This has the property that any child node of a parent node p is emitted +// before p. A grouping function is used to break ties if multiple child nodes +// (of possibly different parents) are ready to be emitted at some point, which +// is when we prefer to stay in the current group. +// The "emit" function is used for outputing the result, and is called once +// for each node. +// This algorithm is O(n). +void TopologicalOrdering( + const Graph& g, const std::function& emit, + const std::function& get_grouping_key); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_NODE_ORDER_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/node_order_test.cc b/tensorflow/compiler/mlir/tensorflow/translate/node_order_test.cc new file mode 100644 index 00000000000000..fc1d6e177f1dcd --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/translate/node_order_test.cc @@ -0,0 +1,239 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/translate/node_order.h" + +#include +#include +#include + +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/common_runtime/graph_def_builder_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +REGISTER_OP("TestParams").Output("o: float"); +REGISTER_OP("TestInput").Output("a: float").Output("b: float"); +REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); +REGISTER_OP("TestUnary").Input("a: float").Output("o: float"); +REGISTER_OP("TestTwoOutputs").Output("a: float").Output("b: float"); +REGISTER_OP("TestBinary") + .Input("a: float") + .Input("b: float") + .Output("o: float"); + +// Compares that the order of nodes in 'inputs' respects the +// pair orders described in 'ordered_pairs'. +bool ExpectBefore(const std::vector>& ordered_pairs, + const std::vector& inputs, string* error) { + for (const std::pair& pair : ordered_pairs) { + const string& before_node = pair.first; + const string& after_node = pair.second; + bool seen_before = false; + bool seen_both = false; + for (const Node* node : inputs) { + if (!seen_before && after_node == node->name()) { + *error = std::string("Saw ") + after_node + std::string(" before ") + + before_node; + return false; + } + + if (before_node == node->name()) { + seen_before = true; + } else if (after_node == node->name()) { + seen_both = seen_before; + break; + } + } + if (!seen_both) { + *error = std::string("didn't see either ") + before_node + + std::string(" or ") + after_node; + return false; + } + } + + return true; +} + +TEST(AlgorithmTest, TopologicalOrdering) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT + Node* n1 = SourceOp("TestParams", b.opts().WithName("n1")); + Node* n2 = + SourceOp("TestParams", b.opts().WithName("n2").WithControlInput(n1)); + Node* n3 = + SourceOp("TestParams", b.opts().WithName("n3").WithControlInput(n2)); + Node* n4 = BinaryOp("TestMul", n1, {n3, 0}, b.opts().WithName("n4")); + Node* n5 = BinaryOp("TestMul", n1, {n3, 0}, + b.opts().WithName("n5").WithControlInput(n1)); + Node* n6 = BinaryOp("TestMul", n2, {n3, 0}, b.opts().WithName("n6")); + n3->set_requested_device("a"); + n4->set_requested_device("a"); + n5->set_requested_device("b"); + n6->set_requested_device("b"); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + std::vector order; + + TopologicalOrdering(g, [&](Node* n) { order.push_back(n); }, GroupByDevice()); + + std::vector> desired_order = { + {"n1", "n2"}, // because of control dependency + {"n2", "n3"}, // because of control dependency + {"n3", "n4"}, // because of NodeScorerDevice + {"n1", "n4"}, // data dependency + {"n1", "n5"}, // data dependency + {"n2", "n6"}, // data dependency + {"n3", "n4"}, // data dependency + {"n3", "n5"}, // data dependency + {"n3", "n6"}, // data dependency + }; + string error; + EXPECT_TRUE(ExpectBefore(desired_order, order, &error)) << error; +} + +TEST(AlgorithmTest, TopologicalOrderingOnShallowTree) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT + Node* n1 = SourceOp("TestParams", b.opts().WithName("n1").WithDevice("a")); + Node* n2 = + SourceOp("TestParams", + b.opts().WithName("n2").WithDevice("b").WithControlInput(n1)); + Node* n3 = + SourceOp("TestParams", + b.opts().WithName("n3").WithDevice("c").WithControlInput(n2)); + Node* n4 = + SourceOp("TestParams", + b.opts().WithName("n4").WithDevice("a").WithControlInput(n1)); + Node* n5 = + SourceOp("TestParams", + b.opts().WithName("n5").WithDevice("b").WithControlInput(n2)); + Node* n6 = + SourceOp("TestParams", + b.opts().WithName("n6").WithDevice("c").WithControlInput(n3)); + Node* n7 = + SourceOp("TestParams", + b.opts().WithName("n7").WithDevice("a").WithControlInput(n4)); + Node* n8 = + SourceOp("TestParams", + b.opts().WithName("n8").WithDevice("b").WithControlInput(n5)); + Node* n9 = + SourceOp("TestParams", + b.opts().WithName("n9").WithDevice("c").WithControlInput(n6)); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + std::vector order; + + TopologicalOrdering(g, [&](Node* n) { order.push_back(n); }, GroupByDevice()); + + std::vector desired_order = { + g.source_node(), n1, n4, n7, n2, n5, n8, n3, n6, n9, g.sink_node()}; + for (int i = 0; i < desired_order.size(); i++) { + desired_order[i] = g.FindNodeId(desired_order[i]->id()); + } + EXPECT_EQ(order, desired_order); +} + +TEST(AlgorithmTest, TopologicalOrderingGivesTheSameResultIfCalledTwice) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT + SourceOp("TestParams", b.opts().WithName("n1")); + SourceOp("TestParams", b.opts().WithName("n2")); + SourceOp("TestParams", b.opts().WithName("n3")); + SourceOp("TestParams", b.opts().WithName("n4")); + SourceOp("TestParams", b.opts().WithName("n5")); + SourceOp("TestParams", b.opts().WithName("n6")); + SourceOp("TestParams", b.opts().WithName("n7")); + SourceOp("TestParams", b.opts().WithName("n8")); + SourceOp("TestParams", b.opts().WithName("n9")); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + std::vector order1; + std::vector order2; + + TopologicalOrdering( + g, [&](Node* n) { order1.push_back(n); }, + [&](const Node* node) { return std::string("same"); }); + + TopologicalOrdering( + g, [&](Node* n) { order2.push_back(n); }, + [&](const Node* node) { return std::string("same"); }); + + EXPECT_EQ(order1, order2); +} + +TEST(AlgorithmTest, TopologicalOrderingOnChain) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT + Node* n1 = SourceOp("TestParams", b.opts().WithName("n1")); + Node* n2 = UnaryOp("TestUnary", n1, b.opts().WithName("n2")); + Node* n3 = UnaryOp("TestUnary", n2, b.opts().WithName("n3")); + Node* n4 = UnaryOp("TestUnary", n3, b.opts().WithName("n4")); + Node* n5 = UnaryOp("TestUnary", n4, b.opts().WithName("n5")); + Node* n6 = UnaryOp("TestUnary", n5, b.opts().WithName("n6")); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + std::vector order; + TopologicalOrdering(g, [&](Node* n) { order.push_back(n); }, GroupByDevice()); + + std::vector desired_order = {g.source_node(), n1, n2, n3, n4, n5, n6, + g.sink_node()}; + for (int i = 0; i < desired_order.size(); i++) { + desired_order[i] = g.FindNodeId(desired_order[i]->id()); + } + EXPECT_EQ(order, desired_order); +} + +TEST(AlgorithmTest, TopologicalOrderingOnMultipleOutputs) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT + Node* n1 = SourceOp("TestTwoOutputs", b.opts().WithName("n1")); + UnaryOp("TestUnary", {n1, 0}, b.opts().WithName("n2")); + UnaryOp("TestUnary", {n1, 1}, b.opts().WithName("n3")); + UnaryOp("TestUnary", {n1, 0}, b.opts().WithName("n4")); + UnaryOp("TestUnary", {n1, 1}, b.opts().WithName("n5")); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + std::vector order; + TopologicalOrdering(g, [&](Node* n) { order.push_back(n); }, GroupByDevice()); + + std::vector> desired_order = { + {"n1", "n2"}, + {"n1", "n3"}, + {"n1", "n4"}, + {"n1", "n5"}, + }; + string error; + EXPECT_TRUE(ExpectBefore(desired_order, order, &error)) << error; +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index ab73156f29c4b3..92ecf3082588ab 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -130,7 +130,7 @@ static LogicalResult MlirToGraphTranslateFunction(ModuleOp module, auto graph = std::make_unique(tensorflow::OpRegistry::Global()); absl::flat_hash_set control_ret_nodes; - auto status = tensorflow::tf2xla::v2::ConvertMlirToGraph( + auto status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( module, confs, &graph, flib_def.get(), &control_ret_nodes); if (!status.ok()) { LOG(ERROR) << "Export to Graph failed: " << status; @@ -179,7 +179,7 @@ static LogicalResult MlirToGraphdefTranslateFunction( std::make_unique(tensorflow::OpRegistry::Global()); absl::flat_hash_set control_ret_nodes; - auto status = tensorflow::tf2xla::v2::ConvertMlirToGraph( + auto status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( module, confs, &graph, &flib_def, &control_ret_nodes); if (!status.ok()) { LOG(ERROR) << "Export to Graph failed: " << status; diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc index 9fc2207d83dee8..265f83caf03727 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc @@ -201,7 +201,7 @@ Status PrepareAndExportToLibrary(mlir::ModuleOp module, GraphExportConfig config; config.export_entry_func_to_flib = true; absl::flat_hash_set control_ret_nodes; - return tensorflow::tf2xla::v2::ConvertMlirToGraph( + return tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( module, config, /*graph=*/nullptr, flib_def, &control_ret_nodes); } diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.cc index b0a770802ac34d..7645125770fae8 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.cc @@ -796,11 +796,11 @@ Status Exporter::Convert(mlir::ModuleOp module, } // namespace -Status ConvertMlirToGraph(mlir::ModuleOp module, - const GraphExportConfig& configs, - std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def, - absl::flat_hash_set* control_ret_nodes) { +Status ConvertTfExecutorToGraph(mlir::ModuleOp module, + const GraphExportConfig& configs, + std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def, + absl::flat_hash_set* control_ret_nodes) { mlir::StatusScopedDiagnosticHandler sh(module.getContext()); if (failed(VerifyExportSuitable(module))) return sh.ConsumeStatus(); return sh.Combine( diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h index 7ee67aa221a91b..bd59770e8164fb 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h @@ -37,11 +37,11 @@ namespace v2 { // The "main" function of the module is stored in the graph and the rest of // functions are stored in the library. Control ret nodes are stored separately // in `control_ret_nodes`. -Status ConvertMlirToGraph(mlir::ModuleOp module, - const GraphExportConfig& configs, - std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def, - absl::flat_hash_set* control_ret_nodes); +Status ConvertTfExecutorToGraph(mlir::ModuleOp module, + const GraphExportConfig& configs, + std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def, + absl::flat_hash_set* control_ret_nodes); // Converts an MLIR function and adds it to a FunctionLibraryDefinition. Status ConvertMlirFunctionToFunctionLibraryDef(mlir::func::FuncOp func, diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index dca10693d74b01..9f45164ba4dfe3 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -3145,7 +3145,8 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { // (The batch dimensions are checked by the broadcasting logic) rewriter.replaceOpWithNewOp( op, op.getType(), lhs, rhs, dimension_numbers, - /*precision_config=*/GetPrecisionConfig(&rewriter)); + /*precision_config=*/GetPrecisionConfig(&rewriter), + /*algorithm=*/DotAlgorithmAttr{}); return success(); } }; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td index 401d1e8b954e40..185216448a15ed 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td @@ -60,6 +60,8 @@ def CastElementsToI64Elements : NativeCodeCall< "hlo::convertElementsAttr(" "$0.cast(), $_builder.getIntegerType(64)).cast()">; +def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::mhlo::DotAlgorithmAttr{}">; + //===----------------------------------------------------------------------===// // ApproximateEqual op pattern. //===----------------------------------------------------------------------===// @@ -760,7 +762,8 @@ def HasValidPrecisionConfig : Constraint>; def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), (MHLO_DotGeneralOp $lhs, $rhs, (ToDotDimNumsAttr $dimension_numbers), - (ToPrecisionConfigsAttr $precision_config)), + (ToPrecisionConfigsAttr $precision_config), + (EmptyDotAlgorithmAttr)), [(HasValidDotDims $dimension_numbers), (HasValidPrecisionConfig $precision_config)]>; //===----------------------------------------------------------------------===// @@ -770,7 +773,8 @@ def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), def : Pat<(TF_XlaDotV2Op $lhs, $rhs, $dimension_numbers, $precision_config), (MHLO_DotGeneralOp $lhs, $rhs, (ToDotDimNumsAttr $dimension_numbers), - (ToPrecisionConfigsAttr $precision_config)), + (ToPrecisionConfigsAttr $precision_config), + (EmptyDotAlgorithmAttr)), [(HasValidDotDims $dimension_numbers), (HasValidPrecisionConfig $precision_config)]>; //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_device_cleanup.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_device_cleanup.mlir new file mode 100644 index 00000000000000..02afa969970004 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_device_cleanup.mlir @@ -0,0 +1,8 @@ +// RUN: tf-tfrt-opt %s -tf-device-cleanup | FileCheck %s + +// CHECK-LABEL: func @ops_with_device +func.func @ops_with_device() { + %0 = "tf.VarHandleOp"() {container = "", shared_name = "var", device = "/device/..."} : () -> tensor>> + // CHECK-NOT: device = "/device/..." + func.return +} diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index 80969fec73cba5..2ec0fdd9d4c215 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -69,6 +69,7 @@ cc_library( "lower_to_ifrt_restore_variable.cc", "rewrite_cluster_to_ifrt_call.cc", "sink_variable_as_named_array.cc", + "tf_device_cleanup.cc", "tf_identity_propagation.cc", "tf_ifrt_passes.cc", "tf_restore_merging.cc", diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td index 7cdc5576ae5465..9c37c58c0e37ba 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td @@ -129,3 +129,14 @@ def TfIdentityPropagationPass let constructor = "CreateTfIdentityPropagationPass()"; } +def TfDeviceCleanupPass : Pass<"tf-device-cleanup", "mlir::func::FuncOp"> { + let summary = "Cleans up device attributes from all ops"; + + let description = [{ + This pass removes `device` attributes from all TF ops. Some Serving + doesn't rely on `device` attributes from SavedModel. + }]; + + let constructor = "CreateTfDeviceCleanupPass()"; +} + diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_device_cleanup.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_device_cleanup.cc new file mode 100644 index 00000000000000..b40c94e6a1de07 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_device_cleanup.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +#define GEN_PASS_DEF_TFDEVICECLEANUPPASS +#define GEN_PASS_DECL_TFDEVICECLEANUPPASS +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep + +class TfDeviceCleanupPass + : public impl::TfDeviceCleanupPassBase { + public: + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + func.walk([](mlir::Operation* op) { + if (llvm::isa(op->getDialect())) { + op->removeAttr("device"); + } + }); + } +}; + +} // namespace + +std::unique_ptr> +CreateTfDeviceCleanupPass() { + return std::make_unique(); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc index 2802cb5a94503a..6d49f9a06141c9 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc @@ -81,6 +81,10 @@ void AddClusterToIfrtRuntimeOpsPassPipeline(OpPassManager& pm, pm.addPass(CreateRewriteClusterToIfrtCallPass()); + // After device program is extracted, we can clean up device attributes from + // all ops. + pm.addNestedPass(CreateTfDeviceCleanupPass()); + // Sink VarHandle with ReadVariableOp: subsequent SinkVariableAsNamedArrayPass // rely on the co-existence of VarHandle and ReadVariable in the same // function. diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h index 93713fbdc13646..92d9b06dc6765a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h @@ -57,6 +57,10 @@ CreateTfRestorePruningPass(); std::unique_ptr> CreateLowerToIfrtRestoreVariablePass(); +// Creates a pass that cleans up device attributes from all ops. +std::unique_ptr> +CreateTfDeviceCleanupPass(); + #define GEN_PASS_REGISTRATION #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc index 51abf57bc00951..2b97ec6a9536ac 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc @@ -100,14 +100,14 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( AddTfDeviceAssignmentPasses(pm, options); // After the standard pass, we now have MLIR in TF dialect, and now we convert - // reference variable to resource variables, which is besteffort. + // reference variable to resource variables, which is best effort. pm.addPass(CreateConvertReferenceVariableToResourceVariablePass()); // Move the tf.Assert op to the end of the function, so that it does not // impose unnecessary control dependencies on other ops. pm.addPass(tfrt_compiler::CreateReorderTfAssertPass()); - // Optimze the side-effects of control flow ops by examining the ops in its + // Optimize the side-effects of control flow ops by examining the ops in its // callees. pm.addPass(tfrt_compiler::CreateOptimizeTfControlFlowSideEffectPass()); @@ -117,10 +117,11 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( // Merge non-side-effecting tf.If ops if their operands are the same. pm.addPass(tfrt_compiler::CreateMergeTfIfOpsPass()); - // Lower bound on the number of batch threads in `tf.BatchFunction`. - pm.addPass(tfrt_compiler::CreateReconfigBatchOpPass( - {.min_num_batch_threads = options.min_num_batch_threads, - .min_max_enqueued_batches = options.min_max_enqueued_batches})); + pm.addPass(tfrt_compiler::CreateReconfigBatchOpPass({ + .min_num_batch_threads = options.min_num_batch_threads, + .min_max_enqueued_batches = options.min_max_enqueued_batches, + .batch_padding_policy = options.batch_padding_policy, + })); // Deduplicate functions invoked by tf.BatchFunction with the same // shared_name diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD index cb517d1039711f..83b70c251d8bf7 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD @@ -20,11 +20,15 @@ cc_library( deps = [ "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/bytecode:executable", + "//tensorflow/core/tfrt/mlrt/bytecode:function", + "//tensorflow/core/tfrt/mlrt/bytecode:kernel", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -38,10 +42,15 @@ tf_cc_test( data = glob(["testdata/**"]), deps = [ ":mlir_to_bytecode", + "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/bytecode:executable", "//tensorflow/core/tfrt/mlrt/interpreter:attribute_span", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:resource_loader", @@ -57,10 +66,15 @@ cc_library( hdrs = ["test_utils.h"], deps = [ # copybara:uncomment "//learning/brain/experimental/tfrt/native_lowering/stubs:tfrt_native_lowering_impl", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/tfrt/graph_executor:sync_resource_state", "//tensorflow/core/tfrt/mlrt/attribute", "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/bytecode:kernel", @@ -70,7 +84,9 @@ cc_library( "//tensorflow/core/tfrt/stubs:tfrt_native_lowering_stub", "//tensorflow/core/tfrt/utils:tensor_util", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", "@tf_runtime//:hostcontext", + "@tf_runtime//:support", "@tf_runtime//:tensor", ], ) diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc index d3b19eb3447cf7..52b1826f4a1f65 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc @@ -25,14 +25,26 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/function.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h" namespace mlrt { namespace { diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h index 7f5416d230cb05..950865644effcc 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h @@ -22,9 +22,16 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" namespace mlrt { diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc index 9f02f1d3c2a531..d7d3065d847d7f 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc @@ -19,9 +19,20 @@ limitations under the License. #include #include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" #include "tensorflow/core/tfrt/mlrt/interpreter/attribute_span.h" #include "tsl/platform/resource_loader.h" diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc index b5a3cb9550c558..e4f9e6f77ba2fc 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc @@ -22,10 +22,19 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/tfrt/mlrt/attribute/attribute.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" #include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace mlrt { namespace testing { diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h index d569f32175f78c..6140c71149c9ee 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h @@ -21,10 +21,15 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h" #include "tensorflow/core/tfrt/mlrt/attribute/attribute.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h" @@ -34,10 +39,13 @@ limitations under the License. #include "tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h" #include "tensorflow/core/tfrt/utils/tensor_util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime #include "tfrt/host_context/execution_context.h" // from @tf_runtime #include "tfrt/host_context/host_allocator.h" // from @tf_runtime #include "tfrt/host_context/host_context.h" // from @tf_runtime +#include "tfrt/support/string_util.h" // from @tf_runtime +#include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime #include "tfrt/tensor/dense_tensor_utils.h" // from @tf_runtime namespace mlrt { diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h index 9f71dce30675fe..69ff39c3dcf95e 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h @@ -125,7 +125,7 @@ struct TfrtCompileOptions { // For TFRT, if true, tf.While's iterations will be parallelized on a // best-effort basis. This is currently experimental. MLRT attempts to convert // tf.while to tf_mlrt.map_fn regardless of this flag. For tf.While that - // cannot be onverted tf_mlrt.map_fn, MLRT try to parallerize tf.while's + // cannot be converted tf_mlrt.map_fn, MLRT try to parallelize tf.while's // iterations on a best-effort basis. bool enable_while_parallel_iterations = false; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/BUILD index e7ac7e34840545..08ed97bfc70f3b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/BUILD @@ -14,6 +14,7 @@ glob_lit_tests( "no_rocm", ], driver = "//tensorflow/compiler/mlir:run_lit.sh", + hermetic_cuda_data_dir = "%S/../../../../../../../../cuda_nvcc", test_file_exts = ["mlir"], ) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 0972f67e6605d0..48f973c8f472de 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -89,20 +89,35 @@ py_strict_test( ], ) +#LINT.IfChange(combined_tests) +# If you add a new tf_xla_py_strict_test please either add the test file to one of the combined test +# targets that matches in all tags and other settings or add a new combined test target. tf_xla_combined_py_test( name = "ops_test_mlir_false", size = "medium", - enable_mlir_bridge = False, package = "tensorflow.compiler.tests", python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], - test_files = [ + tests = [ # go/keep-sorted start - "adadelta_test.py", + ":adadelta_test_lib", # go/keep-sorted end ], +) +#LINT.ThenChange(:individual_tests) + +#LINT.IfChange(individual_tests) +tf_xla_py_strict_test( + name = "adadelta_test", + size = "medium", + srcs = ["adadelta_test.py"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", + ], deps = [ ":xla_test", "//tensorflow/python/framework:constant_op", @@ -118,7 +133,6 @@ tf_xla_py_strict_test( name = "adagrad_test", size = "small", srcs = ["adagrad_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -138,7 +152,6 @@ tf_xla_py_strict_test( name = "adagrad_da_test", size = "small", srcs = ["adagrad_da_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_cuda_asan", # times out @@ -160,7 +173,6 @@ tf_xla_py_strict_test( name = "adam_test", size = "small", srcs = ["adam_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_cuda_asan", # times out @@ -190,7 +202,6 @@ tf_xla_py_strict_test( # copybara:uncomment_end # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -210,7 +221,6 @@ tf_xla_py_strict_test( name = "argminmax_test", size = "small", srcs = ["argminmax_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_cuda_asan", # times out @@ -231,7 +241,6 @@ tf_xla_py_strict_test( name = "binary_ops_test", size = "medium", srcs = ["binary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -260,7 +269,6 @@ tf_xla_py_strict_test( name = "complex_div_test", size = "medium", srcs = ["complex_div_test.py"], - enable_mlir_bridge = True, enabled_backends = [ "cpu", "gpu", @@ -287,7 +295,6 @@ tf_xla_py_strict_test( name = "bucketize_op_test", size = "small", srcs = ["bucketize_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -307,7 +314,6 @@ tf_xla_py_strict_test( name = "categorical_op_test", size = "small", srcs = ["categorical_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -330,7 +336,6 @@ tf_xla_py_strict_test( name = "cholesky_op_test", size = "medium", srcs = ["cholesky_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -358,7 +363,6 @@ tf_xla_py_strict_test( # #TODO(b/286470564): Remove once the bug is fixed. # disable_tpu_tfrt = True, # copybara:uncomment_end - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -388,7 +392,6 @@ tf_xla_py_strict_test( name = "self_adjoint_eig_op_test", size = "medium", srcs = ["self_adjoint_eig_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -409,7 +412,6 @@ tf_xla_py_strict_test( size = "small", timeout = "moderate", srcs = ["searchsorted_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -454,7 +456,6 @@ tf_xla_py_strict_test( size = "small", timeout = "moderate", srcs = ["matrix_inverse_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -475,7 +476,6 @@ tf_xla_py_strict_test( size = "small", timeout = "moderate", srcs = ["matrix_solve_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -496,7 +496,6 @@ tf_xla_py_strict_test( size = "small", timeout = "moderate", srcs = ["matrix_triangular_solve_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_oss", # TODO(b/295649328): fix failed nightly tests @@ -520,7 +519,6 @@ tf_xla_py_strict_test( name = "clustering_test", size = "small", srcs = ["clustering_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -541,7 +539,6 @@ tf_xla_py_strict_test( name = "concat_ops_test", size = "medium", srcs = ["concat_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -566,7 +563,6 @@ tf_xla_py_strict_test( name = "conv2d_test", size = "medium", srcs = ["conv2d_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 10, tags = [ @@ -612,7 +608,6 @@ tf_xla_py_strict_test( name = "conv3d_test", size = "medium", srcs = ["conv3d_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -638,7 +633,6 @@ tf_xla_py_strict_test( name = "depthwise_conv_op_test", size = "medium", srcs = ["depthwise_conv_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -664,7 +658,6 @@ tf_xla_py_strict_test( name = "dynamic_slice_ops_test", size = "small", srcs = ["dynamic_slice_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -683,7 +676,6 @@ tf_xla_py_strict_test( name = "einsum_op_test", size = "medium", srcs = ["einsum_op_test.py"], - enable_mlir_bridge = True, enabled_backends = [ "cpu", "gpu", @@ -708,7 +700,6 @@ tf_xla_py_strict_test( name = "reshape_op_test", size = "small", srcs = ["reshape_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -762,7 +753,6 @@ tf_xla_py_strict_test( name = "eager_test", size = "medium", srcs = ["eager_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "multi_and_single_gpu", @@ -800,7 +790,6 @@ tf_xla_py_strict_test( name = "fifo_queue_test", size = "medium", srcs = ["fifo_queue_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -818,7 +807,6 @@ tf_xla_py_strict_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], - enable_mlir_bridge = False, python_version = "PY3", shard_count = 12, tags = [ @@ -842,7 +830,6 @@ tf_xla_py_strict_test( name = "slice_ops_test", size = "medium", srcs = ["slice_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_cuda_asan", # times out @@ -862,7 +849,6 @@ tf_xla_py_strict_test( name = "ftrl_test", size = "medium", srcs = ["ftrl_test.py"], - enable_mlir_bridge = False, python_version = "PY3", shard_count = 8, tags = [ @@ -885,7 +871,6 @@ tf_xla_py_strict_test( name = "ftrl_ops_test", size = "medium", srcs = ["ftrl_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -905,7 +890,6 @@ tf_xla_py_strict_test( name = "function_test", size = "small", srcs = ["function_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -926,7 +910,6 @@ tf_xla_py_strict_test( size = "small", timeout = "long", srcs = ["image_ops_test.py"], - enable_mlir_bridge = False, enabled_backends = [ "cpu", "gpu", @@ -958,7 +941,6 @@ tf_xla_py_strict_test( name = "listdiff_op_test", size = "small", srcs = ["listdiff_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_cuda_asan", # times out @@ -977,7 +959,6 @@ tf_xla_py_strict_test( name = "lrn_ops_test", size = "medium", srcs = ["lrn_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -999,7 +980,6 @@ tf_xla_py_strict_test( name = "manip_ops_test", size = "small", srcs = ["manip_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1019,7 +999,6 @@ tf_xla_py_strict_test( size = "medium", timeout = "long", srcs = ["matrix_band_part_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_aarch64", # TODO(b/315533266) @@ -1042,7 +1021,6 @@ tf_xla_py_strict_test( size = "medium", timeout = "long", srcs = ["matrix_diag_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 8, tags = [ @@ -1061,7 +1039,6 @@ tf_xla_py_strict_test( name = "momentum_test", size = "small", srcs = ["momentum_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1083,7 +1060,6 @@ tf_xla_py_strict_test( name = "nary_ops_test", size = "small", srcs = ["nary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1103,7 +1079,6 @@ tf_xla_py_strict_test( name = "nullary_ops_test", size = "small", srcs = ["nullary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1121,7 +1096,6 @@ tf_xla_py_strict_test( name = "pooling_ops_test", size = "medium", srcs = ["pooling_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 20, tags = [ @@ -1146,7 +1120,6 @@ tf_xla_py_strict_test( name = "pooling_ops_3d_test", size = "medium", srcs = ["pooling_ops_3d_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 20, tags = [ @@ -1168,7 +1141,6 @@ tf_xla_py_strict_test( name = "proximal_adagrad_test", size = "medium", srcs = ["proximal_adagrad_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1189,7 +1161,6 @@ tf_xla_py_strict_test( name = "proximal_gradient_descent_test", size = "medium", srcs = ["proximal_gradient_descent_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1215,7 +1186,6 @@ tf_xla_py_strict_test( "cpu", "cpu_ondemand", ], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1244,7 +1214,6 @@ tf_xla_py_strict_test( "cpu", "cpu_ondemand", ], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 1, tags = [ @@ -1287,7 +1256,6 @@ tf_xla_py_strict_test( name = "reduce_ops_test", size = "medium", srcs = ["reduce_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1310,7 +1278,6 @@ tf_xla_py_strict_test( name = "reduce_window_test", size = "small", srcs = ["reduce_window_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1330,7 +1297,6 @@ tf_xla_py_strict_test( name = "reverse_ops_test", size = "medium", srcs = ["reverse_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1349,7 +1315,6 @@ tf_xla_py_strict_test( name = "reverse_sequence_op_test", size = "medium", srcs = ["reverse_sequence_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1369,7 +1334,6 @@ tf_xla_py_strict_test( # name = "reverse_sequence_op_args_test", # size = "medium", # srcs = ["reverse_sequence_op_args_test.py"], -# enable_mlir_bridge = False, # main = "reverse_sequence_op_args_test.py", # python_version = "PY3", # tags = [ @@ -1392,7 +1356,6 @@ tf_xla_py_strict_test( name = "rmsprop_test", size = "small", srcs = ["rmsprop_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1412,7 +1375,6 @@ tf_xla_py_strict_test( name = "scan_ops_test", size = "medium", srcs = ["scan_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 6, tags = [ @@ -1438,7 +1400,6 @@ tf_xla_py_strict_test( name = "segment_reduction_ops_test", size = "medium", srcs = ["segment_reduction_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1459,7 +1420,6 @@ tf_xla_py_strict_test( name = "spacetobatch_op_test", size = "medium", srcs = ["spacetobatch_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 3, tags = [ @@ -1480,7 +1440,6 @@ tf_xla_py_strict_test( name = "sparse_to_dense_op_test", size = "medium", srcs = ["sparse_to_dense_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1500,7 +1459,6 @@ tf_xla_py_strict_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "config-cuda-only", @@ -1528,7 +1486,6 @@ tf_xla_py_strict_test( "gpu_a100", "gpu_h100", ], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 10, tags = [ @@ -1566,7 +1523,6 @@ tf_xla_py_strict_test( "gpu_a100", "gpu_h100", ], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 10, tags = [ @@ -1605,8 +1561,6 @@ tf_xla_py_strict_test( "gpu_a100", "gpu_h100", ], - # TODO(b/232442915): Enable MLIR. - enable_mlir_bridge = False, python_version = "PY3", shard_count = 20, tags = [ @@ -1637,7 +1591,6 @@ tf_xla_py_strict_test( srcs = ["tensor_array_ops_test.py"], # TensorArray ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "config-cuda-only", @@ -1677,7 +1630,6 @@ tf_xla_py_strict_test( # copybara:uncomment_end # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1699,7 +1651,6 @@ tf_xla_py_strict_test( name = "ternary_ops_test", size = "medium", srcs = ["ternary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 8, tags = [ @@ -1723,7 +1674,6 @@ tf_xla_py_strict_test( name = "unary_ops_test", size = "medium", srcs = ["unary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 20, tags = [ @@ -1751,7 +1701,6 @@ tf_xla_py_strict_test( name = "fused_batchnorm_test", size = "medium", srcs = ["fused_batchnorm_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1776,7 +1725,6 @@ tf_xla_py_strict_test( size = "small", timeout = "moderate", srcs = ["variable_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1811,7 +1759,6 @@ tf_xla_py_strict_test( # #TODO(b/291130193): Remove once the bug is fixed. # disable_tpu_tfrt = True, # copybara:uncomment_end - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1839,7 +1786,6 @@ tf_xla_py_strict_test( size = "small", srcs = ["case_test.py"], disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1907,7 +1853,6 @@ tf_xla_py_strict_test( name = "gather_nd_op_test", size = "medium", srcs = ["gather_nd_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1926,7 +1871,6 @@ tf_xla_py_strict_test( name = "scatter_nd_op_test", size = "medium", srcs = ["scatter_nd_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1946,7 +1890,6 @@ tf_xla_py_strict_test( name = "sort_ops_test", size = "medium", srcs = ["sort_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 50, # Times out in fastbuild mode. @@ -1977,7 +1920,6 @@ tf_xla_py_strict_test( name = "data_format_ops_test", size = "small", srcs = ["data_format_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1996,7 +1938,6 @@ tf_xla_py_strict_test( name = "xla_device_test", size = "small", srcs = ["xla_device_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -2014,723 +1955,383 @@ tf_xla_py_strict_test( ], ) -cuda_py_strict_test( - name = "xla_device_gpu_test", - size = "small", - srcs = ["xla_device_gpu_test.py"], +tf_xla_py_strict_test( + name = "fake_quant_ops_test", + size = "medium", + srcs = ["fake_quant_ops_test.py"], + python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], - xla_enable_strict_auto_jit = False, - xla_enabled = True, deps = [ - "//tensorflow/python/client:session", - "//tensorflow/python/eager:context", + ":xla_test", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/ops:array_ops_gen", + "//tensorflow/python/platform:test", + "//third_party/py/numpy", ], ) -cuda_py_strict_test( - name = "jit_test", +tf_xla_py_strict_test( + name = "placeholder_test", + size = "small", + srcs = ["placeholder_test.py"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + deps = [ + ":xla_test", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:test", + ], +) + +tf_xla_py_strict_test( + name = "quantized_ops_test", size = "medium", - srcs = ["jit_test.py"], - #shard_count = 5, + srcs = ["quantized_ops_test.py"], + python_version = "PY3", tags = [ - "no_cuda_asan", # Times out. "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], - xla_enable_strict_auto_jit = False, - xla_enabled = True, deps = [ - ":test_utils", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/client:session", - "//tensorflow/python/compiler/xla:compiler_py", + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:function", "//tensorflow/python/framework:ops", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:cond", - "//tensorflow/python/ops:control_flow_ops", - "//tensorflow/python/ops:gradients_impl", + "//tensorflow/python/ops:bitwise_ops", "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:nn_ops", - "//tensorflow/python/ops:while_loop", - "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/platform:test", "//third_party/py/numpy", ], ) -cuda_py_strict_test( - name = "async_comp_test", +tf_xla_py_strict_test( + name = "xla_ops_test", size = "medium", - srcs = ["async_comp_test.py"], - shard_count = 1, + srcs = ["xla_ops_test.py"], + disabled_backends = [ + "gpu", + "gpu_a100", + "gpu_h100", + ], + python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], - xla_enable_strict_auto_jit = False, - xla_enabled = True, deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python/client:session", + ":xla_test", + "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", "//tensorflow/python/framework:function", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:random_ops_util", + "//tensorflow/python/platform:test", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + "@local_xla//xla:xla_data_proto_py", ], ) -cuda_py_strict_test( - name = "dense_layer_test", - size = "medium", - srcs = ["dense_layer_test.py"], +tf_xla_py_strict_test( + name = "xla_custom_call_ops_test", + size = "small", + srcs = ["xla_custom_call_ops_test.py"], + disabled_backends = [ + "gpu", + "gpu_a100", + "gpu_h100", + ], + python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", ], - xla_enable_strict_auto_jit = False, - xla_enabled = True, + use_xla_device = False, # Uses tf.function(jit_compile=True) deps = [ - ":test_utils", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/compiler/xla:compiler_py", + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", - "//tensorflow/python/layers", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:variables", + "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/ops:random_ops", "//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", ], ) -cc_library( - name = "randomized_tests_library", - testonly = 1, - srcs = ["randomized_tests.cc"], +tf_xla_py_strict_test( + name = "runtime_shape_check_test", + size = "small", + srcs = ["runtime_shape_check_test.py"], + disabled_backends = [ + "cpu", + "cpu_ondemand", + ], + python_version = "PY3", + tags = [ + "no_pip", + "notap", + ], + use_xla_device = False, deps = [ - "//tensorflow/compiler/jit", - "//tensorflow/compiler/jit:common", - "//tensorflow/compiler/jit:flags_headers", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow_opensource", - "//tensorflow/core:test", - "//tensorflow/core:testlib", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:status", - "@local_xla//xla:xla_data_proto_cc", + ":xla_test", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/platform:client_testlib", ], ) -tf_cuda_cc_test( - name = "randomized_tests", +tf_xla_py_strict_test( + name = "conv_node_name_test", size = "medium", - args = ["--tf_xla_test_use_mlir=false"], - shard_count = 20, - # This test is randomized, so only run it if explicitly requested. + srcs = ["conv_node_name_test.py"], + python_version = "PY3", + shard_count = 5, tags = [ - "manual", + "no_oss", # TODO(b/148108508): Re-enable this test in OSS. "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "notap", - ] + tf_cuda_tests_tags(), - deps = [":randomized_tests_library"], + ], + deps = [ + ":xla_test", + "//tensorflow/python/framework:ops", + "//tensorflow/python/layers", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:control_flow_ops", + "//tensorflow/python/platform:test", + "//third_party/py/numpy", + ], ) -tf_cuda_cc_test( - name = "randomized_tests_mlir", +tf_xla_py_strict_test( + name = "tridiagonal_solve_ops_test", size = "medium", - args = ["--tf_xla_test_use_mlir=true"], - shard_count = 20, - # This test is randomized, so only run it if explicitly requested. + srcs = ["tridiagonal_solve_ops_test.py"], + python_version = "PY3", tags = [ - "manual", "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "notap", - ] + tf_cuda_tests_tags(), - deps = [":randomized_tests_library"], + "optonly", + ], + deps = [ + ":xla_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:gradients", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops/linalg:linalg_impl", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], ) -# Create a deterministic version of randomized_tests_mlir with fixed seed. -# This can be used in presubmit checks as it is no longer randomized. -tf_cuda_cc_test( - name = "randomized_tests_mlir_seeded", +tf_xla_py_strict_test( + name = "tridiagonal_matmul_ops_test", size = "medium", - args = [ - "--tf_xla_random_seed=200839030", - "--tf_xla_test_use_mlir=true", - "--tf_xla_test_device=GPU:0", - ], - shard_count = 20, - tags = [ - "config-cuda-only", - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "no_rocm", # ROCmSoftwarePlatform #958 - "noasan", # TODO(b/201651800) - "requires-gpu-nvidia", - ] + tf_cuda_tests_tags(), - deps = [":randomized_tests_library"], -) - -# Create a deterministic version of randomized_tests with fixed seed. -# This can be used in presubmit checks as it is no longer randomized. -tf_cuda_cc_test( - name = "randomized_tests_seeded", - size = "medium", - args = [ - "--tf_xla_random_seed=200839030", - "--tf_xla_test_use_mlir=false", - "--tf_xla_test_device=GPU:0", - ], - shard_count = 20, - tags = [ - "config-cuda-only", - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "no_rocm", # ROCmSoftwarePlatform #958 - "noasan", # TODO(b/201651800) - "requires-gpu-nvidia", - ] + tf_cuda_tests_tags(), - deps = [":randomized_tests_library"], -) - -tf_cuda_cc_test( - name = "unary_ops_composition_test", - srcs = ["unary_ops_composition_test.cc"], - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - ] + tf_cuda_tests_tags(), - deps = [ - "//tensorflow/cc:cc_ops", - "//tensorflow/compiler/jit", - "//tensorflow/compiler/jit:flags", - "//tensorflow/compiler/jit:xla_kernel_creator", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/core/kernels:ops_testutil", - "@local_tsl//tsl/platform:status", - ], -) - -py_strict_library( - name = "lstm", - testonly = 1, - srcs = ["lstm.py"], - srcs_version = "PY3", - deps = [ - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:random_ops", - "//tensorflow/python/ops:variable_v1", - "@six_archive//:six", - ], -) - -cuda_py_strict_test( - name = "lstm_test", - srcs = ["lstm_test.py"], + srcs = ["tridiagonal_matmul_ops_test.py"], + python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", ], - xla_enable_strict_auto_jit = False, - xla_enabled = True, deps = [ - ":lstm", ":xla_test", + "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:gradients_impl", - "//tensorflow/python/ops:init_ops", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:gradient_checker_v2", "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:variables", - "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/ops/linalg:linalg_impl", + "//tensorflow/python/platform:test", "//third_party/py/numpy", ], ) -# An example of ahead-of-time compilation using tfcompile. The -# lstm_layer_inference.pbtxt file was generated by running lstm_test -# --dump_graph_dir, and the config file was written by hand. -# -# Run the following to build a minimal benchmark of the computation on Android: -# $ bazel build -c opt --cxxopt='-std=c++11' --linkopt='-lm' \ -# --cpu=armeabi-v7a \ -# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ -# --crosstool_top=//external:android/crosstool \ -# //tensorflow/compiler/tests:lstm_layer_inference_benchmark - -# -# Currently the resulting binary size is ~190KB -tf_library( - name = "lstm_layer_inference", - testonly = 1, - config = "lstm_layer_inference.config.pbtxt", - cpp_class = "LSTMLayerInference", - graph = "lstm_layer_inference.pbtxt", - tags = ["manual"], - tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"], -) - tf_xla_py_strict_test( - name = "fake_quant_ops_test", + name = "special_math_test", size = "medium", - srcs = ["fake_quant_ops_test.py"], - enable_mlir_bridge = True, - python_version = "PY3", + srcs = ["special_math_test.py"], + shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", ], deps = [ ":xla_test", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:array_ops_gen", - "//tensorflow/python/platform:test", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/ops:gradient_checker_v2", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:math_ops_gen", + "//tensorflow/python/ops:random_ops_gen", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", + "@absl_py//absl/flags", + "@absl_py//absl/testing:parameterized", ], ) tf_xla_py_strict_test( - name = "placeholder_test", - size = "small", - srcs = ["placeholder_test.py"], - enable_mlir_bridge = True, - python_version = "PY3", + name = "repeat_op_test", + size = "medium", + srcs = ["repeat_op_test.py"], + shard_count = 1, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", ], deps = [ ":xla_test", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:resource_variable_ops", - "//tensorflow/python/ops:variables", - "//tensorflow/python/platform:test", + "//tensorflow/python/platform:client_testlib", ], ) tf_xla_py_strict_test( - name = "quantized_ops_test", + name = "image_ops_jit_compile_test", size = "medium", - srcs = ["quantized_ops_test.py"], - enable_mlir_bridge = False, - python_version = "PY3", + srcs = ["image_ops_jit_compile_test.py"], + disabled_backends = [ + "cpu_ondemand", + ], + shard_count = 1, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", ], + use_xla_device = False, deps = [ ":xla_test", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python/framework:constant_op", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:bitwise_ops", + "//tensorflow/python/ops:image_ops", "//tensorflow/python/ops:math_ops", - "//tensorflow/python/platform:test", - "//third_party/py/numpy", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", ], ) tf_xla_py_strict_test( - name = "xla_ops_test", + name = "ensure_shape_op_test", size = "medium", - srcs = ["xla_ops_test.py"], - disabled_backends = [ - "gpu", - "gpu_a100", - "gpu_h100", - ], - enable_mlir_bridge = True, + srcs = ["ensure_shape_op_test.py"], python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", ], deps = [ ":xla_test", - "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:function", - "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:array_ops_stack", - "//tensorflow/python/ops:random_ops_util", - "//tensorflow/python/platform:test", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - "@local_xla//xla:xla_data_proto_py", + "//tensorflow/python/ops:check_ops", + "//tensorflow/python/platform:client_testlib", ], ) tf_xla_py_strict_test( - name = "xla_custom_call_ops_test", + name = "where_op_test", size = "small", - srcs = ["xla_custom_call_ops_test.py"], - disabled_backends = [ + srcs = ["where_op_test.py"], + enabled_backends = [ + "cpu", "gpu", "gpu_a100", "gpu_h100", ], - enable_mlir_bridge = False, - python_version = "PY3", tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "notap", + "no_pip", + "optonly", ], - use_xla_device = False, # Uses tf.function(jit_compile=True) deps = [ ":xla_test", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:config", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", - "//tensorflow/python/ops:random_ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/tpu:tpu_py", ], ) tf_xla_py_strict_test( - name = "runtime_shape_check_test", + name = "where_op_tpu_test", size = "small", - srcs = ["runtime_shape_check_test.py"], + srcs = ["where_op_test.py"], + args = [ + "--tpu_use_tfrt=true", + ], disabled_backends = [ "cpu", "cpu_ondemand", + "gpu", + "gpu_a100", + "gpu_h100", ], - enable_mlir_bridge = False, - python_version = "PY3", + main = "where_op_test.py", tags = [ "no_pip", - "notap", + "optonly", ], - use_xla_device = False, deps = [ ":xla_test", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:config", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:ops", "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/tpu:tpu_py", ], ) tf_xla_py_strict_test( - name = "conv_node_name_test", - size = "medium", - srcs = ["conv_node_name_test.py"], - enable_mlir_bridge = True, + name = "const_arg_test", + size = "small", + srcs = ["const_arg_test.py"], python_version = "PY3", - shard_count = 5, tags = [ - "no_oss", # TODO(b/148108508): Re-enable this test in OSS. "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], deps = [ ":xla_test", - "//tensorflow/python/framework:ops", - "//tensorflow/python/layers", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:control_flow_ops", "//tensorflow/python/platform:test", - "//third_party/py/numpy", - ], -) - -tf_xla_py_strict_test( - name = "tridiagonal_solve_ops_test", - size = "medium", - srcs = ["tridiagonal_solve_ops_test.py"], - enable_mlir_bridge = True, - python_version = "PY3", - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:test_lib", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:gradients", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops/linalg:linalg_impl", - "//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", - ], -) - -tf_xla_py_strict_test( - name = "tridiagonal_matmul_ops_test", - size = "medium", - srcs = ["tridiagonal_matmul_ops_test.py"], - enable_mlir_bridge = True, - python_version = "PY3", - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:array_ops_stack", - "//tensorflow/python/ops:gradient_checker_v2", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:stateless_random_ops", - "//tensorflow/python/ops/linalg:linalg_impl", - "//tensorflow/python/platform:test", - "//third_party/py/numpy", - ], -) - -tf_xla_py_strict_test( - name = "special_math_test", - size = "medium", - srcs = ["special_math_test.py"], - enable_mlir_bridge = True, - shard_count = 5, - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/ops:gradient_checker_v2", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:math_ops_gen", - "//tensorflow/python/ops:random_ops_gen", - "//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", - "@absl_py//absl/flags", - "@absl_py//absl/testing:parameterized", - ], -) - -tf_xla_py_strict_test( - name = "repeat_op_test", - size = "medium", - srcs = ["repeat_op_test.py"], - enable_mlir_bridge = True, - shard_count = 1, - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/platform:client_testlib", - ], -) - -tf_xla_py_strict_test( - name = "image_ops_jit_compile_test", - size = "medium", - srcs = ["image_ops_jit_compile_test.py"], - disabled_backends = [ - "cpu_ondemand", - ], - enable_mlir_bridge = False, - shard_count = 1, - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "optonly", - ], - use_xla_device = False, - deps = [ - ":xla_test", - "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python/eager:backprop", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:ops", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:image_ops", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:variables", - "//tensorflow/python/platform:client_testlib", - ], -) - -tf_xla_py_strict_test( - name = "ensure_shape_op_test", - size = "medium", - srcs = ["ensure_shape_op_test.py"], - enable_mlir_bridge = False, - python_version = "PY3", - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:check_ops", - "//tensorflow/python/platform:client_testlib", - ], -) - -tf_xla_py_strict_test( - name = "where_op_test", - size = "small", - srcs = ["where_op_test.py"], - enable_mlir_bridge = False, - enabled_backends = [ - "cpu", - "gpu", - "gpu_a100", - "gpu_h100", - ], - tags = [ - "no_pip", - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python/framework:config", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/platform:client_testlib", - "//tensorflow/python/tpu:tpu_py", - ], -) - -tf_xla_py_strict_test( - name = "where_op_tpu_test", - size = "small", - srcs = ["where_op_test.py"], - args = [ - "--tpu_use_tfrt=true", - ], - disabled_backends = [ - "cpu", - "cpu_ondemand", - "gpu", - "gpu_a100", - "gpu_h100", - ], - enable_mlir_bridge = False, - main = "where_op_test.py", - tags = [ - "no_pip", - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python/framework:config", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/platform:client_testlib", - "//tensorflow/python/tpu:tpu_py", - ], -) - -tf_xla_py_strict_test( - name = "const_arg_test", - size = "small", - srcs = ["const_arg_test.py"], - enable_mlir_bridge = False, - python_version = "PY3", - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - ], - deps = [ - ":xla_test", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/platform:test", - ], -) - -cuda_py_strict_test( - name = "const_test", - size = "small", - srcs = ["const_test.py"], - python_version = "PY3", - xla_enable_strict_auto_jit = False, - xla_enabled = True, - deps = [ - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:test_lib", - "//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", - ], -) - -tpu_py_strict_test( - name = "giant_const_op_test", - srcs = [ - "giant_const_op_test.py", - ], - disable_experimental = True, - # TODO(b/188995810): Add an optimization in MLIR importer to not - # materialize giant splat constants. - disable_mlir_bridge = True, - python_version = "PY3", - tags = ["no_oss"], - deps = [ - "//tensorflow/python/distribute:tpu_strategy", - "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/eager:remote", - "//tensorflow/python/eager:test", - "//tensorflow/python/framework:config", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/platform:flags", - "//third_party/py/numpy", ], ) @@ -2744,7 +2345,6 @@ tf_xla_py_strict_test( "gpu_a100", "gpu_h100", ], - enable_mlir_bridge = False, python_version = "PY3", shard_count = 10, tags = [ @@ -2769,7 +2369,6 @@ tpu_py_strict_test( name = "approx_topk_test", srcs = ["approx_topk_test.py"], disable_experimental = False, - disable_mlir_bridge = False, tags = ["no_oss"], deps = [ "//tensorflow/python/eager:backprop", @@ -2790,7 +2389,6 @@ tf_xla_py_strict_test( name = "xla_call_module_test", size = "small", srcs = ["xla_call_module_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -2821,7 +2419,6 @@ tf_xla_py_strict_test( srcs = ["xla_call_module_test.py"], # cpu_ondemand overrides the TF_XLA_FLAGS disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = False, env = {"TF_XLA_FLAGS": "--tf_xla_call_module_disabled_checks=platform"}, main = "xla_call_module_test.py", python_version = "PY3", @@ -2853,7 +2450,6 @@ tf_xla_py_strict_test( size = "small", srcs = ["xla_call_module_test.py"], disabled_backends = ["cpu_ondemand"], # cpu_ondemand overrides the TF_XLA_FLAGS - enable_mlir_bridge = False, env = {"TF_XLA_FLAGS": "--tf_xla_call_module_disabled_checks=shape_assertions"}, main = "xla_call_module_test.py", python_version = "PY3", @@ -2884,7 +2480,6 @@ tf_xla_py_strict_test( name = "bincount_op_test", size = "small", srcs = ["bincount_op_test.py"], - enable_mlir_bridge = False, python_version = "PY3", shard_count = 1, tags = [ @@ -2902,7 +2497,6 @@ tf_xla_py_strict_test( name = "unique_ops_test", size = "small", srcs = ["unique_ops_test.py"], - enable_mlir_bridge = False, enabled_backends = [ "cpu", "gpu", @@ -2929,7 +2523,6 @@ tpu_py_strict_test( size = "small", srcs = ["mean_op_test.py"], disable_experimental = False, - disable_mlir_bridge = False, tags = [ "notsan", # timesout ], @@ -2948,7 +2541,6 @@ tf_xla_py_strict_test( name = "xla_dump_to_test", size = "medium", srcs = ["xla_dump_to_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -2968,7 +2560,6 @@ tf_xla_py_strict_test( # name = "xla_dump_to_sponge_test", # size = "medium", # srcs = ["xla_dump_to_sponge_test.py"], -# enable_mlir_bridge = True, # python_version = "PY3", # tags = [ # "optonly", @@ -2982,3 +2573,327 @@ tf_xla_py_strict_test( # ], # ) # copybara:uncomment_end +#LINT.ThenChange(:combined_tests) + +cuda_py_strict_test( + name = "xla_device_gpu_test", + size = "small", + srcs = ["xla_device_gpu_test.py"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + "//tensorflow/python/client:session", + "//tensorflow/python/eager:context", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/platform:client_testlib", + ], +) + +cuda_py_strict_test( + name = "jit_test", + size = "medium", + srcs = ["jit_test.py"], + #shard_count = 5, + tags = [ + "no_cuda_asan", # Times out. + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + ":test_utils", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/client:session", + "//tensorflow/python/compiler/xla:compiler_py", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:cond", + "//tensorflow/python/ops:control_flow_ops", + "//tensorflow/python/ops:gradients_impl", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:while_loop", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], +) + +cuda_py_strict_test( + name = "async_comp_test", + size = "medium", + srcs = ["async_comp_test.py"], + shard_count = 1, + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python/client:session", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/platform:client_testlib", + ], +) + +cuda_py_strict_test( + name = "dense_layer_test", + size = "medium", + srcs = ["dense_layer_test.py"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + ":test_utils", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/compiler/xla:compiler_py", + "//tensorflow/python/framework:ops", + "//tensorflow/python/layers", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], +) + +cc_library( + name = "randomized_tests_library", + testonly = 1, + srcs = ["randomized_tests.cc"], + deps = [ + "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:common", + "//tensorflow/compiler/jit:flags_headers", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core:test", + "//tensorflow/core:testlib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:status", + "@local_xla//xla:xla_data_proto_cc", + ], +) + +tf_cuda_cc_test( + name = "randomized_tests", + size = "medium", + args = ["--tf_xla_test_use_mlir=false"], + shard_count = 20, + # This test is randomized, so only run it if explicitly requested. + tags = [ + "manual", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", + ] + tf_cuda_tests_tags(), + deps = [":randomized_tests_library"], +) + +tf_cuda_cc_test( + name = "randomized_tests_mlir", + size = "medium", + args = ["--tf_xla_test_use_mlir=true"], + shard_count = 20, + # This test is randomized, so only run it if explicitly requested. + tags = [ + "manual", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", + ] + tf_cuda_tests_tags(), + deps = [":randomized_tests_library"], +) + +# Create a deterministic version of randomized_tests_mlir with fixed seed. +# This can be used in presubmit checks as it is no longer randomized. +tf_cuda_cc_test( + name = "randomized_tests_mlir_seeded", + size = "medium", + args = [ + "--tf_xla_random_seed=200839030", + "--tf_xla_test_use_mlir=true", + "--tf_xla_test_device=GPU:0", + ], + shard_count = 20, + tags = [ + "config-cuda-only", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "no_rocm", # ROCmSoftwarePlatform #958 + "noasan", # TODO(b/201651800) + "requires-gpu-nvidia", + ] + tf_cuda_tests_tags(), + deps = [":randomized_tests_library"], +) + +# Create a deterministic version of randomized_tests with fixed seed. +# This can be used in presubmit checks as it is no longer randomized. +tf_cuda_cc_test( + name = "randomized_tests_seeded", + size = "medium", + args = [ + "--tf_xla_random_seed=200839030", + "--tf_xla_test_use_mlir=false", + "--tf_xla_test_device=GPU:0", + ], + shard_count = 20, + tags = [ + "config-cuda-only", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "no_rocm", # ROCmSoftwarePlatform #958 + "noasan", # TODO(b/201651800) + "requires-gpu-nvidia", + ] + tf_cuda_tests_tags(), + deps = [":randomized_tests_library"], +) + +tf_cuda_cc_test( + name = "unary_ops_composition_test", + srcs = ["unary_ops_composition_test.cc"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ] + tf_cuda_tests_tags(), + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/jit:xla_kernel_creator", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "@local_tsl//tsl/platform:status", + ], +) + +py_strict_library( + name = "lstm", + testonly = 1, + srcs = ["lstm.py"], + srcs_version = "PY3", + deps = [ + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:random_ops", + "//tensorflow/python/ops:variable_v1", + "@six_archive//:six", + ], +) + +cuda_py_strict_test( + name = "lstm_test", + srcs = ["lstm_test.py"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + ":lstm", + ":xla_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:gradients_impl", + "//tensorflow/python/ops:init_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], +) + +# An example of ahead-of-time compilation using tfcompile. The +# lstm_layer_inference.pbtxt file was generated by running lstm_test +# --dump_graph_dir, and the config file was written by hand. +# +# Run the following to build a minimal benchmark of the computation on Android: +# $ bazel build -c opt --cxxopt='-std=c++11' --linkopt='-lm' \ +# --cpu=armeabi-v7a \ +# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ +# --crosstool_top=//external:android/crosstool \ +# //tensorflow/compiler/tests:lstm_layer_inference_benchmark + +# +# Currently the resulting binary size is ~190KB +tf_library( + name = "lstm_layer_inference", + testonly = 1, + config = "lstm_layer_inference.config.pbtxt", + cpp_class = "LSTMLayerInference", + graph = "lstm_layer_inference.pbtxt", + tags = ["manual"], + tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"], +) + +cuda_py_strict_test( + name = "const_test", + size = "small", + srcs = ["const_test.py"], + python_version = "PY3", + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], +) + +tpu_py_strict_test( + name = "giant_const_op_test", + srcs = [ + "giant_const_op_test.py", + ], + disable_experimental = True, + # TODO(b/188995810): Add an optimization in MLIR importer to not + # materialize giant splat constants. + python_version = "PY3", + tags = ["no_oss"], + deps = [ + "//tensorflow/python/distribute:tpu_strategy", + "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/eager:remote", + "//tensorflow/python/eager:test", + "//tensorflow/python/framework:config", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/platform:flags", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index b54c2e54fa3552..12fa6dd7d04bd2 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -392,9 +392,9 @@ def testNumericOps(self): for dtype in self.numeric_types: self._testBinary( math_ops.subtract, - np.array([1, 2, 100], dtype=dtype), - np.array([10, 20, -1], dtype=dtype), - expected=np.array([-9, -18, 101], dtype=dtype)) + np.array([1, 20, 100], dtype=dtype), + np.array([1, 2, 1], dtype=dtype), + expected=np.array([0, 18, 99], dtype=dtype)) self._testBinary( math_ops.subtract, dtype(5), @@ -402,9 +402,9 @@ def testNumericOps(self): expected=np.array([4, 3], dtype=dtype)) self._testBinary( math_ops.subtract, - np.array([[1], [2]], dtype=dtype), + np.array([[7], [10]], dtype=dtype), dtype(7), - expected=np.array([[-6], [-5]], dtype=dtype)) + expected=np.array([[0], [3]], dtype=dtype)) # min/max not supported for complex if dtype not in self.complex_types | {np.uint8, np.int8}: @@ -461,13 +461,13 @@ def testNumericOps(self): self._testBinary( nn_ops.bias_add, np.array([[1, 2], [3, 4]], dtype=dtype), - np.array([2, -1], dtype=dtype), - expected=np.array([[3, 1], [5, 3]], dtype=dtype)) + np.array([2, 0], dtype=dtype), + expected=np.array([[3, 2], [5, 4]], dtype=dtype)) self._testBinary( nn_ops.bias_add, np.array([[[[1, 2], [3, 4]]]], dtype=dtype), - np.array([2, -1], dtype=dtype), - expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype)) + np.array([2, 0], dtype=dtype), + expected=np.array([[[[3, 2], [5, 4]]]], dtype=dtype)) if np.int64 in self.numeric_types: self._testBinary( @@ -998,8 +998,8 @@ def testFill(self): self._testBinary( array_ops.fill, np.array([], dtype=np.int32), - dtype(-42), - expected=dtype(-42)) + dtype(42), + expected=dtype(42)) self._testBinary( array_ops.fill, np.array([1, 2], dtype=np.int32), diff --git a/tensorflow/compiler/tests/build_combined_defs.bzl b/tensorflow/compiler/tests/build_combined_defs.bzl index 0463fe19326729..92f04ab6215c91 100644 --- a/tensorflow/compiler/tests/build_combined_defs.bzl +++ b/tensorflow/compiler/tests/build_combined_defs.bzl @@ -3,16 +3,30 @@ load("//tensorflow:strict.default.bzl", "py_strict_test") load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") -def tf_xla_combined_py_test(name = "", package = None, test_files = [], **kwargs): +def parse_label_name(label): + """Parse a label into just the name. + + Args: + label: string in relative or absolute form. + + Returns: + The name of the label. + """ + colon_split = label.split(":") + if len(colon_split) == 1: # no ":" in label + return label + return colon_split[-1] + +def tf_xla_combined_py_test(name = "", package = None, tests = [], **kwargs): """Generates combined tf_xla_py_test targets, one per XLA backend. - All tests found in the list test_files are combined into one new test which is then passed on to + All srcs found in the list tests are combined into one new test which is then passed on to tf_xla_py_test which creates a new target per XLA backend. Args: name: Name of the target. - package: The package that all tests in test_files belong to. - test_files: The test files to be combined and tested. + package: The package that all tests in tests belong to. + tests: The test targets to be combined and tested. Assumes all tests are in the same package. **kwargs: keyword arguments passed onto the tf_xla_py_test rule. """ @@ -23,7 +37,7 @@ def tf_xla_combined_py_test(name = "", package = None, test_files = [], **kwargs native.genrule( name = name + "_gen", testonly = 1, - srcs = test_files, + srcs = tests, outs = [test_file], cmd = """ mkdir -p $(@D) && cat > $@ << EOF @@ -33,7 +47,7 @@ from tensorflow.python.platform import test if __name__ == "__main__": test.main() EOF - """ % "\n".join(["from %s.%s import *" % (package, test[:-3]) for test in test_files]), + """ % "\n".join(["from %s.%s import *" % (package, parse_label_name(test)[:-4]) for test in tests]), tools = [], tags = ["generated_python_test=%s.%s" % (package, name)], ) @@ -41,6 +55,9 @@ EOF tf_xla_py_test( name = name, test_rule = py_strict_test, - srcs = [test_file] + test_files, + srcs = [test_file], + deps = [ + "//tensorflow/python/platform:client_testlib", + ] + tests, **kwargs ) diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index ce6b626683e281..fb5cb0448e8224 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -1,5 +1,6 @@ """Build rules for Tensorflow/XLA testing.""" +load("//tensorflow:py.default.bzl", "py_library") load("//tensorflow:strict.default.bzl", "py_strict_test") load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow/compiler/tests:plugin.bzl", "plugins") @@ -73,6 +74,12 @@ def tf_xla_py_test( cpu_xla_device = "CPU" gpu_xla_device = "GPU" + py_library( + name = name + "_lib", + srcs = srcs, + deps = deps, + testonly = 1, + ) for backend in backends: test_name = "{}_{}".format(name, backend) backend_tags = ["tf_xla_{}".format(backend)] @@ -139,7 +146,7 @@ def tf_xla_py_test( args = backend_args, main = "{}.py".format(name) if main == None else main, data = data + backend_data, - deps = deps + backend_deps + extra_dep, + deps = deps + backend_deps + extra_dep + [name + "_lib"], tags = test_tags + extra_tag, exec_properties = tf_exec_properties({"tags": test_tags}), **kwargs diff --git a/tensorflow/compiler/tests/const_test.py b/tensorflow/compiler/tests/const_test.py index 4e11a436e850af..bb1f3e23a7306e 100644 --- a/tensorflow/compiler/tests/const_test.py +++ b/tensorflow/compiler/tests/const_test.py @@ -33,15 +33,33 @@ class ConstOpTest(test_util.TensorFlowTestCase): # @test_util.run_v2_only def testConst(self): types = { - dtypes.bool, dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, - dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64, - dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64, - dtypes.float8_e5m2, dtypes.float8_e4m3fn, + dtypes.bool, + dtypes.int8, + dtypes.int16, + dtypes.int32, + dtypes.int64, + dtypes.uint8, + dtypes.uint16, + dtypes.uint32, + dtypes.uint64, + dtypes.float16, + dtypes.bfloat16, + dtypes.float32, + dtypes.float64, + dtypes.float8_e5m2, + dtypes.float8_e4m3fn, } for dtype in types: with self.subTest(dtype=dtype): if dtype == dtypes.bool: values = [True, False] + elif dtype in [ + dtypes.uint8, + dtypes.uint16, + dtypes.uint32, + dtypes.uint64, + ]: + values = [0., 1., dtype.min, dtype.max] else: values = [0., 1., -1., dtype.min, dtype.max] if dtype.is_floating: diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py index 7abf9a0bba1122..9f4221cfdebe11 100644 --- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -50,10 +50,10 @@ def testUpdateSlice(self): self._assertOpOutputMatchesExpected( xla.dynamic_update_slice, [ np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), - np.array([-1, -2, -3], dtype=dtype), + np.array([11, 12, 13], dtype=dtype), np.array([6], dtype=np.int32) ], - expected=np.array([1, 2, 3, 4, 5, 6, -1, -2, -3, 10], dtype=dtype)) + expected=np.array([1, 2, 3, 4, 5, 6, 11, 12, 13, 10], dtype=dtype)) self._assertOpOutputMatchesExpected( xla.dynamic_update_slice, [ diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 930fd21ab42c27..ded287593029ff 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -236,25 +236,25 @@ def testUnsortedSegmentSum2DIndices3DData(self): for dtype in self.numeric_types: data = np.array( [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ - 200, 201, 202 - ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], + 80, 81, 82 + ], [123, 124, 125]], [[103, 104, 105], [106, 107, 108]]], dtype=dtype) indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32) num_segments = 8 y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( - [[210, 211, 212], [110, 111, 112], [310, 311, 312], [ + [[123, 124, 125], [110, 111, 112], [106, 107, 108], [ 100, 102, 104 - ], [0, 0, 0.], [210, 212, 214], [300, 301, 302], [0, 0, 0]], + ], [0, 0, 0.], [90, 92, 94], [103, 104, 105], [0, 0, 0]], dtype=dtype), y) def testUnsortedSegmentSum1DIndices3DData(self): for dtype in self.numeric_types: data = np.array( [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ - 200, 201, 202 - ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], + 120, 121, 122 + ], [123, 124, 125]], [[103, 104, 105], [106, 107, 108]]], dtype=dtype) indices = np.array([3, 0, 2, 5], dtype=np.int32) num_segments = 6 @@ -262,8 +262,8 @@ def testUnsortedSegmentSum1DIndices3DData(self): self.assertAllClose( np.array( [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]], - [[200, 201, 202], [210, 211, 212]], [[0, 1, 2.], [10, 11, 12]], - [[0, 0, 0], [0, 0, 0]], [[300, 301, 302], [310, 311, 312]]], + [[120, 121, 122], [123, 124, 125]], [[0, 1, 2.], [10, 11, 12]], + [[0, 0, 0], [0, 0, 0]], [[103, 104, 105], [106, 107, 108]]], dtype=dtype), y) def testUnsortedSegmentSumShapeError(self): diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index e4937d223165da..809db242ac4afe 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -185,21 +185,28 @@ def testSlice(self): np.array([[], [], []], dtype=dtype), np.array([1, 0], dtype=np.int32), np.array([2, 0], dtype=np.int32), - expected=np.array([[], []], dtype=dtype)) + expected=np.array([[], []], dtype=dtype), + ) self._testTernary( array_ops.slice, np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), np.array([0, 1], dtype=np.int32), np.array([2, 1], dtype=np.int32), - expected=np.array([[2], [5]], dtype=dtype)) + expected=np.array([[2], [5]], dtype=dtype), + ) def testClipByValue(self): - for dtype in self.numeric_types - self.complex_types: + for dtype in ( + self.numeric_types - self.complex_types - self.unsigned_int_types + ): test_cases = [ (np.array([2, 4, 5], dtype=dtype), dtype(7)), # (dtype(1), np.array([2, 4, 5], dtype=dtype)), # - (np.array([-2, 7, 7], dtype=dtype), np.array([-2, 9, 8], dtype=dtype)) + ( + np.array([-2, 7, 7], dtype=dtype), + np.array([-2, 9, 8], dtype=dtype), + ), ] x = np.array([-2, 10, 6], dtype=dtype) for lower, upper in test_cases: diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index ee0967d2150e3d..99b997561b41c3 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -773,7 +773,7 @@ def testComplexOps(self): expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype])) def testIntOps(self): - for dtype in self.int_types: + for dtype in self.int_types - self.unsigned_int_types: self._assertOpOutputMatchesExpected( bitwise_ops.invert, np.array([0, -1, 1, 16, 42], dtype=dtype), @@ -923,7 +923,10 @@ def _testCast(self, src_type, dst_type): if src_type.is_integer: imin = np.iinfo(src_np_dtype).min imax = np.iinfo(src_np_dtype).max - src = np.array([imin, imax, 0, 1, -1], dtype=src_np_dtype) + if src_type.is_unsigned: + src = np.array([imin, imax, 0, 1], dtype=src_np_dtype) + else: + src = np.array([imin, imax, 0, 1, -1], dtype=src_np_dtype) elif src_type in self.float_tf_types: if dst_type.is_integer: imin = np.iinfo(dst_np_dtype).min @@ -936,63 +939,75 @@ def _testCast(self, src_type, dst_type): eps = np.finfo(dst_np_dtype).eps src = np.array( [fmin, fmax, np.nan, eps, -eps, tiny, -tiny, np.inf, -np.inf], - dtype=src_np_dtype) + dtype=src_np_dtype, + ) dst = src.astype(dst_np_dtype) self._assertOpOutputMatchesExpected( lambda x, dst_type=dst_type: math_ops.cast(x, dst_type), src, - expected=dst) + expected=dst, + ) def testBitcast(self): self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.int32), - np.array([1, 0x3f800000], np.int32), - expected=np.array([1, 0x3f800000], np.int32)) + np.array([1, 0x3F800000], np.int32), + expected=np.array([1, 0x3F800000], np.int32), + ) self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.float32), - np.array([1, 0x3f800000], np.int32), - expected=np.array([1e-45, 1.0], np.float32)) + np.array([1, 0x3F800000], np.int32), + expected=np.array([1e-45, 1.0], np.float32), + ) self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.int32), np.array([1e-45, 1.0], np.float32), - expected=np.array([1, 0x3f800000], np.int32)) + expected=np.array([1, 0x3F800000], np.int32), + ) if np.int64 in self.numeric_types: self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.int64), - np.array([1, 0x100000003f800000], np.uint64), - expected=np.array([1, 0x100000003f800000], np.int64)) + np.array([1, 0x100000003F800000], np.uint64), + expected=np.array([1, 0x100000003F800000], np.int64), + ) self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.uint64), - np.array([1, 0x100000003f800000], np.int64), - expected=np.array([1, 0x100000003f800000], np.uint64)) + np.array([1, 0x100000003F800000], np.int64), + expected=np.array([1, 0x100000003F800000], np.uint64), + ) self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.float64), np.array( - [0, 0x3FF0000000000000, 0xc3af161421c8e000, 0x4032000000000007], + [0, 0x3FF0000000000000, 0xC3AF161421C8E000, 0x4032000000000007], np.uint64, ), expected=np.array( [0, 1.0, -1.12e+18, 18.000000000000024869], np.float64 ), - atol=0 + atol=0, ) def testBitcastInt8ToFloat(self): self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.float32), - np.array([[1, 0, 0, 0], [0xd0, 0x0f, 0x49, 0x40]], np.int8), - expected=np.array([1e-45, 3.14159], np.float32)) + np.array([[1, 0, 0, 0], [0xD0, 0x0F, 0x49, 0x40]]).astype(np.int8), + expected=np.array([1e-45, 3.14159], np.float32), + ) self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.np.int8), np.array([1e-45, 3.14159], np.float32), - expected=np.array([[1, 0, 0, 0], [0xd0, 0x0f, 0x49, 0x40]], np.int8)) + expected=np.array([[1, 0, 0, 0], [0xD0, 0x0F, 0x49, 0x40]]).astype( + np.int8 + ), + ) def testInvertPermutation(self): for np_dtype in [np.int32, np.int64]: self._assertOpOutputMatchesExpected( array_ops.invert_permutation, np.array([1, 2, 0], np_dtype), - expected=np.array([2, 0, 1], dtype=np_dtype)) + expected=np.array([2, 0, 1], dtype=np_dtype), + ) def testInvertPermutationTwiceIsNoop(self): @@ -1013,12 +1028,12 @@ def testRank(self): self._assertOpOutputMatchesExpected( rank_op, np.array([[], []], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( - rank_op, np.array([-1, 1], dtype=dtype), expected=np.int32(1)) + rank_op, np.array([0, 1], dtype=dtype), expected=np.int32(1)) self._assertOpOutputMatchesExpected( - rank_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2)) + rank_op, np.array([[0, 1]], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( rank_op, - np.array([[-1], [1], [4]], dtype=dtype), + np.array([[0], [1], [4]], dtype=dtype), expected=np.int32(2)) def testShape(self): @@ -1032,15 +1047,15 @@ def testShape(self): expected=np.array([2, 0], dtype=np.int32)) self._assertOpOutputMatchesExpected( shape_op, - np.array([-1, 1], dtype=dtype), + np.array([0, 1], dtype=dtype), expected=np.array([2], dtype=np.int32)) self._assertOpOutputMatchesExpected( shape_op, - np.array([[-1, 1]], dtype=dtype), + np.array([[0, 1]], dtype=dtype), expected=np.array([1, 2], dtype=np.int32)) self._assertOpOutputMatchesExpected( shape_op, - np.array([[-1], [1], [4]], dtype=dtype), + np.array([[0], [1], [4]], dtype=dtype), expected=np.array([3, 1], dtype=np.int32)) def testSize(self): @@ -1051,12 +1066,12 @@ def testSize(self): self._assertOpOutputMatchesExpected( size_op, np.array([[], []], dtype=dtype), expected=np.int32(0)) self._assertOpOutputMatchesExpected( - size_op, np.array([-1, 1], dtype=dtype), expected=np.int32(2)) + size_op, np.array([0, 1], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( - size_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2)) + size_op, np.array([[0, 1]], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( size_op, - np.array([[-1], [1], [4]], dtype=dtype), + np.array([[0], [1], [4]], dtype=dtype), expected=np.int32(3)) def testSizeWithInt64OutType(self): @@ -1067,7 +1082,7 @@ def size_op(x): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( size_op, - np.array([[-1], [1], [4]], dtype=dtype), + np.array([[0], [1], [4]], dtype=dtype), expected=np.int64(3)) def testUnpack(self): diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 9ba3dedf4a6f54..e8695e29d7bfb9 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -211,7 +211,6 @@ tf_cuda_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:fingerprint", "@local_tsl//tsl/platform:statusor", @@ -223,9 +222,11 @@ tf_cuda_library( "@local_xla//xla/service:custom_call_target_registry", "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/stream_executor", + "@local_xla//xla/stream_executor:stream_finder", "@local_xla//xla/stream_executor/gpu:gpu_executor_header", "@local_xla//xla/stream_executor/gpu:gpu_stream_header", "@local_xla//xla/stream_executor/gpu:gpu_types_header", + "@local_xla//xla/tsl/lib/strings:proto_serialization", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc index 4971fd07eeaa7e..441c2b400d1b4d 100644 --- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc +++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc @@ -56,6 +56,8 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_finder.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/process_state.h" @@ -73,7 +75,6 @@ limitations under the License. #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/statusor.h" @@ -580,11 +581,8 @@ Status CallTfKernel(void* stream_handle, void** buffers, const char* opaque, }(); if (platform_status != nullptr) return *platform_status; - se::StreamExecutorConfig config; - config.gpu_stream = stream_handle; - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - platform->GetExecutor(config)); - se::Stream* stream = executor->FindAllocatedStream(stream_handle); + TF_ASSIGN_OR_RETURN(se::Stream * stream, + stream_executor::FindStream(platform, stream_handle)); if (!stream) { return xla::Internal("Stream not found for %p", stream_handle); } diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index 48f0622b75f79b..f46bd78cf6960b 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -331,11 +331,6 @@ TEST(XlaJitCompiledCpuFunction, CanCompileWithAdditionalPlatform) { return nullptr; } - absl::StatusOr GetExecutor( - const se::StreamExecutorConfig& config) override { - return nullptr; - } - private: string name_; }; diff --git a/tensorflow/core/common_runtime/accumulate_n_optimizer.cc b/tensorflow/core/common_runtime/accumulate_n_optimizer.cc index 9b7930ad1ef590..473be0c108896d 100644 --- a/tensorflow/core/common_runtime/accumulate_n_optimizer.cc +++ b/tensorflow/core/common_runtime/accumulate_n_optimizer.cc @@ -64,7 +64,7 @@ class AccumulateNV2RemovePass : public GraphOptimizationPass { } // Build up a todo list of ops to replace, *then* modify the graph - gtl::InlinedVector matches; + absl::InlinedVector matches; for (Node* n : g->op_nodes()) { if (n->type_string() == "AccumulateNV2") { matches.push_back(n); diff --git a/tensorflow/core/common_runtime/arg_ret_placement.cc b/tensorflow/core/common_runtime/arg_ret_placement.cc index a995564c8c2964..386c54849a254e 100644 --- a/tensorflow/core/common_runtime/arg_ret_placement.cc +++ b/tensorflow/core/common_runtime/arg_ret_placement.cc @@ -255,7 +255,7 @@ Status SetAllocAttrsForArgs(const gtl::InlinedVector& nodes, /*weak_flag=*/false, nullptr, &alloc_attrs); } -Status WeakSetAllocAttrsForArgs(const gtl::InlinedVector& nodes, +Status WeakSetAllocAttrsForArgs(const absl::InlinedVector& nodes, const DataTypeVector& dtypes, std::vector& alloc_attrs) { return SetMemoryTypeHelper(nodes, dtypes, /*is_arg=*/true, diff --git a/tensorflow/core/common_runtime/arg_ret_placement.h b/tensorflow/core/common_runtime/arg_ret_placement.h index 4f00d18e3bb6ca..fd8a4858b83c8e 100644 --- a/tensorflow/core/common_runtime/arg_ret_placement.h +++ b/tensorflow/core/common_runtime/arg_ret_placement.h @@ -93,7 +93,7 @@ Status SetAllocAttrsForRets(const gtl::InlinedVector& nodes, // ops) based on dtype. Logging of warnings if an int32 ret does not have // expected full_type information (i.e. if the source of the input to the ret // does not have expected full type information) can be enabled. -Status WeakSetAllocAttrsForRets(const gtl::InlinedVector& nodes, +Status WeakSetAllocAttrsForRets(const absl::InlinedVector& nodes, const DataTypeVector& dtypes, std::vector& alloc_attrs); diff --git a/tensorflow/core/common_runtime/arg_ret_placement_test.cc b/tensorflow/core/common_runtime/arg_ret_placement_test.cc index 11b8bdb19c064b..284702a4ecc3e2 100644 --- a/tensorflow/core/common_runtime/arg_ret_placement_test.cc +++ b/tensorflow/core/common_runtime/arg_ret_placement_test.cc @@ -205,7 +205,7 @@ TEST_F(FullTypeGraphUtilsTest, MemoryTypeRetWithFT) { } TEST_F(FullTypeGraphUtilsTest, AllowAttrRetWithFT) { - gtl::InlinedVector nodes; + absl::InlinedVector nodes; DataTypeVector dtypes; std::vector alloc_attrs; diff --git a/tensorflow/core/common_runtime/collective_test_util.cc b/tensorflow/core/common_runtime/collective_test_util.cc index 24c85b321ae0d9..18ef2ab824daf1 100644 --- a/tensorflow/core/common_runtime/collective_test_util.cc +++ b/tensorflow/core/common_runtime/collective_test_util.cc @@ -325,10 +325,11 @@ Status RunCollective(CollectiveTestEnv* test_env, CollectiveParams* col_params, op_params.step_id = kStepId; op_params.device = device; op_params.cancellation_manager = &cancellation_manager; - gtl::InlinedVector inputs; + absl::InlinedVector inputs; inputs.push_back(TensorValue(&input_buffer)); op_params.inputs = inputs; - gtl::InlinedVector input_aa({AllocatorAttributes()}); + absl::InlinedVector input_aa( + {AllocatorAttributes()}); op_params.input_alloc_attrs = input_aa; DeviceContext* dev_ctx = nullptr; auto* dev_info = device->tensorflow_accelerator_device_info(); diff --git a/tensorflow/core/common_runtime/collective_util.h b/tensorflow/core/common_runtime/collective_util.h index 01fb8b8c81cd2f..b53e779701afce 100644 --- a/tensorflow/core/common_runtime/collective_util.h +++ b/tensorflow/core/common_runtime/collective_util.h @@ -37,9 +37,9 @@ string SubdivPermDebugString(const CollectiveParams& col_params); class SubContext { public: OpKernelContext::Params sub_params_; - gtl::InlinedVector sub_inputs_; - gtl::InlinedVector sub_input_attr_; - gtl::InlinedVector sub_input_dc_; + absl::InlinedVector sub_inputs_; + absl::InlinedVector sub_input_attr_; + absl::InlinedVector sub_input_dc_; // Used only for Binary and Unary Ops for which we require // the calculation to be in-place on the first input. int forward_from_ = 0; diff --git a/tensorflow/core/common_runtime/entry.h b/tensorflow/core/common_runtime/entry.h index 9164cce3eae94c..82bf44eae816b9 100644 --- a/tensorflow/core/common_runtime/entry.h +++ b/tensorflow/core/common_runtime/entry.h @@ -134,7 +134,7 @@ struct Entry { }; // TODO(b/152925936): Re-evaluate this constant with current usage patterns. -typedef gtl::InlinedVector EntryVector; +typedef absl::InlinedVector EntryVector; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 48ec47636e30df..2054114de4d86d 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -142,8 +142,8 @@ struct KernelTimer { }; // TODO(b/152925936): Re-evaluate these constants with current usage patterns. -typedef gtl::InlinedVector TensorValueVec; -typedef gtl::InlinedVector AllocatorAttributeVec; +typedef absl::InlinedVector TensorValueVec; +typedef absl::InlinedVector AllocatorAttributeVec; class ExecutorImpl : public Executor { public: diff --git a/tensorflow/core/common_runtime/function_body.cc b/tensorflow/core/common_runtime/function_body.cc index 1ca6f6a535ceb6..60a6f41f1d8162 100644 --- a/tensorflow/core/common_runtime/function_body.cc +++ b/tensorflow/core/common_runtime/function_body.cc @@ -35,7 +35,7 @@ FunctionBody::FunctionBody(core::RefCountPtr&& record, this->arg_nodes.resize(arg_types.size()); this->ret_nodes.resize(ret_types.size()); for (Node* n : this->graph->op_nodes()) { - gtl::InlinedVector* node_vec; + absl::InlinedVector* node_vec; if (n->type_string() == FunctionLibraryDefinition::kRetOp || n->type_string() == FunctionLibraryDefinition::kDeviceRetOp) { node_vec = &this->ret_nodes; diff --git a/tensorflow/core/common_runtime/function_body.h b/tensorflow/core/common_runtime/function_body.h index 97d27f51099e31..959f9803227764 100644 --- a/tensorflow/core/common_runtime/function_body.h +++ b/tensorflow/core/common_runtime/function_body.h @@ -37,11 +37,11 @@ struct FunctionBody { DataTypeVector ret_types; // arg_nodes[i] contains the i'th function input. In other words, // GetNodeAttr(arg_nodes[i]->attrs(), "index") == i. - gtl::InlinedVector arg_nodes; + absl::InlinedVector arg_nodes; // ret_nodes[i] contains the i'th function output. In other words, // GetNodeAttr(ret_nodes[i]->attrs(), "index") == i. - gtl::InlinedVector ret_nodes; - gtl::InlinedVector control_ret_nodes; + absl::InlinedVector ret_nodes; + absl::InlinedVector control_ret_nodes; FunctionBody() {} FunctionBody(core::RefCountPtr&& record, diff --git a/tensorflow/core/common_runtime/function_utils.cc b/tensorflow/core/common_runtime/function_utils.cc index facd31481c05ed..56389623808262 100644 --- a/tensorflow/core/common_runtime/function_utils.cc +++ b/tensorflow/core/common_runtime/function_utils.cc @@ -162,7 +162,7 @@ bool RemoveIdentityNodes(Graph* g) { bool RemoveListArrayConverter(Graph* g) { VLOG(2) << "Removing list array converter"; - gtl::InlinedVector matches; + absl::InlinedVector matches; for (Node* n : g->nodes()) { if ((n->type_string() == "_ListToArray") || (n->type_string() == "_ArrayToList")) { diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index 10d7eccafd2ad6..2483a1bb239b8b 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -1,4 +1,5 @@ load("@bazel_skylib//lib:selects.bzl", "selects") +load("@local_xla//xla/tsl:tsl.bzl", "if_hermetic_cuda_libs") load( "//tensorflow:tensorflow.bzl", "clean_dep", @@ -140,6 +141,19 @@ filegroup( visibility = ["//visibility:private"], ) +cc_library( + name = "gpu_runtime_hermetic_cuda_deps", + visibility = ["//visibility:public"], + deps = if_hermetic_cuda_libs([ + "@local_xla//xla/tsl/cuda:cudart", + "@local_xla//xla/tsl/cuda:cublas", + "@local_xla//xla/tsl/cuda:cufft", + "@local_xla//xla/tsl/cuda:cusolver", + "@local_xla//xla/tsl/cuda:cusparse", + "@local_xla//xla/tsl/cuda:cudnn", + ]), +) + tf_cuda_library( name = "gpu_runtime_impl", srcs = [ @@ -159,6 +173,7 @@ tf_cuda_library( "@local_xla//xla/stream_executor/cuda:cuda_platform", "@local_xla//xla/stream_executor/gpu:gpu_stream", "@local_xla//xla/stream_executor/gpu:gpu_cudamallocasync_allocator", + ":gpu_runtime_hermetic_cuda_deps", ], defines = if_linux_x86_64(["TF_PLATFORM_LINUX_X86_64"]), features = ["-layering_check"], diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h index dcc3bfb4335d95..1c8f6283c57c07 100644 --- a/tensorflow/core/common_runtime/gpu_device_context.h +++ b/tensorflow/core/common_runtime/gpu_device_context.h @@ -29,14 +29,14 @@ namespace tensorflow { class GPUDeviceContext : public DeviceContext { public: // Does not take ownership of streams. - GPUDeviceContext(int stream_id, se::Stream* stream, + GPUDeviceContext( + int stream_id, se::Stream* stream, #if TENSORFLOW_USE_ROCM - se::Stream* nccl_stream, + se::Stream* nccl_stream, #endif - se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, - gtl::InlinedVector device_to_device_stream, - Allocator* host_memory_allocator) + se::Stream* host_to_device_stream, se::Stream* device_to_host_stream, + absl::InlinedVector device_to_device_stream, + Allocator* host_memory_allocator) : stream_id_(stream_id), stream_(stream), #if TENSORFLOW_USE_ROCM @@ -96,7 +96,7 @@ class GPUDeviceContext : public DeviceContext { // The stream to use for copying data from GPU to host. se::Stream* device_to_host_stream_; // Streams to use for copying data between GPUs. - gtl::InlinedVector device_to_device_stream_; + absl::InlinedVector device_to_device_stream_; // The allocator to use for allocating pinned host memory. // Not owned. Allocator* host_memory_allocator_; diff --git a/tensorflow/core/common_runtime/gradients.cc b/tensorflow/core/common_runtime/gradients.cc index 7c48847cb22149..b91d6986705fcc 100644 --- a/tensorflow/core/common_runtime/gradients.cc +++ b/tensorflow/core/common_runtime/gradients.cc @@ -345,7 +345,7 @@ Status SymbolicGradientBuilder::Compute() { InitBackprop(); // Backward propagation. - gtl::InlinedVector dy; + absl::InlinedVector dy; while (!ready_.empty()) { // n has collected all gradients. Node* n = ready_.front(); diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index 78c2713d0d0766..3705ede827e0f6 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -370,7 +370,7 @@ class GraphConstructor { // Mapping between index within node_defs_ and the index within node_defs_ of // all nodes it outputs to. - std::vector> outputs_; + std::vector> outputs_; // Used in the conversion from node_defs_ to g_ to represent the ith input // of a node. diff --git a/tensorflow/core/common_runtime/graph_view.cc b/tensorflow/core/common_runtime/graph_view.cc index 29458524bd5051..4bbd22c89dfe6f 100644 --- a/tensorflow/core/common_runtime/graph_view.cc +++ b/tensorflow/core/common_runtime/graph_view.cc @@ -157,7 +157,7 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) { // a given output slot. For all but the last, we need to do a copy of the // Tensor when propagating results downstream in the graph, but for the // last one, we can just do a move of the Tensor object to propagate it. - gtl::InlinedVector last_indices(num_outputs, nullptr); + absl::InlinedVector last_indices(num_outputs, nullptr); EdgeInfo* dst_edge = item->output_edge_base(); for (auto e : n->out_edges()) { if (e->IsControlEdge()) continue; diff --git a/tensorflow/core/common_runtime/inspecting_placer.cc b/tensorflow/core/common_runtime/inspecting_placer.cc index a84cd700874d8c..8a0eb150dd497d 100644 --- a/tensorflow/core/common_runtime/inspecting_placer.cc +++ b/tensorflow/core/common_runtime/inspecting_placer.cc @@ -77,7 +77,7 @@ class ColocationGraphToIOColocationGroups { ColocationGraph* colocation_graph) : colocation_graph_(colocation_graph), next_group_id_(0) {} - void AssignGroups(const gtl::InlinedVector& nodes, + void AssignGroups(const absl::InlinedVector& nodes, std::vector* groups) { for (int i = 0; i < nodes.size(); ++i) { int root_id = colocation_graph_->FindAndUpdateRoot(nodes[i]->id()); diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc index b20d0057c727d0..63fd2f1b59c223 100644 --- a/tensorflow/core/common_runtime/local_device.cc +++ b/tensorflow/core/common_runtime/local_device.cc @@ -126,7 +126,7 @@ LocalDevice::LocalDevice(const SessionOptions& options, // computations. static mutex& global_tp_mu = *new mutex; static auto& global_tp_info TF_GUARDED_BY(global_tp_mu) = - *new gtl::InlinedVector; + *new absl::InlinedVector; mutex_lock l(global_tp_mu); if (options.config.experimental().use_numa_affinity()) { diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc index bb6cc17b5b1665..8829f6d9d270c0 100644 --- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc +++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc @@ -38,7 +38,7 @@ class ParallelConcatRemovePass : public GraphOptimizationPass { "Parallel concat removal should happen before partitioning and a " "graph should be available."); } - gtl::InlinedVector matches; + absl::InlinedVector matches; for (Node* n : g->op_nodes()) { if (n->type_string() == "ParallelConcat") { matches.push_back(n); diff --git a/tensorflow/core/config/flag_defs.h b/tensorflow/core/config/flag_defs.h index a773fbb1b20c1c..23e9989a31edb7 100644 --- a/tensorflow/core/config/flag_defs.h +++ b/tensorflow/core/config/flag_defs.h @@ -64,6 +64,9 @@ class Flags { // TODO(b/341325107): Make this behavior the default and remove the flag. TF_DECLARE_FLAG(enable_function_pruning_before_inlining, false, "If true, functions will be pruned before inlining.") + TF_DECLARE_FLAG(enable_skip_encapsulation_for_non_tpu_graphs, false, + "If true, TF2XLA encapsulation will be skipped for non-TPU " + "graphs.") // LINT.ThenChange(//tensorflow/core/config/flags_api_wrapper.cc) }; diff --git a/tensorflow/core/config/flags_api_wrapper.cc b/tensorflow/core/config/flags_api_wrapper.cc index 096d48c5dc1720..060ede3846df23 100644 --- a/tensorflow/core/config/flags_api_wrapper.cc +++ b/tensorflow/core/config/flags_api_wrapper.cc @@ -55,5 +55,6 @@ PYBIND11_MODULE(flags_pybind, m) { TF_PY_DECLARE_FLAG(enable_colocation_key_propagation_in_while_op_lowering); TF_PY_DECLARE_FLAG(enable_tf2min_ici_weight) TF_PY_DECLARE_FLAG(enable_function_pruning_before_inlining) + TF_PY_DECLARE_FLAG(enable_skip_encapsulation_for_non_tpu_graphs) // LINT.ThenChange(//tensorflow/core/config/flag_defs.h) }; diff --git a/tensorflow/core/data/dataset_test_base.cc b/tensorflow/core/data/dataset_test_base.cc index 7e295e367285a3..e770b4fa9df02d 100644 --- a/tensorflow/core/data/dataset_test_base.cc +++ b/tensorflow/core/data/dataset_test_base.cc @@ -348,7 +348,7 @@ Status DatasetOpsTestBase::CreateOpKernel( Status DatasetOpsTestBase::CreateDatasetContext( OpKernel* const dateset_kernel, - gtl::InlinedVector* const inputs, + absl::InlinedVector* const inputs, std::unique_ptr* dataset_context_params, std::unique_ptr* dataset_context) { Status status = CheckOpKernelInput(*dateset_kernel, *inputs); @@ -515,13 +515,13 @@ Status DatasetOpsTestBase::RunFunction( } Status DatasetOpsTestBase::CreateOpKernelContext( - OpKernel* kernel, gtl::InlinedVector* inputs, + OpKernel* kernel, absl::InlinedVector* inputs, std::unique_ptr* context) { return CreateOpKernelContext(kernel, inputs, ¶ms_, context); } Status DatasetOpsTestBase::CreateOpKernelContext( - OpKernel* kernel, gtl::InlinedVector* inputs, + OpKernel* kernel, absl::InlinedVector* inputs, std::unique_ptr* context_params, std::unique_ptr* context) { auto params = std::make_unique(); @@ -565,7 +565,7 @@ Status DatasetOpsTestBase::CreateSerializationContext( } Status DatasetOpsTestBase::CheckOpKernelInput( - const OpKernel& kernel, const gtl::InlinedVector& inputs) { + const OpKernel& kernel, const absl::InlinedVector& inputs) { if (kernel.num_inputs() != inputs.size()) { return errors::InvalidArgument("The number of input elements should be ", kernel.num_inputs(), @@ -575,7 +575,7 @@ Status DatasetOpsTestBase::CheckOpKernelInput( } Status DatasetOpsTestBase::AddDatasetInput( - gtl::InlinedVector* inputs, DataTypeVector input_types, + absl::InlinedVector* inputs, DataTypeVector input_types, DataType dtype, const TensorShape& shape) { if (input_types.size() < inputs->size()) { return errors::InvalidArgument("Adding more inputs than types: ", @@ -862,7 +862,7 @@ Status DatasetOpsTestBase::RunDatasetOp( input_datasets.push_back(t.get()); created_tensors->push_back(std::move(t)); } - gtl::InlinedVector inputs; + absl::InlinedVector inputs; inputs.reserve(input_datasets.size()); for (auto input_dataset : input_datasets) { inputs.emplace_back(TensorValue(input_dataset)); @@ -985,7 +985,7 @@ Status DatasetOpsTestBase::MakeDatasetTensor( TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes)); auto input_tensors = dataset_params.GetInputTensors(); - gtl::InlinedVector inputs; + absl::InlinedVector inputs; inputs.reserve(input_datasets.size() + input_tensors.size()); for (auto input_dataset : input_datasets) { inputs.emplace_back(TensorValue(input_dataset)); @@ -1165,7 +1165,7 @@ std::vector TensorSliceDatasetParams::TensorSliceShapes( const std::vector& input_components) { std::vector shapes; for (const auto& component : input_components) { - gtl::InlinedVector partial_dim_sizes; + absl::InlinedVector partial_dim_sizes; for (int i = 1; i < component.dims(); ++i) { partial_dim_sizes.push_back(component.dim_size(i)); } diff --git a/tensorflow/core/data/dataset_test_base.h b/tensorflow/core/data/dataset_test_base.h index ec9805b806fe0d..e7278237d9f130 100644 --- a/tensorflow/core/data/dataset_test_base.h +++ b/tensorflow/core/data/dataset_test_base.h @@ -766,7 +766,7 @@ class DatasetOpsTestBase : public ::testing::Test { // Creates a new op kernel context. Status CreateDatasetContext( - OpKernel* dateset_kernel, gtl::InlinedVector* inputs, + OpKernel* dateset_kernel, absl::InlinedVector* inputs, std::unique_ptr* dataset_context_params, std::unique_ptr* dataset_context); @@ -798,16 +798,16 @@ class DatasetOpsTestBase : public ::testing::Test { // Checks that the size of `inputs` matches the requirement of the op kernel. Status CheckOpKernelInput(const OpKernel& kernel, - const gtl::InlinedVector& inputs); + const absl::InlinedVector& inputs); // Creates a new context for running the dataset operation. Status CreateOpKernelContext(OpKernel* kernel, - gtl::InlinedVector* inputs, + absl::InlinedVector* inputs, std::unique_ptr* context); // Creates a new context for running the dataset operation. Status CreateOpKernelContext(OpKernel* kernel, - gtl::InlinedVector* inputs, + absl::InlinedVector* inputs, std::unique_ptr* params, std::unique_ptr* context); @@ -856,7 +856,7 @@ class DatasetOpsTestBase : public ::testing::Test { // Adds an empty tensor with the specified dtype and shape to the input // vector. - Status AddDatasetInput(gtl::InlinedVector* inputs, + Status AddDatasetInput(absl::InlinedVector* inputs, DataTypeVector input_types, DataType dtype, const TensorShape& shape); diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index cc7ed17cdd767a..19345990f355f8 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -1018,7 +1018,7 @@ REGISTER_DATASET_EXPERIMENT("file_locality_v2", RandomJobSamplePercentage<0>, AllTasks); REGISTER_DATASET_EXPERIMENT("no_compression", RandomJobSamplePercentage<0>, AllTasks); -REGISTER_DATASET_EXPERIMENT("no_compression_v2", RandomJobSamplePercentage<50>, +REGISTER_DATASET_EXPERIMENT("no_compression_v2", RandomJobSamplePercentage<0>, AllTasks); REGISTER_DATASET_EXPERIMENT("inject_io_prefetch", RandomJobSamplePercentage<0>, AllTasks); diff --git a/tensorflow/core/data/dataset_utils_test.cc b/tensorflow/core/data/dataset_utils_test.cc index e581f6e3cbe3e8..2e107eb29b0778 100644 --- a/tensorflow/core/data/dataset_utils_test.cc +++ b/tensorflow/core/data/dataset_utils_test.cc @@ -359,11 +359,10 @@ TEST_P(GetExperimentsOptTest, DatasetUtils) { auto opt_ins = test_case.opt_ins; auto opt_outs = test_case.opt_outs; if (!opt_ins.empty()) { - setenv("TF_DATA_EXPERIMENT_OPT_IN", str_util::Join(opt_ins, ",").c_str(), - 1); + setenv("TF_DATA_EXPERIMENT_OPT_IN", absl::StrJoin(opt_ins, ",").c_str(), 1); } if (!opt_outs.empty()) { - setenv("TF_DATA_EXPERIMENT_OPT_OUT", str_util::Join(opt_outs, ",").c_str(), + setenv("TF_DATA_EXPERIMENT_OPT_OUT", absl::StrJoin(opt_outs, ",").c_str(), 1); } const std::string job_name = "job"; @@ -376,14 +375,14 @@ TEST_P(GetExperimentsOptTest, DatasetUtils) { for (const auto& experiment : test_case.expected_in) { EXPECT_TRUE(experiment_set.find(experiment) != experiment_set.end()) << "experiment=" << experiment << " opt_ins={" - << str_util::Join(opt_ins, ",") << "} opt_outs={" - << str_util::Join(opt_outs, ",") << "}"; + << absl::StrJoin(opt_ins, ",") << "} opt_outs={" + << absl::StrJoin(opt_outs, ",") << "}"; } for (const auto& experiment : test_case.expected_out) { EXPECT_TRUE(experiment_set.find(experiment) == experiment_set.end()) << "experiment=" << experiment << " opt_ins={" - << str_util::Join(opt_ins, ",") << "} opt_outs={" - << str_util::Join(opt_outs, ",") << "}"; + << absl::StrJoin(opt_ins, ",") << "} opt_outs={" + << absl::StrJoin(opt_outs, ",") << "}"; } if (!opt_ins.empty()) { diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index cecae5351d765f..8309b8cdd210a6 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -680,8 +680,6 @@ cc_library( # copybara:uncomment copts = ["-Wthread-safety-analysis"], deps = [ ":credentials_factory", - "//tensorflow/core:framework", - "//tensorflow/core/data:dataset_utils", ], ) @@ -1055,6 +1053,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:platform_port", ] + tf_grpc_cc_dependencies(), ) diff --git a/tensorflow/core/data/service/client/BUILD b/tensorflow/core/data/service/client/BUILD index 34b0b7a1a562de..60e2da30b5a8ac 100644 --- a/tensorflow/core/data/service/client/BUILD +++ b/tensorflow/core/data/service/client/BUILD @@ -57,7 +57,6 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:retrying_utils", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ], diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index dbd65ea37d7a61..a323a02b6096bc 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -53,7 +53,6 @@ limitations under the License. #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/lib/traceme_encode.h" -#include "tsl/platform/host_info.h" #include "tsl/platform/retrying_utils.h" #include "tsl/protobuf/error_codes.pb.h" @@ -381,9 +380,7 @@ DataServiceClient::CreateAlternativeWorkerClientWithGrpcFallback( absl::StatusOr> DataServiceClient::CreateWorkerClient(const TaskInfo& task_info) { if (params_.data_transfer_protocol == kLocalTransferProtocol || - // TODO(b/291994182): Use remote workers in unit tests. - (tsl::port::JobUid() != -1 && - LocalWorkers::Get(task_info.worker_address()) != nullptr)) { + ForceLocalProtocol(task_info.worker_address())) { DataTransferServerInfo info; info.set_protocol(kLocalTransferProtocol); info.set_address(task_info.worker_address()); diff --git a/tensorflow/core/data/service/py_utils.cc b/tensorflow/core/data/service/py_utils.cc index d14e1c9d1ed2cf..be5308df607f98 100644 --- a/tensorflow/core/data/service/py_utils.cc +++ b/tensorflow/core/data/service/py_utils.cc @@ -17,9 +17,7 @@ limitations under the License. #include -#include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/data/service/credentials_factory.h" -#include "tensorflow/core/framework/metrics.h" namespace tensorflow { namespace data { @@ -39,17 +37,5 @@ std::string DefaultProtocol() { return "grpc"; } -bool DisableCompressionAtRegistrationTime() { -#if defined(PLATFORM_GOOGLE) - if (!GetExperiments().contains("no_compression_v2")) { - return false; - } - metrics::RecordTFDataServiceCompressionAction( - "disabled_at_registration_time"); - return true; -#endif // PLATFORM_GOOGLE - return false; -} - } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/py_utils.h b/tensorflow/core/data/service/py_utils.h index 010c155022fee3..b0ea8928a3af4e 100644 --- a/tensorflow/core/data/service/py_utils.h +++ b/tensorflow/core/data/service/py_utils.h @@ -27,10 +27,6 @@ namespace data { // Returns the default protocol to use for tf.data service control flow. std::string DefaultProtocol(); -// Returns `true` if tf.data service compression is to be disabled at -// registration time. -bool DisableCompressionAtRegistrationTime(); - } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/worker_client.cc b/tensorflow/core/data/service/worker_client.cc index 871d549a729c3b..673bc59976c814 100644 --- a/tensorflow/core/data/service/worker_client.cc +++ b/tensorflow/core/data/service/worker_client.cc @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" #include "tsl/platform/errors.h" +#include "tsl/platform/host_info.h" namespace tensorflow { namespace data { @@ -91,7 +92,7 @@ Status DataServiceWorkerClient::EnsureInitialized() { } std::string DataServiceWorkerClient::GetDataTransferProtocol() const { - if (LocalWorkers::Get(address_) != nullptr) { + if (ForceLocalProtocol(address_)) { return kLocalTransferProtocol; } return transfer_protocol_; @@ -275,5 +276,13 @@ class LocalTransferClientRegistrar { }; static LocalTransferClientRegistrar local_client_registrar; +bool ForceLocalProtocol(const std::string& worker_address) { + // TODO(b/291994182): Use remote workers in unit tests. + if (tsl::port::JobUid() == -1) { + return false; + } + return LocalWorkers::Get(worker_address) != nullptr; +} + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/worker_client.h b/tensorflow/core/data/service/worker_client.h index 0799ab72999044..014afdc6a98d1c 100644 --- a/tensorflow/core/data/service/worker_client.h +++ b/tensorflow/core/data/service/worker_client.h @@ -22,11 +22,8 @@ limitations under the License. #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/data_transfer.h" #include "tensorflow/core/data/service/worker.pb.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace data { @@ -85,6 +82,10 @@ CreateDataServiceWorkerClient( const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, Allocator* allocator); +// If true, clients should use local protocol for data transfer (disregarding +// any other user-specified or runtime-defaulted protocol). +bool ForceLocalProtocol(const std::string& worker_address); + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/snapshot_utils.cc b/tensorflow/core/data/snapshot_utils.cc index a1dd6179cc8e50..8874484c835af0 100644 --- a/tensorflow/core/data/snapshot_utils.cc +++ b/tensorflow/core/data/snapshot_utils.cc @@ -579,7 +579,7 @@ class Reader::NestedDataset : public DatasetBase { std::vector datasets) : DatasetBase(std::move(ctx)), datasets_(datasets) { dtypes_.push_back(DT_VARIANT); - gtl::InlinedVector element_dim_sizes; + absl::InlinedVector element_dim_sizes; element_dim_sizes.push_back(1); partial_shapes_.emplace_back(element_dim_sizes); } @@ -859,9 +859,9 @@ Status CustomReader::Initialize(Env* env) { } Status CustomReader::ReadTensors(std::vector* read_tensors) { - profiler::TraceMe activity( + tsl::profiler::TraceMe activity( [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); }, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMeLevel::kInfo); if (version_ == 0 || compression_type_ != io::compression::kSnappy) { return ReadTensorsV0(read_tensors); } diff --git a/tensorflow/core/example/CMakeLists.txt b/tensorflow/core/example/CMakeLists.txt new file mode 100644 index 00000000000000..2450c9eddd5107 --- /dev/null +++ b/tensorflow/core/example/CMakeLists.txt @@ -0,0 +1,50 @@ +# +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if (NOT TARGET protobuf::libprotobuf) + find_package(Protobuf REQUIRED) +endif() + +set(GEN_PROTO_DIR ${CMAKE_CURRENT_BINARY_DIR}/tensorflow/core/example) + +# Generate feature proto .h, .cc and lib. +list(APPEND feature_generated_files ${GEN_PROTO_DIR}/feature.pb.h ${GEN_PROTO_DIR}/feature.pb.cc) + +add_custom_command( + OUTPUT ${feature_generated_files} + COMMAND ${Protobuf_PROTOC_EXECUTABLE} + ARGS --cpp_out=${CMAKE_CURRENT_BINARY_DIR} --proto_path=${TENSORFLOW_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/feature.proto + DEPENDS ${Protobuf_PROTOC_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/feature.proto +) + +set_source_files_properties(${feature_generated_files} PROPERTIES GENERATED TRUE) +add_library(feature_proto ${feature_generated_files}) +target_link_libraries(feature_proto protobuf::libprotobuf) +target_include_directories(feature_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) + +# Generate example proto .h, .cc and lib. +list(APPEND example_generated_files ${GEN_PROTO_DIR}/example.pb.h ${GEN_PROTO_DIR}/example.pb.cc) + +add_custom_command( + OUTPUT ${example_generated_files} + COMMAND ${Protobuf_PROTOC_EXECUTABLE} + ARGS --cpp_out=${CMAKE_CURRENT_BINARY_DIR} --proto_path=${TENSORFLOW_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/example.proto + DEPENDS ${Protobuf_PROTOC_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/example.proto ${feature_generated_files} +) + +set_source_files_properties(${example_generated_files} PROPERTIES GENERATED TRUE) +add_library(example_proto ${example_generated_files}) +target_link_libraries(example_proto feature_proto protobuf::libprotobuf) +target_include_directories(example_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) \ No newline at end of file diff --git a/tensorflow/core/framework/allocator_test.cc b/tensorflow/core/framework/allocator_test.cc index 7e85b25a9df6f7..6557a4cec7598e 100644 --- a/tensorflow/core/framework/allocator_test.cc +++ b/tensorflow/core/framework/allocator_test.cc @@ -236,7 +236,7 @@ TEST(CPUAllocatorTest, ProfilerReporting) { // Get profiling results tensorflow::profiler::XSpace xspace; - EXPECT_EQ(OkStatus(), profiler->CollectData(&xspace)); + EXPECT_EQ(absl::OkStatus(), profiler->CollectData(&xspace)); // Validate the output const auto plane = ::tsl::profiler::FindPlaneWithName( diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc index 31bddbea68f93b..351ba293276456 100644 --- a/tensorflow/core/framework/attr_value_util.cc +++ b/tensorflow/core/framework/attr_value_util.cc @@ -446,7 +446,7 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) { } } - return OkStatus(); + return absl::OkStatus(); } bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { @@ -530,7 +530,7 @@ void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; } DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice, FIELD) DEFINE_SET_ATTR_VALUE_ONE(const string&, s) -DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice, s) +DEFINE_SET_ATTR_VALUE_LIST(absl::Span, s) DEFINE_SET_ATTR_VALUE_BOTH(const char*, s) DEFINE_SET_ATTR_VALUE_BOTH(int64_t, i) DEFINE_SET_ATTR_VALUE_BOTH(int32_t, i) @@ -545,7 +545,7 @@ void SetAttrValue(const tstring& value, AttrValue* out) { out->set_s(value.data(), value.size()); } -void SetAttrValue(gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); for (const auto& v : value) { out->mutable_list()->add_s(v.data(), v.size()); @@ -556,7 +556,7 @@ void SetAttrValue(StringPiece value, AttrValue* out) { out->set_s(value.data(), value.size()); } -void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(const absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { out->mutable_list()->add_s(v.data(), v.size()); @@ -582,21 +582,21 @@ void SetAttrValue(const PartialTensorShape& value, AttrValue* out) { value.AsProto(out->mutable_shape()); } -void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(const absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { v.AsProto(out->mutable_list()->add_shape()); } } -void SetAttrValue(gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { *out->mutable_list()->add_shape() = v; } } -void SetAttrValue(const gtl::ArraySlice value, +void SetAttrValue(const absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { @@ -612,7 +612,7 @@ void SetAttrValue(const Tensor& value, AttrValue* out) { } } -void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(const absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { if (v.NumElements() > 1) { @@ -627,7 +627,7 @@ void SetAttrValue(const TensorProto& value, AttrValue* out) { *out->mutable_tensor() = value; } -void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(const absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { *out->mutable_list()->add_tensor() = v; @@ -638,7 +638,7 @@ void SetAttrValue(const NameAttrList& value, AttrValue* out) { *out->mutable_func() = value; } -void SetAttrValue(gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { *out->mutable_list()->add_func() = v; diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc index bf99f54baa6b45..996acd12d78b3b 100644 --- a/tensorflow/core/framework/collective.cc +++ b/tensorflow/core/framework/collective.cc @@ -207,7 +207,7 @@ Status CollectiveRegistry::Register(const string& collective_name, collective_name); } registry->emplace_back(collective_name, std::move(factory)); - return OkStatus(); + return absl::OkStatus(); } /*static*/ @@ -222,7 +222,7 @@ Status CollectiveRegistry::LookupHelper( } else { *implementation = reg_info.factory(); } - return OkStatus(); + return absl::OkStatus(); } } return errors::Internal( diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index be1bfc2581e2aa..b400203013b0b2 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -84,7 +84,7 @@ Status GetWindowedOutputSizeFromDimsV2( /*evenly_divisible=*/false, output_size)); break; } - return OkStatus(); + return absl::OkStatus(); } Status GetWindowedOutputSizeFromDims( @@ -112,7 +112,7 @@ Status UnchangedShape(shape_inference::InferenceContext* c) { if (handle_data != nullptr) { c->set_output_handle_shapes_and_types(0, *handle_data); } - return OkStatus(); + return absl::OkStatus(); } Status MatMulShape(shape_inference::InferenceContext* c) { @@ -135,7 +135,7 @@ Status MatMulShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged)); c->set_output(0, c->Matrix(output_rows, output_cols)); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -155,7 +155,7 @@ Status ValidateEinsumEllipsis(absl::string_view subscript, "Periods found outside of ellipsis in subscript: ", subscript); } *found_ellipsis = num_periods > 0; - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -166,7 +166,7 @@ Status EinsumShape(shape_inference::InferenceContext* c) { // more latin alphabets and contains at most one ellipsis ('...'). string equation; TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation)); - gtl::InlinedVector input_labels; + absl::InlinedVector input_labels; string output_labels; TF_RETURN_IF_ERROR( ValidateEinsumEquation(equation, &input_labels, &output_labels)); @@ -185,7 +185,7 @@ Status EinsumShape(shape_inference::InferenceContext* c) { // Validate input subscripts, build the label to dimension mapping and obtain // the broadcast shapes that map to ellipsis. absl::flat_hash_map label_to_dimension; - gtl::InlinedVector input_bcast_shapes(c->num_inputs()); + absl::InlinedVector input_bcast_shapes(c->num_inputs()); for (int i = 0, end = c->num_inputs(); i < end; ++i) { bool has_ellipsis = false; TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis)); @@ -276,7 +276,7 @@ Status EinsumShape(shape_inference::InferenceContext* c) { // unknown, then the output shape should have unknown rank. if (!c->RankKnown(output_bcast_shape)) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } } else { // If the output subscripts don't have ellipsis then make sure the output @@ -311,7 +311,7 @@ Status EinsumShape(shape_inference::InferenceContext* c) { output_dims.push_back(dimension_it->second); } c->set_output(0, c->MakeShape(output_dims)); - return OkStatus(); + return absl::OkStatus(); } Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) { @@ -348,7 +348,7 @@ Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) { output_batch_shape, c->Matrix(output_rows, output_cols), &output_shape)); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status BatchMatMulShape(shape_inference::InferenceContext* c) { @@ -382,7 +382,7 @@ Status BatchMatMulShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR( c->Concatenate(batch_dims, c->Matrix(output_rows, output_cols), &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } // -------------------------------------------------------------------------- @@ -407,7 +407,7 @@ Status BiasAddShape(shape_inference::InferenceContext* c) { // If rank unknown, return unknown shape. if (!c->RankKnown(input_shape)) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } // Output has the same shape as the input, and matches the length of @@ -443,7 +443,7 @@ Status BiasAddShape(shape_inference::InferenceContext* c) { } c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status BiasAddGradShape(shape_inference::InferenceContext* c) { @@ -460,7 +460,7 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) { c->set_output(0, c->Vector(c->Dim(input_shape, -1))); } - return OkStatus(); + return absl::OkStatus(); } Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, @@ -479,7 +479,7 @@ Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, } } - return OkStatus(); + return absl::OkStatus(); } Status DatasetIteratorShape(shape_inference::InferenceContext* c) { @@ -499,7 +499,7 @@ Status DatasetIteratorShape(shape_inference::InferenceContext* c) { output_shapes[i], &output_shape_handle)); c->set_output(static_cast(i), output_shape_handle); } - return OkStatus(); + return absl::OkStatus(); } Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, @@ -524,12 +524,12 @@ Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, context->MakeDim(spatial[spatial_dim]); } *out = context->MakeShape(dims_actual); - return OkStatus(); + return absl::OkStatus(); } Status DimensionsFromShape(ShapeHandle shape, TensorFormat format, DimensionHandle* batch_dim, - gtl::MutableArraySlice spatial_dims, + absl::Span spatial_dims, DimensionHandle* filter_dim, InferenceContext* context) { const int32_t rank = @@ -550,12 +550,12 @@ Status DimensionsFromShape(ShapeHandle shape, TensorFormat format, context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)), filter_dim)); } - return OkStatus(); + return absl::OkStatus(); } // vect_size must be provided if format is NCHW_VECT_C. Status ShapeFromDimensions(DimensionHandle batch_dim, - gtl::ArraySlice spatial_dims, + absl::Span spatial_dims, DimensionHandle filter_dim, TensorFormat format, absl::optional vect_size, InferenceContext* context, ShapeHandle* shape) { @@ -585,7 +585,7 @@ Status ShapeFromDimensions(DimensionHandle batch_dim, } *shape = context->MakeShape(out_dims); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -652,7 +652,7 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c, DimensionHandle batch_size_dim; DimensionHandle input_depth_dim; - gtl::InlinedVector input_spatial_dims(2); + absl::InlinedVector input_spatial_dims(2); TF_RETURN_IF_ERROR(DimensionsFromShape( conv_input_shape, data_format, &batch_size_dim, absl::MakeSpan(input_spatial_dims), &input_depth_dim, c)); @@ -760,7 +760,7 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c, batch_size_dim, {output_rows, output_cols}, output_depth_dim, data_format, vect_size, c, &output_shape)); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -777,7 +777,7 @@ Status ConvShape(shape_inference::InferenceContext* c) { if (input_rank == InferenceContext::kUnknownRank || filter_rank == InferenceContext::kUnknownRank) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } int batch_dims; @@ -981,7 +981,7 @@ Status ConvShape(shape_inference::InferenceContext* c) { output_shape = c->MakeShape(output_shape_vector); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } // Shape function for Conv2D-like operations that support explicit padding. @@ -1107,7 +1107,7 @@ Status Conv3DShape(shape_inference::InferenceContext* c) { output_cols, output_depth_dim}); } c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { @@ -1130,7 +1130,7 @@ Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { DimensionHandle batch_size_dim; DimensionHandle output_grad_depth_dim; - gtl::InlinedVector output_grad_spatial_dims(2); + absl::InlinedVector output_grad_spatial_dims(2); TF_RETURN_IF_ERROR(DimensionsFromShape( output_grad_shape, data_format, &batch_size_dim, absl::MakeSpan(output_grad_spatial_dims), &output_grad_depth_dim, c)); @@ -1151,7 +1151,7 @@ Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { // input_grad_depth_dim from input_sizes; otherwise we compute it as // c->Dim(filter_shape,2). DimensionHandle input_grad_depth_dim; - gtl::InlinedVector specified_input_grad_spatial_dims(2); + absl::InlinedVector specified_input_grad_spatial_dims(2); int specified_input_grad_rank = c->Rank(specified_input_grad_shape); if (specified_input_grad_rank == 4) { DimensionHandle specified_batch_size_dim; @@ -1179,7 +1179,7 @@ Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { batch_size_dim, specified_input_grad_spatial_dims, input_grad_depth_dim, data_format, /*vect_size=*/absl::nullopt, c, &input_grad_shape)); c->set_output(0, input_grad_shape); - return OkStatus(); + return absl::OkStatus(); } Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c) { @@ -1198,7 +1198,7 @@ Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh)); TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh)); c->set_output(0, sh); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -1320,7 +1320,7 @@ Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c, c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); } c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } }; // namespace @@ -1400,7 +1400,7 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) { {output_rows, output_cols}, depth_dim, &output_shape, c)); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status AvgPoolGradShape(shape_inference::InferenceContext* c) { @@ -1408,7 +1408,7 @@ Status AvgPoolGradShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); c->set_output(0, s); - return OkStatus(); + return absl::OkStatus(); } Status FusedBatchNormShape(shape_inference::InferenceContext* c) { @@ -1450,13 +1450,13 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) { c->set_output(2, vector_shape); c->set_output(3, vector_shape); c->set_output(4, vector_shape); - return OkStatus(); + return absl::OkStatus(); } Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(FusedBatchNormShape(c)); c->set_output(5, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } Status FusedBatchNormExShape(shape_inference::InferenceContext* c) { @@ -1481,7 +1481,7 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c) { "_FusedBatchNormEx channel dimension must be divisible by 4."); } - return OkStatus(); + return absl::OkStatus(); } Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { @@ -1522,7 +1522,7 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { c->set_output(2, c->Vector(channel_dim)); c->set_output(3, c->Vector(0)); c->set_output(4, c->Vector(0)); - return OkStatus(); + return absl::OkStatus(); } Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) { @@ -1531,7 +1531,7 @@ Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) { int num_side_inputs; TF_RETURN_IF_ERROR(c->GetAttr("num_side_inputs", &num_side_inputs)); if (num_side_inputs == 0) { - return OkStatus(); + return absl::OkStatus(); } string data_format_str; @@ -1558,7 +1558,7 @@ Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) { &side_input_backprop)); c->set_output(5, side_input_backprop); - return OkStatus(); + return absl::OkStatus(); } Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor, @@ -1581,7 +1581,7 @@ Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor, num_elements, " elements."); } } - return OkStatus(); + return absl::OkStatus(); } Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) { @@ -1594,7 +1594,7 @@ Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) { if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) || diag_index_tensor == nullptr) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } int32_t lower_diag_index = 0; int32_t upper_diag_index = 0; @@ -1634,7 +1634,7 @@ Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) { } dims.push_back(c->MakeDim(max_diag_len)); c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); } Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) { @@ -1651,7 +1651,7 @@ Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) { if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) || diag_index_tensor == nullptr) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } int32_t lower_diag_index = 0; int32_t upper_diag_index = 0; @@ -1735,7 +1735,7 @@ Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) { output_col_dim, &output_shape)); } c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) { @@ -1807,7 +1807,7 @@ Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape)); } c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status MaxPoolShapeImpl(shape_inference::InferenceContext* c, @@ -1903,7 +1903,7 @@ Status MaxPoolShapeImpl(shape_inference::InferenceContext* c, output_depth, &output_shape, c)); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status MaxPoolShape(shape_inference::InferenceContext* c) { @@ -1954,7 +1954,7 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2); if (kernel_sizes_tensor == nullptr) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements()); auto kernel_sizes_vec = kernel_sizes_tensor->flat(); @@ -1964,7 +1964,7 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1); if (strides_tensor == nullptr) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } strides.resize(strides_tensor->shape().num_elements()); auto strides_vec = strides_tensor->flat(); @@ -2017,7 +2017,7 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { output_depth, &output_shape, c)); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status Pool3DShape(shape_inference::InferenceContext* c) { @@ -2099,7 +2099,7 @@ Status Pool3DShape(shape_inference::InferenceContext* c) { } c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status MaxPool3DGradShape(shape_inference::InferenceContext* c) { @@ -2111,14 +2111,14 @@ Status AvgPool3DGradShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); c->set_output(0, s); - return OkStatus(); + return absl::OkStatus(); } Status UnknownShape(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_outputs(); ++i) { c->set_output(i, c->UnknownShape()); } - return OkStatus(); + return absl::OkStatus(); } template @@ -2141,7 +2141,7 @@ Status ReductionShapeHelper(const Tensor* reduction_indices_t, true_indices->insert(wrapped_index); } - return OkStatus(); + return absl::OkStatus(); } Status ReductionShape(InferenceContext* c) { @@ -2167,7 +2167,7 @@ Status ReductionShape(InferenceContext* c) { if (keep_dims && c->RankKnown(input)) { // output rank matches input input if . c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); - return OkStatus(); + return absl::OkStatus(); } else { return shape_inference::UnknownShape(c); } @@ -2198,7 +2198,7 @@ Status ReductionShape(InferenceContext* c) { } c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); } Status ConcatShapeHelper(InferenceContext* c, int start_value_index, @@ -2220,7 +2220,7 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index, } if (rank == InferenceContext::kUnknownRank) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } else if (rank == 0) { return errors::InvalidArgument( "Can't concatenate scalars (use tf.stack instead)"); @@ -2235,7 +2235,7 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index, dims.reserve(rank); for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim()); c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); } // Merge all the non-concat dims, and sum the concat dim to make an output @@ -2286,7 +2286,7 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index, c->Concatenate(output_before, c->Vector(output_middle), &s)); TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s)); c->set_output(0, s); - return OkStatus(); + return absl::OkStatus(); } Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) { @@ -2315,7 +2315,7 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, CHECK_NOTNULL(out); if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) { *out = c->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } const int32_t rank_x = c->Rank(shape_x); const int32_t rank_y = c->Rank(shape_y); @@ -2347,13 +2347,13 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, if (c->Value(dim_x) > 1) { if (!incompatible_shape_error) { *out = c->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } dims.push_back(dim_x); } else if (c->Value(dim_y) > 1) { if (!incompatible_shape_error) { *out = c->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } dims.push_back(dim_y); } else if (c->Value(dim_x) == 1) { @@ -2367,7 +2367,7 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, } else { if (!incompatible_shape_error) { *out = c->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } dims.push_back(c->UnknownDim()); } @@ -2386,7 +2386,7 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, if (!s.ok()) { if (!incompatible_shape_error) { *out = c->MakeShape({}); - return OkStatus(); + return absl::OkStatus(); } return s; } @@ -2395,14 +2395,14 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, } *out = c->MakeShape(dims); - return OkStatus(); + return absl::OkStatus(); } Status RandomShape(shape_inference::InferenceContext* c) { shape_inference::ShapeHandle out; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } Status SegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { @@ -2433,7 +2433,7 @@ Status SegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { out = c->UnknownShape(); } c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -2463,7 +2463,7 @@ Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -2507,7 +2507,7 @@ Status SliceShape(InferenceContext* c) { SliceHelper(c, begin_value, sizes_value, &dims)); } c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); } else { // In case `sizes` is not available (`sizes_value` is null), // we could try to use `MakeShapeFromShapeTensor` here. @@ -2529,18 +2529,18 @@ Status SliceShape(InferenceContext* c) { dims.emplace_back(c->Dim(sizes_value, i)); } c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); } // We might know the rank of the input. if (c->RankKnown(input)) { c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); - return OkStatus(); + return absl::OkStatus(); } else { return shape_inference::UnknownShape(c); } } - return OkStatus(); + return absl::OkStatus(); } Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, @@ -2581,7 +2581,7 @@ Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, } } - return OkStatus(); + return absl::OkStatus(); } Status ValidateVariableResourceHandle( @@ -2601,7 +2601,7 @@ Status ValidateVariableResourceHandle( DataTypeString(value_dtype)); } } - return OkStatus(); + return absl::OkStatus(); } Status GatherNdShape(InferenceContext* c) { @@ -2620,7 +2620,7 @@ Status GatherNdShape(InferenceContext* c) { if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } if (c->Value(r_dim) > c->Rank(params)) { @@ -2637,7 +2637,7 @@ Status GatherNdShape(InferenceContext* c) { ShapeHandle out; TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, @@ -2700,7 +2700,7 @@ Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, // This is called for tf.scatter_nd; output is a tensor with this shape. c->set_output(0, input_shape); } - return OkStatus(); + return absl::OkStatus(); } Status ExplicitShape(InferenceContext* c) { @@ -2709,7 +2709,7 @@ Status ExplicitShape(InferenceContext* c) { ShapeHandle output_shape; TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape)); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status ExplicitShapes(InferenceContext* c) { @@ -2724,7 +2724,7 @@ Status ExplicitShapes(InferenceContext* c) { c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape)); c->set_output(i, output_shape); } - return OkStatus(); + return absl::OkStatus(); } Status SparseReduceShapeFn(InferenceContext* c) { @@ -2770,7 +2770,7 @@ Status SparseReduceShapeFn(InferenceContext* c) { } c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); } return UnknownShape(c); } @@ -2784,7 +2784,7 @@ Status QuantizedConv2DShape(InferenceContext* c) { TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); c->set_output(1, c->Scalar()); c->set_output(2, c->Scalar()); - return OkStatus(); + return absl::OkStatus(); } Status FusedQuantizedConvShape(InferenceContext* c, int num_dims) { @@ -2831,19 +2831,19 @@ Status FusedQuantizedConvShape(InferenceContext* c, int num_dims) { c->set_output(1, channel); c->set_output(2, channel); } - return OkStatus(); + return absl::OkStatus(); } Status FusedQuantizedConv2DShape(InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShapeImpl(c, true)); TF_RETURN_IF_ERROR(FusedQuantizedConvShape(c, 4)); - return OkStatus(); + return absl::OkStatus(); } Status FusedQuantizedDepthwiseConv2D(InferenceContext* c) { TF_RETURN_IF_ERROR(DepthwiseConv2DNativeShapeImpl(c, true)); TF_RETURN_IF_ERROR(FusedQuantizedConvShape(c, 4)); - return OkStatus(); + return absl::OkStatus(); } Status QuantizedAvgPoolShape(InferenceContext* c) { @@ -2853,7 +2853,7 @@ Status QuantizedAvgPoolShape(InferenceContext* c) { TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); c->set_output(1, c->Scalar()); c->set_output(2, c->Scalar()); - return OkStatus(); + return absl::OkStatus(); } Status QuantizeV2Shape(InferenceContext* c) { @@ -2879,7 +2879,7 @@ Status QuantizeV2Shape(InferenceContext* c) { } c->set_output(1, minmax); c->set_output(2, minmax); - return OkStatus(); + return absl::OkStatus(); } Status ReduceScatterShape(shape_inference::InferenceContext* c) { @@ -2887,7 +2887,7 @@ Status ReduceScatterShape(shape_inference::InferenceContext* c) { if (!c->RankKnown(in)) { // Input shape unknown, so set unknown output shape. c->set_output(0, in); - return OkStatus(); + return absl::OkStatus(); } shape_inference::ShapeHandle group_assignment_shape = c->input(1); @@ -2898,7 +2898,7 @@ Status ReduceScatterShape(shape_inference::InferenceContext* c) { const Tensor* scatter_dimension = c->input_tensor(2); if (!scatter_dimension) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } int64_t scatter_dim; TF_RETURN_IF_ERROR(c->GetScalarFromTensor(scatter_dimension, &scatter_dim)); @@ -2919,7 +2919,7 @@ Status ReduceScatterShape(shape_inference::InferenceContext* c) { } } c->set_output(0, c->MakeShape(out_dims)); - return OkStatus(); + return absl::OkStatus(); } } // namespace shape_inference diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index ce65aa99d13706..f1d43d6c2abfd3 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -52,7 +52,7 @@ inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c, ShapeHandle out; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } // Transfers shape of input(0) to output(0), after asserting its rank >= . @@ -61,7 +61,7 @@ inline Status UnchangedShapeWithRankAtLeast( ShapeHandle out; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } // Transfers shape of input(0) to output(0), after asserting its rank <= . @@ -70,18 +70,18 @@ inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c, ShapeHandle out; TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } // Shape function for use with ops no outputs. inline Status NoOutputs(shape_inference::InferenceContext* c) { - return OkStatus(); + return absl::OkStatus(); } // Shape function for ops that output a single scalar value. inline Status ScalarShape(shape_inference::InferenceContext* c) { c->set_output(0, c->Scalar()); - return OkStatus(); + return absl::OkStatus(); } // Shape function for binary ops where both inputs and the output match. @@ -89,7 +89,7 @@ inline Status MergeBothInputsShapeFn(InferenceContext* c) { ShapeHandle out; TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } // Shape function for dataset iterators. @@ -240,7 +240,7 @@ inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( c, c->input(0), c->input(1), true, &out)); c->set_output(output_index, out); - return OkStatus(); + return absl::OkStatus(); } // Shape function for binary operators that broadcast their inputs. diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 188e9813359e9a..4fd31ab201458e 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -211,7 +211,7 @@ static Status WrappedDatasetVariantDeviceCopy( const WrappedDatasetVariantWrapper& from, WrappedDatasetVariantWrapper* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { *to = WrappedDatasetVariantWrapper(from); - return OkStatus(); + return absl::OkStatus(); } #define REGISTER_OPTIONAL_COPY(DIRECTION) \ @@ -248,7 +248,7 @@ Status GraphDefBuilderWrapper::AddDataset( Status GraphDefBuilderWrapper::AddDataset( const DatasetBase* dataset, const std::vector>& inputs, - const std::vector>>& list_inputs, + const std::vector>>& list_inputs, const std::vector>& attrs, Node** output) { return AddDataset(dataset, inputs, list_inputs, attrs, @@ -258,7 +258,7 @@ Status GraphDefBuilderWrapper::AddDataset( Status GraphDefBuilderWrapper::AddDataset( const DatasetBase* dataset, const std::vector>& inputs, - const std::vector>>& list_inputs, + const std::vector>>& list_inputs, const std::vector>& attrs, bool use_dataset_name, Node** output) { auto& type_string = dataset->type_string(); @@ -320,7 +320,7 @@ Status GraphDefBuilderWrapper::AddDataset( return errors::Internal("AddDataset: Failed to build ", type_string, " op with error ", opts->StatusToString()); } - return OkStatus(); + return absl::OkStatus(); } Status GraphDefBuilderWrapper::AddFunction( @@ -329,7 +329,7 @@ Status GraphDefBuilderWrapper::AddFunction( if (b_->HasFunction(function_name)) { VLOG(1) << "Function with name " << function_name << "already exists in" << " the graph. It will not be added again."; - return OkStatus(); + return absl::OkStatus(); } const FunctionDef* f_def = lib_def.Find(function_name); if (f_def == nullptr) { @@ -363,7 +363,7 @@ Status GraphDefBuilderWrapper::AddFunction( for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) { TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second, lib_def)); } - return OkStatus(); + return absl::OkStatus(); } void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val, @@ -529,7 +529,7 @@ Status MemoryCheckpoint::Save(IteratorStateWriter* writer) const { auto [prefix, key] = id_registry_->Get(id); TF_RETURN_IF_ERROR(writer->WriteTensor(prefix, key, value)); } - return OkStatus(); + return absl::OkStatus(); } Status IteratorBase::InitializeBase(IteratorContext* ctx, @@ -551,7 +551,7 @@ Status IteratorBase::InitializeBase(IteratorContext* ctx, cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); }); } } - return OkStatus(); + return absl::OkStatus(); } Status GetCompressedElementFromVariantTensor( @@ -569,7 +569,7 @@ Status GetCompressedElementFromVariantTensor( "Tensor must be a `CompressedElement` object."); } *out_compressed_element = compressed_element; - return OkStatus(); + return absl::OkStatus(); } int64_t GetAllocatedBytes(const std::vector& element) { @@ -619,7 +619,7 @@ int64_t GetTotalBytes(const std::vector& element) { } std::string FullName(const std::string& prefix, const std::string& name) { - if (str_util::StrContains(name, kColon)) { + if (absl::StrContains(name, kColon)) { LOG(ERROR) << name << " should not contain " << kColon; } @@ -627,7 +627,7 @@ std::string FullName(const std::string& prefix, const std::string& name) { } Status ExtractIteratorPrefix(StringPiece key, string* prefix) { - if (!str_util::StartsWith(key, data::kFullNameRandomHex)) { + if (!absl::StartsWith(key, data::kFullNameRandomHex)) { return errors::InvalidArgument("Key: ", key, " was not generated using full_name."); } @@ -639,7 +639,7 @@ Status ExtractIteratorPrefix(StringPiece key, string* prefix) { string real_key = split_keys[1]; const int pos = real_key.rfind(kColon); *prefix = real_key.substr(0, pos); - return OkStatus(); + return absl::OkStatus(); } Status GetDatasetFromVariantTensor(const Tensor& tensor, @@ -658,7 +658,7 @@ Status GetDatasetFromVariantTensor(const Tensor& tensor, if (*out_dataset == nullptr) { return errors::Internal("Read uninitialized Dataset variant."); } - return OkStatus(); + return absl::OkStatus(); } Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) { @@ -668,7 +668,7 @@ Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) { "Dataset tensor must be a scalar of dtype DT_VARIANT."); } tensor->scalar()() = DatasetVariantWrapper(dataset); - return OkStatus(); + return absl::OkStatus(); } namespace internal { @@ -792,12 +792,12 @@ Status DatasetBase::ComputeNumSources() { } if (num_sources_ >= 0) { // Already computed. - return OkStatus(); + return absl::OkStatus(); } num_sources_ = 0; if (inputs.empty()) { num_sources_ = 1; - return OkStatus(); + return absl::OkStatus(); } for (const auto& input : inputs) { if (input->num_sources() < 0) { @@ -808,7 +808,7 @@ Status DatasetBase::ComputeNumSources() { } num_sources_ += input->num_sources(); } - return OkStatus(); + return absl::OkStatus(); } Status DatasetBase::CheckRandomAccessCompatible(const int64 index) const { @@ -826,7 +826,7 @@ Status DatasetBase::CheckRandomAccessCompatible(const int64 index) const { return errors::OutOfRange("Index out of range [0, ", cardinality, "):", index); } - return OkStatus(); + return absl::OkStatus(); } Status DatasetBase::Get(OpKernelContext* ctx, int64 index, @@ -859,7 +859,7 @@ Status DatasetBase::MergeOptionsFromInputs() { return s; } if (inputs.empty()) { - return OkStatus(); + return absl::OkStatus(); } // Merge options from inputs sequentially before merging options from dataset. // Since the last options merged takes precedence, the options that may be set @@ -871,7 +871,7 @@ Status DatasetBase::MergeOptionsFromInputs() { } internal::MergeOptions(options_, &merged_options); options_ = merged_options; - return OkStatus(); + return absl::OkStatus(); } Status DatasetBase::MakeIterator( @@ -883,12 +883,12 @@ Status DatasetBase::MakeIterator( Status s = InputDatasets(&inputs); return inputs[0]->MakeIterator(ctx, parent, output_prefix, iterator); } - profiler::TraceMe traceme( + tsl::profiler::TraceMe traceme( [&] { - return profiler::TraceMeEncode( + return tsl::profiler::TraceMeEncode( strings::StrCat("MakeIterator::", type_string()), {}); }, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMeLevel::kInfo); *iterator = MakeIteratorInternal(output_prefix); Status s = (*iterator)->InitializeBase(ctx, parent); if (s.ok()) { @@ -995,7 +995,7 @@ Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset( << " will not be optimized because the dataset does not implement " "the " "AsGraphDefInternal() method needed to apply optimizations."; - return OkStatus(); + return absl::OkStatus(); } } return status; @@ -1033,7 +1033,7 @@ Status DatasetBase::DatasetGraphDefBuilder::AddIdentity( *output = ops::UnaryOp("Identity", *input, builder()->opts().WithName(UniqueNodeName(name_prefix))); - return OkStatus(); + return absl::OkStatus(); } Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper( @@ -1055,7 +1055,7 @@ Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper( opts.op_registry()); node_builder.Input(std::move(nodes)); *output = opts.FinalizeBuilder(&node_builder); - return OkStatus(); + return absl::OkStatus(); } Status DatasetBase::DatasetGraphDefBuilder::AddResourceHelper( @@ -1138,8 +1138,8 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx, "Iterator::GetNext", activity_watcher::ActivityCategory::kDatasetOp, std::move(attributes)); }); - profiler::TraceMe activity([&] { return BuildTraceMeName(); }, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMe activity([&] { return BuildTraceMeName(); }, + tsl::profiler::TraceMeLevel::kInfo); DVLOG(3) << prefix() << " GetNext enter"; auto model = ctx->model(); bool output_was_recording = @@ -1189,8 +1189,8 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx, Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence, int* num_skipped) { - profiler::TraceMe activity([&] { return BuildTraceMeName(); }, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMe activity([&] { return BuildTraceMeName(); }, + tsl::profiler::TraceMeLevel::kInfo); DVLOG(3) << prefix() << " Skip enter"; auto model = ctx->model(); bool output_was_recording = @@ -1232,7 +1232,7 @@ Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, int num_to_skip, std::vector out_tensors; TF_RETURN_IF_ERROR(GetNextInternal(ctx, &out_tensors, end_of_sequence)); if (*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } // RecordElement is used to count the number of element computed and // help calculate the CPU time spent on a given iterator to do the @@ -1244,7 +1244,7 @@ Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, int num_to_skip, RecordElement(ctx, &out_tensors); (*num_skipped)++; } - return OkStatus(); + return absl::OkStatus(); } void DatasetOpKernel::Compute(OpKernelContext* ctx) { @@ -1269,7 +1269,7 @@ void DatasetOpKernel::Compute(OpKernelContext* ctx) { string DatasetOpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const { - return profiler::TraceMeOp(name_view(), type_string_view()); + return tsl::profiler::TraceMeOp(name_view(), type_string_view()); } // static diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index ca9e5a639a8ba0..03470e6dd298f9 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -231,7 +231,7 @@ class GraphDefBuilderWrapper { if (*output == nullptr) { return errors::Internal("AddScalar: Failed to build Const op."); } - return OkStatus(); + return absl::OkStatus(); } // Adds a Const node with vector value to the Graph. @@ -250,7 +250,7 @@ class GraphDefBuilderWrapper { if (*output == nullptr) { return errors::Internal("AddVector: Failed to build Const op."); } - return OkStatus(); + return absl::OkStatus(); } Status AddVector(const std::vector& val, Node** output) { @@ -263,7 +263,7 @@ class GraphDefBuilderWrapper { if (*output == nullptr) { return errors::Internal("AddVector: Failed to build Const op."); } - return OkStatus(); + return absl::OkStatus(); } // Adds a `Const` node for the given tensor value to the graph. @@ -276,7 +276,7 @@ class GraphDefBuilderWrapper { if (*output == nullptr) { return errors::Internal("AddTensor: Failed to build Const op."); } - return OkStatus(); + return absl::OkStatus(); } // Adds a `Placeholder` node for the given tensor value to the graph. @@ -290,7 +290,7 @@ class GraphDefBuilderWrapper { return errors::Internal( "AddPlaceholder: Failed to build Placeholder op."); } - return OkStatus(); + return absl::OkStatus(); } // Adds a node for the given dataset to the `Graph`. The value of @@ -319,13 +319,15 @@ class GraphDefBuilderWrapper { Status AddDataset( const DatasetBase* dataset, const std::vector>& inputs, - const std::vector>>& list_inputs, + const std::vector>>& + list_inputs, const std::vector>& attrs, Node** output); Status AddDataset( const DatasetBase* dataset, const std::vector>& inputs, - const std::vector>>& list_inputs, + const std::vector>>& + list_inputs, const std::vector>& attrs, bool use_dataset_name, Node** output); @@ -378,7 +380,7 @@ class GraphDefBuilderWrapper { TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def)); } } - return OkStatus(); + return absl::OkStatus(); } GraphDefBuilder* b_; @@ -501,7 +503,7 @@ class MemoryCheckpoint final : public IteratorStateWriter { Status WriteScalar(StringPiece name, StringPiece key, int64_t val) override { auto id = id_registry_->Add(string(name), string(key)); int_values_[id] = val; - return OkStatus(); + return absl::OkStatus(); } Status WriteScalar(StringPiece key, const tstring& val) override { string prefix; @@ -512,7 +514,7 @@ class MemoryCheckpoint final : public IteratorStateWriter { const tstring& val) override { auto id = id_registry_->Add(string(name), string(key)); str_values_[id] = val; - return OkStatus(); + return absl::OkStatus(); } Status WriteTensor(StringPiece key, const Tensor& val) override { string prefix; @@ -523,7 +525,7 @@ class MemoryCheckpoint final : public IteratorStateWriter { const Tensor& val) override { auto id = id_registry_->Add(string(name), string(key)); tensor_values_[id] = val; - return OkStatus(); + return absl::OkStatus(); } // END implementation of `IteratorStateWriter` interface @@ -554,7 +556,7 @@ class MemoryCheckpoint final : public IteratorStateWriter { : is_root_(is_root), id_registry_(registry) {} void operator=(const MemoryCheckpoint&) = delete; - Status status_ = OkStatus(); + Status status_ = absl::OkStatus(); // Only set to true for the checkpoint in IteratorResource. // Root checkpoint does not track expired prefixes. const bool is_root_ = false; @@ -579,10 +581,10 @@ class SerializationContext { switch (params_.external_state_policy) { case ExternalStatePolicy::POLICY_WARN: LOG(WARNING) << s.ToString(); - return OkStatus(); + return absl::OkStatus(); case ExternalStatePolicy::POLICY_IGNORE: VLOG(2) << "Ignoring error status: " << s.ToString(); - return OkStatus(); + return absl::OkStatus(); case ExternalStatePolicy::POLICY_FAIL: return s; default: @@ -1117,7 +1119,7 @@ class IteratorBase : public Checkpointable { // Performs initialization that needs to happen outside of a constructor to // properly propagate errors. - virtual Status Initialize(IteratorContext* ctx) { return OkStatus(); } + virtual Status Initialize(IteratorContext* ctx) { return absl::OkStatus(); } // Performs initialization of the base iterator. Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent); @@ -1128,7 +1130,7 @@ class IteratorBase : public Checkpointable { TF_RETURN_IF_ERROR(SaveInternal(ctx, writer)); VLOG(1) << "Saved " << prefix() << " in " << (EnvTime::NowMicros() - start_us) << "us"; - return OkStatus(); + return absl::OkStatus(); } // Restores the state of this iterator. @@ -1138,7 +1140,7 @@ class IteratorBase : public Checkpointable { ctx->SaveCheckpoint(this); VLOG(1) << "Restored " << prefix() << " in " << (EnvTime::NowMicros() - start_us) << "us"; - return OkStatus(); + return absl::OkStatus(); } // Returns the total number of bytes buffered by the iterator across all nodes @@ -1158,7 +1160,7 @@ class IteratorBase : public Checkpointable { Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer, const std::unique_ptr& input) { if (ctx->symbolic_checkpoint()) { - return OkStatus(); + return absl::OkStatus(); } return input->Save(ctx, writer); } @@ -1326,7 +1328,7 @@ class DatasetBase : public core::RefCounted { TF_RETURN_IF_ERROR(it->Restore(&restore_ctx, reader)); ctx->MergeCheckpoint(restore_ctx.checkpoint()); *iterator = std::move(it); - return OkStatus(); + return absl::OkStatus(); } Status MakeIteratorFromCheckpoint( @@ -1699,7 +1701,7 @@ Status ParseScalarArgument(OpKernelContext* ctx, return errors::InvalidArgument(argument_name, " must be a scalar"); } *output = argument_t->scalar()(); - return OkStatus(); + return absl::OkStatus(); } template @@ -1716,7 +1718,7 @@ Status ParseVectorArgument(OpKernelContext* ctx, for (int i = 0; i < size; ++i) { output->push_back(argument_t->vec()(i)); } - return OkStatus(); + return absl::OkStatus(); } // Encapsulates the work required to plug a DatasetBase into the core TensorFlow diff --git a/tensorflow/core/framework/dataset_stateful_op_allowlist.h b/tensorflow/core/framework/dataset_stateful_op_allowlist.h index 5e8cdd4af32a19..b92acf5fb74972 100644 --- a/tensorflow/core/framework/dataset_stateful_op_allowlist.h +++ b/tensorflow/core/framework/dataset_stateful_op_allowlist.h @@ -27,12 +27,12 @@ class AllowlistedStatefulOpRegistry { public: Status Add(string op_name) { op_names_.insert(std::move(op_name)); - return OkStatus(); + return absl::OkStatus(); } Status Remove(string op_name) { op_names_.erase(op_name); - return OkStatus(); + return absl::OkStatus(); } bool Contains(const string& op_name) { return op_names_.count(op_name); } diff --git a/tensorflow/core/framework/device.h b/tensorflow/core/framework/device.h index 6cdcd2efd90ab9..08231d55d3a160 100644 --- a/tensorflow/core/framework/device.h +++ b/tensorflow/core/framework/device.h @@ -142,7 +142,7 @@ class Device : public DeviceBase { // 'graph' supplies the partition of the graph assigned to this // device. virtual Status MaybeRewriteGraph(std::unique_ptr* /*graph*/) { - return OkStatus(); + return absl::OkStatus(); } // Sets `out_context` a new DeviceContext* for executing a graph, or nullptr @@ -153,7 +153,7 @@ class Device : public DeviceBase { // and should call Unref(). virtual Status TryGetDeviceContext(DeviceContext** out_context) { *out_context = nullptr; - return OkStatus(); + return absl::OkStatus(); } // Returns the op segment of this device. The caller can reuse op diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index c8fbf9e1635296..065707fde4b8c2 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -229,7 +229,7 @@ class DeviceBase { PerOpGpuDevice* /*device*/, DeviceContext* /*dc*/, Allocator* /*allocator*/) { - return OkStatus(); + return absl::OkStatus(); } // Unimplemented by default diff --git a/tensorflow/core/framework/device_factory.cc b/tensorflow/core/framework/device_factory.cc index 43ad12393ac9a3..e39d768a56c785 100644 --- a/tensorflow/core/framework/device_factory.cc +++ b/tensorflow/core/framework/device_factory.cc @@ -151,7 +151,7 @@ Status DeviceFactory::ListAllPhysicalDevices(std::vector* devices) { } } - return OkStatus(); + return absl::OkStatus(); } Status DeviceFactory::ListPluggablePhysicalDevices( @@ -163,7 +163,7 @@ Status DeviceFactory::ListPluggablePhysicalDevices( TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(devices)); } } - return OkStatus(); + return absl::OkStatus(); } Status DeviceFactory::GetAnyDeviceDetails( @@ -223,7 +223,7 @@ Status DeviceFactory::AddCpuDevices( return errors::NotFound("No CPU devices are available in this process"); } - return OkStatus(); + return absl::OkStatus(); } Status DeviceFactory::AddDevices( @@ -259,7 +259,7 @@ Status DeviceFactory::AddDevices( } } - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr DeviceFactory::NewDevice(const string& type, diff --git a/tensorflow/core/framework/device_factory.h b/tensorflow/core/framework/device_factory.h index c238aebf475bd2..7957af3cbad869 100644 --- a/tensorflow/core/framework/device_factory.h +++ b/tensorflow/core/framework/device_factory.h @@ -85,7 +85,7 @@ class DeviceFactory { // into devices from ListPhysicalDevices. virtual Status GetDeviceDetails(int device_index, std::unordered_map* details) { - return OkStatus(); + return absl::OkStatus(); } // Most clients should call AddDevices() instead. diff --git a/tensorflow/core/framework/fake_input.cc b/tensorflow/core/framework/fake_input.cc index 415125c73b1bf3..bf7edef06ddae9 100644 --- a/tensorflow/core/framework/fake_input.cc +++ b/tensorflow/core/framework/fake_input.cc @@ -108,14 +108,14 @@ Status FakeInputImpl::AddInputToBuilder() { "': ", status.message()); } SourceList(dts); - return OkStatus(); + return absl::OkStatus(); } DataType dt; TF_RETURN_IF_ERROR(GetDataType(&dt)); builder_->Input(in_node_, 0, dt); } - return OkStatus(); + return absl::OkStatus(); } // static @@ -134,13 +134,13 @@ Status FakeInputImpl::GetN(int* n) const { arg_->name(), "': ", status.message()); } } - return OkStatus(); + return absl::OkStatus(); } Status FakeInputImpl::GetDataType(DataType* dt) const { if (dt_specified_) { *dt = dt_; - return OkStatus(); // Ignore is_ref field of arg_. + return absl::OkStatus(); // Ignore is_ref field of arg_. } else if (arg_->type() != DT_INVALID) { *dt = arg_->type(); } else if (!arg_->type_attr().empty()) { @@ -162,7 +162,7 @@ Status FakeInputImpl::GetDataType(DataType* dt) const { if (arg_->is_ref()) { *dt = MakeRefType(*dt); } - return OkStatus(); + return absl::OkStatus(); } void FakeInputImpl::NSources(int n, DataType dt) const { @@ -171,7 +171,7 @@ void FakeInputImpl::NSources(int n, DataType dt) const { for (int i = 0; i < n; ++i) { srcs.emplace_back(in_node_, i, dt); } - builder_->Input(gtl::ArraySlice(srcs)); + builder_->Input(absl::Span(srcs)); } void FakeInputImpl::SourceList(DataTypeSlice dts) const { @@ -180,7 +180,7 @@ void FakeInputImpl::SourceList(DataTypeSlice dts) const { for (size_t i = 0; i < dts.size(); ++i) { srcs.emplace_back(in_node_, i, dts[i]); } - builder_->Input(gtl::ArraySlice(srcs)); + builder_->Input(absl::Span(srcs)); } } // namespace diff --git a/tensorflow/core/framework/full_type_util.cc b/tensorflow/core/framework/full_type_util.cc index fcc6446b67ac4b..b76b1d52274095 100644 --- a/tensorflow/core/framework/full_type_util.cc +++ b/tensorflow/core/framework/full_type_util.cc @@ -41,7 +41,7 @@ OpTypeConstructor NoOp() { OpTypeConstructor NoOutputs() { return [](OpDef* op_def) { op_def->mutable_output_arg(); - return OkStatus(); + return absl::OkStatus(); }; } @@ -50,7 +50,7 @@ OpTypeConstructor Nullary(FullTypeId t) { FullTypeDef* tdef = op_def->mutable_output_arg(0)->mutable_experimental_full_type(); tdef->set_type_id(t); - return OkStatus(); + return absl::OkStatus(); }; } @@ -64,7 +64,7 @@ OpTypeConstructor Unary(FullTypeId t, const string& var_name) { arg->set_type_id(TFT_VAR); arg->set_s(var_name); - return OkStatus(); + return absl::OkStatus(); }; } @@ -77,7 +77,7 @@ OpTypeConstructor UnaryGeneric(FullTypeId t) { FullTypeDef* arg = tdef->add_args(); arg->set_type_id(TFT_ANY); - return OkStatus(); + return absl::OkStatus(); }; } @@ -92,7 +92,7 @@ OpTypeConstructor UnaryTensorContainer(FullTypeId t, FullTypeId dtype) { FullTypeDef* targ = arg->add_args(); targ->set_type_id(dtype); - return OkStatus(); + return absl::OkStatus(); }; } @@ -108,7 +108,7 @@ OpTypeConstructor UnaryTensorContainer(FullTypeId t, const string& var_name) { varg->set_type_id(TFT_VAR); varg->set_s(var_name); - return OkStatus(); + return absl::OkStatus(); }; } @@ -133,7 +133,7 @@ OpTypeConstructor VariadicTensorContainer(FullTypeId t, tvar->set_type_id(TFT_VAR); tvar->set_s(var_name); - return OkStatus(); + return absl::OkStatus(); }; } @@ -176,7 +176,7 @@ Status SubstituteVar(AttrMap& attrs, FullTypeDef& t) { attr->DebugString(), " for name ", var_name)); } t.clear_s(); - return OkStatus(); + return absl::OkStatus(); } Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) { @@ -238,7 +238,7 @@ Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) { attr->DebugString(), "\nfor name ", var_name)); } t = result; - return OkStatus(); + return absl::OkStatus(); } Status SubstituteGeneric(AttrMap& attrs, FullTypeDef& t) { @@ -257,7 +257,7 @@ Status SubstituteGeneric(AttrMap& attrs, FullTypeDef& t) { break; } } - return OkStatus(); + return absl::OkStatus(); } inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t) { @@ -281,7 +281,7 @@ inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t) { default: return SubstituteGeneric(attrs, t); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -312,7 +312,7 @@ Status SpecializeType(const AttrSlice& attrs, const OpDef& op_def, t.DebugString(), "\nfrom\n", attrs.SummarizeNode()); } - return OkStatus(); + return absl::OkStatus(); } const FullTypeDef& GetArgDefaultUnset(const FullTypeDef& t, int i) { diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 0b6bacd94af0d9..61cfee4198de94 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -90,7 +90,7 @@ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, for (int i = 0; i < v->list().type_size(); ++i) { dtypes->push_back(v->list().type(i)); } - return OkStatus(); + return absl::OkStatus(); } *is_type_list = false; @@ -116,7 +116,7 @@ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, dtype = v->type(); } dtypes->resize(num, dtype); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -166,7 +166,7 @@ Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) { } #endif - return OkStatus(); + return absl::OkStatus(); } // A helper class for instantiating functions. This contains shared information @@ -229,7 +229,7 @@ class FunctionInstantiationHelper { result_.arg_types.push_back(dtypes[i]); ++arg_index; } - return OkStatus(); + return absl::OkStatus(); } Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs, @@ -259,7 +259,7 @@ class FunctionInstantiationHelper { } start += dtypes.size(); } - return OkStatus(); + return absl::OkStatus(); } Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) { @@ -363,7 +363,7 @@ class FunctionInstantiationHelper { *gnode->mutable_experimental_type() = fnode.experimental_type(); } - return OkStatus(); + return absl::OkStatus(); } Status AddReturnNode( @@ -406,7 +406,7 @@ class FunctionInstantiationHelper { AddAttr("index", (*ret_index)++, gnode); result_.ret_types.push_back(dtypes[i]); } - return OkStatus(); + return absl::OkStatus(); } // Adds the actual node inputs to the result graph by converting indexes to @@ -452,7 +452,7 @@ class FunctionInstantiationHelper { " name: "), name); } - return OkStatus(); + return absl::OkStatus(); } const NameInfoItem* GetItemOrNull(const string& name) const { @@ -644,7 +644,7 @@ string Print(const FunctionDef& fdef) { return out; } -string Print(gtl::ArraySlice nodes) { +string Print(absl::Span nodes) { std::vector arg; std::vector ret; std::vector body; @@ -738,7 +738,7 @@ Status AddDefaultAttrs(const string& op, } } } - return OkStatus(); + return absl::OkStatus(); } } // end namespace @@ -857,7 +857,7 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, // Adds the actual node inputs using the input indexes. helper.AddNodeInputs(); - return OkStatus(); + return absl::OkStatus(); } string DebugString(const FunctionDef& func_def) { return Print(func_def); } @@ -870,7 +870,7 @@ string DebugString(const GraphDef& instantiated_func_def) { return Print(ptrs); } -string DebugString(gtl::ArraySlice instantiated_func_nodes) { +string DebugString(absl::Span instantiated_func_nodes) { std::vector ptrs; for (const NodeDef& n : instantiated_func_nodes) { ptrs.push_back(&n); @@ -1147,7 +1147,7 @@ FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types, FunctionCallFrame::~FunctionCallFrame() {} -Status FunctionCallFrame::SetArgs(gtl::ArraySlice args) { +Status FunctionCallFrame::SetArgs(absl::Span args) { // Input type checks. if (args.size() != arg_types_.size()) { return errors::InvalidArgument("Expects ", arg_types_.size(), @@ -1162,7 +1162,7 @@ Status FunctionCallFrame::SetArgs(gtl::ArraySlice args) { } args_[i] = args[i]; } - return OkStatus(); + return absl::OkStatus(); } Status FunctionCallFrame::GetRetvals(std::vector* rets) const { @@ -1176,7 +1176,7 @@ Status FunctionCallFrame::GetRetvals(std::vector* rets) const { return errors::Internal("Retval[", i, "] does not have value"); } } - return OkStatus(); + return absl::OkStatus(); } Status FunctionCallFrame::ConsumeRetvals(std::vector* rets, @@ -1192,7 +1192,7 @@ Status FunctionCallFrame::ConsumeRetvals(std::vector* rets, return errors::Internal("Retval[", i, "] does not have value"); } } - return OkStatus(); + return absl::OkStatus(); } Status FunctionCallFrame::GetArg(int index, const Tensor** val) { @@ -1201,7 +1201,7 @@ Status FunctionCallFrame::GetArg(int index, const Tensor** val) { args_.size(), ")"); } *val = &args_[index]; - return OkStatus(); + return absl::OkStatus(); } Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { @@ -1221,7 +1221,7 @@ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { } else { return errors::Internal("Retval[", index, "] has already been set."); } - return OkStatus(); + return absl::OkStatus(); } FunctionRecord::FunctionRecord(const FunctionDef& fdef, @@ -1446,7 +1446,7 @@ Status FunctionLibraryDefinition::AddHelper(FunctionRecord* registration, "exists."); } // Ignore duplicate FunctionDefs. - return OkStatus(); + return absl::OkStatus(); } const OpDef* op_def; if (default_registry_ @@ -1460,7 +1460,7 @@ Status FunctionLibraryDefinition::AddHelper(FunctionRecord* registration, registration->finalize(); records_.insert({registration->fdef().signature().name(), registration}); *added = true; - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::CopyFunctionDefFrom( @@ -1485,7 +1485,7 @@ Status FunctionLibraryDefinition::CopyFunctionDefFrom( "' because a different function with the same name already " "exists."); } else { - return OkStatus(); + return absl::OkStatus(); } } else if (other_record->finalized()) { bool added; @@ -1514,11 +1514,11 @@ Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad, "'", *entry, "'"); } // Ignore duplicate GradientDefs - return OkStatus(); + return absl::OkStatus(); } *entry = grad.gradient_func(); *added = true; - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::AddLibrary( @@ -1567,7 +1567,7 @@ Status FunctionLibraryDefinition::AddLibrary( funcs_with_grads.push_back(grad.function_name()); } } - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::AddLibrary( @@ -1625,7 +1625,7 @@ Status FunctionLibraryDefinition::AddLibrary( funcs_with_grads.push_back(grad.function_name()); } } - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::ReplaceFunction( @@ -1636,7 +1636,7 @@ Status FunctionLibraryDefinition::ReplaceFunction( TF_RETURN_IF_ERROR(RemoveFunctionHelper(func)); TF_RETURN_IF_ERROR(AddFunctionDefHelper( FunctionDef(fdef), StackTracesMap(stack_traces), &added)); - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) { @@ -1644,13 +1644,13 @@ Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) { bool added; TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name())); TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added)); - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::RemoveFunction(const string& func) { mutex_lock l(mu_); TF_RETURN_IF_ERROR(RemoveFunctionHelper(func)); - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) { @@ -1661,7 +1661,7 @@ Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) { } iter->second->Unref(); records_.erase(iter); - return OkStatus(); + return absl::OkStatus(); } void FunctionLibraryDefinition::Clear() { @@ -1681,7 +1681,7 @@ Status FunctionLibraryDefinition::RemoveGradient(const string& func) { func, "'."); } func_grad_.erase(i); - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::Remove( @@ -1700,7 +1700,7 @@ Status FunctionLibraryDefinition::Remove( return s; } } - return OkStatus(); + return absl::OkStatus(); } string FunctionLibraryDefinition::FindGradient(const string& func) const { @@ -1718,7 +1718,7 @@ Status FunctionLibraryDefinition::LookUp( auto iter = records_.find(op); if (iter != records_.end()) { *op_reg_data = &iter->second->op_registration_data(); - return OkStatus(); + return absl::OkStatus(); } return default_registry_->LookUp(op, op_reg_data); } @@ -1796,7 +1796,7 @@ Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, const string& attr, T* value) const { const FunctionDef* fdef = GetAttrImpl(ndef); if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) { - return OkStatus(); + return absl::OkStatus(); } return errors::InvalidArgument("Attr ", attr, " is not defined."); } @@ -1837,7 +1837,7 @@ std::set ReachableFunctions(const FunctionLibraryDefinition& flib, // Functions might be reachable from the nested function calls, so we keep a // queue of functions that we have to check. - gtl::InlinedVector, 4> func_queue; + absl::InlinedVector, 4> func_queue; // Add reachable and not already processed functions to the functions queue. const auto add_to_func_queue = [&](const string& func_name) { @@ -2043,7 +2043,7 @@ void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) { FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef( const string& name, - gtl::ArraySlice> attrs) { + absl::Span> attrs) { AttrValueWrapper ret; ret.proto.mutable_func()->set_name(name); for (const auto& a : attrs) { @@ -2081,11 +2081,11 @@ NodeDef FunctionDefHelper::Node::ToNodeDef() const { /* static */ FunctionDef FunctionDefHelper::Create( - const string& function_name, gtl::ArraySlice in_def, - gtl::ArraySlice out_def, gtl::ArraySlice attr_def, - gtl::ArraySlice node_def, - gtl::ArraySlice> ret_def, - gtl::ArraySlice> control_ret_def) { + const string& function_name, absl::Span in_def, + absl::Span out_def, absl::Span attr_def, + absl::Span node_def, + absl::Span> ret_def, + absl::Span> control_ret_def) { FunctionDef fdef; // Signature @@ -2131,20 +2131,20 @@ FunctionDef FunctionDefHelper::Create( /* static */ FunctionDef FunctionDefHelper::Create( - const string& function_name, gtl::ArraySlice in_def, - gtl::ArraySlice out_def, gtl::ArraySlice attr_def, - gtl::ArraySlice node_def, - gtl::ArraySlice> ret_def) { + const string& function_name, absl::Span in_def, + absl::Span out_def, absl::Span attr_def, + absl::Span node_def, + absl::Span> ret_def) { return Create(function_name, in_def, out_def, attr_def, node_def, ret_def, /*control_ret_def=*/{}); } /* static */ FunctionDef FunctionDefHelper::Define(const string& name, - gtl::ArraySlice arg_def, - gtl::ArraySlice ret_def, - gtl::ArraySlice attr_def, - gtl::ArraySlice node_def) { + absl::Span arg_def, + absl::Span ret_def, + absl::Span attr_def, + absl::Span node_def) { FunctionDef fdef; OpDefBuilder b(name); for (const auto& a : arg_def) b.Input(a); @@ -2209,10 +2209,10 @@ FunctionDef FunctionDefHelper::Define(const string& name, return fdef; } -FunctionDef FunctionDefHelper::Define(gtl::ArraySlice arg_def, - gtl::ArraySlice ret_def, - gtl::ArraySlice attr_def, - gtl::ArraySlice node_def) { +FunctionDef FunctionDefHelper::Define(absl::Span arg_def, + absl::Span ret_def, + absl::Span attr_def, + absl::Span node_def) { return Define("_", arg_def, ret_def, attr_def, node_def); } @@ -2238,7 +2238,7 @@ Status GetOpGradientCreator(const string& op, Creator* creator) { return errors::NotFound("No gradient defined for op: ", op); } *creator = iter->second; - return OkStatus(); + return absl::OkStatus(); } } // end namespace gradient diff --git a/tensorflow/core/framework/function_handle_cache.cc b/tensorflow/core/framework/function_handle_cache.cc index 446f8cefdc81ed..add92c44aff5bc 100644 --- a/tensorflow/core/framework/function_handle_cache.cc +++ b/tensorflow/core/framework/function_handle_cache.cc @@ -51,7 +51,7 @@ Status FunctionHandleCache::Instantiate( } else { *handle = h; } - return OkStatus(); + return absl::OkStatus(); } Status FunctionHandleCache::Clear() { @@ -60,7 +60,7 @@ Status FunctionHandleCache::Clear() { TF_RETURN_IF_ERROR(lib_->ReleaseHandle(entry.second)); } handles_.clear(); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index 418de4290a5c2d..8b9a8615bc6113 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -49,7 +49,7 @@ Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) { for (const NodeDef& node : graph_def.node()) { TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node)); } - return OkStatus(); + return absl::OkStatus(); } Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, @@ -79,7 +79,7 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, } } - return OkStatus(); + return absl::OkStatus(); } static Status RemoveNewDefaultAttrsFromNodeDef( @@ -124,7 +124,7 @@ static Status RemoveNewDefaultAttrsFromNodeDef( } } - return OkStatus(); + return absl::OkStatus(); } static bool IsFunction(const GraphDef& graph_def, const string& op_name) { @@ -161,7 +161,7 @@ Status RemoveNewDefaultAttrsFromGraphDef( } } - return OkStatus(); + return absl::OkStatus(); } void StripDefaultAttributes(const OpRegistryInterface& op_registry, @@ -261,7 +261,7 @@ Status StrippedOpListForGraph(const GraphDef& graph_def, stripped_op->CopyFrom(*op_def); RemoveDescriptionsFromOpDef(stripped_op); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/graph_to_functiondef.cc b/tensorflow/core/framework/graph_to_functiondef.cc index 384d9cba6865a2..fcd48e3fc5e047 100644 --- a/tensorflow/core/framework/graph_to_functiondef.cc +++ b/tensorflow/core/framework/graph_to_functiondef.cc @@ -145,7 +145,7 @@ Status NodeNameMapping::UseOutputName(const string& name) { "' appears more than once in 'output_names' array."); } used_names_.emplace(name, 0); - return OkStatus(); + return absl::OkStatus(); } string NodeNameMapping::Lookup(const string& name) const { @@ -318,7 +318,7 @@ Status FillFunctionBody( func_attr_names.insert(func_attr_name); } } - return OkStatus(); + return absl::OkStatus(); } Status GraphToFunctionDefHelper( @@ -536,7 +536,7 @@ Status GraphToFunctionDefHelper( fdef->mutable_signature()->add_control_output(control_output); } - return OkStatus(); + return absl::OkStatus(); } Status GraphToFunctionDefHelper( @@ -560,7 +560,7 @@ Status GraphToFunctionDefHelper( (*args_or_retvals)[index].node->DebugString(), "\nNow we have:\n", node->DebugString()); } - return OkStatus(); + return absl::OkStatus(); }; std::vector body_nodes; @@ -599,7 +599,7 @@ Status GraphToFunctionDefHelper( "' node at index ", i); } } - return OkStatus(); + return absl::OkStatus(); }; TF_RETURN_IF_ERROR(validate_args_retvals(inputs, "_Arg")); @@ -631,7 +631,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, copy_placeholder_attrs_from_nodes, body_nodes, inputs, outputs, output_names, control_outputs, control_output_names, description, /*allow_destructive_reads=*/false, fdef); - return OkStatus(); + return absl::OkStatus(); } Status GraphToFunctionDef( diff --git a/tensorflow/core/framework/graph_to_functiondef_test.cc b/tensorflow/core/framework/graph_to_functiondef_test.cc index e6c30171910402..f29295274dfbe2 100644 --- a/tensorflow/core/framework/graph_to_functiondef_test.cc +++ b/tensorflow/core/framework/graph_to_functiondef_test.cc @@ -229,7 +229,7 @@ TEST(GraphToFunctionDefTest, ArgAttrConstInput) { args_or_retvals->resize(index + 1); } (*args_or_retvals)[index].node = node; - return OkStatus(); + return absl::OkStatus(); }; for (Node* node : root.graph()->op_nodes()) { // Set const as the input node. diff --git a/tensorflow/core/framework/kernel_def_util.cc b/tensorflow/core/framework/kernel_def_util.cc index 69738eea671f52..d1f556bdaa9288 100644 --- a/tensorflow/core/framework/kernel_def_util.cc +++ b/tensorflow/core/framework/kernel_def_util.cc @@ -117,7 +117,7 @@ Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs, if (attr_value->type() != DT_INVALID) { if (!InTypeList(attr_value->type(), constraint.allowed_values())) { - return OkStatus(); + return absl::OkStatus(); } } else { if (!AttrValueHasType(*attr_value, "list(type)").ok()) { @@ -133,13 +133,13 @@ Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs, for (int t : attr_value->list().type()) { if (!InTypeList(static_cast(t), constraint.allowed_values())) { - return OkStatus(); + return absl::OkStatus(); } } } } *match = true; - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/kernel_shape_util.cc b/tensorflow/core/framework/kernel_shape_util.cc index 071821ce4a56d6..f06a366f435e5f 100644 --- a/tensorflow/core/framework/kernel_shape_util.cc +++ b/tensorflow/core/framework/kernel_shape_util.cc @@ -63,7 +63,7 @@ Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, ", effective_filter_size: ", effective_filter_size, ", stride: ", stride, "]"); } - return OkStatus(); + return absl::OkStatus(); } Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, @@ -93,6 +93,6 @@ Status Get3dOutputSizeV2(const std::array& input, input[i], window[i], dilations[i], strides[i], padding_type, &(*output_ptr)[i], &(*padding_ptr)[i])); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc index f1c3a4b3935605..d428f6d463ea51 100644 --- a/tensorflow/core/framework/load_library.cc +++ b/tensorflow/core/framework/load_library.cc @@ -66,7 +66,7 @@ Status LoadDynamicLibrary(const char* library_filename, void** result, if (seen_op_names.find(opdef.name()) == seen_op_names.end()) { // Over writing a registration of an op not in this custom op // library. Treat this as not an error. - return OkStatus(); + return absl::OkStatus(); } } if (s.ok()) { @@ -98,7 +98,7 @@ Status LoadDynamicLibrary(const char* library_filename, void** result, *len = str.length(); *result = library.handle; - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/local_rendezvous.cc b/tensorflow/core/framework/local_rendezvous.cc index 488e9251d8e913..910c8a92a744fb 100644 --- a/tensorflow/core/framework/local_rendezvous.cc +++ b/tensorflow/core/framework/local_rendezvous.cc @@ -191,7 +191,7 @@ Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, queue->push_back(new Item(std::move(rc_owner), send_args, val, is_dead, std::move(activity_scope))); bucket.mu.unlock(); - return OkStatus(); + return absl::OkStatus(); } DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). "; @@ -210,7 +210,8 @@ Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, bucket.mu.unlock(); DCHECK_EQ(item->type, Item::kRecv); - (*item->recv_state.waiter)(OkStatus(), send_args, item->args, val, is_dead); + (*item->recv_state.waiter)(absl::OkStatus(), send_args, item->args, val, + is_dead); { mutex_lock l(bucket.mu); bucket.pending_callback_counter--; @@ -220,7 +221,7 @@ Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, } // Delete the item at last since it may unref and destruct the rendezvous. delete item; - return OkStatus(); + return absl::OkStatus(); } void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key, @@ -367,7 +368,7 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key, bucket.mu.unlock(); DCHECK_EQ(item->type, Item::kSend); - done(OkStatus(), item->args, recv_args, *item->send_state.value, + done(absl::OkStatus(), item->args, recv_args, *item->send_state.value, item->send_state.is_dead); { mutex_lock l(bucket.mu); diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc index b868faf03ef426..2dc224c3f5b6ea 100644 --- a/tensorflow/core/framework/lookup_interface.cc +++ b/tensorflow/core/framework/lookup_interface.cc @@ -27,7 +27,7 @@ Status LookupInterface::CheckKeyShape(const TensorShape& shape) { " must end with the table's key shape ", key_shape().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys, @@ -40,7 +40,7 @@ Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys, return errors::InvalidArgument("Value must be type ", value_dtype(), " but got ", values.dtype()); } - return OkStatus(); + return absl::OkStatus(); } Status LookupInterface::CheckKeyAndValueTensorsHelper(const Tensor& keys, @@ -58,7 +58,7 @@ Status LookupInterface::CheckKeyAndValueTensorsHelper(const Tensor& keys, "Expected shape ", expected_value_shape.DebugString(), " for value, got ", values.shape().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } Status LookupInterface::CheckKeyAndValueTensorsForInsert(const Tensor& keys, @@ -95,7 +95,7 @@ Status LookupInterface::CheckFindArguments(const Tensor& key, fullsize_value_shape.DebugString(), " for default value, got ", default_value.shape().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } } // namespace lookup diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index 602b27db362010..d3d9bcbf759032 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -171,7 +171,7 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index a34e274c48228f..1fc6622bebe170 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -375,7 +375,7 @@ Status ModelToProtoHelper(std::shared_ptr output, ModelProto* model) { to_serialize.push_back(input); } } - return OkStatus(); + return absl::OkStatus(); } // Recursively produces node tree rooted in `output` from the given model proto. @@ -398,7 +398,7 @@ Status ModelFromProtoHelper(ModelProto model, std::shared_ptr* output) { to_restore_inputs.push_back(input); } } - return OkStatus(); + return absl::OkStatus(); } // The first input of InterleaveMany corresponds to the input dataset whose @@ -555,7 +555,7 @@ class InterleaveMany : public Node { Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::INTERLEAVE_MANY); - return OkStatus(); + return absl::OkStatus(); } }; @@ -778,7 +778,7 @@ class AsyncInterleaveMany : public Node { Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::ASYNC_INTERLEAVE_MANY); - return OkStatus(); + return absl::OkStatus(); } }; @@ -871,7 +871,7 @@ class KnownRatio : public Node { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::KNOWN_RATIO); node_proto->set_ratio(ratio_); - return OkStatus(); + return absl::OkStatus(); } private: @@ -1250,7 +1250,7 @@ class UnknownRatio : public Node { Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::UNKNOWN_RATIO); - return OkStatus(); + return absl::OkStatus(); } }; @@ -1304,7 +1304,7 @@ class Unknown : public Node { Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::UNKNOWN); - return OkStatus(); + return absl::OkStatus(); } }; @@ -1347,7 +1347,7 @@ class AsyncKnownRatio : public AsyncRatio { parameter->set_value(parameter->state_value()); parameter->set_tunable(true); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -1390,7 +1390,7 @@ class AsyncUnknownRatio : public AsyncRatio { Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::ASYNC_UNKNOWN_RATIO); - return OkStatus(); + return absl::OkStatus(); } }; @@ -2168,7 +2168,7 @@ Status Node::ToProto(ModelProto::Node* node_proto) const { for (auto const& input : inputs_) { node_proto->add_inputs(input->id()); } - return OkStatus(); + return absl::OkStatus(); } Status Node::FromProtoHelper(ModelProto::Node node_proto, @@ -2218,7 +2218,7 @@ Status Node::FromProtoHelper(ModelProto::Node node_proto, mutex_lock l(node->mu_); node->UpdateProcessingTimeEma(); } - return OkStatus(); + return absl::OkStatus(); } Status Node::FromProto(ModelProto::Node node_proto, @@ -2567,7 +2567,7 @@ Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros; } if (cancellation_manager->IsCancelled()) { - return OkStatus(); + return absl::OkStatus(); } } @@ -3194,7 +3194,7 @@ Status Model::ToProto(ModelProto* model_proto) { tf_shared_lock gap_lock(gap_mu_); *model_proto->mutable_gap_times() = {gap_times_usec_.begin(), gap_times_usec_.end()}; - return OkStatus(); + return absl::OkStatus(); } Status Model::FromProto(ModelProto model_proto, std::unique_ptr* model) { @@ -3204,7 +3204,7 @@ Status Model::FromProto(ModelProto model_proto, std::unique_ptr* model) { ModelFromProtoHelper(model_proto, &restored_model->output_)); restored_model->id_counter_ = model_proto.id_counter(); *model = std::move(restored_model); - return OkStatus(); + return absl::OkStatus(); } Status Model::Save(const string& fname, std::shared_ptr snapshot, @@ -3232,7 +3232,7 @@ Status Model::Load(const string& fname, std::unique_ptr* model, const OptimizationParams restored_optimization_params = model_proto.optimization_params(); *optimization_params = restored_optimization_params; - return OkStatus(); + return absl::OkStatus(); } std::string Model::DebugString() { diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc index fcf73e6970bb5c..86365b494217bd 100644 --- a/tensorflow/core/framework/node_def_builder.cc +++ b/tensorflow/core/framework/node_def_builder.cc @@ -106,7 +106,7 @@ NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) { } // For inputs that take a list of tensors. -NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice src_list) { +NodeDefBuilder& NodeDefBuilder::Input(absl::Span src_list) { const OpDef::ArgDef* arg = NextArgDef(); if (arg != nullptr) ListInput(arg, src_list); return *this; @@ -134,7 +134,7 @@ void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg, } void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg, - gtl::ArraySlice src_list) { + absl::Span src_list) { for (const auto& node_out : src_list) { AddInput(node_out.node, node_out.index); } @@ -262,7 +262,7 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) { // Add default values for unspecified attrs. AddDefaultsToNodeDef(*op_def_, node_def); - return OkStatus(); + return absl::OkStatus(); } } @@ -311,21 +311,21 @@ ATTR(const PartialTensorShape&) ATTR(const Tensor&) ATTR(const TensorProto&) ATTR(const NameAttrList&) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) ATTR(const std::vector&) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) #undef ATTR } // namespace tensorflow diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index d3af99893e7897..183a80ac18b1f5 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -185,7 +185,7 @@ const AttrValue* AttrSlice::FindByString(const string& attr_name) const { Status AttrSlice::CheckFind(StringPiece attr_name, const AttrValue* attr_value) const { if (attr_value != nullptr) { - return OkStatus(); + return absl::OkStatus(); } Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:"); // Skip AttachDef for internal attrs since it is a little bit @@ -402,7 +402,7 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, for (const auto& v : attr_value->list().type()) { value->push_back(static_cast(v)); } - return OkStatus(); + return absl::OkStatus(); } Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, @@ -411,7 +411,7 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor")); *value = &attr_value->tensor(); - return OkStatus(); + return absl::OkStatus(); } bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, @@ -434,7 +434,7 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func")); *value = &attr_value->func(); - return OkStatus(); + return absl::OkStatus(); } bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, @@ -523,7 +523,7 @@ Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs, (*sig)[i] = MakeRefType((*sig)[i]); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -537,7 +537,7 @@ Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, if (input_types_size > input_port) { const DataType dtype = input_types[input_port]; *input_type = dtype; - return OkStatus(); + return absl::OkStatus(); } } return errors::InvalidArgument("Input ", input_port, " not found for node ", @@ -549,7 +549,7 @@ Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, for (const auto& arg : op_def.input_arg()) { TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); } - return OkStatus(); + return absl::OkStatus(); } Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, @@ -561,7 +561,7 @@ Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, if (output_types_size > output_port) { const DataType dtype = output_types[output_port]; *output_type = dtype; - return OkStatus(); + return absl::OkStatus(); } } return errors::InvalidArgument("Output ", output_port, " not found for node ", @@ -573,7 +573,7 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, for (const auto& arg : op_def.output_arg()) { TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs)); } - return OkStatus(); + return absl::OkStatus(); } Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def, @@ -581,7 +581,7 @@ Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def, for (const auto& arg : op_def.output_arg()) { TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs)); } - return OkStatus(); + return absl::OkStatus(); } Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, @@ -595,7 +595,7 @@ Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def, DataTypeVector outputs; TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs)); *num_outputs = outputs.size(); - return OkStatus(); + return absl::OkStatus(); } int OpPortIdToArgId(const NodeDef& node, @@ -718,7 +718,7 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def)); } - return OkStatus(); + return absl::OkStatus(); } namespace { // Helpers for NameRangesForNode() @@ -739,7 +739,7 @@ Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def, "Argument '", arg_def.name(), "' incorrectly specified in op definition: ", SummarizeOpDef(op_def)); } - return OkStatus(); + return absl::OkStatus(); } Status NameRangesHelper(const AttrSlice& attrs, @@ -752,7 +752,7 @@ Status NameRangesHelper(const AttrSlice& attrs, (*result)[arg.name()] = std::make_pair(start, start + num); start += num; } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -766,7 +766,7 @@ Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def, if (outputs != nullptr) { return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs); } - return OkStatus(); + return absl::OkStatus(); } void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) { @@ -866,10 +866,10 @@ const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix); Status ValidateOpInput(const string& input_name, bool* is_control_input) { *is_control_input = false; if (IsValidDataInputName(input_name)) { - return OkStatus(); + return absl::OkStatus(); } else if (IsValidControlInputName(input_name)) { *is_control_input = true; - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument("Illegal op input name '", input_name, "'"); } @@ -877,7 +877,7 @@ Status ValidateOpInput(const string& input_name, bool* is_control_input) { Status ValidateNodeName(const string& node_name) { if (IsValidNodeName(node_name)) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument("Illegal op name '", node_name, "'"); } @@ -903,7 +903,7 @@ Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) { } in_control_inputs = is_control_input; } - return OkStatus(); + return absl::OkStatus(); } Status AttachDef(const Status& status, const NodeDef& node_def, @@ -947,20 +947,20 @@ ADD_NODE_ATTR(const PartialTensorShape&) ADD_NODE_ATTR(const Tensor&) ADD_NODE_ATTR(const TensorProto&) ADD_NODE_ATTR(const NameAttrList&) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) ADD_NODE_ATTR(const std::vector&) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) #undef ADD_NODE_ATTR void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) { @@ -990,7 +990,7 @@ Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix, attr.set_s(frame_name); } - return OkStatus(); + return absl::OkStatus(); } Status MaybeAddPrefixToColocationConstraints( @@ -998,7 +998,7 @@ Status MaybeAddPrefixToColocationConstraints( NodeDef* node_def) { auto attr = node_def->mutable_attr()->find(kColocationAttrName); if (attr == node_def->mutable_attr()->end()) { - return OkStatus(); + return absl::OkStatus(); } auto constraints_list = attr->second.mutable_list(); auto constraints_size = constraints_list->s_size(); @@ -1011,7 +1011,7 @@ Status MaybeAddPrefixToColocationConstraints( } } } - return OkStatus(); + return absl::OkStatus(); } Status MaybeUpdateColocationConstraintsWithMap( @@ -1019,7 +1019,7 @@ Status MaybeUpdateColocationConstraintsWithMap( NodeDef* node_def) { auto attr = node_def->mutable_attr()->find(kColocationAttrName); if (attr == node_def->mutable_attr()->end()) { - return OkStatus(); + return absl::OkStatus(); } auto constraints_list = attr->second.mutable_list(); auto constraints_size = constraints_list->s_size(); @@ -1032,7 +1032,7 @@ Status MaybeUpdateColocationConstraintsWithMap( } } } - return OkStatus(); + return absl::OkStatus(); } void ChangeToNoOp(NodeDef* node_def) { diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index fbba2c86892112..67bde1fc71e228 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -50,7 +50,7 @@ NodeDef ToNodeDef(NodeDefBuilder&& builder) { } void ExpectSuccess(const NodeDef& good, const OpDef& op_def) { - EXPECT_EQ(OkStatus(), ValidateNodeDef(good, op_def)) + EXPECT_EQ(absl::OkStatus(), ValidateNodeDef(good, op_def)) << "NodeDef: " << SummarizeNodeDef(good) << "; OpDef: " << SummarizeOpDef(op_def); } @@ -318,7 +318,7 @@ TEST(NodeDefUtilTest, Device) { } void ExpectValidSyntax(const NodeDef& good) { - EXPECT_EQ(OkStatus(), ValidateExternalNodeDefSyntax(good)) + EXPECT_EQ(absl::OkStatus(), ValidateExternalNodeDefSyntax(good)) << "NodeDef: " << SummarizeNodeDef(good); } diff --git a/tensorflow/core/framework/node_properties.cc b/tensorflow/core/framework/node_properties.cc index 23eda55c6da49b..4af538b3b2c1c5 100644 --- a/tensorflow/core/framework/node_properties.cc +++ b/tensorflow/core/framework/node_properties.cc @@ -33,7 +33,7 @@ Status NodeProperties::CreateFromNodeDef( props->reset(new NodeProperties(op_def, std::move(node_def), std::move(input_types), std::move(output_types))); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/node_properties_test.cc b/tensorflow/core/framework/node_properties_test.cc index 258f413fba8c6e..5621137c7aba71 100644 --- a/tensorflow/core/framework/node_properties_test.cc +++ b/tensorflow/core/framework/node_properties_test.cc @@ -44,7 +44,7 @@ class MockOpRegistry : public OpRegistryInterface { const OpRegistrationData** op_reg_data) const override { if (op_type_name == "Foo") { *op_reg_data = &op_reg_; - return OkStatus(); + return absl::OkStatus(); } else { *op_reg_data = nullptr; return errors::InvalidArgument("Op type named ", op_type_name, diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index ccd5edcb3d37b5..3c3970506389f9 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -34,7 +34,7 @@ namespace tensorflow { Status DefaultValidator(const OpRegistryInterface& op_registry) { LOG(WARNING) << "No kernel validator registered with OpRegistry."; - return OkStatus(); + return absl::OkStatus(); } // OpRegistry ----------------------------------------------------------------- @@ -45,7 +45,7 @@ Status OpRegistryInterface::LookUpOpDef(const string& op_type_name, const OpRegistrationData* op_reg_data = nullptr; TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data)); *op_def = &op_reg_data->op_def; - return OkStatus(); + return absl::OkStatus(); } OpRegistry::OpRegistry() @@ -78,7 +78,7 @@ Status OpNotFound(const string& op_type_name) { Status OpRegistry::LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const { - if ((*op_reg_data = LookUp(op_type_name))) return OkStatus(); + if ((*op_reg_data = LookUp(op_type_name))) return absl::OkStatus(); return OpNotFound(op_type_name); } @@ -155,7 +155,7 @@ Status OpRegistry::SetWatcher(const Watcher& watcher) { "Cannot over-write a valid watcher with another."); } watcher_ = watcher; - return OkStatus(); + return absl::OkStatus(); } void OpRegistry::Export(bool include_internal, OpList* ops) const { @@ -217,7 +217,7 @@ bool OpRegistry::MustCallDeferred() const { } Status OpRegistry::CallDeferred() const { - if (initialized_) return OkStatus(); + if (initialized_) return absl::OkStatus(); initialized_ = true; registry_.reserve(registry_.size() + deferred_.size()); for (const auto& op_data_factory : deferred_) { @@ -227,7 +227,7 @@ Status OpRegistry::CallDeferred() const { } } deferred_.clear(); - return OkStatus(); + return absl::OkStatus(); } Status OpRegistry::RegisterAlreadyLocked( @@ -278,7 +278,7 @@ const OpRegistrationData* OpListOpRegistry::LookUp( Status OpListOpRegistry::LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const { - if ((*op_reg_data = LookUp(op_type_name))) return OkStatus(); + if ((*op_reg_data = LookUp(op_type_name))) return absl::OkStatus(); return OpNotFound(op_type_name); } diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index 71bc11acb1f8ea..83aa4d8e1974dd 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -492,7 +492,7 @@ void FinalizeDoc(const string& text, OpDef* op_def, // Trim trailing blank lines from the description. while (start_l < end_l && lines[end_l - 1].empty()) --end_l; string desc = absl::StrJoin( - gtl::ArraySlice(lines.data() + start_l, end_l - start_l), "\n"); + absl::Span(lines.data() + start_l, end_l - start_l), "\n"); if (!desc.empty()) op_def->set_description(desc); // name: description @@ -687,7 +687,7 @@ Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const { TF_RETURN_IF_ERROR(op_reg_data->type_ctor(op_def)); } - if (errors.empty()) return OkStatus(); + if (errors.empty()) return absl::OkStatus(); return errors::InvalidArgument(absl::StrJoin(errors, "\n")); } diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index fd6e284e5c1917..1da0aa726d64ca 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -45,7 +45,7 @@ Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { const AttrValue& allowed_values(attr.allowed_values()); for (auto allowed : allowed_values.list().type()) { if (dt == allowed) { - return OkStatus(); + return absl::OkStatus(); } } string allowed_str; @@ -65,7 +65,7 @@ Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) { const AttrValue& allowed_values(attr.allowed_values()); for (const auto& allowed : allowed_values.list().s()) { if (str == allowed) { - return OkStatus(); + return absl::OkStatus(); } } string allowed_str; @@ -143,7 +143,7 @@ Status ValidateAttrValue(const AttrValue& attr_value, "Support for allowed_values not implemented for type ", attr.type()); } } - return OkStatus(); + return absl::OkStatus(); } const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) { @@ -244,7 +244,7 @@ static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix); } - return OkStatus(); + return absl::OkStatus(); } bool IsValidOpName(StringPiece sp) { @@ -343,7 +343,7 @@ Status ValidateOpDef(const OpDef& op_def) { TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names)); } - return OkStatus(); + return absl::OkStatus(); } #undef VALIDATE @@ -372,7 +372,7 @@ Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) { } } } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -684,7 +684,7 @@ Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { " changed from ref to non-ref"); } - return OkStatus(); + return absl::OkStatus(); } Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, @@ -723,7 +723,7 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, } } - return OkStatus(); + return absl::OkStatus(); } Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) { @@ -752,7 +752,7 @@ Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) { } } - return OkStatus(); + return absl::OkStatus(); } void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) { diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc index 11a17486372f21..9151e1b0448fb2 100644 --- a/tensorflow/core/framework/op_gen_lib.cc +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -55,7 +55,7 @@ string WordWrap(StringPiece prefix, StringPiece str, int width) { StringPiece to_append = str.substr(0, space); str.remove_prefix(space + 1); // Remove spaces at break. - while (str_util::EndsWith(to_append, " ")) { + while (absl::EndsWith(to_append, " ")) { to_append.remove_suffix(1); } while (absl::ConsumePrefix(&str, " ")) { @@ -466,7 +466,7 @@ Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) { strings::StrCat(description, "\n", new_api_def.description_suffix()); } base_api_def->set_description(description); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -484,11 +484,11 @@ Status ApiDefMap::LoadFileList(Env* env, const std::vector& filenames) { for (const auto& filename : filenames) { TF_RETURN_IF_ERROR(LoadFile(env, filename)); } - return OkStatus(); + return absl::OkStatus(); } Status ApiDefMap::LoadFile(Env* env, const string& filename) { - if (filename.empty()) return OkStatus(); + if (filename.empty()) return absl::OkStatus(); string contents; TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents)); Status status = LoadApiDef(contents); @@ -498,7 +498,7 @@ Status ApiDefMap::LoadFile(Env* env, const string& filename) { status, strings::StrCat("Error parsing ApiDef file ", filename, ": ", status.message())); } - return OkStatus(); + return absl::OkStatus(); } Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) { @@ -514,7 +514,7 @@ Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) { TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def)); } } - return OkStatus(); + return absl::OkStatus(); } void ApiDefMap::UpdateDocs() { diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index f8b8f81b15a67a..cd9c83bebc626f 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -93,7 +93,7 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs, " expected: ", DataTypeSliceString(expected_inputs), "->", DataTypeSliceString(expected_outputs)); } - return OkStatus(); + return absl::OkStatus(); } const absl::flat_hash_set* GetOpNodeDefsToLogFromEnv() { @@ -196,7 +196,7 @@ Status OpKernel::InputRange(StringPiece input_name, int* start, } else { *start = result->second.first; *stop = result->second.second; - return OkStatus(); + return absl::OkStatus(); } } @@ -208,7 +208,7 @@ Status OpKernel::OutputRange(StringPiece output_name, int* start, } else { *start = result->second.first; *stop = result->second.second; - return OkStatus(); + return absl::OkStatus(); } } @@ -235,12 +235,13 @@ string OpKernel::ShapeTraceString(const OpKernelContext& ctx) const { } string OpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const { - string trace_string = profiler::TraceMeOp(name_view(), type_string_view()); + string trace_string = + tsl::profiler::TraceMeOp(name_view(), type_string_view()); if (verbose) { string shape = ShapeTraceString(ctx); if (!shape.empty()) { - trace_string = - profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}}); + trace_string = tsl::profiler::TraceMeEncode(std::move(trace_string), + {{"shape", shape}}); } } return trace_string; @@ -302,7 +303,7 @@ Status OpKernelConstruction::allocate_temp(DataType type, def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); } *out_temp = new_temp; - return OkStatus(); + return absl::OkStatus(); } Status OpKernelConstruction::allocate_temp(DataType type, @@ -327,7 +328,7 @@ Status OpKernelConstruction::allocate_temp(DataType type, def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); } *out_temp = new_temp; - return OkStatus(); + return absl::OkStatus(); } // OpKernelContext ----------------------------------------------------------- @@ -411,7 +412,7 @@ Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { "' when non-ref input was expected"); } *tensor = params_->inputs[index].tensor; - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const { @@ -419,14 +420,14 @@ Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const { TF_RETURN_IF_ERROR(get_input_index(name, &index)); const TensorValue& value(params_->inputs[index]); *dtype = value.dtype(); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) { int index; TF_RETURN_IF_ERROR(get_input_index(name, &index)); *out_mutex = input_ref_mutex(index); - return OkStatus(); + return absl::OkStatus(); } absl::StatusOr OpKernelContext::get_input(int index) const { @@ -516,7 +517,7 @@ Status OpKernelContext::forward_input_to_output_with_shape( return errors::FailedPrecondition("OpKernel could not forward input '", input_name, "' to output '", output_name); } - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr OpKernelContext::forward_input( @@ -588,7 +589,7 @@ std::unique_ptr OpKernelContext::forward_input( } Status OpKernelContext::forward_input_or_allocate_temp( - gtl::ArraySlice candidate_input_indices, DataType type, + absl::Span candidate_input_indices, DataType type, const TensorShape& shape, const AllocatorAttributes& allocator_attr, Tensor* out_temp) { for (int input_index : candidate_input_indices) { @@ -597,14 +598,14 @@ Status OpKernelContext::forward_input_or_allocate_temp( type, shape, DEVICE_MEMORY, allocator_attr); if (new_tensor != nullptr) { *out_temp = std::move(*new_tensor); - return OkStatus(); + return absl::OkStatus(); } } return allocate_temp(type, shape, out_temp, allocator_attr); } Status OpKernelContext::forward_input_or_allocate_output( - gtl::ArraySlice candidate_input_indices, int output_index, + absl::Span candidate_input_indices, int output_index, const TensorShape& output_shape, Tensor** output, int* forwarded_input) { for (int input_index : candidate_input_indices) { if (forward_input_to_output_with_shape(input_index, output_index, @@ -612,7 +613,7 @@ Status OpKernelContext::forward_input_or_allocate_output( if (forwarded_input != nullptr) { *forwarded_input = input_index; } - return OkStatus(); + return absl::OkStatus(); } } if (forwarded_input != nullptr) { @@ -622,13 +623,13 @@ Status OpKernelContext::forward_input_or_allocate_output( } Status OpKernelContext::forward_input_or_allocate_output( - gtl::ArraySlice candidate_input_names, StringPiece output_name, - const TensorShape& output_shape, Tensor** output) { + absl::Span candidate_input_names, + StringPiece output_name, const TensorShape& output_shape, Tensor** output) { for (const StringPiece& input_name : candidate_input_names) { if (forward_input_to_output_with_shape(input_name, output_name, output_shape, output) .ok()) { - return OkStatus(); + return absl::OkStatus(); } } return allocate_output(output_name, output_shape, output); @@ -662,7 +663,7 @@ Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor, tf_shared_lock l(*input_ref_mutex(index)); *tensor = *params_->inputs[index].tensor; } - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::replace_ref_input(StringPiece name, @@ -675,14 +676,14 @@ Status OpKernelContext::replace_ref_input(StringPiece name, "' when ref input was expected"); } replace_ref_input(index, tensor, lock_held); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::input_list(StringPiece name, OpInputList* list) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); *list = OpInputList(this, start, stop); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::mutable_input_list(StringPiece name, @@ -690,14 +691,14 @@ Status OpKernelContext::mutable_input_list(StringPiece name, int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); *list = OpMutableInputList(this, start, stop); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); *list = OpOutputList(this, start, stop); - return OkStatus(); + return absl::OkStatus(); } void OpKernelContext::maybe_initialize_scope_id_set() { @@ -779,7 +780,7 @@ Status OpKernelContext::allocate_tensor( params_->step_id, new_tensor); } *out_tensor = std::move(new_tensor); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::allocate_output(int index, const TensorShape& shape, @@ -889,7 +890,7 @@ Status OpKernelContext::get_input_index(StringPiece name, "expected"); } *out_index = start; - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::get_output_index(StringPiece name, @@ -903,21 +904,21 @@ Status OpKernelContext::get_output_index(StringPiece name, "expected"); } *out_index = start; - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) { int index; TF_RETURN_IF_ERROR(get_output_index(name, &index)); set_output(index, tensor); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) { int index; TF_RETURN_IF_ERROR(get_output_index(name, &index)); set_output(index, std::move(tensor)); - return OkStatus(); + return absl::OkStatus(); } bool OpKernelContext::maybe_set_output_by_allocate_and_copy( @@ -1025,14 +1026,14 @@ Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu, int index; TF_RETURN_IF_ERROR(get_output_index(name, &index)); set_output_ref(index, mu, tensor_for_ref); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) { int index; TF_RETURN_IF_ERROR(get_output_index(name, &index)); *tensor = mutable_output(index); - return OkStatus(); + return absl::OkStatus(); } bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { @@ -1200,7 +1201,7 @@ static Status IsProbablySafeToLoad(const string& path) { errmsg.append(absl::StrJoin(missing_features, ", ")); return errors::FailedPrecondition(errmsg); } - return OkStatus(); + return absl::OkStatus(); } void LoadDynamicKernelsInternal() { @@ -1453,7 +1454,7 @@ Status FindKernelRegistration( } } - return OkStatus(); + return absl::OkStatus(); } Status FindKernelRegistration(const DeviceType& device_type, @@ -1517,7 +1518,7 @@ Status FindKernelDef( } if (def != nullptr) *def = ®->def; if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name; - return OkStatus(); + return absl::OkStatus(); } Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, @@ -1596,7 +1597,7 @@ Status SupportedDeviceTypesForNode( prioritized_device_types->push_back(std::make_pair(device_type, 0)); } } - return OkStatus(); + return absl::OkStatus(); } void LogAllRegisteredKernels() { @@ -1782,7 +1783,7 @@ Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) { } } } - return OkStatus(); + return absl::OkStatus(); } template <> diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index a4373446481d93..bea1208053c5e2 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -422,8 +422,8 @@ TEST_F(OpKernelTest, InputDtype) { Tensor a(DT_FLOAT, TensorShape({})); Tensor b(DT_INT32, TensorShape({})); Tensor c(DT_UINT8, TensorShape({})); - gtl::InlinedVector inputs{TensorValue(&a), TensorValue(&b), - TensorValue(&c)}; + absl::InlinedVector inputs{TensorValue(&a), TensorValue(&b), + TensorValue(&c)}; params.inputs = inputs; auto ctx = std::make_unique(¶ms); @@ -448,7 +448,7 @@ TEST_F(OpKernelTest, InputOnly) { EXPECT_TRUE(status.ok()); params.op_kernel = op.get(); Tensor a(DT_FLOAT, TensorShape({})); - gtl::InlinedVector inputs{TensorValue(&a)}; + absl::InlinedVector inputs{TensorValue(&a)}; params.inputs = inputs; auto ctx = std::make_unique(¶ms); @@ -475,8 +475,8 @@ TEST_F(OpKernelTest, RefInputs) { Tensor* a = new Tensor(DT_FLOAT, TensorShape({})); Tensor* b = new Tensor(DT_FLOAT, TensorShape({2})); mutex mu_a, mu_b; - gtl::InlinedVector inputs{TensorValue(&mu_a, a), - TensorValue(&mu_b, b)}; + absl::InlinedVector inputs{TensorValue(&mu_a, a), + TensorValue(&mu_b, b)}; params.inputs = inputs; auto ctx = std::make_unique(¶ms); @@ -502,7 +502,7 @@ TEST_F(OpKernelTest, AllocateOutput) { params.op_kernel = op.get(); Tensor a(DT_FLOAT, TensorShape({})); Tensor b(DT_INT32, TensorShape({})); - gtl::InlinedVector inputs{TensorValue(&a), TensorValue(&b)}; + absl::InlinedVector inputs{TensorValue(&a), TensorValue(&b)}; params.inputs = inputs; auto ctx = std::make_unique(¶ms); Tensor* output = nullptr; @@ -566,7 +566,7 @@ class ScopedAllocatorDevice : public DeviceBase { StatusCallback done) override { CHECK(input_tensor->NumElements() == output_tensor->NumElements()); tensor::DeepCopy(*input_tensor, output_tensor); - done(OkStatus()); + done(absl::OkStatus()); } // Return the count of calls to GetAllocator or GetScopedAllocator, depending @@ -641,7 +641,7 @@ TEST_F(OpKernelTest, TraceString) { params.op_kernel = op.get(); Tensor a(DT_FLOAT, TensorShape({4, 8})); - gtl::InlinedVector inputs{TensorValue(&a)}; + absl::InlinedVector inputs{TensorValue(&a)}; params.inputs = inputs; params.op_kernel = op.get(); @@ -1162,7 +1162,7 @@ void BM_TraceString(::testing::benchmark::State& state) { params.op_kernel = op.get(); Tensor a(DT_FLOAT, TensorShape({99000, 256})); Tensor b(DT_FLOAT, TensorShape({256, 256})); - gtl::InlinedVector inputs{TensorValue(&a), TensorValue(&b)}; + absl::InlinedVector inputs{TensorValue(&a), TensorValue(&b)}; params.inputs = inputs; auto ctx = std::make_unique(¶ms); diff --git a/tensorflow/core/framework/op_registration_test.cc b/tensorflow/core/framework/op_registration_test.cc index af80036272a367..286a0db358702c 100644 --- a/tensorflow/core/framework/op_registration_test.cc +++ b/tensorflow/core/framework/op_registration_test.cc @@ -27,7 +27,7 @@ namespace { void Register(const string& op_name, OpRegistry* registry) { registry->Register([op_name](OpRegistrationData* op_reg_data) -> Status { op_reg_data->op_def.set_name(op_name); - return OkStatus(); + return absl::OkStatus(); }); } @@ -51,7 +51,7 @@ TEST(OpRegistrationTest, TestDuplicate) { TF_EXPECT_OK( registry->SetWatcher([](const Status& s, const OpDef& op_def) -> Status { EXPECT_TRUE(errors::IsAlreadyExists(s)); - return OkStatus(); + return absl::OkStatus(); })); Register("Foo", registry.get()); s = registry->ProcessRegistrations(); diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc index 42651c8c6dde6c..6af4d8973b3e1c 100644 --- a/tensorflow/core/framework/op_segment.cc +++ b/tensorflow/core/framework/op_segment.cc @@ -46,7 +46,7 @@ Status OpSegment::FindOrCreate(const string& session_handle, } *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name); if (*kernel != nullptr) { - return OkStatus(); + return absl::OkStatus(); } } Status s = create_fn(kernel); @@ -68,7 +68,7 @@ Status OpSegment::FindOrCreate(const string& session_handle, *kernel = *p_kernel; } } - return OkStatus(); + return absl::OkStatus(); } void OpSegment::AddHold(const string& session_handle) { diff --git a/tensorflow/core/framework/ops_util.cc b/tensorflow/core/framework/ops_util.cc index b53fb3e6c2b70c..abe57812774933 100644 --- a/tensorflow/core/framework/ops_util.cc +++ b/tensorflow/core/framework/ops_util.cc @@ -59,7 +59,7 @@ Status GetBroadcastSize(const int index, const int in_size, const int ksize, if (*bindex + ksize > in_size) { *bsize = std::min((in_size - *bindex), ksize); } - return OkStatus(); + return absl::OkStatus(); } string SanitizeThreadSuffix(string suffix) { diff --git a/tensorflow/core/framework/partial_tensor_shape_test.cc b/tensorflow/core/framework/partial_tensor_shape_test.cc index 581989c8cc3c31..77f81cc5a8a549 100644 --- a/tensorflow/core/framework/partial_tensor_shape_test.cc +++ b/tensorflow/core/framework/partial_tensor_shape_test.cc @@ -295,14 +295,14 @@ TEST(PartialTensorShapeTest, PartialShapeMergeWith) { const PartialTensorShape e; PartialTensorShape test; - EXPECT_EQ(OkStatus(), a.MergeWith(a, &test)); + EXPECT_EQ(absl::OkStatus(), a.MergeWith(a, &test)); EXPECT_EQ(test.dims(), 3); EXPECT_EQ(test.dim_size(0), -1); EXPECT_EQ(test.dim_size(1), 0); EXPECT_EQ(test.dim_size(2), 1); test = PartialTensorShape(); - EXPECT_EQ(OkStatus(), a.MergeWith(b, &test)); + EXPECT_EQ(absl::OkStatus(), a.MergeWith(b, &test)); EXPECT_EQ(test.dims(), 3); EXPECT_EQ(test.dim_size(0), 1); EXPECT_EQ(test.dim_size(1), 0); @@ -312,28 +312,28 @@ TEST(PartialTensorShapeTest, PartialShapeMergeWith) { EXPECT_TRUE(errors::IsInvalidArgument(a.MergeWith(d, &test))); test = PartialTensorShape(); - EXPECT_EQ(OkStatus(), a.MergeWith(c, &test)); + EXPECT_EQ(absl::OkStatus(), a.MergeWith(c, &test)); EXPECT_EQ(test.dims(), 3); EXPECT_EQ(test.dim_size(0), -1); EXPECT_EQ(test.dim_size(1), 0); EXPECT_EQ(test.dim_size(2), 1); test = PartialTensorShape(); - EXPECT_EQ(OkStatus(), c.MergeWith(a, &test)); + EXPECT_EQ(absl::OkStatus(), c.MergeWith(a, &test)); EXPECT_EQ(test.dims(), 3); EXPECT_EQ(test.dim_size(0), -1); EXPECT_EQ(test.dim_size(1), 0); EXPECT_EQ(test.dim_size(2), 1); test = PartialTensorShape(); - EXPECT_EQ(OkStatus(), a.MergeWith(e, &test)); + EXPECT_EQ(absl::OkStatus(), a.MergeWith(e, &test)); EXPECT_EQ(test.dims(), 3); EXPECT_EQ(test.dim_size(0), -1); EXPECT_EQ(test.dim_size(1), 0); EXPECT_EQ(test.dim_size(2), 1); test = PartialTensorShape(); - EXPECT_EQ(OkStatus(), e.MergeWith(a, &test)); + EXPECT_EQ(absl::OkStatus(), e.MergeWith(a, &test)); EXPECT_EQ(test.dims(), 3); EXPECT_EQ(test.dim_size(0), -1); EXPECT_EQ(test.dim_size(1), 0); diff --git a/tensorflow/core/framework/reader_base.cc b/tensorflow/core/framework/reader_base.cc index 2bc23d0b8a6d30..2e433fb1359d5a 100644 --- a/tensorflow/core/framework/reader_base.cc +++ b/tensorflow/core/framework/reader_base.cc @@ -50,7 +50,7 @@ Status ReaderBase::ResetLocked() { work_finished_ = 0; num_records_produced_ = 0; work_.clear(); - return OkStatus(); + return absl::OkStatus(); } Status ReaderBase::SerializeState(tstring* state) { @@ -261,7 +261,7 @@ Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) { "Inconsistent work started vs. finished when restoring in ", name(), ": ", debug_string); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/reader_base.h b/tensorflow/core/framework/reader_base.h index 8f4e347e09aa99..644a5618f7564e 100644 --- a/tensorflow/core/framework/reader_base.h +++ b/tensorflow/core/framework/reader_base.h @@ -64,8 +64,8 @@ class ReaderBase : public ReaderInterface { bool* at_end); // Called when work starts / finishes. - virtual Status OnWorkStartedLocked() { return OkStatus(); } - virtual Status OnWorkFinishedLocked() { return OkStatus(); } + virtual Status OnWorkStartedLocked() { return absl::OkStatus(); } + virtual Status OnWorkFinishedLocked() { return absl::OkStatus(); } // Called to reset the Reader to a newly constructed state. virtual Status ResetLocked(); diff --git a/tensorflow/core/framework/reader_op_kernel.h b/tensorflow/core/framework/reader_op_kernel.h index 36f59717e0e9db..1433a54e5e7d12 100644 --- a/tensorflow/core/framework/reader_op_kernel.h +++ b/tensorflow/core/framework/reader_op_kernel.h @@ -76,7 +76,7 @@ class ReaderOpKernel : public ResourceOpKernel { } std::function temp = nullptr; factory_.swap(temp); - return OkStatus(); + return absl::OkStatus(); } std::function factory_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index efea3e2597c803..1792a1c1fed17d 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -109,7 +109,7 @@ Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { out->src_device = StringPiece(parts[0].data(), parts[0].size()); out->dst_device = StringPiece(parts[2].data(), parts[2].size()); out->edge_name = StringPiece(parts[3].data(), parts[3].size()); - return OkStatus(); + return absl::OkStatus(); } return errors::InvalidArgument("Invalid rendezvous key: ", key); } diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc index 1212fadfc1bdc8..1c52e259ba55b1 100644 --- a/tensorflow/core/framework/rendezvous_test.cc +++ b/tensorflow/core/framework/rendezvous_test.cc @@ -403,7 +403,7 @@ class DummyDeviceContext : public DeviceContext { void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, Tensor* output_tensor, StatusCallback done) const override { - done(OkStatus()); + done(absl::OkStatus()); } private: diff --git a/tensorflow/core/framework/resource_handle.cc b/tensorflow/core/framework/resource_handle.cc index bc6e459a6566e9..0fe49206846a5f 100644 --- a/tensorflow/core/framework/resource_handle.cc +++ b/tensorflow/core/framework/resource_handle.cc @@ -96,7 +96,7 @@ Status ResourceHandle::FromProto(const ResourceHandleProto& proto) { dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{dtype, shape}); } dtypes_and_shapes_ = std::move(dtypes_and_shapes); - return OkStatus(); + return absl::OkStatus(); } string ResourceHandle::SerializeAsString() const { @@ -157,7 +157,7 @@ Status ResourceHandle::ValidateType(const TypeIndex& type_index) const { port::Demangle(type_index.name()), "' (hash code ", type_index.hash_code(), ")"); } - return OkStatus(); + return absl::OkStatus(); } std::atomic ResourceHandle::current_id_; diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index 872665170ae08a..a738f8d735addd 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -61,7 +61,7 @@ Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, context->allocate_output(output_index, TensorShape({}), &handle)); handle->scalar()() = MakeResourceHandle(container, name, *context->device(), type_index); - return OkStatus(); + return absl::OkStatus(); } namespace internal { @@ -72,7 +72,7 @@ Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) { "Trying to access resource ", p.name(), " located in device ", p.device(), " from device ", ctx->device()->attributes().name()); } - return OkStatus(); + return absl::OkStatus(); } } // end namespace internal @@ -84,7 +84,7 @@ Status ResourceMgr::InsertDebugTypeName(uint64 hash_code, return errors::AlreadyExists("Duplicate hash code found for type ", type_name); } - return OkStatus(); + return absl::OkStatus(); } const char* ResourceMgr::DebugTypeName(uint64 hash_code) const { @@ -219,7 +219,7 @@ Status ResourceMgr::DoCreate(const string& container_name, TypeIndex type, auto st = container->insert(std::move(key_and_value)); if (st.second) { TF_RETURN_IF_ERROR(InsertDebugTypeName(type.hash_code(), type.name())); - return OkStatus(); + return absl::OkStatus(); } return errors::AlreadyExists("Resource ", container_name, "/", name, "/", type.name()); @@ -259,7 +259,7 @@ Status ResourceMgr::DoLookup(const string& container, uint64 type_hash_code, type_name, " has been destroyed."); } *resource = ptr; - return OkStatus(); + return absl::OkStatus(); } Status ResourceMgr::PopResourceAndName(const string& container, @@ -279,7 +279,7 @@ Status ResourceMgr::PopResourceAndName(const string& container, } std::swap(resource_and_name, iter->second); b->erase(iter); - return OkStatus(); + return absl::OkStatus(); } Status ResourceMgr::DoDelete(const string& container, uint64 type_hash_code, @@ -297,7 +297,7 @@ Status ResourceMgr::DoDelete(const string& container, uint64 type_hash_code, "This indicates ref-counting ResourceHandle is exposed to weak " "ResourceHandle code paths."); } - return OkStatus(); + return absl::OkStatus(); } Status ResourceMgr::DoDelete(const string& container, TypeIndex type, @@ -315,7 +315,7 @@ Status ResourceMgr::Cleanup(const string& container) { tf_shared_lock l(mu_); if (!gtl::FindOrNull(containers_, container)) { // Nothing to cleanup. - return OkStatus(); + return absl::OkStatus(); } } Container* b = nullptr; @@ -324,14 +324,14 @@ Status ResourceMgr::Cleanup(const string& container) { auto iter = containers_.find(container); if (iter == containers_.end()) { // Nothing to cleanup, it's OK (concurrent cleanup). - return OkStatus(); + return absl::OkStatus(); } b = iter->second; containers_.erase(iter); } CHECK(b != nullptr); delete b; - return OkStatus(); + return absl::OkStatus(); } static bool IsValidContainerName(StringPiece s) { @@ -373,7 +373,7 @@ Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef, static std::atomic counter(0); name_ = strings::StrCat("_", counter.fetch_add(1), "_", ndef.name()); } - return OkStatus(); + return absl::OkStatus(); } string ContainerInfo::DebugString() const { @@ -394,7 +394,7 @@ Status HandleFromInput(OpKernelContext* ctx, int input, return absl::InvalidArgumentError("Empty resource handle"); } *handle = tensor->flat()(0); - return OkStatus(); + return absl::OkStatus(); } Status HandleFromInput(OpKernelContext* ctx, StringPiece input, @@ -405,7 +405,7 @@ Status HandleFromInput(OpKernelContext* ctx, StringPiece input, return absl::InvalidArgumentError("Empty resource handle"); } *handle = tensor->flat()(0); - return OkStatus(); + return absl::OkStatus(); } Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, @@ -414,7 +414,7 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, if (p.IsRefCounting()) { TF_ASSIGN_OR_RETURN(*value, p.GetResource()); (*value)->Ref(); - return OkStatus(); + return absl::OkStatus(); } return ctx->resource_manager()->Lookup(p, value); } @@ -422,7 +422,7 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); if (p.IsRefCounting()) { - return OkStatus(); + return absl::OkStatus(); } return ctx->resource_manager()->Delete(p); } diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index b13de22dd49e99..658ed31ebfea9f 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -680,7 +680,7 @@ Status ResourceMgr::LookupMany( (*resources)[i].reset(resource); } } - return OkStatus(); + return absl::OkStatus(); } // Simple wrapper to allow conditional dynamic / static casts. @@ -777,7 +777,7 @@ template Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) { TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); TF_RETURN_IF_ERROR(p.ValidateType()); - return OkStatus(); + return absl::OkStatus(); } } // namespace internal @@ -804,7 +804,7 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, TF_ASSIGN_OR_RETURN(*value, p.GetResource()); // Transfers out a new reference. (*value)->Ref(); - return OkStatus(); + return absl::OkStatus(); } return ctx->resource_manager()->Lookup(p.container(), @@ -825,7 +825,7 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, TF_RETURN_IF_ERROR(LookupResource(ctx, p, &raw_ptr)); value->reset(raw_ptr); - return OkStatus(); + return absl::OkStatus(); } // Similar to Lookup, but looks up multiple resources at once, with only a @@ -872,7 +872,7 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, TF_RETURN_IF_ERROR(LookupOrCreateResource(ctx, p, &raw_ptr, creator)); value->reset(raw_ptr); - return OkStatus(); + return absl::OkStatus(); } // Deletes the resource pointed by "p", using the resource manager in "ctx". @@ -883,7 +883,7 @@ Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { // NOTE(feyu): if we can convert all resources handle to ref-counting, then // DeleteResource can be removed. if (p.IsRefCounting()) { - return OkStatus(); + return absl::OkStatus(); } return ctx->resource_manager()->Delete(p.container(), p.name()); } diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index 5c079cb2ac7318..6b12270ab97528 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -73,7 +73,7 @@ string LookupOrCreate(ResourceMgr* rm, const string& container, T* r; TF_CHECK_OK(rm->LookupOrCreate(container, name, &r, [&label](T** ret) { *ret = new T(label); - return OkStatus(); + return absl::OkStatus(); })); const string ret = r->DebugString(); r->Unref(); @@ -240,7 +240,7 @@ TEST(ResourceMgrTest, CreateOrLookupRaceCondition) { Env::Default()->SleepForMicroseconds(1 * 1000 * 1000); atomic_int += 1; *ret = new Resource("label"); - return OkStatus(); + return absl::OkStatus(); })); r->Unref(); }); @@ -265,7 +265,7 @@ Status ComputePolicy(const string& attr_container, } TF_RETURN_IF_ERROR(cinfo.Init(&rmgr, ndef, use_node_name_as_default)); *result = cinfo.DebugString(); - return OkStatus(); + return absl::OkStatus(); } string Policy(const string& attr_container, const string& attr_shared_name, diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index d74366937210c9..71d856eaeebb6b 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -1288,6 +1288,10 @@ bool InferenceContext::RelaxHandleShapesAndMergeTypes( bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes( int idx, const std::vector& shapes_and_types) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, output_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << output_handle_shapes_and_types_.size() << " inputs."; if (output_handle_shapes_and_types_[idx] == nullptr) { output_handle_shapes_and_types_[idx].reset( new std::vector(shapes_and_types)); @@ -1299,6 +1303,10 @@ bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes( bool InferenceContext::RelaxInputHandleShapesAndMergeTypes( int idx, const std::vector& shapes_and_types) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, input_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << input_handle_shapes_and_types_.size() << " inputs."; if (input_handle_shapes_and_types_[idx] == nullptr) { input_handle_shapes_and_types_[idx].reset( new std::vector(shapes_and_types)); diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index f00dac88fd0388..6ed932e0c78189 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -683,6 +683,10 @@ class InferenceContext { void set_input_handle_shapes_and_types( int idx, const std::vector& shapes_and_types) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, input_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << input_handle_shapes_and_types_.size() << " inputs."; input_handle_shapes_and_types_[idx] = absl::make_unique>(shapes_and_types); } @@ -690,17 +694,29 @@ class InferenceContext { // Returns the output handle shapes and types, for the resource tensor output // at index . Returns NULL if the shape and types were never set. const std::vector* output_handle_shapes_and_types(int idx) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, output_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << output_handle_shapes_and_types_.size() << " outputs."; return output_handle_shapes_and_types_[idx].get(); } // Returns the inputs handle shapes and types, for the resource tensor input // at index . Returns NULL if the shape and types were not available. const std::vector* input_handle_shapes_and_types(int idx) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, input_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << input_handle_shapes_and_types_.size() << " inputs."; return input_handle_shapes_and_types_[idx].get(); } void set_output_handle_shapes_and_types( int idx, const std::vector& shapes_and_types) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, output_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << output_handle_shapes_and_types_.size() << " inputs."; output_handle_shapes_and_types_[idx] = absl::make_unique>(shapes_and_types); } diff --git a/tensorflow/core/framework/typed_allocator.h b/tensorflow/core/framework/typed_allocator.h index 20e16358f2c4c3..6d89983b2fb575 100644 --- a/tensorflow/core/framework/typed_allocator.h +++ b/tensorflow/core/framework/typed_allocator.h @@ -56,7 +56,8 @@ class TypedAllocator { size_t num_elements) { if (ptr) { RunDtor(raw_allocator, ptr, num_elements); - raw_allocator->DeallocateRaw(ptr); + raw_allocator->DeallocateRaw(ptr, Allocator::kAllocatorAlignment, + sizeof(T) * num_elements); } } diff --git a/tensorflow/core/grappler/mutable_graph_view.cc b/tensorflow/core/grappler/mutable_graph_view.cc index 638a6a33f9395f..cf159922c51daa 100644 --- a/tensorflow/core/grappler/mutable_graph_view.cc +++ b/tensorflow/core/grappler/mutable_graph_view.cc @@ -386,8 +386,8 @@ void MutableGraphView::AddAndDedupFanouts(NodeDef* node) { fanouts()[output].emplace(node, Graph::kControlSlot); } else { max_input_port = pos; - max_regular_output_port()[output.node] = - std::max(max_regular_output_port()[output.node], output.port_id); + int& max_port = max_regular_output_port()[output.node]; + max_port = std::max(max_port, output.port_id); fanouts()[output].emplace(node, pos); } ++pos; diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index e23e1c04fe1df6..6f867024bb9000 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -8,7 +8,6 @@ package( "//tensorflow/core/data:__pkg__", "//tensorflow/core/data/service:__pkg__", "//tensorflow/core/grappler/optimizers/data:__subpackages__", - "//tensorflow/core/kernels/data:__pkg__", "//tensorflow/core/kernels/data/experimental:__pkg__", ], licenses = ["notice"], diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index 7a02b8283e752e..5ddff709e7435c 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -3085,7 +3085,7 @@ class XlaCpuJitDisableFusionTest : public RemapperTest { } Remapper optimizer(RewriterConfig::ON, RewriterConfig::NO_CONVERSION_ON_CPU, - /*xla_clustering_on=*/true); + /*xla_auto_clustering_on=*/true); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); diff --git a/tensorflow/core/ir/importexport/convert_attributes.cc b/tensorflow/core/ir/importexport/convert_attributes.cc index 24ab1d2c12cba2..dee8e7eb4c21d5 100644 --- a/tensorflow/core/ir/importexport/convert_attributes.cc +++ b/tensorflow/core/ir/importexport/convert_attributes.cc @@ -417,9 +417,10 @@ absl::StatusOr ConvertAttribute( default: return InvalidArgument("Unsupported attr kind in FullType"); } - - return FullTypeAttr::get(builder.getContext(), full_type.type_id(), args, - attr); + IntegerAttr type_id_attr = + mlir::IntegerAttr::get(mlir::IntegerType::get(builder.getContext(), 32), + static_cast(full_type.type_id())); + return FullTypeAttr::get(builder.getContext(), type_id_attr, args, attr); } absl::StatusOr ConvertAttribute( @@ -447,7 +448,8 @@ absl::StatusOr ConvertAttribute( mlir::debugString(full_type.getAttr())); } - ret.set_type_id(static_cast(full_type.getTypeId())); + ret.set_type_id( + static_cast(full_type.getTypeId().getInt())); return ret; } diff --git a/tensorflow/core/ir/types/attributes.td b/tensorflow/core/ir/types/attributes.td index c0af7de6f12b8e..3215c52212a90d 100644 --- a/tensorflow/core/ir/types/attributes.td +++ b/tensorflow/core/ir/types/attributes.td @@ -299,7 +299,7 @@ def TFType_FullTypeId : I32EnumAttr<"FullTypeId", "", [ I32EnumAttrCase<"TFT_LEGACY_VARIANT", 10203, "legacy_variant"> ]> { let cppNamespace = "::mlir::tf_type"; - string cppType = "int32_t"; + string cppType = "::mlir::IntegerAttr"; let genSpecializedAttr = 0; } @@ -320,7 +320,7 @@ def TFType_FullTypeAttr : AttrDef { let parameters = (ins TFType_FullTypeId:$type_id, TFType_FullTypeArgsAttr:$args, - TFType_FullTypeAttrAttr:$attr + "Attribute":$attr ); let mnemonic = "full_type"; let hasCustomAssemblyFormat = 1; diff --git a/tensorflow/core/ir/types/dialect.cc b/tensorflow/core/ir/types/dialect.cc index 9805e17462876f..db175cfa089936 100644 --- a/tensorflow/core/ir/types/dialect.cc +++ b/tensorflow/core/ir/types/dialect.cc @@ -260,8 +260,11 @@ FailureOr RawFullTypeAttrParser(AsmParser& parser) { // Parse variable 'attr' Attribute attr; parser.parseOptionalAttribute(attr); - return FullTypeAttr::get(parser.getContext(), static_cast(*type_id), - args, attr); + return FullTypeAttr::get( + parser.getContext(), + mlir::IntegerAttr::get(mlir::IntegerType::get(parser.getContext(), 32), + static_cast(*type_id)), + args, attr); } Attribute FullTypeAttr::parse(AsmParser& parser, Type odsType) { @@ -272,7 +275,8 @@ Attribute FullTypeAttr::parse(AsmParser& parser, Type odsType) { } static void RawFullTypeAttrPrint(FullTypeAttr tfattr, AsmPrinter& printer) { - printer << stringifyFullTypeId(tf_type::FullTypeId(tfattr.getTypeId())); + printer << stringifyFullTypeId( + tf_type::FullTypeId(tfattr.getTypeId().getInt())); if (!tfattr.getArgs().empty()) { printer << "<"; llvm::interleaveComma(tfattr.getArgs(), printer, [&](Attribute arg) { diff --git a/tensorflow/core/kernels/aggregate_ops.cc b/tensorflow/core/kernels/aggregate_ops.cc index 4a6bfe0bf046ff..31a40b1d5b9662 100644 --- a/tensorflow/core/kernels/aggregate_ops.cc +++ b/tensorflow/core/kernels/aggregate_ops.cc @@ -186,7 +186,7 @@ class AddNOp : public OpKernel { TF_RETURN_IF_ERROR( BinaryOpVariants(ctx, ADD_VARIANT_BINARY_OP, a, b, c)); temp_filled->at(lhs_ix) = true; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index 6c70f5cf05eff4..bd93c1ec3a02a3 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -178,7 +178,8 @@ class BatchResource : public serving::BatchResourceBase { /*mixed_priority_batching_policy=*/ serving::MixedPriorityBatchingPolicy:: kLowPriorityPaddingWithMaxBatchSize, - enable_large_batch_splitting, resource); + enable_large_batch_splitting, + /*batch_padding_policy=*/"PAD_UP", resource); } static Status Create( @@ -191,7 +192,7 @@ class BatchResource : public serving::BatchResourceBase { int32_t low_priority_max_enqueued_batches, const std::vector& low_priority_allowed_batch_sizes, serving::MixedPriorityBatchingPolicy mixed_priority_batching_policy, - bool enable_large_batch_splitting, + bool enable_large_batch_splitting, absl::string_view batch_padding_policy, std::unique_ptr* resource) { BatcherT::Options batcher_options; batcher_options.num_batch_threads = num_batch_threads; @@ -204,8 +205,7 @@ class BatchResource : public serving::BatchResourceBase { num_batch_threads, max_execution_batch_size, batch_timeout_micros, max_enqueued_batches, allowed_batch_sizes, enable_large_batch_splitting, - /*disable_padding=*/false, - /*batch_padding_policy=*/serving::kPadUpPolicy, + /*disable_padding=*/false, batch_padding_policy, low_priority_max_batch_size, low_priority_batch_timeout_micros, low_priority_max_enqueued_batches, low_priority_allowed_batch_sizes, mixed_priority_batching_policy), @@ -441,7 +441,7 @@ void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) { low_priority_batch_timeout_micros_, low_priority_max_enqueued_batches_, low_priority_allowed_batch_sizes_, mixed_priority_batching_policy, enable_large_batch_splitting_, - &new_resource)); + batch_padding_policy_, &new_resource)); if (session_metadata) { new_resource->set_session_metadata(*session_metadata); } diff --git a/tensorflow/core/kernels/batch_kernels_test.cc b/tensorflow/core/kernels/batch_kernels_test.cc index 2d9b0d6068d2d2..9aaeb5ad5207c8 100644 --- a/tensorflow/core/kernels/batch_kernels_test.cc +++ b/tensorflow/core/kernels/batch_kernels_test.cc @@ -17,11 +17,13 @@ limitations under the License. #include #include +#include #include #include #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" @@ -46,6 +48,7 @@ limitations under the License. #include "tsl/platform/refcount.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" +#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace { @@ -90,6 +93,7 @@ class SharedBatchFunctionTestState : public OpsTestBase { // Create common batch function op for testing. absl::StatusOr CreateBatchFunctionBuilder( const std::vector &allowed_batch_sizes, int max_batch_size, + absl::string_view padding_policy, const TensorShape &expected_output_shape) { NameAttrList f; f.set_name("ShapeEnforcingFunction"); @@ -114,13 +118,15 @@ class SharedBatchFunctionTestState : public OpsTestBase { std::vector inputs( {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})}); - return NodeDefBuilder("BatchTPUInput", "BatchFunction") + return NodeDefBuilder(absl::StrCat("BatchTPUInput", padding_policy), + "BatchFunction") .Attr("max_batch_size", max_batch_size) .Attr("num_batch_threads", 8) .Attr("allowed_batch_sizes", allowed_batch_sizes) .Attr("batch_timeout_micros", 1000000) .Attr("max_enqueued_batches", 10) .Attr("enable_large_batch_splitting", true) + .Attr("batch_padding_policy", padding_policy) .Attr("Tin", {DataType::DT_INT64}) .Input(inputs) .Attr("Tcaptured", std::vector{}) @@ -144,7 +150,7 @@ class BatchFunctionTestState : public SharedBatchFunctionTestState { const TensorShape expected_output_shape({expected_batch_size, 2}); TF_ASSIGN_OR_RETURN( NodeDefBuilder builder, - CreateBatchFunctionBuilder({4, 8}, 8, expected_output_shape)); + CreateBatchFunctionBuilder({4, 8}, 8, "PAD_UP", expected_output_shape)); TF_RETURN_IF_ERROR(builder .Attr("low_priority_max_batch_size", enable_low_priority_queue ? 8 : 0) @@ -592,7 +598,7 @@ class BatchFunctionKernelParallelWarmupTestState TF_ASSIGN_OR_RETURN( NodeDefBuilder builder, CreateBatchFunctionBuilder({2, 4, 8}, enable_splitting ? 16 : 8, - expected_output_shape)); + "PAD_UP", expected_output_shape)); TF_RETURN_IF_ERROR(builder.Finalize(node_def())); return OpsTestBase::InitOp(); @@ -665,5 +671,80 @@ INSTANTIATE_TEST_SUITE_P(BatchFunctionKernelParallelWarmupTestSuite, BatchFunctionKernelParallelWarmupTest, ::testing::Bool()); +class BatchFunctionKernelPaddingTestState + : public SharedBatchFunctionTestState { + public: + // Init test fixture with a batch kernel instance. + absl::Status Init(absl::string_view padding_policy, int expected_batch_size) { + static auto *const cpu_device = []() { + auto device = + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); + return device.release(); + }(); + + // Override the per-test/per-op device with a global device so that it can + // be shared between ops. + device_ = cpu_device; + + const TensorShape expected_output_shape({expected_batch_size, 2}); + TF_RETURN_IF_ERROR(CreateBatchFunctionBuilder({4, 8}, 8, padding_policy, + expected_output_shape) + ->Finalize(node_def())); + + return OpsTestBase::InitOp(); + } + + void TestBody() override {} +}; + +class BatchFunctionKernelPaddingTest + : public ::testing::TestWithParam {}; + +TEST_P(BatchFunctionKernelPaddingTest, PadUp) { + SessionMetadata session_metadata; + session_metadata.set_name("test_model"); + session_metadata.set_version(123); + + // Send 5 requests in parallel and check that the given batch padding + // policy behaves as expected. + int64_t num_requests = 5; + int64_t expected_batch_size = 0; + std::string padding_policy = GetParam(); + if (padding_policy == "PAD_UP") { + expected_batch_size = 8; + } else if (padding_policy == "BATCH_DOWN") { + expected_batch_size = 4; + } else if (padding_policy == "MINIMIZE_TPU_COST_PER_REQUEST") { + expected_batch_size = 8; + } else { + FAIL() << "Unsupported padding policy: " << padding_policy; + } + + { + tsl::BlockingCounter blocking_counter(num_requests); + for (int i = 0; i < num_requests; ++i) { + Env::Default()->SchedClosure([&]() { + BatchFunctionKernelPaddingTestState test_state; + test_state.set_session_metadata(session_metadata); + TF_CHECK_OK(test_state.Init(padding_policy, expected_batch_size)); + test_state.AddInputFromList(TensorShape({1, 2}), {123, 456}); + TF_EXPECT_OK(test_state.RunOpKernel()); + + test::ExpectTensorEqual( + *test_state.GetOutput(0), + test::AsTensor({123, 456}, TensorShape({1, 2}))); + blocking_counter.DecrementCount(); + }); + } + + blocking_counter.Wait(); + } +} + +INSTANTIATE_TEST_SUITE_P(BatchFunctionKernelPaddingTestSuite, + BatchFunctionKernelPaddingTest, + ::testing::Values("PAD_UP", "BATCH_DOWN", + "MINIMIZE_TPU_COST_PER_REQUEST")); + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_stats.h b/tensorflow/core/kernels/batching_util/batch_stats.h index 4b23f9886e8d15..87c36fca0c02a1 100644 --- a/tensorflow/core/kernels/batching_util/batch_stats.h +++ b/tensorflow/core/kernels/batching_util/batch_stats.h @@ -206,14 +206,12 @@ class ModelBatchStats { // RegisterQuerySize for more details. std::atomic cumulative_processed_size_ = 0; - // The number of batch threads assigned to this model. Set to -1 if there is - // no batch thread count information for this model. - std::atomic num_batch_threads_ = -1; + // The number of batch threads assigned to this model. + std::atomic num_batch_threads_ = kNumBatchThreadsUnknown; // The timeout in microseconds for this model (after which the current batch - // is sent to be processed by the TPU). Set to -1 if there is no batch - // timeout information for this model. - std::atomic batch_timeout_micros_ = -1; + // is sent to be processed by the TPU). + std::atomic batch_timeout_micros_ = kBatchTimeoutMicrosUnknown; }; // Tracks batch statistics for all models. diff --git a/tensorflow/core/kernels/batchtospace_op.cc b/tensorflow/core/kernels/batchtospace_op.cc index 624b136d30a574..50ad9472a39198 100644 --- a/tensorflow/core/kernels/batchtospace_op.cc +++ b/tensorflow/core/kernels/batchtospace_op.cc @@ -64,8 +64,8 @@ static void BatchToSpaceOpCompute(OpKernelContext* context, orig_crops.shape().DebugString())); // To avoid out-of-bounds access in the case that the block_shape and/or // crops tensors are concurrently modified, we must copy the values. - gtl::InlinedVector block_shape; - gtl::InlinedVector crops; + absl::InlinedVector block_shape; + absl::InlinedVector crops; internal::spacetobatch::SubtleMustCopyFlat(orig_block_shape, &block_shape); internal::spacetobatch::SubtleMustCopyFlat(orig_crops, &crops); diff --git a/tensorflow/core/kernels/bcast_ops.cc b/tensorflow/core/kernels/bcast_ops.cc index b60c5dd763923b..b4959d43d9c5e5 100644 --- a/tensorflow/core/kernels/bcast_ops.cc +++ b/tensorflow/core/kernels/bcast_ops.cc @@ -31,7 +31,7 @@ class BCastArgsOp : public OpKernel { OP_REQUIRES( ctx, ctx->num_inputs() == 2, errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); - gtl::InlinedVector shapes; + absl::InlinedVector shapes; for (int i = 0; i < ctx->num_inputs(); ++i) { const Tensor& in = ctx->input(i); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()), @@ -81,7 +81,7 @@ class BCastGradArgsOp : public OpKernel { OP_REQUIRES( ctx, ctx->num_inputs() == 2, errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); - gtl::InlinedVector shapes; + absl::InlinedVector shapes; for (int i = 0; i < ctx->num_inputs(); ++i) { const Tensor& in = ctx->input(i); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()), diff --git a/tensorflow/core/kernels/bincount_op.cc b/tensorflow/core/kernels/bincount_op.cc index 1a1e55ed067fd3..d6f8d3dbad9ed0 100644 --- a/tensorflow/core/kernels/bincount_op.cc +++ b/tensorflow/core/kernels/bincount_op.cc @@ -81,7 +81,7 @@ struct BincountFunctor { Eigen::array reduce_dim({0}); output.device(context->eigen_cpu_device()) = partial_bins.any(reduce_dim).cast(); - return OkStatus(); + return absl::OkStatus(); } }; @@ -164,7 +164,7 @@ struct BincountFunctor { Eigen::array reduce_dim({0}); output.device(context->eigen_cpu_device()) = partial_bins.sum(reduce_dim); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -209,7 +209,7 @@ struct BincountReduceFunctor { static_cast(err_neg_val))); } - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/bucketize_op.cc b/tensorflow/core/kernels/bucketize_op.cc index 03dc11ffe62ad4..179a930da5790c 100644 --- a/tensorflow/core/kernels/bucketize_op.cc +++ b/tensorflow/core/kernels/bucketize_op.cc @@ -44,7 +44,7 @@ struct BucketizeFunctor { output(i) = first_bigger_it - boundaries_vector.begin(); } - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.cc b/tensorflow/core/kernels/conv_grad_shape_utils.cc index 00aceb02e31f5b..0be69d2689e7be 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.cc +++ b/tensorflow/core/kernels/conv_grad_shape_utils.cc @@ -95,7 +95,7 @@ Status ConvBackpropExtractAndVerifyDimension( Status ConvBackpropComputeDimensionsV2( StringPiece label, int num_spatial_dims, const TensorShape& input_shape, const TensorShape& filter_shape, const TensorShape& out_backprop_shape, - const absl::Span& dilations, const std::vector& strides, + const absl::Span dilations, const std::vector& strides, Padding padding, absl::Span explicit_paddings, TensorFormat data_format, ConvBackpropDimensions* dims) { // The + 2 in the following line is for the batch and feature dimensions. diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.h b/tensorflow/core/kernels/conv_grad_shape_utils.h index 8d105a9df92e0a..f61f53ee13cc38 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.h +++ b/tensorflow/core/kernels/conv_grad_shape_utils.h @@ -44,7 +44,7 @@ struct ConvBackpropSpatialDimension { // Computed dimensions for a backwards convolution. struct ConvBackpropDimensions { // Information about each spatial dimension. - gtl::InlinedVector spatial_dims; + absl::InlinedVector spatial_dims; // Batch size. int64_t batch_size; @@ -80,7 +80,7 @@ Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims, Status ConvBackpropComputeDimensionsV2( StringPiece label, int num_spatial_dims, const TensorShape& input_shape, const TensorShape& filter_shape, const TensorShape& out_backprop_shape, - const absl::Span& dilations, const std::vector& strides, + absl::Span dilations, const std::vector& strides, Padding padding, absl::Span explicit_paddings, TensorFormat data_format, ConvBackpropDimensions* dims); diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 72b230c0d71485..2774d25747340b 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -85,6 +85,13 @@ constexpr char kErrorMessage[] = "error_message"; // Period between reporting dataset statistics. constexpr int kStatsReportingPeriodMillis = 1000; +// Factor used to determine the autotune parallelism limit when using an +// unbounded threadpool. The limit is determined by multiplying this factor +// by the default threadpool size, which is typically based on the number of +// CPU cores. Without this limit, we see autotune sometimes choose unreasonably +// large values for the parallelism, e.g. creating 300k threads. +constexpr int kUnboundedThreadpoolAutotuningFactor = 10; + } // namespace class ParallelMapDatasetOp::Dataset : public DatasetBase { @@ -338,12 +345,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { std::shared_ptr CreateNode( IteratorContext* ctx, model::Node::Args args) const override { std::shared_ptr parameter; - // If unbounded threadpool is used, sets the max of `num_parallel_calls` - // to be infinite and lets Autotune find the right value that is under - // the ram budget. - double max_parallelism_value = use_unbounded_threadpool_ - ? std::numeric_limits::max() - : ctx->runner_threadpool_size(); + double max_parallelism_value = ctx->runner_threadpool_size(); + if (use_unbounded_threadpool_) { + max_parallelism_value *= kUnboundedThreadpoolAutotuningFactor; + } if (num_parallel_calls_ && dataset()->num_parallel_calls_ == model::kAutotune) { parameter = model::MakeParameter( diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 08718909ef0714..b2930d4b45a670 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -400,7 +400,7 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { << " with handle: " << handle; tsl::profiler::TraceMe trace_me( [&] { - return profiler::TraceMeEncode( + return tsl::profiler::TraceMeEncode( "RemoteCallOp", {{"func_name", func_name}, {"device", target_device}}); }, @@ -411,7 +411,7 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { target_device = std::move(function_target.first)](const Status& status) { tsl::profiler::TraceMe activity( [&] { - return profiler::TraceMeEncode( + return tsl::profiler::TraceMeEncode( "RemoteCallOpDone", {{"func_name", func_name}, {"device", target_device}}); }, @@ -431,13 +431,13 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { string RemoteCallOp::TraceString(const OpKernelContext& ctx, bool verbose) const { - string trace_string = profiler::TraceMeOp( + string trace_string = tsl::profiler::TraceMeOp( strings::StrCat(name_view(), "__", func_.name()), type_string_view()); if (verbose) { string shape = ShapeTraceString(ctx); if (!shape.empty()) { - trace_string = - profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}}); + trace_string = tsl::profiler::TraceMeEncode(std::move(trace_string), + {{"shape", shape}}); } } return trace_string; diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h index 6dc4b07070e81b..79c7e3f1729f2f 100644 --- a/tensorflow/core/kernels/linalg/einsum_op_impl.h +++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -612,11 +612,12 @@ class EinsumOp : public OpKernel { if (verbose) { string shape = ShapeTraceString(ctx); if (!shape.empty()) { - return profiler::TraceMeEncode( + return tsl::profiler::TraceMeEncode( std::move(op), {{"equation", equation}, {"shape", shape}}); } } - return profiler::TraceMeEncode(std::move(op), {{"equation", equation}}); + return tsl::profiler::TraceMeEncode(std::move(op), + {{"equation", equation}}); } private: diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc index de5205c23201da..5256db35a1f228 100644 --- a/tensorflow/core/kernels/sequence_ops.cc +++ b/tensorflow/core/kernels/sequence_ops.cc @@ -127,6 +127,8 @@ class RangeOp : public OpKernel { #define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL(DEVICE_CPU, CPUDevice, T) #define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL(DEVICE_GPU, GPUDevice, T) +TF_CALL_half(REGISTER_CPU_KERNEL); +TF_CALL_bfloat16(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); TF_CALL_int32(REGISTER_CPU_KERNEL); @@ -134,6 +136,8 @@ TF_CALL_int64(REGISTER_CPU_KERNEL); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +TF_CALL_half(REGISTER_GPU_KERNEL); +TF_CALL_bfloat16(REGISTER_GPU_KERNEL); TF_CALL_float(REGISTER_GPU_KERNEL); TF_CALL_double(REGISTER_GPU_KERNEL); TF_CALL_int64(REGISTER_GPU_KERNEL); diff --git a/tensorflow/core/kernels/sequence_ops_gpu.cu.cc b/tensorflow/core/kernels/sequence_ops_gpu.cu.cc index 205978fc1a4ecc..f33b8cc982d2d6 100644 --- a/tensorflow/core/kernels/sequence_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/sequence_ops_gpu.cu.cc @@ -58,6 +58,8 @@ struct RangeFunctor { } // namespace functor #define DEFINE_FUNCTOR(T) template struct functor::RangeFunctor; +TF_CALL_half(DEFINE_FUNCTOR); +TF_CALL_bfloat16(DEFINE_FUNCTOR); TF_CALL_float(DEFINE_FUNCTOR); TF_CALL_double(DEFINE_FUNCTOR); TF_CALL_int32(DEFINE_FUNCTOR); diff --git a/tensorflow/core/kernels/sequence_ops_test.cc b/tensorflow/core/kernels/sequence_ops_test.cc index 1985d631d23739..d0a079f1827428 100644 --- a/tensorflow/core/kernels/sequence_ops_test.cc +++ b/tensorflow/core/kernels/sequence_ops_test.cc @@ -68,6 +68,21 @@ TEST_F(RangeOpTest, Simple_D32) { test::ExpectTensorEqual(expected, *GetOutput(0)); } +TEST_F(RangeOpTest, Simple_Half) { + MakeOp(DT_HALF); + + // Feed and run + AddInputFromList(TensorShape({}), {0.5}); + AddInputFromList(TensorShape({}), {2}); + AddInputFromList(TensorShape({}), {0.3}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output + Tensor expected(allocator(), DT_HALF, TensorShape({5})); + test::FillValues(&expected, {0.5, 0.8, 1.1, 1.4, 1.7}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + TEST_F(RangeOpTest, Simple_Float) { MakeOp(DT_FLOAT); diff --git a/tensorflow/core/kernels/special_math/special_math_op_bessel.cc b/tensorflow/core/kernels/special_math/special_math_op_bessel.cc index 8efa183655e3c3..e29042cea1cd04 100644 --- a/tensorflow/core/kernels/special_math/special_math_op_bessel.cc +++ b/tensorflow/core/kernels/special_math/special_math_op_bessel.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "Eigen/Core" // from @eigen_archive #include "tensorflow/core/kernels/cwise_ops_common.h" #include "tensorflow/core/kernels/special_math/special_math_op_misc_impl.h" diff --git a/tensorflow/core/lib/histogram/BUILD b/tensorflow/core/lib/histogram/BUILD index 04a698ff39dd2a..8701b2f5c49f6e 100644 --- a/tensorflow/core/lib/histogram/BUILD +++ b/tensorflow/core/lib/histogram/BUILD @@ -25,7 +25,7 @@ cc_library( "//tensorflow/core/platform:mutex", "//tensorflow/core/platform:thread_annotations", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/histogram", + "@local_xla//xla/tsl/lib/histogram", ], alwayslink = True, ) @@ -35,7 +35,7 @@ filegroup( name = "mobile_srcs_only_runtime", srcs = [ "histogram.h", - "@local_tsl//tsl/lib/histogram:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/histogram:mobile_srcs_only_runtime", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -44,7 +44,7 @@ filegroup( name = "legacy_lib_histogram_all_headers", srcs = [ "histogram.h", - "@local_tsl//tsl/lib/histogram:legacy_lib_histogram_all_headers", + "@local_xla//xla/tsl/lib/histogram:legacy_lib_histogram_all_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/histogram/histogram.h b/tensorflow/core/lib/histogram/histogram.h index 551477cf483961..281e190f0bb615 100644 --- a/tensorflow/core/lib/histogram/histogram.h +++ b/tensorflow/core/lib/histogram/histogram.h @@ -19,12 +19,12 @@ limitations under the License. #include #include +#include "xla/tsl/lib/histogram/histogram.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/histogram/histogram.h" namespace tensorflow { diff --git a/tensorflow/core/lib/strings/BUILD b/tensorflow/core/lib/strings/BUILD index 72eb0a6dac308c..d8f4e6df21d573 100644 --- a/tensorflow/core/lib/strings/BUILD +++ b/tensorflow/core/lib/strings/BUILD @@ -51,7 +51,7 @@ cc_library( name = "proto_serialization", hdrs = ["proto_serialization.h"], deps = [ - "@local_tsl//tsl/lib/strings:proto_serialization", + "@local_xla//xla/tsl/lib/strings:proto_serialization", ], ) @@ -116,7 +116,7 @@ filegroup( "ordered_code.cc", "ordered_code.h", "proto_serialization.h", - "@local_tsl//tsl/lib/strings:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/strings:mobile_srcs_only_runtime", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -133,7 +133,7 @@ filegroup( "str_util.h", "strcat.h", "stringprintf.h", - "@local_tsl//tsl/lib/strings:legacy_lib_strings_all_headers", + "@local_xla//xla/tsl/lib/strings:legacy_lib_strings_all_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -165,7 +165,7 @@ filegroup( "str_util.h", "strcat.h", "stringprintf.h", - "@local_tsl//tsl/lib/strings:legacy_lib_string_headers", + "@local_xla//xla/tsl/lib/strings:legacy_lib_string_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -178,7 +178,7 @@ filegroup( "proto_serialization.h", "proto_text_util.h", "scanner.h", - "@local_tsl//tsl/lib/strings:legacy_lib_internal_public_string_headers", + "@local_xla//xla/tsl/lib/strings:legacy_lib_internal_public_string_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/strings/proto_serialization.h b/tensorflow/core/lib/strings/proto_serialization.h index 0c01708dadf4b2..e0c253f52dbe45 100644 --- a/tensorflow/core/lib/strings/proto_serialization.h +++ b/tensorflow/core/lib/strings/proto_serialization.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_STRINGS_PROTO_SERIALIZATION_H_ #define TENSORFLOW_CORE_LIB_STRINGS_PROTO_SERIALIZATION_H_ -#include "tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/lib/strings/proto_serialization.h" namespace tensorflow { // NOLINTBEGIN(misc-unused-using-decls) diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index ebaade2c926c8f..b05c4125eaa9bd 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -157,7 +157,6 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g, offset_i.push_back(strings::StrCat("offset:offset:", i)); dx_i.push_back(strings::StrCat("dx_", i, ":output:0")); } - DataTypeVector dtype_list(N, T); // ConcatGrad(dim, x, dy): // for i in range(N): diff --git a/tensorflow/core/ops/batch_ops.cc b/tensorflow/core/ops/batch_ops.cc index 99d45512374584..6d21ee483a1948 100644 --- a/tensorflow/core/ops/batch_ops.cc +++ b/tensorflow/core/ops/batch_ops.cc @@ -76,9 +76,17 @@ REGISTER_OP("BatchFunction") // allowed. The following options are available. // // - PAD_UP: pad to size 32. + // - BATCH_DOWN: schedule a batch of size 16 and leave 2 requests in the + // batch buffer. + // - MINIMIZE_TPU_COST_PER_REQUEST: a smarter greedy policy that chooses + // to either PAD_UP or BATCH_DOWN so as to minimize the TPU costs per + // real request. In this case, it would compare (batch_16_cost / 16) and + // (batch_32_cost / 18). + // + // WARNING: Not all batch schedulers might support this attribute. .Attr( "batch_padding_policy: " - "{'PAD_UP'} = 'PAD_UP'") + "{'PAD_UP', 'BATCH_DOWN', 'MINIMIZE_TPU_COST_PER_REQUEST'} = 'PAD_UP'") .Attr("Tin: list(type)") .Attr("Tcaptured: list(type) >= 0") .Attr("Tout: list(type)") diff --git a/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt index 8fecdf6b1490e7..d743b8e513a1b2 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt @@ -802,3 +802,152 @@ op { } is_distributed_communication: true } +op { + name: "BatchFunction" + input_arg { + name: "in_tensors" + type_list_attr: "Tin" + } + input_arg { + name: "captured_tensors" + type_list_attr: "Tcaptured" + } + output_arg { + name: "out_tensors" + type_list_attr: "Tout" + } + attr { + name: "f" + type: "func" + } + attr { + name: "num_batch_threads" + type: "int" + } + attr { + name: "max_batch_size" + type: "int" + } + attr { + name: "batch_timeout_micros" + type: "int" + } + attr { + name: "max_enqueued_batches" + type: "int" + default_value { + i: 10 + } + } + attr { + name: "allowed_batch_sizes" + type: "list(int)" + default_value { + list { + } + } + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "batching_queue" + type: "string" + default_value { + s: "" + } + } + attr { + name: "low_priority_max_batch_size" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "low_priority_batch_timeout_micros" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "low_priority_allowed_batch_sizes" + type: "list(int)" + default_value { + list { + } + } + } + attr { + name: "low_priority_max_enqueued_batches" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "mixed_priority_policy" + type: "string" + default_value { + s: "low_priority_padding_with_max_batch_size" + } + allowed_values { + list { + s: "low_priority_padding_with_max_batch_size" + s: "low_priority_padding_with_next_allowed_batch_size" + s: "priority_isolation" + } + } + } + attr { + name: "batch_padding_policy" + type: "string" + default_value { + s: "PAD_UP" + } + allowed_values { + list { + s: "PAD_UP" + s: "BATCH_DOWN" + s: "MINIMIZE_TPU_COST_PER_REQUEST" + } + } + } + attr { + name: "Tin" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Tcaptured" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tout" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "enable_large_batch_splitting" + type: "bool" + default_value { + b: false + } + } + is_distributed_communication: true +} diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index ab35e0e5631851..dcf9e2f0e666e5 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -4112,6 +4112,8 @@ op { allowed_values { list { s: "PAD_UP" + s: "BATCH_DOWN" + s: "MINIMIZE_TPU_COST_PER_REQUEST" } } } diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc index 2059e3b189bf31..3080e97f03feb2 100644 --- a/tensorflow/core/platform/env_test.cc +++ b/tensorflow/core/platform/env_test.cc @@ -53,7 +53,7 @@ tensorflow::GraphDef CreateTestProto() { return g; } -static void ExpectHasSubstr(StringPiece s, StringPiece expected) { +static void ExpectHasSubstr(absl::string_view s, absl::string_view expected) { EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } @@ -82,7 +82,7 @@ TEST_F(DefaultEnvTest, IncompleteReadOutOfRange) { TF_EXPECT_OK(env_->NewRandomAccessFile(filename, &f)); // Reading past EOF should give an OUT_OF_RANGE error - StringPiece result; + absl::string_view result; char scratch[3]; EXPECT_EQ(error::OUT_OF_RANGE, f->Read(0, 3, &result, scratch).code()); EXPECT_EQ(input, result); @@ -300,7 +300,7 @@ class TmpDirFileSystem : public NullFileSystem { TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT; absl::Status FileExists(const string& dir, TransactionToken* token) override { - StringPiece scheme, host, path; + absl::string_view scheme, host, path; io::ParseURI(dir, &scheme, &host, &path); if (path.empty()) return errors::NotFound(dir, " not found"); // The special "flushed" file exists only if the filesystem's caches have @@ -316,7 +316,7 @@ class TmpDirFileSystem : public NullFileSystem { } absl::Status CreateDir(const string& dir, TransactionToken* token) override { - StringPiece scheme, host, path; + absl::string_view scheme, host, path; io::ParseURI(dir, &scheme, &host, &path); if (scheme != "tmpdirfs") { return errors::FailedPrecondition("scheme must be tmpdirfs"); @@ -335,7 +335,7 @@ class TmpDirFileSystem : public NullFileSystem { absl::Status IsDirectory(const string& dir, TransactionToken* token) override { - StringPiece scheme, host, path; + absl::string_view scheme, host, path; io::ParseURI(dir, &scheme, &host, &path); for (const auto& existing_dir : created_directories_) if (existing_dir == path) return absl::OkStatus(); @@ -405,7 +405,7 @@ TEST_F(DefaultEnvTest, LocalTempFilename) { // Read from the temporary file and check content. std::unique_ptr file_to_read; TF_CHECK_OK(env->NewRandomAccessFile(filename, &file_to_read)); - StringPiece content; + absl::string_view content; char scratch[1024]; CHECK_EQ( error::OUT_OF_RANGE, @@ -427,7 +427,7 @@ TEST_F(DefaultEnvTest, CreateUniqueFileName) { EXPECT_TRUE(env->CreateUniqueFileName(&filename, suffix)); EXPECT_TRUE(absl::StartsWith(filename, prefix)); - EXPECT_TRUE(str_util::EndsWith(filename, suffix)); + EXPECT_TRUE(absl::EndsWith(filename, suffix)); } TEST_F(DefaultEnvTest, GetProcessId) { diff --git a/tensorflow/core/platform/stringpiece.h b/tensorflow/core/platform/stringpiece.h index 17760cd7fee327..66040fc997173c 100644 --- a/tensorflow/core/platform/stringpiece.h +++ b/tensorflow/core/platform/stringpiece.h @@ -30,7 +30,7 @@ limitations under the License. namespace tensorflow { -using StringPiece = tsl::StringPiece; +using StringPiece = absl::string_view; } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index 5609fd7658d86d..8b228479872bcc 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -227,8 +227,6 @@ OpMetricsDb ConvertTpuDeviceTraceXPlaneToOpMetricsDb( XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace); using OpMetricBySymbol = absl::flat_hash_map; - absl::flat_hash_map flat_op_metric; - XEventsOpMetricsDbBuilder builder; plane.ForEachLine([&](const XLineVisitor& line) { diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc index 9aef8808ff49c6..383aad17ec1bdc 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.cc +++ b/tensorflow/core/profiler/utils/derived_timeline.cc @@ -69,7 +69,6 @@ inline std::string HloOpEventPrefix(const GpuEventStats& stats) { std::vector GetOrCreateHloOpEventsMetadata( XPlaneBuilder& xplane, const GpuEventStats& stats, const Symbol symbol) { DCHECK(stats.IsXlaOp()); - DCHECK(!stats.hlo_module_name.empty()); std::vector hlo_op_events_metadata; hlo_op_events_metadata.reserve(stats.hlo_op_names.size()); // Prepend an HLO module identifier so HLO operators with the same name but in diff --git a/tensorflow/core/profiler/utils/derived_timeline_test.cc b/tensorflow/core/profiler/utils/derived_timeline_test.cc index 15de9ff05e3e19..ae9decdc19d259 100644 --- a/tensorflow/core/profiler/utils/derived_timeline_test.cc +++ b/tensorflow/core/profiler/utils/derived_timeline_test.cc @@ -71,6 +71,30 @@ TEST(DerivedTimelineTest, HloModuleNameTest) { }); } +// Checks that HLO module events are expanded. +TEST(DerivedTimelineTest, NoHloModuleNameTest) { + const absl::string_view kKernelDetails = "kernel_details"; + XSpace space; + tsl::profiler::GroupMetadataMap group_metadata_map; + XPlane& plane = *GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); + XPlaneBuilder plane_builder(&plane); + auto line_builder = plane_builder.GetOrCreateLine(0); + CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, + {{StatType::kKernelDetails, kKernelDetails}}); + CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, + {{StatType::kKernelDetails, kKernelDetails}}); + GenerateDerivedTimeLines(group_metadata_map, &space); + XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&plane); + // Only the hlo module line is added and other empty lines are removed at the + // end. + EXPECT_EQ(plane_visitor.NumLines(), 1); + plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { + if (line_visitor.Id() == 0) return; + EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule); + EXPECT_EQ(line_visitor.NumEvents(), 0); + }); +} + // Checks that the TF op events are expanded. TEST(DerivedTimelineTest, TfOpLineTest) { const absl::string_view kTfOpName = "mul:Mul"; diff --git a/tensorflow/core/profiler/utils/gpu_event_stats.cc b/tensorflow/core/profiler/utils/gpu_event_stats.cc index be4a9246ba4d7f..80de74edec0968 100644 --- a/tensorflow/core/profiler/utils/gpu_event_stats.cc +++ b/tensorflow/core/profiler/utils/gpu_event_stats.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/profiler/utils/gpu_event_stats.h" +#include + #include "absl/strings/str_split.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" @@ -57,7 +59,7 @@ GpuEventStats::GpuEventStats(const XEventVisitor* event) { memcpy_details = stat.StrOrRefValue(); break; case StatType::kCorrelationId: - correlation_id = stat.IntValue(); + correlation_id = static_cast(stat.IntOrUintValue()); break; case StatType::kGroupId: group_id = stat.IntValue(); @@ -79,7 +81,7 @@ LaunchEventStats::LaunchEventStats(const XEventVisitor* event) { device_id = stat.IntOrUintValue(); break; case StatType::kCorrelationId: - correlation_id = stat.IntValue(); + correlation_id = static_cast(stat.IntOrUintValue()); break; case StatType::kGroupId: group_id = stat.IntValue(); diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index 70d5efc7c11a09..d6efbd1cd7a1b1 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -54,6 +54,7 @@ using tsl::profiler::kMetadataPlaneName; // NOLINT using tsl::profiler::kPythonTracerPlaneName; // NOLINT using tsl::profiler::kRoctracerApiPlaneName; // NOLINT using tsl::profiler::kSourceLineName; // NOLINT +using tsl::profiler::kSparseCorePlaneRegex; // NOLINT using tsl::profiler::kStepLineName; // NOLINT using tsl::profiler::kTensorFlowNameScopeLineName; // NOLINT using tsl::profiler::kTensorFlowOpLineName; // NOLINT diff --git a/tensorflow/core/protobuf/BUILD b/tensorflow/core/protobuf/BUILD index 86bf0017f3cfda..c5bfac7a5bd974 100644 --- a/tensorflow/core/protobuf/BUILD +++ b/tensorflow/core/protobuf/BUILD @@ -210,7 +210,7 @@ tf_proto_library( protodeps = [ ":error_codes_proto_impl", "//tensorflow/core/framework:protos_all", - "@local_tsl//tsl/protobuf:bfc_memory_map_proto", + "@local_xla//xla/tsl/protobuf:bfc_memory_map_proto", "@local_tsl//tsl/protobuf:coordination_config_proto", "@local_tsl//tsl/protobuf:rpc_options_proto", "@local_tsl//tsl/protobuf:status_proto", @@ -218,9 +218,9 @@ tf_proto_library( tags = ["alt_dep=//third_party/tensorflow/core:protos_all"], visibility = ["//visibility:public"], exports = [ - "@local_tsl//tsl/protobuf:bfc_memory_map_proto", "@local_tsl//tsl/protobuf:rpc_options_proto", "@local_tsl//tsl/protobuf:status_proto", + "@local_xla//xla/tsl/protobuf:bfc_memory_map_proto", ], ) diff --git a/tensorflow/core/protobuf/bfc_memory_map.proto b/tensorflow/core/protobuf/bfc_memory_map.proto index 2dbcbf00bc6102..fcde598787250f 100644 --- a/tensorflow/core/protobuf/bfc_memory_map.proto +++ b/tensorflow/core/protobuf/bfc_memory_map.proto @@ -2,6 +2,6 @@ syntax = "proto3"; package tensorflow.dummy; -import public "tsl/protobuf/bfc_memory_map.proto"; +import public "xla/tsl/protobuf/bfc_memory_map.proto"; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index dcc61f23379764..a04655fb2ce770 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1948 // Updated: 2024/8/8 +#define TF_GRAPH_DEF_VERSION 1960 // Updated: 2024/8/20 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/runtime_fallback/runtime/BUILD b/tensorflow/core/runtime_fallback/runtime/BUILD index 9b4d014b0d6905..45f433d2d732a9 100644 --- a/tensorflow/core/runtime_fallback/runtime/BUILD +++ b/tensorflow/core/runtime_fallback/runtime/BUILD @@ -195,6 +195,7 @@ cc_library( "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler", "//tensorflow/core/kernels/batching_util:batch_resource_base", "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs", + "//tensorflow/core/kernels/batching_util:batch_stats", "//tensorflow/core/kernels/batching_util:bounded_executor", "//tensorflow/core/kernels/batching_util:warmup", "//tensorflow/core/lib/core:refcount", diff --git a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h index 367b8b89482b6a..86772a2a38d437 100644 --- a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h +++ b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/batch_resource_base.h" #include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_stats.h" #include "tensorflow/core/kernels/batching_util/warmup.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" @@ -225,6 +226,13 @@ void BatchFunctionFallbackKernel::ComputeAsync( batch_resource_options.low_priority_allowed_batch_sizes = low_priority_allowed_batch_sizes_; + serving::ModelBatchStats& model_batch_stats = + serving::GlobalBatchStatsRegistry().model( + /* model_name= */ std::string(GetModelName(c)), + /* op_name= */ c->op_kernel().name()); + model_batch_stats.SetBatchTimeoutMicros(batch_timeout_micros_); + model_batch_stats.SetNumBatchThreads(num_batch_threads_); + std::unique_ptr new_resource; auto status = BatchResourceType::Create( c, batch_resource_options, batch_function_, diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc index 5af1ea812a2d46..38b8fc37f3a432 100644 --- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc +++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc @@ -438,7 +438,7 @@ REGISTER_OP("_BatchFunctionFallback") // BatchFunction in core/ops/batch_ops.cc. .Attr( "batch_padding_policy: " - "{'PAD_UP'} = 'PAD_UP'") + "{'PAD_UP', 'BATCH_DOWN', 'MINIMIZE_TPU_COST_PER_REQUEST'} = 'PAD_UP'") .Attr("Tin: list(type)") .Attr("Tcaptured: list(type) >= 0") .Attr("Tout: list(type)") diff --git a/tensorflow/core/tfrt/gpu/kernel/BUILD b/tensorflow/core/tfrt/gpu/kernel/BUILD index bd4f86131e3117..fef0e58310a334 100644 --- a/tensorflow/core/tfrt/gpu/kernel/BUILD +++ b/tensorflow/core/tfrt/gpu/kernel/BUILD @@ -13,20 +13,16 @@ cc_library( deps = [ ":gpu_runner", "//tensorflow/core:framework", - "//tensorflow/core/common_runtime:copy_tensor", "//tensorflow/core/framework:tensor", "//tensorflow/core/platform:status", "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state", - "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_utils", "//tensorflow/core/runtime_fallback/kernel:tensor_util", "//tensorflow/core/tfrt/utils:fallback_tensor", "//tensorflow/core/tfrt/utils:gpu_variables_table", - "//tensorflow/core/tfrt/utils:tensor_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@tf_runtime//:core_runtime", "@tf_runtime//:hostcontext", "@tf_runtime//:support", "@tf_runtime//:tensor_alwayslink", @@ -47,9 +43,11 @@ cc_library( "//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", + "//tensorflow/core/framework:attr_value_proto_cc", + "//tensorflow/core/framework:function_proto_cc", + "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/platform:notification", "//tensorflow/core/platform:status", - "//tensorflow/core/platform:statusor", "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state", "//tensorflow/core/tfrt/common:global_state", "//tensorflow/core/tfrt/utils:fallback_tensor", @@ -59,6 +57,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", @@ -122,6 +121,7 @@ cc_library( "//tensorflow/core/common_runtime/gpu:gpu_serving_device_selector", "//tensorflow/core/platform:status", "//tensorflow/core/tfrt/runtime", + "@com_google_absl//absl/status", "@local_xla//xla/tsl/framework:serving_device_selector_policies", "@tf_runtime//:hostcontext", ], diff --git a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc index 53586841f4beb6..d4047d4d206043 100644 --- a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc +++ b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc @@ -41,15 +41,18 @@ limitations under the License. #include "xla/tsl/framework/device_id_manager.h" #include "xla/tsl/framework/serving_device_selector.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" #include "tensorflow/core/tfrt/common/global_state.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" diff --git a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h index fc61eff2d28139..d292fedfbc4bfc 100644 --- a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h +++ b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h @@ -18,7 +18,10 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/SmallVector.h" #include "xla/tsl/framework/serving_device_selector.h" +#include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/status.h" @@ -27,6 +30,7 @@ limitations under the License. #include "tensorflow/core/tfrt/utils/gpu_variables_table.h" #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime #include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime namespace tensorflow { namespace gpu { diff --git a/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc b/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc index 8cc6a6286abf75..43cb013da7b926 100644 --- a/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc +++ b/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc @@ -21,19 +21,15 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "tensorflow/core/common_runtime/copy_tensor.h" -#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" -#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.h" #include "tensorflow/core/runtime_fallback/kernel/tensor_util.h" #include "tensorflow/core/tfrt/gpu/kernel/gpu_runner.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tensorflow/core/tfrt/utils/gpu_variables_table.h" -#include "tensorflow/core/tfrt/utils/tensor_util.h" -#include "tfrt/host_context/async_dispatch.h" // from @tf_runtime #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime #include "tfrt/host_context/attribute_utils.h" // from @tf_runtime #include "tfrt/host_context/execution_context.h" // from @tf_runtime diff --git a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc index 94e52ad23ed51a..48f3160f8138da 100644 --- a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc +++ b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "xla/tsl/framework/serving_device_selector_policies.h" #include "tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h index bb990224ea0fc9..452ccdd9b1804d 100644 --- a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h +++ b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_ #define TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_ #include "xla/tsl/framework/serving_device_selector_policies.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/tfrt/runtime/runtime.h" namespace tensorflow { diff --git a/tensorflow/core/tfrt/ifrt/BUILD b/tensorflow/core/tfrt/ifrt/BUILD index dfbe4def44f256..919e0df3be45a0 100644 --- a/tensorflow/core/tfrt/ifrt/BUILD +++ b/tensorflow/core/tfrt/ifrt/BUILD @@ -30,6 +30,7 @@ cc_library( srcs = ["ifrt_serving_core_selector.cc"], hdrs = ["ifrt_serving_core_selector.h"], deps = [ + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -226,6 +227,15 @@ cc_library( ], ) +cc_library( + name = "ifrt_model_restore_context", + hdrs = ["ifrt_model_restore_context.h"], + deps = [ + ":checkpoint_loader", + "@com_google_absl//absl/strings:string_view", + ], +) + cc_library( name = "ifrt_model_context", srcs = ["ifrt_model_context.cc"], @@ -316,7 +326,6 @@ cc_library( ":sharding_utils", "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types", "//tensorflow/core:framework", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -347,6 +356,7 @@ cc_library( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:core_no_xla", "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/framework:function_proto_cc", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:fixed_array", @@ -491,6 +501,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core/framework:tensor_matcher", "//tensorflow/core/framework:tensor_testutil", + "//tensorflow/core/framework:types_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -599,3 +610,38 @@ tf_cc_test( "@tf_runtime//backends/cpu:tf_ops_alwayslink", ], ) + +cc_library( + name = "checkpoint_loader", + srcs = ["checkpoint_loader.cc"], + hdrs = ["checkpoint_loader.h"], + deps = [ + ":ifrt_loaded_variable_utils", + ":ifrt_restore_tensor_registry", + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types", + "//tensorflow/core:framework", + "//tensorflow/core/common_runtime:function", + "//tensorflow/core/framework:attr_value_proto_cc", + "//tensorflow/core/framework:node_def_util", + "//tensorflow/core/framework:tensor", + "//tensorflow/core/framework:types_proto_cc", + "//tensorflow/core/tfrt/fallback:op_kernel_runner", + "//tensorflow/core/tfrt/mlrt/bytecode", + "//tensorflow/core/tfrt/mlrt/kernel:context", + "//tensorflow/core/tfrt/mlrt/kernel:kernel_runner_utils", + "//tensorflow/core/tfrt/mlrt/kernel:shard_restore_util", + "//tensorflow/core/tfrt/utils:fallback_tensor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:tstring", + "@local_xla//xla/python/ifrt", + "@tf_runtime//:hostcontext", + ], +) diff --git a/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc b/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc new file mode 100644 index 00000000000000..a970b027e48b40 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc @@ -0,0 +1,359 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" +#include "xla/python/ifrt/future.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/kernel/context.h" +#include "tensorflow/core/tfrt/mlrt/kernel/kernel_runner_utils.h" +#include "tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.h" +#include "tensorflow/core/tfrt/utils/fallback_tensor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/tstring.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime + +namespace tensorflow { +namespace ifrt_serving { + +namespace { + +static constexpr int kNumRestoreClusters = 4; + +// A shard of variables to be restored. +struct RestoreVariableShard { + tensorflow::Tensor prefix; + tensorflow::Tensor tensor_names; + tensorflow::Tensor shape_and_slices; + std::vector var_handles; + tensorflow::AttrValue dtypes_attr_value; + std::vector restored_dtypes; + std::vector truncate_in_cast; +}; + +struct AsyncState { + explicit AsyncState( + const std::vector& input_tf_tensor_values, + const OpKernelContext::Params& params, int num_outputs, + const tensorflow::DeviceMgr& device_manager, + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime) + : run_state(input_tf_tensor_values, params), + context(&run_state.params, num_outputs), + device_manager(device_manager), + process_function_library_runtime(process_function_library_runtime) {} + + tfrt_stub::OpKernelRunState run_state; + OpKernelContext context; + const tensorflow::DeviceMgr& device_manager; + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime; + + std::vector> results; +}; + +// Returns a casted tensor if successful. +absl::StatusOr Cast( + tensorflow::Tensor& in_tensor, tensorflow::DataType restored_dtype, + tensorflow::DataType cast_dtype, bool truncate_in_cast, + const tensorflow::DeviceMgr& device_manager, + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime, + OpKernelContext::Params& params) { + auto runner = + tfrt_stub::OpKernelRunner::Create( + /*op_name=*/ + "Cast", /*node_name=*/"Cast", params.device->name(), + /*num_args=*/1, + [&](tensorflow::AttrValueMap* attr_value_map) { + tensorflow::AttrValue restored_dtype_attr_value; + restored_dtype_attr_value.set_type(restored_dtype); + attr_value_map->insert({"SrcT", restored_dtype_attr_value}); + + tensorflow::AttrValue cast_dtype_attr_value; + cast_dtype_attr_value.set_type(cast_dtype); + attr_value_map->insert({"DstT", cast_dtype_attr_value}); + + tensorflow::AttrValue truncate_attr_value; + truncate_attr_value.set_b(truncate_in_cast); + attr_value_map->insert({"Truncate", truncate_attr_value}); + return absl::OkStatus(); + }, + device_manager, process_function_library_runtime) + .value(); + + std::vector input_tf_tensor_values; + input_tf_tensor_values.push_back(tensorflow::TensorValue(&in_tensor)); + + tf_mlrt::SetUpParams(runner, input_tf_tensor_values, params); + // Use persistent device instead of the per request device. + + OpKernelContext op_kernel_context(¶ms, /*num_outputs=*/1); + + runner.Run(&op_kernel_context); + + if (!op_kernel_context.status().ok()) { + return op_kernel_context.status(); + } + DCHECK_EQ(op_kernel_context.num_outputs(), 1); + return *(op_kernel_context.mutable_output(0)); +} + +absl::Status RunShard(RestoreVariableShard shard, + IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry, + tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue, + tf_mlrt::Context& context) { + if (!ifrt_restore_tensor_registry) { + return absl::InternalError("ifrt_restore_tensor_registry must not be null"); + } + if (!checkpoint_loader_work_queue) { + return absl::InternalError("checkpoint_loader_work_queue must not be null"); + } + const int num_outputs = shard.var_handles.size(); + DCHECK_EQ(num_outputs, shard.tensor_names.NumElements()); + auto& fallback_request_state = context.fallback_request_state(); + + // Use `tf.RestoreV2` to restore tensor. This will also populate + // tensorflow::ResourceManager. + // TODO(b/319045348): avoid populating tensorflow::ResourceManager if the + // variable is only used by device/IFRT. + // TODO(b/319045348): consider directly calling restore function such as that + // in /tensorflow/core/kernels/save_restore_v2_ops.cc + auto runner = + tfrt_stub::OpKernelRunner::Create( + /*op_name=*/ + "RestoreV2", /*node_name=*/"RestoreV2", + context.params().device->name(), + /*num_args=*/3, + [&](tensorflow::AttrValueMap* attr_value_map) { + attr_value_map->insert({"dtypes", shard.dtypes_attr_value}); + return absl::OkStatus(); + }, + fallback_request_state.device_manager(), + fallback_request_state.process_function_library_runtime()) + .value(); + + // Prepare the input tensors. + std::vector input_tf_tensor_values; + static constexpr int kNumInputArgs = 3; + input_tf_tensor_values.resize(kNumInputArgs); + // We need to keep these tensor alive + input_tf_tensor_values[0].tensor = &shard.prefix; + input_tf_tensor_values[1].tensor = &shard.tensor_names; + input_tf_tensor_values[2].tensor = &shard.shape_and_slices; + + auto& params = context.params(); + tf_mlrt::SetUpParams(runner, input_tf_tensor_values, params); + // Use persistent device instead of the per request device. + params.device = context.fallback_request_state().device_manager().HostCPU(); + + auto async_state = std::make_unique( + input_tf_tensor_values, params, num_outputs, + fallback_request_state.device_manager(), + fallback_request_state.process_function_library_runtime()); + + for (int i = 0; i < num_outputs; ++i) { + auto promise = xla::ifrt::Future::CreatePromise(); + auto future = xla::ifrt::Future(promise); + const ResourceHandle& var_handle = + shard.var_handles[i].tensor().scalar()(); + + TF_ASSIGN_OR_RETURN(ifrt_serving::DtypeAndShape dtype_and_shape, + ifrt_serving::GetDtypeAndShape(var_handle)); + + std::string runtime_name = + ifrt_serving::GetRuntimeNameFromVarHandle(var_handle); + + ifrt_serving::IfrtRestoreTensorRegistry::RestoredTensorInfo + restored_tensor_info = {false, std::move(dtype_and_shape), + std::move(future)}; + if (auto status = ifrt_restore_tensor_registry->TryRegister( + runtime_name, restored_tensor_info); + !status.ok()) { + // Propagate errors so that if already-registered futures are being waited + // on, they can be unblocked. + for (auto& result : async_state->results) { + std::move(result).Set(status); + }; + return status; + } + async_state->results.push_back(std::move(promise)); + } + + // Use dedicated work queue for restore operation. + checkpoint_loader_work_queue->AddTask([runner = std::move(runner), + async_state = std::move(async_state), + shard = std::move(shard)]() { + // Keep input tensor alive in `shard`. + auto* op_kernel_context_ptr = &async_state->context; + runner.Run(op_kernel_context_ptr); + + auto& op_kernel_context = async_state->context; + if (!op_kernel_context.status().ok()) { + for (auto& result : async_state->results) { + std::move(result).Set(op_kernel_context.status()); + } + return; + } + DCHECK_EQ(shard.var_handles.size(), op_kernel_context.num_outputs()); + DCHECK_EQ(shard.truncate_in_cast.size(), op_kernel_context.num_outputs()); + + // TODO(b/343964091): consider to run multiple casts in parallel. + for (int i = 0; i < op_kernel_context.num_outputs(); ++i) { + DCHECK(op_kernel_context.mutable_output(i)); + + if (op_kernel_context.mutable_output(i)->dtype() != + shard.restored_dtypes[i]) { + std::move(async_state->results[i]) + .Set(absl::InvalidArgumentError(absl::StrCat( + "The restored tensor has a different dtype than the " + "variable handle: ", + op_kernel_context.mutable_output(i)->dtype(), " vs. ", + shard.restored_dtypes[i]))); + return; + } + const ResourceHandle& var_handle = + shard.var_handles[i].tensor().scalar()(); + + if (shard.restored_dtypes[i] == var_handle.dtypes_and_shapes()[0].dtype) { + std::move(async_state->results[i]) + .Set(*std::move(op_kernel_context.mutable_output(i))); + } else { + absl::StatusOr cast_output = + Cast(*op_kernel_context.mutable_output(i), shard.restored_dtypes[i], + var_handle.dtypes_and_shapes()[0].dtype, + shard.truncate_in_cast[i], async_state->device_manager, + async_state->process_function_library_runtime, + async_state->run_state.params); + if (!cast_output.ok()) { + std::move(async_state->results[i]).Set(cast_output.status()); + } else { + std::move(async_state->results[i]).Set(*std::move(cast_output)); + } + } + } + }); + return absl::OkStatus(); +} + +int64_t GetSizeFromVarHandle(const ResourceHandle& handle) { + int size = 0; + for (auto& dtype_and_shape : handle.dtypes_and_shapes()) { + size += DataTypeSize(dtype_and_shape.dtype) * + dtype_and_shape.shape.num_elements(); + } + return size; +} + +} // namespace + +absl::Status CheckpointLoader::PrepareRestore( + mlir::OwningOpRef module) { + VLOG(1) << "Skip CheckpointLoader::PrepareRestore"; + return absl::OkStatus(); +} + +absl::Status CheckpointLoader::Load( + const tensorflow::tfrt_stub::FallbackTensor& prefix, + const std::vector& var_handles, + const tensorflow::tfrt_stub::FallbackTensor& tensor_names, + const tensorflow::tfrt_stub::FallbackTensor& shape_and_slices, + const mlrt::bc::Vector& restored_dtypes, + const mlrt::bc::Vector& truncate_in_cast, tf_mlrt::Context& context) { + std::vector variable_sizes; + variable_sizes.reserve(var_handles.size()); + for (auto& handle : var_handles) { + variable_sizes.push_back(GetSizeFromVarHandle( + handle.tensor().scalar()())); + } + + std::vector> sharded_indices = tf_mlrt::ShardVariables( + kNumRestoreClusters, absl::MakeSpan(variable_sizes)); + + // Converts the names and slices back to the tensor. + auto vector_to_tensor = [](const std::vector& vec) { + tensorflow::Tensor tensor(tensorflow::DT_STRING, + TensorShape({static_cast(vec.size())})); + for (int i = 0; i < vec.size(); ++i) { + tensor.flat()(i) = vec[i]; + } + return tensor; + }; + + const auto& tensor_names_flat = tensor_names.tensor().flat(); + const auto& shape_and_slices_flat = + shape_and_slices.tensor().flat(); + + std::vector shards; + shards.reserve(sharded_indices.size()); + for (auto& sharded_index : sharded_indices) { + RestoreVariableShard shard; + shard.var_handles.reserve(sharded_index.size()); + shard.truncate_in_cast.reserve(sharded_index.size()); + shard.restored_dtypes.reserve(sharded_index.size()); + std::vector tensor_names; + std::vector shape_and_slices; + shape_and_slices.reserve(sharded_index.size()); + tensor_names.reserve(sharded_index.size()); + for (int index : sharded_index) { + tensor_names.push_back(tensor_names_flat(index)); + shape_and_slices.push_back(shape_and_slices_flat(index)); + shard.dtypes_attr_value.mutable_list()->add_type(restored_dtypes[index]); + shard.var_handles.push_back(var_handles[index]); + shard.restored_dtypes.push_back(restored_dtypes[index]); + shard.truncate_in_cast.push_back(truncate_in_cast[index]); + } + shard.prefix = prefix.tensor(); + shard.tensor_names = vector_to_tensor(tensor_names); + shard.shape_and_slices = vector_to_tensor(shape_and_slices); + shards.push_back(std::move(shard)); + } + for (const auto& shard : shards) { + TF_RETURN_IF_ERROR(RunShard(shard, ifrt_restore_tensor_registry_, + checkpoint_loader_work_queue_, context)); + } + return absl::OkStatus(); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/checkpoint_loader.h b/tensorflow/core/tfrt/ifrt/checkpoint_loader.h new file mode 100644 index 00000000000000..ab4a2ab48e12aa --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/checkpoint_loader.h @@ -0,0 +1,67 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_ + +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/kernel/context.h" +#include "tensorflow/core/tfrt/utils/fallback_tensor.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime + +namespace tensorflow { +namespace ifrt_serving { + +// TODO(b/352551302) Move the unit test in ifrt_ops_kernel for restore to test +// this class's APIs. +// Implement the `CheckpointLoaderInterface` by using RestoreV2. +class CheckpointLoader { + public: + explicit CheckpointLoader( + IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry, + tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue) + : ifrt_restore_tensor_registry_(ifrt_restore_tensor_registry), + checkpoint_loader_work_queue_(checkpoint_loader_work_queue) {} + virtual ~CheckpointLoader() = default; + + // Called before `Load` to do some preparation work. + virtual absl::Status PrepareRestore(mlir::OwningOpRef module); + + // Load the checkpoint. This API is designed to be compatible with the + // `tf_mlrt.ifrt_restore_variable` kernel. + virtual absl::Status Load( + const tensorflow::tfrt_stub::FallbackTensor& prefix, + const std::vector& var_handles, + const tensorflow::tfrt_stub::FallbackTensor& tensor_names, + const tensorflow::tfrt_stub::FallbackTensor& shape_and_slices, + const mlrt::bc::Vector& restored_dtypes, + const mlrt::bc::Vector& truncate_in_cast, + tf_mlrt::Context& context); + + IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry_; + tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue_; +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_ diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h index 2c6a566c5a1dd2..d488d936776954 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h @@ -25,6 +25,8 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/python/ifrt/array.h" diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc index 0af5363da9a109..ff71481a490d60 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -28,7 +27,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" -#include "xla/hlo/ir/hlo_sharding.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/future.h" diff --git a/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h b/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h new file mode 100644 index 00000000000000..da9528eab6b023 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" + +namespace tensorflow { +namespace ifrt_serving { + +inline constexpr absl::string_view kIfrtModelRestoreContextName = + "IfrtModelRestoreContext"; + +// A resource context that holds the `CheckpointLoader` for a model. We need a +// different context than `IfrtModelContext` because `IfrtModelContext` is too +// large to be a dependency of other libraries. +class IfrtModelRestoreContext { + public: + explicit IfrtModelRestoreContext( + std::unique_ptr checkpoint_loader) + : checkpoint_loader_(std::move(checkpoint_loader)) {} + + CheckpointLoader* checkpoint_loader() const { + return checkpoint_loader_.get(); + } + + private: + std::unique_ptr checkpoint_loader_; +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_ diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h b/tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h index 0ab53974be06f3..a4505cbab06f38 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "xla/tsl/framework/serving_device_selector.h" diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils.h b/tensorflow/core/tfrt/ifrt/sharding_utils.h index aef068534b741a..43dbe9e8bca8dd 100644 --- a/tensorflow/core/tfrt/ifrt/sharding_utils.h +++ b/tensorflow/core/tfrt/ifrt/sharding_utils.h @@ -26,9 +26,11 @@ limitations under the License. #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/future.h" #include "xla/tsl/concurrency/ref_count.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tsl/platform/threadpool.h" namespace tensorflow { diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc index 77043fce3a9faa..f85b3243c36191 100644 --- a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc +++ b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_matcher.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/ml_dtypes.h" diff --git a/tensorflow/core/tfrt/ifrt/tf_host_callback.cc b/tensorflow/core/tfrt/ifrt/tf_host_callback.cc index 5c5a48f4fc52b4..8beeddf82a92e3 100644 --- a/tensorflow/core/tfrt/ifrt/tf_host_callback.cc +++ b/tensorflow/core/tfrt/ifrt/tf_host_callback.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/protobuf/config.pb.h" diff --git a/tensorflow/core/tfrt/ifrt/tf_host_callback.h b/tensorflow/core/tfrt/ifrt/tf_host_callback.h index a78b0e5d0aecea..5b73221e6d3afa 100644 --- a/tensorflow/core/tfrt/ifrt/tf_host_callback.h +++ b/tensorflow/core/tfrt/ifrt/tf_host_callback.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { diff --git a/tensorflow/core/tfrt/ifrt/tf_host_callback_test.cc b/tensorflow/core/tfrt/ifrt/tf_host_callback_test.cc index bc67bbae34d94a..17240e361881c8 100644 --- a/tensorflow/core/tfrt/ifrt/tf_host_callback_test.cc +++ b/tensorflow/core/tfrt/ifrt/tf_host_callback_test.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/core/tfrt/mlrt/interpreter/BUILD b/tensorflow/core/tfrt/mlrt/interpreter/BUILD index 0b1eee7667f43b..10b5346a49553e 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/BUILD +++ b/tensorflow/core/tfrt/mlrt/interpreter/BUILD @@ -127,6 +127,8 @@ cc_library( ":future", ":value", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_xla//xla/tsl/concurrency:async_value", "@tf_runtime//:async_value", ], diff --git a/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h b/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h index ceef6679b6fa8a..43d43422e60093 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h +++ b/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h @@ -19,6 +19,8 @@ limitations under the License. #include #include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" #include "tensorflow/core/tfrt/mlrt/interpreter/future.h" #include "tensorflow/core/tfrt/mlrt/interpreter/value.h" @@ -141,6 +143,12 @@ class AsyncHandle { } auto& execution_context = *arg->Get(); + execution_context.LogError(absl::InternalError(absl::StrCat( + "UnwindOnError: unwind AsyncHandle of context ", + absl::Hex(reinterpret_cast(execution_context_.get())), + " from context ", + absl::Hex(reinterpret_cast(&execution_context)), + " of state ", execution_context.state_))); execution_context.Await(std::move(*this)); } diff --git a/tensorflow/core/tfrt/mlrt/interpreter/execute.cc b/tensorflow/core/tfrt/mlrt/interpreter/execute.cc index f3ef9bc2822085..635935911aa221 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/execute.cc +++ b/tensorflow/core/tfrt/mlrt/interpreter/execute.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tfrt/mlrt/interpreter/execute.h" +#include #include #include @@ -178,7 +179,10 @@ void UnwindOnError(ExecutionContext& context, int64_t pc) { function_name = context.function_stack_.back().function_object().name(); } context.LogError(absl::InternalError(absl::StrCat( - "Start UnwindOnError from function ", function_name, " at pc: ", pc))); + "UnwindOnError: start from function ", function_name, + " with stack size: ", context.function_stack_.size(), " at pc: ", pc, + " for context ", absl::Hex(reinterpret_cast(&context)), + " at state ", context.state_))); while (!context.function_stack_.empty()) { DCHECK(context.state_ == ExecutionContext::State::kError); @@ -199,6 +203,11 @@ void UnwindOnError(ExecutionContext& context, int64_t pc) { reg.HandleError(context_value); if (context.state_ != ExecutionContext::State::kError) { DCHECK(context.state_ == ExecutionContext::State::kSuspended); + + context.LogError(absl::InternalError(absl::StrCat( + "UnwindOnError: entering state", context.state_, " for context ", + absl::Hex(reinterpret_cast(&context))))); + // Rewind current pc so that the execution context come back to where // is is suspended. --pc; @@ -207,6 +216,12 @@ void UnwindOnError(ExecutionContext& context, int64_t pc) { } } + context.LogError(absl::InternalError( + absl::StrCat("UnwindOnError: unwinding function from ", pc, " to ", + current_function->pc_, " for context ", + absl::Hex(reinterpret_cast(&context)), + " at state ", context.state_))); + for (; context.state_ == ExecutionContext::State::kError && pc <= current_function->pc_; ++pc) { @@ -218,6 +233,10 @@ void UnwindOnError(ExecutionContext& context, int64_t pc) { reg.HandleError(context_value); if (context.state_ != ExecutionContext::State::kError) { DCHECK(context.state_ == ExecutionContext::State::kSuspended); + context.LogError(absl::InternalError(absl::StrCat( + "UnwindOnError: entering state", context.state_, " for context ", + absl::Hex(reinterpret_cast(&context))))); + // Rewind current pc so that the execution context come back to where // is is suspended. --pc; @@ -230,6 +249,9 @@ void UnwindOnError(ExecutionContext& context, int64_t pc) { DCHECK(context.suspend_handler_) << "suspend_handler_ must be populated when the state is set to " "kSuspended."; + context.LogError(absl::InternalError(absl::StrCat( + "UnwindOnError: suspended state ", context.state_, " for context ", + absl::Hex(reinterpret_cast(&context))))); std::move(context.suspend_handler_)([&context, pc]() { auto* work_queue = context.work_queue(); DCHECK(work_queue); @@ -247,8 +269,10 @@ void UnwindOnError(ExecutionContext& context, int64_t pc) { context.function_stack_.pop_back(); } - context.LogError(absl::InternalError( - absl::StrCat("Finish UnwindOnError for function ", function_name))); + context.LogError(absl::InternalError(absl::StrCat( + "UnwindOnError: done for function ", function_name, + " for context: ", absl::Hex(reinterpret_cast(&context)), + " at state ", context.state_))); // Context may no longer be valid after exit_handler_ is called. if (context.exit_handler_) { diff --git a/tensorflow/core/tfrt/mlrt/kernel/BUILD b/tensorflow/core/tfrt/mlrt/kernel/BUILD index fb749ea227b7df..5e377f8f809153 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/BUILD +++ b/tensorflow/core/tfrt/mlrt/kernel/BUILD @@ -10,6 +10,7 @@ package( # copybara:uncomment "//learning/brain/tfrt:__subpackages__", # copybara:uncomment "//learning/serving/servables/tfrt:__subpackages__", "//tensorflow/core/tfrt/graph_executor:__subpackages__", + "//tensorflow/core/tfrt/ifrt:__subpackages__", "//tensorflow/core/tfrt/saved_model:__subpackages__", "//tensorflow/core/tfrt/tfrt_session:__subpackages__", ], @@ -67,19 +68,16 @@ cc_library( deps = [ ":context", ":kernel", - ":kernel_runner_utils", - ":shard_restore_util", - "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", - "//tensorflow/core/common_runtime:function", "//tensorflow/core/framework:attr_value_proto_cc", "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/platform:protobuf", - "//tensorflow/core/tfrt/fallback:op_kernel_runner", + "//tensorflow/core/tfrt/ifrt:checkpoint_loader", "//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc", "//tensorflow/core/tfrt/ifrt:ifrt_loaded_variable_utils", "//tensorflow/core/tfrt/ifrt:ifrt_model_context", + "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context", "//tensorflow/core/tfrt/ifrt:ifrt_restore_tensor_registry", "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/interpreter:context", @@ -89,13 +87,10 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tstring", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/python/ifrt", - "@tf_runtime//:hostcontext", ], alwayslink = 1, ) @@ -210,9 +205,11 @@ tf_cc_shared_test( "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state", "//tensorflow/core/tfrt/fallback:fallback_state", "//tensorflow/core/tfrt/fallback:op_kernel_runner", + "//tensorflow/core/tfrt/ifrt:checkpoint_loader", "//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc", "//tensorflow/core/tfrt/ifrt:ifrt_loaded_variable_registry", "//tensorflow/core/tfrt/ifrt:ifrt_model_context", + "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context", "//tensorflow/core/tfrt/ifrt:ifrt_restore_tensor_registry", "//tensorflow/core/tfrt/ifrt:ifrt_serving_core_selector", "//tensorflow/core/tfrt/mlrt/bytecode", @@ -230,6 +227,7 @@ tf_cc_shared_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@eigen_archive//:eigen3", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:refcount", "@local_tsl//tsl/platform:status", diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc index ca9dd2271335fb..e5c7dbd1dc0c72 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc @@ -25,39 +25,31 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "xla/python/ifrt/future.h" #include "xla/xla_data.pb.h" -#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/device_base.h" -#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/protobuf.h" // IWYU pragma: keep -#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h" #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" #include "tensorflow/core/tfrt/mlrt/interpreter/future.h" #include "tensorflow/core/tfrt/mlrt/kernel/context.h" #include "tensorflow/core/tfrt/mlrt/kernel/kernel.h" -#include "tensorflow/core/tfrt/mlrt/kernel/kernel_runner_utils.h" -#include "tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" #include "tsl/platform/tstring.h" -#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime using tensorflow::ifrt_serving::IfrtModelContext; @@ -65,14 +57,6 @@ namespace tensorflow { namespace tf_mlrt { namespace { -int64_t GetSizeFromVarHandle(const ResourceHandle& handle) { - int size = 0; - for (auto& dtype_and_shape : handle.dtypes_and_shapes()) { - size += DataTypeSize(dtype_and_shape.dtype) * - dtype_and_shape.shape.num_elements(); - } - return size; -} struct MlrtIfrtRestoreVariableKernel : mlrt::KernelFrame { using KernelFrame::KernelFrame; @@ -119,20 +103,8 @@ struct MlrtIfrtRestoreVariableKernel : mlrt::KernelFrame { // dynamically decide it based on the size of the variables. static constexpr int kNumRestoreClusters = 4; - // A shard of variables to be restored. - struct RestoreVariableShard { - tensorflow::Tensor prefix; - tensorflow::Tensor tensor_names; - tensorflow::Tensor shape_and_slices; - std::vector var_handles; - tensorflow::AttrValue dtypes_attr_value; - std::vector restored_dtypes; - std::vector truncate_in_cast; - }; - absl::Status InvokeHelper(); - absl::Status RunShard(RestoreVariableShard shard); absl::Status ValidateInput(); }; @@ -144,218 +116,6 @@ void MlrtIfrtRestoreVariableKernel::Invoke() { } } -// Returns a casted tensor if successful. -absl::StatusOr Cast( - tensorflow::Tensor& in_tensor, tensorflow::DataType restored_dtype, - tensorflow::DataType cast_dtype, bool truncate_in_cast, - const tensorflow::DeviceMgr& device_manager, - const tensorflow::ProcessFunctionLibraryRuntime& - process_function_library_runtime, - OpKernelContext::Params& params) { - auto runner = - tfrt_stub::OpKernelRunner::Create( - /*op_name=*/ - "Cast", /*node_name=*/"Cast", params.device->name(), - /*num_args=*/1, - [&](tensorflow::AttrValueMap* attr_value_map) { - tensorflow::AttrValue restored_dtype_attr_value; - restored_dtype_attr_value.set_type(restored_dtype); - attr_value_map->insert({"SrcT", restored_dtype_attr_value}); - - tensorflow::AttrValue cast_dtype_attr_value; - cast_dtype_attr_value.set_type(cast_dtype); - attr_value_map->insert({"DstT", cast_dtype_attr_value}); - - tensorflow::AttrValue truncate_attr_value; - truncate_attr_value.set_b(truncate_in_cast); - attr_value_map->insert({"Truncate", truncate_attr_value}); - return absl::OkStatus(); - }, - device_manager, process_function_library_runtime) - .value(); - - std::vector input_tf_tensor_values; - input_tf_tensor_values.push_back(tensorflow::TensorValue(&in_tensor)); - - SetUpParams(runner, input_tf_tensor_values, params); - // Use persistent device instead of the per request device. - - OpKernelContext op_kernel_context(¶ms, /*num_outputs=*/1); - - runner.Run(&op_kernel_context); - - if (!op_kernel_context.status().ok()) { - return op_kernel_context.status(); - } - DCHECK_EQ(op_kernel_context.num_outputs(), 1); - return *(op_kernel_context.mutable_output(0)); -} - -absl::Status MlrtIfrtRestoreVariableKernel::RunShard( - RestoreVariableShard shard) { - std::optional ifrt_model_context = - context().resource_context().GetResource( - "IfrtModelContext"); - if (!ifrt_model_context.has_value()) { - return absl::FailedPreconditionError( - "RestoreVariableOp: failed to fetch IfrtModelContext"); - } - const int num_outputs = shard.var_handles.size(); - DCHECK_EQ(num_outputs, shard.tensor_names.NumElements()); - auto& fallback_request_state = context().fallback_request_state(); - - // Use `tf.RestoreV2` to restore tensor. This will also populate - // tensorflow::ResourceManager. - // TODO(b/319045348): avoid populating tensorflow::ResourceManager if the - // variable is only used by device/IFRT. - // TODO(b/319045348): consider directly calling restore function such as that - // in /tensorflow/core/kernels/save_restore_v2_ops.cc - auto runner = - tfrt_stub::OpKernelRunner::Create( - /*op_name=*/ - "RestoreV2", /*node_name=*/"RestoreV2", - context().params().device->name(), - /*num_args=*/3, - [&](tensorflow::AttrValueMap* attr_value_map) { - attr_value_map->insert({"dtypes", shard.dtypes_attr_value}); - return absl::OkStatus(); - }, - fallback_request_state.device_manager(), - fallback_request_state.process_function_library_runtime()) - .value(); - - // Prepare the input tensors. - std::vector input_tf_tensor_values; - static constexpr int kNumInputArgs = 3; - input_tf_tensor_values.resize(kNumInputArgs); - // We need to keep these tensor alive - input_tf_tensor_values[0].tensor = &shard.prefix; - input_tf_tensor_values[1].tensor = &shard.tensor_names; - input_tf_tensor_values[2].tensor = &shard.shape_and_slices; - - auto& params = context().params(); - SetUpParams(runner, input_tf_tensor_values, params); - // Use persistent device instead of the per request device. - params.device = context().fallback_request_state().device_manager().HostCPU(); - - struct AsyncState { - explicit AsyncState( - const std::vector& input_tf_tensor_values, - const OpKernelContext::Params& params, int num_outputs, - const tensorflow::DeviceMgr& device_manager, - const tensorflow::ProcessFunctionLibraryRuntime& - process_function_library_runtime) - : run_state(input_tf_tensor_values, params), - context(&run_state.params, num_outputs), - device_manager(device_manager), - process_function_library_runtime(process_function_library_runtime) {} - - tfrt_stub::OpKernelRunState run_state; - OpKernelContext context; - const tensorflow::DeviceMgr& device_manager; - const tensorflow::ProcessFunctionLibraryRuntime& - process_function_library_runtime; - - std::vector> results; - }; - auto async_state = std::make_unique( - input_tf_tensor_values, params, num_outputs, - fallback_request_state.device_manager(), - fallback_request_state.process_function_library_runtime()); - - ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry = - (*ifrt_model_context)->GetRestoreTensorRegistry(); - for (int i = 0; i < num_outputs; ++i) { - auto promise = xla::ifrt::Future::CreatePromise(); - auto future = xla::ifrt::Future(promise); - const ResourceHandle& var_handle = - shard.var_handles[i].tensor().scalar()(); - - TF_ASSIGN_OR_RETURN(ifrt_serving::DtypeAndShape dtype_and_shape, - ifrt_serving::GetDtypeAndShape(var_handle)); - - std::string runtime_name = - ifrt_serving::GetRuntimeNameFromVarHandle(var_handle); - - ifrt_serving::IfrtRestoreTensorRegistry::RestoredTensorInfo - restored_tensor_info = {false, std::move(dtype_and_shape), - std::move(future)}; - if (auto status = ifrt_restore_tensor_registry.TryRegister( - runtime_name, restored_tensor_info); - !status.ok()) { - // Propagate errors so that if already-registered futures are being waited - // on, they can be unblocked. - for (auto& result : async_state->results) { - std::move(result).Set(status); - }; - return status; - } - async_state->results.push_back(std::move(promise)); - } - - // Use dedicated work queue for restore operation. - DCHECK((*ifrt_model_context)->checkpoint_loader_queue() != nullptr); - (*ifrt_model_context) - ->checkpoint_loader_queue() - ->AddTask([runner = std::move(runner), - async_state = std::move(async_state), - shard = std::move(shard)]() { - // Keep input tensor alive in `shard`. - auto* op_kernel_context_ptr = &async_state->context; - runner.Run(op_kernel_context_ptr); - - auto& op_kernel_context = async_state->context; - if (!op_kernel_context.status().ok()) { - for (auto& result : async_state->results) { - std::move(result).Set(op_kernel_context.status()); - } - return; - } - DCHECK_EQ(shard.var_handles.size(), op_kernel_context.num_outputs()); - DCHECK_EQ(shard.truncate_in_cast.size(), - op_kernel_context.num_outputs()); - - // TODO(b/343964091): consider to run multiple casts in parallel. - for (int i = 0; i < op_kernel_context.num_outputs(); ++i) { - DCHECK(op_kernel_context.mutable_output(i)); - - if (op_kernel_context.mutable_output(i)->dtype() != - shard.restored_dtypes[i]) { - std::move(async_state->results[i]) - .Set(absl::InvalidArgumentError(absl::StrCat( - "The restored tensor has a different dtype than the " - "variable handle: ", - op_kernel_context.mutable_output(i)->dtype(), " vs. ", - shard.restored_dtypes[i]))); - return; - } - const ResourceHandle& var_handle = - shard.var_handles[i] - .tensor() - .scalar()(); - - if (shard.restored_dtypes[i] == - var_handle.dtypes_and_shapes()[0].dtype) { - std::move(async_state->results[i]) - .Set(*std::move(op_kernel_context.mutable_output(i))); - } else { - absl::StatusOr cast_output = Cast( - *op_kernel_context.mutable_output(i), shard.restored_dtypes[i], - var_handle.dtypes_and_shapes()[0].dtype, - shard.truncate_in_cast[i], async_state->device_manager, - async_state->process_function_library_runtime, - async_state->run_state.params); - if (!cast_output.ok()) { - std::move(async_state->results[i]).Set(cast_output.status()); - } else { - std::move(async_state->results[i]).Set(*std::move(cast_output)); - } - } - } - }); - return absl::OkStatus(); -} - absl::Status MlrtIfrtRestoreVariableKernel::ValidateInput() { if (prefix().tensor().NumElements() != 1) { return absl::InvalidArgumentError( @@ -398,65 +158,26 @@ absl::Status MlrtIfrtRestoreVariableKernel::ValidateInput() { } absl::Status MlrtIfrtRestoreVariableKernel::InvokeHelper() { - TF_RETURN_IF_ERROR(ValidateInput()); - - std::vector variable_sizes; - variable_sizes.reserve(var_handles().size()); - for (auto& handle : var_handles()) { - variable_sizes.push_back(GetSizeFromVarHandle( - handle.tensor().scalar()())); + std::optional model_restore_context = + context() + .resource_context() + .GetResource( + ifrt_serving::kIfrtModelRestoreContextName); + if (!model_restore_context.has_value()) { + return absl::InternalError( + "Did not find IfrtModelRestoreContext resource."); } - - std::vector> sharded_indices = - ShardVariables(kNumRestoreClusters, absl::MakeSpan(variable_sizes)); - - // Converts the names and slices back to the tensor. - auto vector_to_tensor = [](const std::vector& vec) { - tensorflow::Tensor tensor(tensorflow::DT_STRING, - TensorShape({static_cast(vec.size())})); - for (int i = 0; i < vec.size(); ++i) { - tensor.flat()(i) = vec[i]; - } - return tensor; - }; - - const auto& tensor_names_flat = tensor_names().tensor().flat(); - const auto& shape_and_slices_flat = - shape_and_slices().tensor().flat(); - - std::vector shards; - shards.reserve(sharded_indices.size()); - for (auto& sharded_index : sharded_indices) { - RestoreVariableShard shard; - shard.var_handles.reserve(sharded_index.size()); - shard.truncate_in_cast.reserve(sharded_index.size()); - shard.restored_dtypes.reserve(sharded_index.size()); - - std::vector tensor_names; - std::vector shape_and_slices; - shape_and_slices.reserve(sharded_index.size()); - tensor_names.reserve(sharded_index.size()); - for (int index : sharded_index) { - tensor_names.push_back(tensor_names_flat(index)); - shape_and_slices.push_back(shape_and_slices_flat(index)); - shard.dtypes_attr_value.mutable_list()->add_type( - restored_dtypes()[index]); - - shard.var_handles.push_back(var_handles()[index]); - shard.restored_dtypes.push_back(restored_dtypes()[index]); - shard.truncate_in_cast.push_back(truncate_in_cast()[index]); - } - - shard.prefix = prefix().tensor(); - shard.tensor_names = vector_to_tensor(tensor_names); - shard.shape_and_slices = vector_to_tensor(shape_and_slices); - shards.push_back(std::move(shard)); + if (*model_restore_context == nullptr) { + return absl::InternalError("IfrtModelRestoreContext must not be null."); } - - for (const auto& shard : shards) { - TF_RETURN_IF_ERROR(RunShard(shard)); + ifrt_serving::CheckpointLoader* checkpoint_loader = + (*model_restore_context)->checkpoint_loader(); + if (!checkpoint_loader) { + return absl::InternalError("CheckpointLoader must not be null."); } - return absl::OkStatus(); + return checkpoint_loader->Load(prefix(), var_handles(), tensor_names(), + shape_and_slices(), restored_dtypes(), + truncate_in_cast(), context()); } class MlrtIfrtLoadVariableKernel : public mlrt::KernelFrame { diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc index 8fea9e5a9deb29..07fb83b1e6eb32 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc @@ -44,8 +44,10 @@ limitations under the License. #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" #include "tensorflow/core/tfrt/fallback/fallback_state.h" #include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" @@ -403,6 +405,13 @@ class KernelTest : public ::testing::Test { .value(); ifrt_model_context_->set_checkpoint_loader_queue(restore_work_queue_.get()); + resource_context_ + .CreateResource( + ifrt_serving::kIfrtModelRestoreContextName, + std::make_unique( + &ifrt_model_context_->GetRestoreTensorRegistry(), + ifrt_model_context_->checkpoint_loader_queue())); + serving_device_selector_ = std::make_unique(); ifrt_core_selector_ = diff --git a/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc b/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc index cd3f49f3d6b37c..16293c2ed5d4bb 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc @@ -66,7 +66,7 @@ std::vector> ShardVariables( }; std::priority_queue, decltype(cmp)> - min_heap; + min_heap(cmp); for (int i = 0; i < num_shards; ++i) { min_heap.push(RestoreVariableCluster()); } diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD index 9e85c14baef362..5261546e6c6a0f 100644 --- a/tensorflow/core/tfrt/saved_model/BUILD +++ b/tensorflow/core/tfrt/saved_model/BUILD @@ -135,6 +135,8 @@ cc_library( "//tensorflow/core/tfrt/graph_executor", "//tensorflow/core/tfrt/graph_executor:export_mlir", "//tensorflow/core/tfrt/graph_executor:graph_execution_options", + "//tensorflow/core/tfrt/ifrt:checkpoint_loader", + "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context", "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/bytecode:executable", "//tensorflow/core/tfrt/mlrt/interpreter:context", diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc index 62ad6550cd6c6d..84fbbff7401340 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model.cc @@ -16,13 +16,11 @@ limitations under the License. #include #include -#include #include #include #include #include #include -#include #include #include @@ -70,6 +68,8 @@ limitations under the License. #include "tensorflow/core/tfrt/graph_executor/export_mlir.h" #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" #include "tensorflow/core/tfrt/graph_executor/graph_executor.h" +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" @@ -134,6 +134,34 @@ auto* saved_model_input_spec_validation_failure = "/tensorflow/tfrt/saved_model/input_spec_validation_failure", "Record the models that failed input spec validation.", "model_name"); +absl::Status PrepareRestore(mlir::MLIRContext* context, + ModelRuntimeContext* model_runtime_context, + const tensorflow::MetaGraphDef& meta_graph_def, + FallbackState& fallback_state, + const std::string& saved_model_dir, + const SavedModel::Options& options, + ifrt_serving::CheckpointLoader* checkpoint_loader) { + // Import the global MLIR with `import_user_signatures` as true so that we can + // analysis the global MLIR to retrieve data needed for restore. + mlir::OwningOpRef mlir_module_restore_analysis; + ASSIGN_OR_RETURN_IN_IMPORT( + mlir_module_restore_analysis, + ImportSavedModel( + context, meta_graph_def, fallback_state, saved_model_dir, + /*import_user_signatures=*/true, + options.graph_execution_options.run_placer_grappler_on_functions)); + + if (!checkpoint_loader) { + return absl::InternalError("Missing checkpoint loader."); + } + + TF_RETURN_IF_ERROR(checkpoint_loader->PrepareRestore( + std::move(mlir_module_restore_analysis))); + + LOG(INFO) << "Complete set restore metadata."; + return absl::OkStatus(); +} + tensorflow::Status RunBytecodeInitializers( const GraphExecutionOptions& options, const InitializersAndSignatures& initializers_and_signatures, @@ -596,6 +624,25 @@ absl::StatusOr> SavedModelImpl::LoadSavedModel( model_context.set_callable_options(nullptr); } + if (options.graph_execution_options.use_ifrt) { + std::optional + model_restore_context = + model_context.resource_context() + .GetResource( + ifrt_serving::kIfrtModelRestoreContextName); + if (!model_restore_context.has_value()) { + return absl::InternalError( + "Did not find IfrtModelRestoreContext resource."); + } + if (*model_restore_context == nullptr) { + return absl::InternalError("IfrtModelRestoreContexts must not be null."); + } + TF_RETURN_IF_ERROR( + PrepareRestore(&context, &model_context, meta_graph_def, + *fallback_state, std::string(saved_model_dir), options, + (*model_restore_context)->checkpoint_loader())); + } + GetDefaultInputValue(meta_graph_def.signature_def(), model_context, initializers_and_signatures.signature_map); diff --git a/tensorflow/core/tfrt/saved_model/tests/BUILD b/tensorflow/core/tfrt/saved_model/tests/BUILD index 3dfc07d245eaf1..c026800861ff2a 100644 --- a/tensorflow/core/tfrt/saved_model/tests/BUILD +++ b/tensorflow/core/tfrt/saved_model/tests/BUILD @@ -649,7 +649,9 @@ cc_library( "//tensorflow/core/platform:resource_loader", "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink", "//tensorflow/core/tfrt:ifrt_program_ops_op_lib", + "//tensorflow/core/tfrt/ifrt:checkpoint_loader", "//tensorflow/core/tfrt/ifrt:ifrt_model_context", + "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context", "//tensorflow/core/tfrt/ifrt:ifrt_serving_core_selector", "//tensorflow/core/tfrt/mlrt/kernel:ifrt_ops_kernel", "//tensorflow/core/tfrt/runtime", diff --git a/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc b/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc index c403424ae4882b..4f4caf0e028b52 100644 --- a/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc +++ b/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc @@ -29,7 +29,9 @@ limitations under the License. #include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/resource_loader.h" +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" #include "tensorflow/core/tfrt/runtime/runtime.h" #include "tensorflow/core/tfrt/saved_model/saved_model.h" @@ -77,10 +79,17 @@ TEST(SavedModelIfrt, Basic) { "IfrtModelContext", client, &core_selector, &GetThreadPool(), /*compilation_environment_proto=*/nullptr); - (*model_context.resource_context() - .GetResource( - "IfrtModelContext")) - ->set_checkpoint_loader_queue(work_queue.get()); + tensorflow::ifrt_serving::IfrtModelContext* ifrt_model_context = + (*model_context.resource_context() + .GetResource( + "IfrtModelContext")); + ifrt_model_context->set_checkpoint_loader_queue(work_queue.get()); + model_context.resource_context() + .CreateResource( + ifrt_serving::kIfrtModelRestoreContextName, + std::make_unique( + &ifrt_model_context->GetRestoreTensorRegistry(), + ifrt_model_context->checkpoint_loader_queue())); return absl::OkStatus(); }); diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc index 498d07a0e41e23..c18230e0b431dc 100644 --- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc +++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc @@ -250,9 +250,9 @@ TfrtGraphExecutionState::CreateOptimizedGraph( DumpGraphDefToFile("before_pruning", graph_def); } - TF_ASSIGN_OR_RETURN( - result.graph, - CreatePrunedGraph(graph_def, build_graph_options.callable_options)); + TF_ASSIGN_OR_RETURN(result.graph, + CreatePrunedGraph(std::move(graph_def), + build_graph_options.callable_options)); DCHECK(result.graph); if (VLOG_IS_ON(1)) { diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD index 273c822fd74df4..73fbacd589160b 100644 --- a/tensorflow/core/tpu/graph_rewrite/BUILD +++ b/tensorflow/core/tpu/graph_rewrite/BUILD @@ -1,5 +1,9 @@ # Contains graph rewrites for TPU runtimes and optimizations. +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) load( "//tensorflow/core/platform:build_config_root.bzl", "if_static", @@ -119,6 +123,7 @@ cc_library( "//tensorflow/core:session_options", "//tensorflow/core/common_runtime:function_body", "//tensorflow/core/common_runtime:function_utils", + "//tensorflow/core/config:flag_defs", "//tensorflow/core/tpu:tpu_compile_interface", "//tensorflow/core/tpu:tpu_defs", "@com_google_absl//absl/container:flat_hash_map", @@ -131,6 +136,7 @@ cc_library( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", "@local_xla//xla:status_macros", + "@local_xla//xla/tsl/util:env_var", ] + if_static( [ "//tensorflow/core/common_runtime:function", @@ -140,6 +146,26 @@ cc_library( ), ) +tf_cc_test( + name = "encapsulate_tpu_computations_pass_test", + srcs = ["encapsulate_tpu_computations_pass_test.cc"], + deps = [ + ":encapsulate_tpu_computations_pass", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:optimization_registry", + "//tensorflow/core/config:flag_defs", + ], +) + cc_library( name = "distributed_tpu_rewrite_pass_internal", srcs = ["distributed_tpu_rewrite_pass_internal.cc"], diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc index 62a4c45d696017..9370cd6b01ab1c 100644 --- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc @@ -48,6 +48,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function_utils.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/config/flag_defs.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -2481,10 +2482,32 @@ Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr, return absl::OkStatus(); } +// TODO(b/355263902): Encapsulation fails for some non-TPU graphs that are +// missing full variable shape information. Remove this path once the +// underlying issue is fixed. +bool ShouldSkipEncapsulationForNonTPUGraph() { + return flags::Global().enable_skip_encapsulation_for_non_tpu_graphs.value(); +} + } // namespace /*static*/ Status EncapsulateTPUComputationsPass::Encapsulate( std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + // If the graph does not contain any TPU computations, there is nothing to do. + if (ShouldSkipEncapsulationForNonTPUGraph()) { + bool found_tpu_replicate = false; + for (const Node* n : (*graph)->nodes()) { + if (n->attrs().Find(kTPUReplicateAttr) != nullptr) { + found_tpu_replicate = true; + break; + } + } + if (!found_tpu_replicate) { + VLOG(1) << "No TPU replicate found, skipping encapsulation"; + return absl::OkStatus(); + } + } + // Check for undeclared outputs before Encapsulation, so we can give a better // error message. // TODO(phawkins): merge this with the encapsulation code to avoid the extra diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc new file mode 100644 index 00000000000000..a21cdaec4dbc72 --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h" + +#include + +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/config/flag_defs.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +std::unique_ptr CreateGraph() { + // c = a + b + auto g = std::make_unique(OpRegistry::Global()); + auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT); + auto in1 = test::graph::Arg(g.get(), 1, DT_FLOAT); + auto tmp = test::graph::Add(g.get(), in0, in1); + auto ret = test::graph::Retval(g.get(), 0, tmp); + g->AddControlEdge(in1, ret); + FixupSourceAndSinkEdges(g.get()); + return g; +} + +TEST(EncapsulateTPUComputationsPassTest, NonTPUGraph) { + auto g = CreateGraph(); + GraphOptimizationPassOptions options; + options.graph = &g; + options.flib_def = g->mutable_flib_def(); + + EncapsulateTPUComputationsPass pass; + TF_ASSERT_OK(pass.Run(options)); + + int nodes_meeting_expectations = 0; + + for (const auto* node : g->nodes()) { + if (!IsSource(node) && !IsSink(node)) { + ASSERT_TRUE(node->attrs().Find("_xla_inferred_shapes")); + ++nodes_meeting_expectations; + } + } + EXPECT_EQ(nodes_meeting_expectations, 4); +} + +TEST(EncapsulateTPUComputationsPassTest, SkipEncapsulationForNonTPUGraph) { + flags::Global().enable_skip_encapsulation_for_non_tpu_graphs.reset(true); + auto g = CreateGraph(); + GraphOptimizationPassOptions options; + options.graph = &g; + options.flib_def = g->mutable_flib_def(); + + EncapsulateTPUComputationsPass pass; + TF_ASSERT_OK(pass.Run(options)); + + int nodes_meeting_expectations = 0; + + for (const auto* node : g->nodes()) { + if (!IsSource(node) && !IsSink(node)) { + ASSERT_FALSE(node->attrs().Find("_xla_inferred_shapes")); + ++nodes_meeting_expectations; + } + } + EXPECT_EQ(nodes_meeting_expectations, 4); + + flags::Global().enable_skip_encapsulation_for_non_tpu_graphs.reset(false); +} + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/sharding_util_ops.cc b/tensorflow/core/tpu/kernels/sharding_util_ops.cc index 4ca817b23ebc13..643f5bc588f334 100644 --- a/tensorflow/core/tpu/kernels/sharding_util_ops.cc +++ b/tensorflow/core/tpu/kernels/sharding_util_ops.cc @@ -178,7 +178,6 @@ class XlaSplitNDBaseOp : public XlaSplitNDShared { bool resource, OpKernelContext* ctx, const std::function& assign_or_copy_value_fn, const Tensor* input) { - const auto& input_shape = input->shape().dim_sizes(); absl::string_view input_name = resource ? kResourceName : kTensorName; auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, diff --git a/tensorflow/core/util/autotune_maps/BUILD b/tensorflow/core/util/autotune_maps/BUILD index 8b89487f0b0d9b..990edbe549f3eb 100644 --- a/tensorflow/core/util/autotune_maps/BUILD +++ b/tensorflow/core/util/autotune_maps/BUILD @@ -52,8 +52,8 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/protobuf:dnn_proto_cc", + "@local_xla//xla/tsl/lib/strings:proto_serialization", ], ) @@ -118,7 +118,7 @@ tf_cuda_library( "conv_parameters.h", ], cuda_deps = [ - "@local_tsl//tsl/lib/strings:proto_serialization", + "@local_xla//xla/tsl/lib/strings:proto_serialization", ], deps = [ ":conv_parameters_proto_cc", @@ -182,12 +182,12 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core/platform:status", "//tensorflow/core/platform:str_util", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/protobuf:dnn_proto_cc", "@local_xla//xla:status_macros", "@local_xla//xla/stream_executor:dnn", "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/stream_executor/gpu:gpu_init", + "@local_xla//xla/tsl/lib/strings:proto_serialization", ], ) diff --git a/tensorflow/core/util/autotune_maps/autotune_serialize.cc b/tensorflow/core/util/autotune_maps/autotune_serialize.cc index 63470c09df5f87..c601502a0d0512 100644 --- a/tensorflow/core/util/autotune_maps/autotune_serialize.cc +++ b/tensorflow/core/util/autotune_maps/autotune_serialize.cc @@ -25,13 +25,13 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/platform_manager.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/util/activation_mode.h" #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h" #include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/protobuf/dnn.pb.h" namespace tensorflow { diff --git a/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc b/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc index baa68aae1131c1..0bd1122c132238 100644 --- a/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc +++ b/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc @@ -19,9 +19,9 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/protobuf/dnn.pb.h" namespace tensorflow { diff --git a/tensorflow/core/util/autotune_maps/conv_parameters.cc b/tensorflow/core/util/autotune_maps/conv_parameters.cc index 63436938980b68..a620e39c2b2afe 100644 --- a/tensorflow/core/util/autotune_maps/conv_parameters.cc +++ b/tensorflow/core/util/autotune_maps/conv_parameters.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include "absl/strings/str_format.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/hash.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" -#include "tsl/lib/strings/proto_serialization.h" namespace tensorflow { diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h index 5c8a5dbfda4fcf..61d1fb5a19d538 100644 --- a/tensorflow/core/util/bcast.h +++ b/tensorflow/core/util/bcast.h @@ -199,7 +199,6 @@ BCastList::BCastList(const BCastList::Vec (&x)[N], prev_is_one[i] = false; current_is_one[i] = false; } - Vec output; bool output_dim_set = false; int64_t output_dim = -1; bool none_is_one = true; diff --git a/tensorflow/dtensor/cc/dtensor_device.cc b/tensorflow/dtensor/cc/dtensor_device.cc index 76abf34544f88f..6600a6b23ebd9d 100644 --- a/tensorflow/dtensor/cc/dtensor_device.cc +++ b/tensorflow/dtensor/cc/dtensor_device.cc @@ -1716,7 +1716,7 @@ void DTensorDevice::ModuleToExecutionFunctions( absl::flat_hash_set control_ret_nodes; GraphExportConfig export_config; RETURN_C_STATUS_IF_NOT_OK( - tensorflow::tf2xla::v2::ConvertMlirToGraph( + tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( *lowering_context.module, export_config, &(lowering_context.graph), flib_def, &control_ret_nodes), status); diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index 97f7d3d2a7a93d..f304d843096efb 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -436,6 +436,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@llvm-project//llvm:Support", "@llvm-project//mlir:BytecodeOpInterface", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", ], alwayslink = True, diff --git a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc index cff404f2095fec..a89a07521eb939 100644 --- a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc @@ -235,7 +235,6 @@ GetSpecsFromLabelsAndMap( std::vector sharding_specs(layout_rank); absl::flat_hash_map dimension_use_count; - absl::flat_hash_set dimension_use_set; for (const auto& label_and_indices : label_to_index) { const auto& loc = label_to_sharding_spec.find(label_and_indices.first); if (loc != label_to_sharding_spec.end()) { diff --git a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc index cac70c2b9848a6..61d51226141168 100644 --- a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc @@ -61,7 +61,6 @@ void GetTransposeSettings(mlir::Operation* op, bool* left_transposed, } // namespace StatusOr MatMulSPMDExpander::ExpandOp(mlir::Operation* op) { - absl::flat_hash_set reduced_dims; bool left_transposed; bool right_transposed; TF_ASSIGN_OR_RETURN(const Layout left_layout, diff --git a/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc index 07e121b88424fb..737d1f562bb8ff 100644 --- a/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/log/log.h" @@ -388,9 +389,6 @@ StatusOr ExpandSaveV2Op(mlir::Operation* op) { auto save_v2 = mlir::cast(op); mlir::OpBuilder builder(save_v2); - - absl::flat_hash_map, Layout>> - tensor_shape_layout_map; std::vector metadata; for (const auto& it : llvm::enumerate(save_v2.getTensors())) { mlir::Value tensor = it.value(); diff --git a/tensorflow/dtensor/mlir/sparse_expander_common.h b/tensorflow/dtensor/mlir/sparse_expander_common.h index 4496043bed384f..9d6115067ae2a8 100644 --- a/tensorflow/dtensor/mlir/sparse_expander_common.h +++ b/tensorflow/dtensor/mlir/sparse_expander_common.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/dtensor/mlir/spmd_expander_common.h b/tensorflow/dtensor/mlir/spmd_expander_common.h index 90b5ba5346bc34..0a35ce8032b07b 100644 --- a/tensorflow/dtensor/mlir/spmd_expander_common.h +++ b/tensorflow/dtensor/mlir/spmd_expander_common.h @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project diff --git a/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc b/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc index 802d46fd27ecde..2e24e5d1f9db4c 100644 --- a/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc +++ b/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc @@ -165,7 +165,7 @@ Status UpdateMetadataProtoXlaSpmd(const Mesh& mesh_config, mesh_name = ""; } const std::vector& tpu_core_ids = Mesh::tpu_core_ids()[mesh_name]; - VLOG(1) << "tpu_core_ids: " << str_util::Join(tpu_core_ids, ", "); + VLOG(1) << "tpu_core_ids: " << absl::StrJoin(tpu_core_ids, ", "); xla::DeviceAssignmentProto device_assignment; device_assignment.set_replica_count(1); @@ -223,7 +223,7 @@ Status UpdateMetadataProtoDtensorSpmd(const Mesh& mesh_config, mesh_name = ""; } const std::vector& tpu_core_ids = Mesh::tpu_core_ids()[mesh_name]; - VLOG(1) << "tpu_core_ids: " << str_util::Join(tpu_core_ids, ", "); + VLOG(1) << "tpu_core_ids: " << absl::StrJoin(tpu_core_ids, ", "); xla::DeviceAssignmentProto device_assignment; device_assignment.set_replica_count(num_replicas); diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index eb41fb0e6a223c..d49976b8cf3886 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -126,6 +126,8 @@ filegroup( name = "tflite_internal_cc_3p_api_deps_src_all", srcs = [ ":tflite_internal_cc_3p_api_deps_src", + "//tensorflow/compiler/mlir/lite:tflite_internal_cc_3p_api_deps_src", + "//tensorflow/compiler/mlir/lite/core/api:tflite_internal_cc_3p_api_deps_src", "//tensorflow/compiler/mlir/lite/schema:tflite_internal_cc_3p_api_deps_src", "//tensorflow/lite/core:macros.h", "//tensorflow/lite/core/acceleration/configuration/c:tflite_internal_cc_3p_api_deps_src", @@ -141,7 +143,6 @@ filegroup( filegroup( name = "tflite_internal_cc_3p_api_deps_src", srcs = [ - ":allocation.cc", ":allocation.h", ":array.cc", ":array.h", @@ -150,7 +151,6 @@ filegroup( ":minimal_logging.cc", ":minimal_logging.h", ":minimal_logging_android.cc", - ":mmap_allocation.cc", ":mutable_op_resolver.cc", ":mutable_op_resolver.h", ":op_resolver.h", @@ -482,22 +482,16 @@ cc_library( cc_library( name = "allocation", - srcs = [ - "allocation.cc", - ] + select({ - ":tflite_mmap_disabled": [ - "mmap_allocation_disabled.cc", - ], - "//conditions:default": [ - "mmap_allocation.cc", - ], - }), hdrs = [ "allocation.h", + "//tensorflow/compiler/mlir/lite:allocation.h", ], compatible_with = get_compatible_with_portable(), copts = tflite_copts_warnings(), - deps = ["//tensorflow/lite/core/api:error_reporter"], + deps = [ + "//tensorflow/compiler/mlir/lite:allocation", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", + ], ) cc_library( diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index 09e9ed33c61626..1aa7dec994944a 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -275,12 +275,6 @@ list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*with_selected_ops\\.cc$") # Exclude tensorflow_profiler_logger files. list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*tensorflow_profiler_logger\\.cc$") -if(_TFLITE_ENABLE_MMAP) - list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*mmap_allocation_disabled\\.cc$") -else() - list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*mmap_allocation\\.cc$") -endif() - # Handle TFLite logging source. list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*minimal_logging_.*\\.cc$") if("${CMAKE_SYSTEM_NAME}" STREQUAL "Android") @@ -373,7 +367,9 @@ if(TFLITE_ENABLE_GPU) list(APPEND TFLITE_DELEGATES_GPU_SRCS ${TFLITE_SOURCE_DIR}/delegates/gpu/api.cc ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.cc + ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.h ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate_options.cc + ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate_options.h ${TFLITE_SOURCE_DIR}/delegates/gpu/tflite_profile.cc ${TFLITE_SOURCE_DIR}/experimental/acceleration/compatibility/android_info.cc ${TFLITE_DELEGATES_GPU_CL_SRCS} @@ -681,10 +677,25 @@ set(_ALL_TFLITE_SRCS ${TF_SOURCE_DIR}/compiler/mlir/lite/utils/string_utils.h ${TF_SOURCE_DIR}/compiler/mlir/lite/experimental/remat/metadata_util.h ${TF_SOURCE_DIR}/compiler/mlir/lite/experimental/remat/metadata_util.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/core/macros.h ${TF_SOURCE_DIR}/compiler/mlir/lite/core/model_builder_base.h ${TF_SOURCE_DIR}/compiler/mlir/lite/core/model_builder_base.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/error_reporter.h + ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/error_reporter.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/verifier.h + ${TF_SOURCE_DIR}/compiler/mlir/lite/allocation.h + ${TF_SOURCE_DIR}/compiler/mlir/lite/allocation.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/mmap_allocation.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/mmap_allocation_disabled.cc ${TFLITE_SOURCE_DIR}/schema/schema_generated.h ) + +if(_TFLITE_ENABLE_MMAP) + list(FILTER _ALL_TFLITE_SRCS EXCLUDE REGEX ".*mmap_allocation_disabled\\.cc$") +else() + list(FILTER _ALL_TFLITE_SRCS EXCLUDE REGEX ".*mmap_allocation\\.cc$") +endif() + add_library(tensorflow-lite ${_ALL_TFLITE_SRCS} ) @@ -774,6 +785,9 @@ set(TFLITE_GENERATED_HEADERS_DIR ${CMAKE_BINARY_DIR}/tensorflow/lite) # Add the profiling proto directory. add_subdirectory(${TFLITE_SOURCE_DIR}/profiling/proto) +# Add the tf example directory. +add_subdirectory(${TF_SOURCE_DIR}/core/example ${CMAKE_BINARY_DIR}/example_proto_generated) + # The benchmark tool. add_subdirectory(${TFLITE_SOURCE_DIR}/tools/benchmark) diff --git a/tensorflow/lite/allocation.h b/tensorflow/lite/allocation.h index f007b3cf0e540e..b2a03a66ae36bf 100644 --- a/tensorflow/lite/allocation.h +++ b/tensorflow/lite/allocation.h @@ -18,139 +18,6 @@ limitations under the License. #ifndef TENSORFLOW_LITE_ALLOCATION_H_ #define TENSORFLOW_LITE_ALLOCATION_H_ -#include - -#include -#include -#include - -#include "tensorflow/lite/core/api/error_reporter.h" - -namespace tflite { - -/// A memory allocation handle. This could be a mmap or shared memory. -class Allocation { - public: - virtual ~Allocation() {} - - enum class Type { - kMMap, - kFileCopy, - kMemory, - }; - - /// Base pointer of this allocation - virtual const void* base() const = 0; - /// Size in bytes of the allocation - virtual size_t bytes() const = 0; - /// Whether the allocation is valid - virtual bool valid() const = 0; - /// Return the type of the Allocation. - Type type() const { return type_; } - - protected: - Allocation(ErrorReporter* error_reporter, Type type) - : error_reporter_(error_reporter), type_(type) {} - ErrorReporter* error_reporter_; - - private: - const Type type_; -}; - -/// Note that not all platforms support MMAP-based allocation. -/// Use `IsSupported()` to check. -class MMAPAllocation : public Allocation { - public: - /// Loads and maps the provided file to a memory region. - MMAPAllocation(const char* filename, ErrorReporter* error_reporter); - - /// Maps the provided file descriptor to a memory region. - /// Note: The provided file descriptor will be dup'ed for usage; the caller - /// retains ownership of the provided descriptor and should close accordingly. - MMAPAllocation(int fd, ErrorReporter* error_reporter); - - /// Maps the provided file descriptor, with the given offset and length (both - /// in bytes), to a memory region. - /// Note: The provided file descriptor will be dup'ed for usage; the caller - /// retains ownership of the provided descriptor and should close accordingly. - MMAPAllocation(int fd, size_t offset, size_t length, - ErrorReporter* error_reporter); - - ~MMAPAllocation() override; - const void* base() const override; - size_t bytes() const override; - bool valid() const override; - - int fd() const { return mmap_fd_; } - - // The start address of the mmapped buffer. - // This will be base() rounded down to the nearest page boundary. - const void* mmapped_buffer() const { return mmapped_buffer_; } - - // The size of the mmapped buffer. - size_t mmapped_buffer_size() const { return bytes() + offset_in_buffer_; } - - // Offset of mmapped_buffer() in the file referenced by the file descriptor. - size_t mmapped_buffer_offset_in_file() const { - return offset_of_buffer_in_file_; - } - - static bool IsSupported(); - - protected: - // Data required for mmap. - int mmap_fd_ = -1; // mmap file descriptor - const void* mmapped_buffer_; - size_t buffer_size_bytes_ = 0; - // Used when the address to mmap is not page-aligned. - size_t offset_in_buffer_ = 0; - size_t offset_of_buffer_in_file_ = 0; - - private: - // Assumes ownership of the provided `owned_fd` instance. - MMAPAllocation(ErrorReporter* error_reporter, int owned_fd); - - // Assumes ownership of the provided `owned_fd` instance, and uses the given - // offset and length (both in bytes) for memory mapping. - MMAPAllocation(ErrorReporter* error_reporter, int owned_fd, size_t offset, - size_t length); -}; - -class FileCopyAllocation : public Allocation { - public: - /// Loads the provided file into a heap memory region. - FileCopyAllocation(const char* filename, ErrorReporter* error_reporter); - ~FileCopyAllocation() override; - const void* base() const override; - size_t bytes() const override; - bool valid() const override; - - private: - std::unique_ptr copied_buffer_; - size_t buffer_size_bytes_ = 0; -}; - -class MemoryAllocation : public Allocation { - public: - /// Provides a (read-only) view of the provided buffer region as an - /// allocation. - /// Note: The caller retains ownership of `ptr`, and must ensure it remains - /// valid for the lifetime of the class instance. - MemoryAllocation(const void* ptr, size_t num_bytes, - ErrorReporter* error_reporter); - ~MemoryAllocation() override; - const void* base() const override; - size_t bytes() const override; - bool valid() const override; - - private: - const void* buffer_; -#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) - void* aligned_ptr_ = nullptr; -#endif - size_t buffer_size_bytes_ = 0; -}; - -} // namespace tflite +#include "tensorflow/compiler/mlir/lite/allocation.h" #endif // TENSORFLOW_LITE_ALLOCATION_H_ diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index 19cdd37ed4f549..f1664849f36e50 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -292,7 +292,10 @@ cc_test( size = "small", srcs = ["c_api_signature_runner_test.cc"], copts = tflite_copts(), - data = ["//tensorflow/lite:testdata/multi_signatures.bin"], + data = [ + "//tensorflow/lite:testdata/multi_signatures.bin", + "//tensorflow/lite:testdata/no_signatures.bin", + ], deps = [ ":c_api", "//tensorflow/lite/core/c:c_api", diff --git a/tensorflow/lite/c/c_api_signature_runner_test.cc b/tensorflow/lite/c/c_api_signature_runner_test.cc index 30614e5d7e59f5..61af71ffd863a6 100644 --- a/tensorflow/lite/c/c_api_signature_runner_test.cc +++ b/tensorflow/lite/c/c_api_signature_runner_test.cc @@ -24,6 +24,94 @@ limitations under the License. namespace tflite { namespace { +TEST(SignatureRunnerTest, TestNoSignatures) { + TfLiteModel* model = TfLiteModelCreateFromFile( + "tensorflow/lite/testdata/no_signatures.bin"); + ASSERT_NE(model, nullptr); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreate(model, /*optional_options=*/nullptr); + ASSERT_NE(interpreter, nullptr); + + int nun_signatures = TfLiteInterpreterGetSignatureCount(interpreter); + ASSERT_EQ(nun_signatures, 0); + + ASSERT_EQ(TfLiteInterpreterGetSignatureRunner(interpreter, "foo"), nullptr); + + TfLiteSignatureRunner* runner = + TfLiteInterpreterGetSignatureRunner(interpreter, nullptr); + ASSERT_NE(runner, nullptr); + + int num_interpreter_inputs = + TfLiteInterpreterGetInputTensorCount(interpreter); + int num_runner_inputs = TfLiteSignatureRunnerGetInputCount(runner); + ASSERT_EQ(num_runner_inputs, num_interpreter_inputs); + + for (int i = 0; i < num_interpreter_inputs; ++i) { + auto* interpreter_input_tensor = + TfLiteInterpreterGetInputTensor(interpreter, i); + ASSERT_NE(interpreter_input_tensor, nullptr); + auto* interpreter_input_name = TfLiteTensorName(interpreter_input_tensor); + ASSERT_NE(interpreter_input_name, nullptr); + auto* runner_input_name = TfLiteSignatureRunnerGetInputName(runner, i); + ASSERT_NE(runner_input_name, nullptr); + EXPECT_STREQ(runner_input_name, interpreter_input_name); + auto* runner_input_tensor = + TfLiteSignatureRunnerGetInputTensor(runner, interpreter_input_name); + ASSERT_NE(runner_input_tensor, nullptr); + ASSERT_EQ(runner_input_tensor, interpreter_input_tensor); + } + + int num_interpreter_outputs = + TfLiteInterpreterGetOutputTensorCount(interpreter); + int num_runner_outputs = TfLiteSignatureRunnerGetOutputCount(runner); + ASSERT_EQ(num_runner_outputs, num_interpreter_outputs); + + for (int i = 0; i < num_interpreter_outputs; ++i) { + auto* interpreter_output_tensor = + TfLiteInterpreterGetOutputTensor(interpreter, i); + ASSERT_NE(interpreter_output_tensor, nullptr); + auto* interpreter_output_name = TfLiteTensorName(interpreter_output_tensor); + ASSERT_NE(interpreter_output_name, nullptr); + auto* runner_output_name = TfLiteSignatureRunnerGetOutputName(runner, i); + ASSERT_NE(runner_output_name, nullptr); + EXPECT_STREQ(runner_output_name, interpreter_output_name); + auto* runner_output_tensor = + TfLiteSignatureRunnerGetOutputTensor(runner, interpreter_output_name); + ASSERT_NE(runner_output_tensor, nullptr); + ASSERT_EQ(runner_output_tensor, interpreter_output_tensor); + } + + std::array input_dims{2}; + ASSERT_EQ(TfLiteSignatureRunnerResizeInputTensor( + runner, "x1", input_dims.data(), input_dims.size()), + kTfLiteOk); + ASSERT_EQ(TfLiteSignatureRunnerResizeInputTensor( + runner, "x2", input_dims.data(), input_dims.size()), + kTfLiteOk); + ASSERT_EQ(TfLiteSignatureRunnerAllocateTensors(runner), kTfLiteOk); + TfLiteTensor* input1 = TfLiteSignatureRunnerGetInputTensor(runner, "x1"); + ASSERT_NE(input1, nullptr); + TfLiteTensor* input2 = TfLiteSignatureRunnerGetInputTensor(runner, "x2"); + ASSERT_NE(input2, nullptr); + ASSERT_EQ(TfLiteSignatureRunnerGetInputTensor(runner, "foo"), nullptr); + const TfLiteTensor* output = + TfLiteSignatureRunnerGetOutputTensor(runner, "Identity"); + ASSERT_NE(output, nullptr); + ASSERT_EQ(TfLiteSignatureRunnerGetOutputTensor(runner, "foo"), nullptr); + input1->data.f[0] = -8; + input1->data.f[1] = 0.5; + input2->data.f[0] = -1; + input2->data.f[1] = 1.5; + ASSERT_EQ(TfLiteSignatureRunnerInvoke(runner), kTfLiteOk); + ASSERT_EQ(output->data.f[0], 0); + ASSERT_EQ(output->data.f[1], 2); + + TfLiteSignatureRunnerDelete(runner); + TfLiteInterpreterDelete(interpreter); + TfLiteModelDelete(model); +} + TEST(SignatureRunnerTest, TestMultiSignatures) { TfLiteModel* model = TfLiteModelCreateFromFile( "tensorflow/lite/testdata/multi_signatures.bin"); diff --git a/tensorflow/lite/core/BUILD b/tensorflow/lite/core/BUILD index d3939b91f911ea..4309e28baf8e38 100644 --- a/tensorflow/lite/core/BUILD +++ b/tensorflow/lite/core/BUILD @@ -43,9 +43,7 @@ cc_library( ], compatible_with = get_compatible_with_portable(), copts = tflite_copts() + tflite_copts_warnings(), - visibility = [ - "//tensorflow/lite:__subpackages__", - ], + visibility = ["//tensorflow/lite:__subpackages__"], deps = [ ":cc_api_stable", ":signature_runner", diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD index 08ac033fcb0f77..6613d1c3e14c96 100644 --- a/tensorflow/lite/core/api/BUILD +++ b/tensorflow/lite/core/api/BUILD @@ -11,7 +11,6 @@ package( filegroup( name = "tflite_internal_cc_3p_api_deps_src", srcs = [ - ":error_reporter.cc", ":error_reporter.h", ":op_resolver.cc", ":op_resolver.h", @@ -68,32 +67,41 @@ cc_library( ], deps = [ ":error_reporter", + "//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/lite/core/c:common", "//tensorflow/lite/schema:schema_fbs", - "//tensorflow/lite/schema:schema_utils", - "@flatbuffers//:runtime_cc", ], ) cc_library( name = "error_reporter", - srcs = ["error_reporter.cc"], - hdrs = ["error_reporter.h"], + hdrs = [ + "error_reporter.h", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter.h", + ], compatible_with = get_compatible_with_portable(), copts = tflite_copts(), visibility = [ "//visibility:public", ], - deps = [], + deps = [ + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", + ], ) cc_library( name = "verifier", - hdrs = ["verifier.h"], + hdrs = [ + "verifier.h", + "//tensorflow/compiler/mlir/lite/core/api:verifier.h", + ], compatible_with = get_compatible_with_portable(), copts = tflite_copts(), visibility = ["//visibility:public"], - deps = [":error_reporter"], + deps = [ + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", + "//tensorflow/compiler/mlir/lite/core/api:verifier", + ], ) cc_library( @@ -108,24 +116,19 @@ cc_library( deps = [":op_resolver"], ) -cc_test( - name = "error_reporter_test", - size = "small", - srcs = ["error_reporter_test.cc"], - deps = [ - ":api", - "@com_google_googletest//:gtest_main", - ], -) - cc_test( name = "op_resolver_test", size = "small", srcs = ["op_resolver_test.cc"], deps = [ ":api", - "//tensorflow/lite/schema:schema_conversion_utils", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/c:common", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest_main", + "@flatbuffers//:runtime_cc", ], ) @@ -136,7 +139,6 @@ cc_test( deps = [ ":op_resolver", ":op_resolver_internal", - "//tensorflow/lite:builtin_ops", "//tensorflow/lite:framework", "//tensorflow/lite:mutable_op_resolver", "//tensorflow/lite/core/kernels:builtin_ops", @@ -151,6 +153,7 @@ cc_test( srcs = ["flatbuffer_conversions_test.cc"], deps = [ ":api", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", "//tensorflow/lite:string", "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core/c:c_api_types", diff --git a/tensorflow/lite/core/api/error_reporter.h b/tensorflow/lite/core/api/error_reporter.h index 1e0ef7dc913a44..f9106046b2f231 100644 --- a/tensorflow/lite/core/api/error_reporter.h +++ b/tensorflow/lite/core/api/error_reporter.h @@ -15,58 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_ #define TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_ -#include - -namespace tflite { - -/// A functor that reports error to supporting system. Invoked similar to -/// printf. -/// -/// Usage: -/// ErrorReporter foo; -/// foo.Report("test %d", 5); -/// or -/// va_list args; -/// foo.Report("test %d", args); // where args is va_list -/// -/// Subclass ErrorReporter to provide another reporting destination. -/// For example, if you have a GUI program, you might redirect to a buffer -/// that drives a GUI error log box. -class ErrorReporter { - public: - virtual ~ErrorReporter() = default; - /// Converts `args` to character equivalents according to `format` string, - /// constructs the error string and report it. - /// Returns number of characters written or zero on success, and negative - /// number on error. - virtual int Report(const char* format, va_list args) = 0; - - /// Converts arguments to character equivalents according to `format` string, - /// constructs the error string and report it. - /// Returns number of characters written or zero on success, and negative - /// number on error. - int Report(const char* format, ...); - - /// Equivalent to `Report` above. The additional `void*` parameter is unused. - /// This method is for compatibility with macros that takes `TfLiteContext`, - /// like TF_LITE_ENSURE and related macros. - int ReportError(void*, const char* format, ...); -}; - -} // namespace tflite - -// You should not make bare calls to the error reporter, instead use the -// TF_LITE_REPORT_ERROR macro, since this allows message strings to be -// stripped when the binary size has to be optimized. If you are looking to -// reduce binary size, define TF_LITE_STRIP_ERROR_STRINGS when compiling and -// every call will be stubbed out, taking no memory. -#ifndef TF_LITE_STRIP_ERROR_STRINGS -#define TF_LITE_REPORT_ERROR(reporter, ...) \ - do { \ - static_cast<::tflite::ErrorReporter*>(reporter)->Report(__VA_ARGS__); \ - } while (false) -#else // TF_LITE_STRIP_ERROR_STRINGS -#define TF_LITE_REPORT_ERROR(reporter, ...) -#endif // TF_LITE_STRIP_ERROR_STRINGS +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" // IWYU pragma: export #endif // TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_ diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 35268103be8792..c27e4e6f8b82a9 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -20,9 +20,8 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h index c01e8875813f93..de287af21c8a5b 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc index 87c897dfc0928e..98c8c910ac1d84 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/lite/core/api/op_resolver.cc b/tensorflow/lite/core/api/op_resolver.cc index ce5ae4f406eb6a..214490c874d7ad 100644 --- a/tensorflow/lite/core/api/op_resolver.cc +++ b/tensorflow/lite/core/api/op_resolver.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/lite/core/api/op_resolver.h" -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/schema/schema_utils.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { diff --git a/tensorflow/lite/core/api/op_resolver.h b/tensorflow/lite/core/api/op_resolver.h index 7aff7cafea1783..f6f5fd214d187a 100644 --- a/tensorflow/lite/core/api/op_resolver.h +++ b/tensorflow/lite/core/api/op_resolver.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/lite/core/api/op_resolver_internal_test.cc b/tensorflow/lite/core/api/op_resolver_internal_test.cc index d052e9c7bab8ee..b62df374c483ef 100644 --- a/tensorflow/lite/core/api/op_resolver_internal_test.cc +++ b/tensorflow/lite/core/api/op_resolver_internal_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/lite/core/kernels/builtin_op_kernels.h" #include "tensorflow/lite/core/kernels/register.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { diff --git a/tensorflow/lite/core/api/op_resolver_test.cc b/tensorflow/lite/core/api/op_resolver_test.cc index 45fcdcf81dac18..59b08ad21864dc 100644 --- a/tensorflow/lite/core/api/op_resolver_test.cc +++ b/tensorflow/lite/core/api/op_resolver_test.cc @@ -18,7 +18,13 @@ limitations under the License. #include #include -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace { diff --git a/tensorflow/lite/core/api/verifier.h b/tensorflow/lite/core/api/verifier.h index 8128ff31e1ea85..dcb1d029b5678a 100644 --- a/tensorflow/lite/core/api/verifier.h +++ b/tensorflow/lite/core/api/verifier.h @@ -18,22 +18,6 @@ limitations under the License. #ifndef TENSORFLOW_LITE_CORE_API_VERIFIER_H_ #define TENSORFLOW_LITE_CORE_API_VERIFIER_H_ -#include "tensorflow/lite/core/api/error_reporter.h" - -namespace tflite { - -/// Abstract interface that verifies whether a given model is legit. -/// It facilitates the use-case to verify and build a model without loading it -/// twice. -/// (See also "tensorflow/lite/tools/verifier.h".) -class TfLiteVerifier { - public: - /// Returns true if the model is legit. - virtual bool Verify(const char* data, int length, - ErrorReporter* reporter) = 0; - virtual ~TfLiteVerifier() {} -}; - -} // namespace tflite +#include "tensorflow/compiler/mlir/lite/core/api/verifier.h" // IWYU pragma: export #endif // TENSORFLOW_LITE_CORE_API_VERIFIER_H_ diff --git a/tensorflow/lite/core/async/BUILD b/tensorflow/lite/core/async/BUILD index 625104252899a1..ca2f3caac2906a 100644 --- a/tensorflow/lite/core/async/BUILD +++ b/tensorflow/lite/core/async/BUILD @@ -38,8 +38,9 @@ cc_test( name = "task_internal_test", srcs = ["task_internal_test.cc"], deps = [ - ":async_kernel_internal", ":task_internal", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/c:common", "//tensorflow/lite/core/async/c:types", "//tensorflow/lite/core/async/interop/c:types", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/core/async/async_signature_runner_test.cc b/tensorflow/lite/core/async/async_signature_runner_test.cc index bb5e23b31111d1..3eb075ac143b35 100644 --- a/tensorflow/lite/core/async/async_signature_runner_test.cc +++ b/tensorflow/lite/core/async/async_signature_runner_test.cc @@ -183,7 +183,7 @@ TEST_F(AsyncSignatureRunnerNoSignatureDefTest, GetAsyncSignatureRunner) { TEST_F(AsyncSignatureRunnerNoSignatureDefTest, InputsTest) { signature_runner_ = interpreter_->GetAsyncSignatureRunner(nullptr); EXPECT_EQ(1, signature_runner_->input_size()); - EXPECT_EQ(0, signature_runner_->input_names().size()); + EXPECT_EQ(1, signature_runner_->input_names().size()); EXPECT_EQ(1, signature_runner_->inputs().size()); EXPECT_NE(nullptr, signature_runner_->tensor(signature_runner_->inputs()[0])); @@ -192,7 +192,7 @@ TEST_F(AsyncSignatureRunnerNoSignatureDefTest, InputsTest) { TEST_F(AsyncSignatureRunnerNoSignatureDefTest, OutputsTest) { signature_runner_ = interpreter_->GetAsyncSignatureRunner(nullptr); EXPECT_EQ(1, signature_runner_->output_size()); - EXPECT_EQ(0, signature_runner_->output_names().size()); + EXPECT_EQ(1, signature_runner_->output_names().size()); EXPECT_EQ(1, signature_runner_->outputs().size()); EXPECT_NE(nullptr, diff --git a/tensorflow/lite/core/async/c/BUILD b/tensorflow/lite/core/async/c/BUILD index e9a8bf9ae6c7cc..0f6bb9c62bc2d8 100644 --- a/tensorflow/lite/core/async/c/BUILD +++ b/tensorflow/lite/core/async/c/BUILD @@ -118,6 +118,9 @@ cc_test( name = "async_signature_runner_test", srcs = ["async_signature_runner_test.cc"], copts = tflite_copts() + tflite_copts_warnings(), + data = [ + "//tensorflow/lite:testdata/no_signatures.bin", + ], deps = [ ":async_signature_runner", ":internal", diff --git a/tensorflow/lite/core/async/c/async_signature_runner_test.cc b/tensorflow/lite/core/async/c/async_signature_runner_test.cc index 2648e5028ed84b..1e2b54dacd55f3 100644 --- a/tensorflow/lite/core/async/c/async_signature_runner_test.cc +++ b/tensorflow/lite/core/async/c/async_signature_runner_test.cc @@ -182,9 +182,10 @@ TEST_P(AsyncSignatureRunnerTest, InputsTest) { "x", TfLiteOpaqueTensorName( TfLiteAsyncSignatureRunnerGetInputTensor(runner_, "input"))); } else { - EXPECT_EQ(nullptr, TfLiteAsyncSignatureRunnerGetInputName(runner_, 0)); - EXPECT_EQ(nullptr, - TfLiteAsyncSignatureRunnerGetInputTensor(runner_, "input")); + EXPECT_STREQ("x", TfLiteAsyncSignatureRunnerGetInputName(runner_, 0)); + EXPECT_STREQ("x", + TfLiteOpaqueTensorName( + TfLiteAsyncSignatureRunnerGetInputTensor(runner_, "x"))); } } @@ -198,9 +199,10 @@ TEST_P(AsyncSignatureRunnerTest, OutputsTest) { "a", TfLiteOpaqueTensorName( TfLiteAsyncSignatureRunnerGetOutputTensor(runner_, "output"))); } else { - EXPECT_EQ(nullptr, TfLiteAsyncSignatureRunnerGetOutputName(runner_, 0)); - EXPECT_EQ(nullptr, - TfLiteAsyncSignatureRunnerGetOutputTensor(runner_, "output")); + EXPECT_STREQ("a", TfLiteAsyncSignatureRunnerGetOutputName(runner_, 0)); + EXPECT_STREQ("a", + TfLiteOpaqueTensorName( + TfLiteAsyncSignatureRunnerGetOutputTensor(runner_, "a"))); } } @@ -229,5 +231,93 @@ TEST_P(AsyncSignatureRunnerTest, IndexOutOfBound) { EXPECT_EQ(nullptr, TfLiteAsyncSignatureRunnerGetTensor(runner_, 42)); } +TEST(AsyncSignatureRunnerTest, TestNoSignatures) { + TfLiteModel* model = TfLiteModelCreateFromFile( + "third_party/tensorflow/lite/testdata/no_signatures.bin"); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + ASSERT_NE(options, nullptr); + auto kernel = + std::make_unique<::testing::StrictMock>(); + auto backend = std::make_unique(kernel->kernel()); + TfLiteInterpreterOptionsAddDelegate(options, backend->get_delegate()); + + TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options); + ASSERT_NE(interpreter, nullptr); + + TfLiteInterpreterOptionsDelete(options); + + int nun_signatures = TfLiteInterpreterGetSignatureCount(interpreter); + ASSERT_EQ(nun_signatures, 0); + + ASSERT_EQ(TfLiteInterpreterGetAsyncSignatureRunner(interpreter, "foo"), + nullptr); + + TfLiteAsyncSignatureRunner* runner = + TfLiteInterpreterGetAsyncSignatureRunner(interpreter, nullptr); + ASSERT_NE(runner, nullptr); + + int num_interpreter_inputs = + TfLiteInterpreterGetInputTensorCount(interpreter); + int num_runner_inputs = TfLiteAsyncSignatureRunnerGetInputCount(runner); + ASSERT_EQ(num_runner_inputs, num_interpreter_inputs); + + for (int i = 0; i < num_interpreter_inputs; ++i) { + auto* interpreter_input_tensor = + TfLiteInterpreterGetInputTensor(interpreter, i); + ASSERT_NE(interpreter_input_tensor, nullptr); + auto* interpreter_input_name = TfLiteTensorName(interpreter_input_tensor); + ASSERT_NE(interpreter_input_name, nullptr); + auto* runner_input_name = TfLiteAsyncSignatureRunnerGetInputName(runner, i); + ASSERT_NE(runner_input_name, nullptr); + EXPECT_STREQ(runner_input_name, interpreter_input_name); + auto* runner_input_tensor = TfLiteAsyncSignatureRunnerGetInputTensor( + runner, interpreter_input_name); + ASSERT_NE(runner_input_tensor, nullptr); + ASSERT_EQ(runner_input_tensor, reinterpret_cast( + interpreter_input_tensor)); + } + + int num_interpreter_outputs = + TfLiteInterpreterGetOutputTensorCount(interpreter); + int num_runner_outputs = TfLiteAsyncSignatureRunnerGetOutputCount(runner); + ASSERT_EQ(num_runner_outputs, num_interpreter_outputs); + + for (int i = 0; i < num_interpreter_outputs; ++i) { + auto* interpreter_output_tensor = + TfLiteInterpreterGetOutputTensor(interpreter, i); + ASSERT_NE(interpreter_output_tensor, nullptr); + auto* interpreter_output_name = TfLiteTensorName(interpreter_output_tensor); + ASSERT_NE(interpreter_output_name, nullptr); + auto* runner_output_name = + TfLiteAsyncSignatureRunnerGetOutputName(runner, i); + ASSERT_NE(runner_output_name, nullptr); + EXPECT_STREQ(runner_output_name, interpreter_output_name); + auto* runner_output_tensor = TfLiteAsyncSignatureRunnerGetOutputTensor( + runner, interpreter_output_name); + ASSERT_NE(runner_output_tensor, nullptr); + ASSERT_EQ(runner_output_tensor, reinterpret_cast( + interpreter_output_tensor)); + } + + EXPECT_CALL(*kernel, Prepare(_, _)).WillOnce(Return(kTfLiteOk)); + EXPECT_CALL(*kernel, Eval(_, _, _)).WillOnce(Return(kTfLiteOk)); + EXPECT_CALL(*kernel, Wait(_, _)).WillOnce(Return(kTfLiteOk)); + EXPECT_CALL(*kernel, Finish(_, _)).WillOnce(Return(kTfLiteOk)); + + EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerPrepareBackends(runner)); + + auto* task = TfLiteAsyncSignatureRunnerCreateTask(runner); + + EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerInvokeAsync(runner, task)); + EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerWait(runner, task)); + EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerFinish(runner, task)); + + TfLiteAsyncSignatureRunnerDelete(runner); + TfLiteInterpreterDelete(interpreter); + TfLiteModelDelete(model); +} + } // namespace async } // namespace tflite diff --git a/tensorflow/lite/core/async/task_internal_test.cc b/tensorflow/lite/core/async/task_internal_test.cc index d63eb03e89767f..b0dc1ae385917f 100644 --- a/tensorflow/lite/core/async/task_internal_test.cc +++ b/tensorflow/lite/core/async/task_internal_test.cc @@ -17,7 +17,8 @@ limitations under the License. #include #include -#include "tensorflow/lite/core/async/async_kernel_internal.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/async/c/types.h" #include "tensorflow/lite/core/async/interop/c/types.h" diff --git a/tensorflow/lite/core/interpreter.cc b/tensorflow/lite/core/interpreter.cc index 9d1623c5b4821f..dd9feb9b90c3be 100644 --- a/tensorflow/lite/core/interpreter.cc +++ b/tensorflow/lite/core/interpreter.cc @@ -520,7 +520,13 @@ void Interpreter::AddProfiler(std::unique_ptr profiler) { } impl::SignatureRunner* Interpreter::GetSignatureRunner( - const char* signature_key) { + const char* signature_key_) { + auto [signature_key, empty_signature_fallback] = + ReplaceWithPlaceholderSignatureKeyIfNeeded(signature_key_); + if (!signature_key) { + return nullptr; + } + auto iter = signature_runner_map_.find(signature_key); if (iter != signature_runner_map_.end()) { return &(iter->second); @@ -533,6 +539,14 @@ impl::SignatureRunner* Interpreter::GetSignatureRunner( return nullptr; } + if (empty_signature_fallback) { + placeholder_signature_def_ = CreatePlaceholderSignatureDef(); + auto status = signature_runner_map_.insert( + {signature_key, SignatureRunner(placeholder_signature_def_.get(), + &primary_subgraph())}); + return &(status.first->second); + } + for (const auto& signature : signature_defs_) { if (signature.signature_key == signature_key) { auto status = signature_runner_map_.insert( @@ -541,7 +555,56 @@ impl::SignatureRunner* Interpreter::GetSignatureRunner( return &(status.first->second); } } + return nullptr; } +std::unique_ptr +Interpreter::CreatePlaceholderSignatureDef() { + auto placeholder_signature_def = std::make_unique(); + for (auto i = 0; i < inputs().size(); ++i) { + auto* name = GetInputName(i); + placeholder_signature_def->inputs[name] = inputs()[i]; + } + for (auto i = 0; i < outputs().size(); ++i) { + auto* name = GetOutputName(i); + placeholder_signature_def->outputs[name] = outputs()[i]; + } + placeholder_signature_def->signature_key = kPlaceholderSignatureDefKey; + placeholder_signature_def->subgraph_index = 0; + return placeholder_signature_def; +} + +std::pair +Interpreter::ReplaceWithPlaceholderSignatureKeyIfNeeded( + const char* signature_key) { + // Handles nullptr signature key. + // If the model does not have signature def, use default name as placeholder. + // Otherwise use the first signature key that points to primary subgraph. + bool empty_signature_fallback = false; + if (signature_key == nullptr) { + if (signature_defs_.empty()) { + signature_key = kPlaceholderSignatureDefKey; + empty_signature_fallback = true; + } else { + for (const auto& signature : signature_defs_) { + if (signature.subgraph_index == 0) { + signature_key = signature.signature_key.c_str(); + break; + } + } + } + } + + if (signature_key == nullptr) { + // The model has signature def but none of those points to primary subgraph. + TF_LITE_REPORT_ERROR(error_reporter_, + "The model has signature def but none of those points " + "to primary subgraph."); + return {nullptr, empty_signature_fallback}; + } else { + return {signature_key, empty_signature_fallback}; + } +} + } // namespace tflite diff --git a/tensorflow/lite/core/interpreter.h b/tensorflow/lite/core/interpreter.h index 4a3fb131da3c14..f26a15dcd0b9b8 100644 --- a/tensorflow/lite/core/interpreter.h +++ b/tensorflow/lite/core/interpreter.h @@ -335,21 +335,25 @@ class Interpreter { } /// \brief Returns a pointer to the SignatureRunner instance to run the part - /// of the graph identified by a SignatureDef. The nullptr is returned if the - /// given signature key is not valid. + /// of the graph identified by a SignatureDef. If the model does not have any + /// signature defs, pass nullptr as signature_key and a SignatureRunner will + /// be created using the primary subgraph (0). A nullptr is returned if the + /// given signature_key is not valid. Note, the returned SignatureRunner + /// instance is owned by and has the same lifetime as the Interpreter object; + /// additionally, class SignatureRunner is *not* thread-safe. /// If you need to specify delegates, you have to do that before calling this /// function. This function will additionally apply default delegates. Thus, /// applying delegates after that might lead to undesirable behaviors. - /// Note, the pointed instance has lifetime same as the Interpreter object - /// and the SignatureRunner class is *not* thread-safe. SignatureRunner* GetSignatureRunner(const char* signature_key); - /// \warning Experimental interface, subject to change. \n - /// \brief Returns a pointer to the AsyncSignatureRunner instance to run the - /// part of the graph identified by a SignatureDef. The nullptr is returned if - /// the given signature key is not valid. - /// if the model does not have signature def, pass nullptr to signature_key - /// and AsyncSignatureRunner will be created using primary subgraph (0). + /// \warning Experimental interface, subject to change. \n \brief Returns a + /// pointer to the AsyncSignatureRunner instance to run the part of the graph + /// identified by a SignatureDef. If the model does not have any signature + /// defs, pass nullptr as signature_key and an AsyncSignatureRunner will be + /// created using the primary subgraph (0). A nullptr is returned if the + /// given signature_key is not valid. Note, the returned AsyncSignatureRunner + /// instance is owned by and has the same lifetime as the Interpreter object; + /// additionally, class AsyncSignatureRunner is *not* thread-safe. /// The async delegate should be applied before calling this function. async::AsyncSignatureRunner* GetAsyncSignatureRunner( const char* signature_key); @@ -905,6 +909,10 @@ class Interpreter { TfLiteStatus ApplyOptionsImpl(InterpreterOptions* options); + std::unique_ptr CreatePlaceholderSignatureDef(); + std::pair ReplaceWithPlaceholderSignatureKeyIfNeeded( + const char* signature_key); + // A pure C data structure used to communicate with the pure C plugin // interface. To avoid copying tensor metadata, this is also the definitive // structure to store tensors. @@ -964,6 +972,13 @@ class Interpreter { // List of SignatureDefs obtained from the model. std::vector signature_defs_; + // Default signature key to use when the model has no signatures. + static constexpr char kPlaceholderSignatureDefKey[] = + ""; + + // Placeholder SignatureDef for legacy models with no signatures. + std::unique_ptr placeholder_signature_def_; + // Map of signature key to its corresponding SignatureRunner object. // A SignatureRunner is basically a wrapper of the Subgraph corresponding to // its SignatureDef. diff --git a/tensorflow/lite/core/interpreter_experimental.cc b/tensorflow/lite/core/interpreter_experimental.cc index 7eef090791df8f..4a7bca720d8239 100644 --- a/tensorflow/lite/core/interpreter_experimental.cc +++ b/tensorflow/lite/core/interpreter_experimental.cc @@ -34,10 +34,6 @@ limitations under the License. namespace tflite { -namespace { -static constexpr char kDefaultServingSignatureDefKey[] = "serving_default"; -} // namespace - TfLiteStatus Interpreter::SetCustomAllocationForTensor( int tensor_index, const TfLiteCustomAllocation& allocation, int64_t flags) { return primary_subgraph().SetCustomAllocationForTensor(tensor_index, @@ -145,27 +141,10 @@ TfLiteStatus Interpreter::ApplyOptions(InterpreterOptions* options) { } async::AsyncSignatureRunner* Interpreter::GetAsyncSignatureRunner( - const char* signature_key) { - // Handles nullptr signature key. - // If the model does not have signature def, use default name as placeholder. - // Otherwise use the first signature key that points to primary subgraph. - bool empty_signature_fallback = false; - if (signature_key == nullptr) { - if (signature_defs_.empty()) { - signature_key = kDefaultServingSignatureDefKey; - empty_signature_fallback = true; - } else { - for (const auto& signature : signature_defs_) { - if (signature.subgraph_index == 0) { - signature_key = signature.signature_key.c_str(); - break; - } - } - } - } - - if (signature_key == nullptr) { - // The model has signature def but none of those points to primary subgraph. + const char* signature_key_) { + auto [signature_key, empty_signature_fallback] = + ReplaceWithPlaceholderSignatureKeyIfNeeded(signature_key_); + if (!signature_key) { return nullptr; } @@ -175,11 +154,14 @@ async::AsyncSignatureRunner* Interpreter::GetAsyncSignatureRunner( } if (empty_signature_fallback) { + placeholder_signature_def_ = CreatePlaceholderSignatureDef(); auto status = async_signature_runner_map_.insert( {signature_key, - async::AsyncSignatureRunner(nullptr, &primary_subgraph())}); + async::AsyncSignatureRunner(placeholder_signature_def_.get(), + &primary_subgraph())}); return &(status.first->second); } + for (const auto& signature : signature_defs_) { if (signature.signature_key == signature_key) { auto status = async_signature_runner_map_.insert( diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD index 87f701f5adc427..d5ae8d056bde0d 100644 --- a/tensorflow/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -364,6 +364,24 @@ tf_cc_test( ], ) +tf_cc_test( + name = "allowlisted_flex_ops_test", + size = "small", + srcs = [ + "allowlisted_flex_ops_test.cc", + ], + features = tf_features_nolayering_check_if_ios(), + deps = [ + ":delegate", + "//tensorflow/compiler/mlir/lite/delegates/flex:allowlisted_flex_ops_lib", + "@com_google_googletest//:gtest_main", + ] + if_mobile([ + "//tensorflow/core:portable_tensorflow_lib_lite", + ]) + if_not_mobile([ + "//tensorflow/core:framework", + ]), +) + # Alias to support selective build of image ops. # TODO(b/163285312): Remove after tensorflow/core refactoring completed. cc_library( diff --git a/tensorflow/compiler/mlir/lite/delegates/flex/allowlisted_flex_ops_test.cc b/tensorflow/lite/delegates/flex/allowlisted_flex_ops_test.cc similarity index 100% rename from tensorflow/compiler/mlir/lite/delegates/flex/allowlisted_flex_ops_test.cc rename to tensorflow/lite/delegates/flex/allowlisted_flex_ops_test.cc diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 2be6504b9a5878..d66d66b544a608 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -85,6 +85,7 @@ cc_library( "//tensorflow/lite:minimal_logging", "//tensorflow/lite/core/c:common", "//tensorflow/lite/delegates/gpu/common:convert", + "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_builder", "//tensorflow/lite/delegates/gpu/common:model_transformer", @@ -95,9 +96,14 @@ cc_library( "//tensorflow/lite/delegates/gpu/gl:api", "//tensorflow/lite/delegates/gpu/gl:command_queue", "//tensorflow/lite/delegates/gpu/gl:compiler", + "//tensorflow/lite/delegates/gpu/gl:compiler_options", "//tensorflow/lite/delegates/gpu/gl:egl_environment", + "//tensorflow/lite/delegates/gpu/gl:gl_buffer", "//tensorflow/lite/delegates/gpu/gl:gl_call", + "//tensorflow/lite/delegates/gpu/gl:object", + "//tensorflow/lite/delegates/gpu/gl:object_manager", "//tensorflow/lite/delegates/gpu/gl:request_gpu_info", + "//tensorflow/lite/delegates/gpu/gl:runtime_options", "//tensorflow/lite/delegates/gpu/gl/converters:bhwc_to_phwc4", "//tensorflow/lite/delegates/gpu/gl/converters:phwc4_to_bhwc", "//tensorflow/lite/delegates/gpu/gl/kernels:registry", diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 73f192d17ebf0c..b84cb9a71a46f0 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -299,6 +299,7 @@ cc_library( ":cl_kernel", ":program_cache", ":tensor", + "//tensorflow/lite/delegates/gpu/common/task:compiler_options", "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc index 1cc1738d071d44..8fd94938b57258 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "tensorflow/lite/delegates/gpu/common/task/compiler_options.h" + namespace tflite { namespace gpu { namespace cl { @@ -165,6 +167,10 @@ absl::Status ClOperation::Compile(const CreationContext& creation_context) { creation_context.context, &operation_->args_, &operation_->code_)); operation_->args_.ReleaseCPURepresentation(); + if (creation_context.device->info_.opencl_info.IsCLVK()) { + operation_->compiler_options_.push_back( + CompilerOptions::kClFastRelaxedMath); + } RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( operation_->code_, "main_function", operation_->compiler_options_, *creation_context.context, *creation_context.device, &kernel_, diff --git a/tensorflow/lite/delegates/gpu/common/model.cc b/tensorflow/lite/delegates/gpu/common/model.cc index dc68e702ac2328..a7a174f60f54d2 100644 --- a/tensorflow/lite/delegates/gpu/common/model.cc +++ b/tensorflow/lite/delegates/gpu/common/model.cc @@ -333,10 +333,16 @@ absl::Status GraphFloat32::MakeExactCopy(GraphFloat32* model) const { model->nodes_.clear(); model->execution_plan_.clear(); model->values_.clear(); + model->known_graph_outputs_.clear(); for (auto& value_def : values_) { model->values_.push_back({}); if (value_def.value) { model->values_.back().value = std::make_unique(*value_def.value); + if (std::find(known_graph_outputs_.begin(), known_graph_outputs_.end(), + value_def.value.get()) != known_graph_outputs_.end()) { + model->known_graph_outputs_.push_back( + model->values_.back().value.get()); + } } } // Add all nodes first. diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc index ae3e4e5438a5d4..804eac531e26f9 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h" +#include #include #include #include @@ -59,12 +60,12 @@ class AddBias : public NodeTransformation { "runtime input."}; } auto& attr = - absl::any_cast(node->operation.attributes); + std::any_cast(node->operation.attributes); return FillBias(attr.weights.shape.o, &attr.bias); } if (node->operation.type == ToString(OperationType::CONVOLUTION_TRANSPOSED)) { - auto& attr = absl::any_cast( + auto& attr = std::any_cast( node->operation.attributes); return FillBias(attr.weights.shape.o, &attr.bias); } @@ -76,17 +77,17 @@ class AddBias : public NodeTransformation { "with one " "runtime input."}; } - auto& attr = absl::any_cast( + auto& attr = std::any_cast( node->operation.attributes); return FillBias(attr.weights.shape.o * attr.weights.shape.i, &attr.bias); } if (node->operation.type == ToString(OperationType::FULLY_CONNECTED)) { auto& attr = - absl::any_cast(node->operation.attributes); + std::any_cast(node->operation.attributes); return FillBias(attr.weights.shape.o, &attr.bias); } if (node->operation.type == ToString(OperationType::FULLY_CONNECTED_INT8)) { - auto& attr = absl::any_cast( + auto& attr = std::any_cast( node->operation.attributes); return FillBias(attr.weights.shape.o, &attr.bias); } @@ -97,7 +98,7 @@ class AddBias : public NodeTransformation { } // namespace std::unique_ptr NewAddBias() { - return absl::make_unique(); + return std::make_unique(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc index 361b6d0ebf1322..66040d03aa8cde 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h" +#include #include +#include #include #include @@ -34,7 +36,7 @@ namespace tflite { namespace gpu { namespace { -void AddQuantParams(absl::optional* params, float min, +void AddQuantParams(std::optional* params, float min, float max, float scale) { params->emplace(); params->value().min = min; @@ -154,7 +156,7 @@ TEST(AddQuantAdjustments, GeneralCase) { graph.nodes()[2]->operation.type); EXPECT_EQ(quant_node->id, graph.nodes()[2]->id); EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[3]->operation.type); - auto new_quant_attr = absl::any_cast( + auto new_quant_attr = std::any_cast( graph.nodes()[1]->operation.attributes); EXPECT_EQ(0.0, new_quant_attr.min); EXPECT_EQ(2.0, new_quant_attr.max); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc index 0236bfa4326ce0..4500b0ed50655a 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc @@ -39,8 +39,8 @@ namespace { void FuseBiasWithAddAttributes(const ElementwiseAttributes& add_attr, const int channels, Tensor* bias) { - auto add = absl::get_if>(&add_attr.param); - auto add_scalar = absl::get_if(&add_attr.param); + auto add = std::get_if>(&add_attr.param); + auto add_scalar = std::get_if(&add_attr.param); if (bias->data.empty()) { *bias = MakeZeroTensor(Linear(channels)); } @@ -65,35 +65,35 @@ class MergeConvolutionWithAdd : public SequenceTransformation { return {TransformStatus::SKIPPED, ""}; } ElementwiseAttributes add_attr = - absl::any_cast(add_node.operation.attributes); - if (!absl::holds_alternative>( + std::any_cast(add_node.operation.attributes); + if (!std::holds_alternative>( add_attr.param) && - !absl::holds_alternative(add_attr.param)) { + !std::holds_alternative(add_attr.param)) { return {TransformStatus::DECLINED, "This fuse applicable only for broadcast or scalar addition."}; } if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { Convolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseConvolution2DWithAdd(add_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_TRANSPOSED)) { ConvolutionTransposedAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseConvolutionTransposedWithAdd(add_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::DEPTHWISE_CONVOLUTION)) { DepthwiseConvolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseDepthwiseConvolution2DWithAdd(add_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::FULLY_CONNECTED)) { FullyConnectedAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseFullyConnectedWithAdd(add_attr, conv_attr); } else { @@ -112,8 +112,8 @@ class MergeConvolutionWithAdd : public SequenceTransformation { void FuseAddWithConvolution2D(const ElementwiseAttributes& add_attr, Convolution2DAttributes* attr) { - auto add = absl::get_if>(&add_attr.param); - auto add_scalar = absl::get_if(&add_attr.param); + auto add = std::get_if>(&add_attr.param); + auto add_scalar = std::get_if(&add_attr.param); if (attr->bias.data.empty()) { attr->bias = MakeZeroTensor( Linear(attr->weights.shape.o)); @@ -149,17 +149,17 @@ class MergeAddWithConvolution : public SequenceTransformation { return {TransformStatus::SKIPPED, ""}; } ElementwiseAttributes add_attr = - absl::any_cast(add_node.operation.attributes); - if (!absl::holds_alternative>( + std::any_cast(add_node.operation.attributes); + if (!std::holds_alternative>( add_attr.param) && - !absl::holds_alternative(add_attr.param)) { + !std::holds_alternative(add_attr.param)) { return {TransformStatus::DECLINED, "This fuse applicable only for broadcast or scalar addition."}; } if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { Convolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); if (conv_attr->groups != 1) { return {TransformStatus::DECLINED, @@ -191,11 +191,11 @@ class MergeAddWithConvolution : public SequenceTransformation { } // namespace std::unique_ptr NewMergeConvolutionWithAdd() { - return absl::make_unique(); + return std::make_unique(); } std::unique_ptr NewMergeAddWithConvolution() { - return absl::make_unique(); + return std::make_unique(); } void FuseConvolution2DWithAdd(const ElementwiseAttributes& add_attr, diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc index ca2ec7caba7805..fc6c3e2975c98d 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc @@ -224,7 +224,7 @@ TEST(MergeAddWithConvolutionTest, Smoke) { graph.nodes()[0]->operation.type); Convolution2DAttributes* conv_attr_new = - absl::any_cast( + std::any_cast( &graph.nodes()[0]->operation.attributes); EXPECT_THAT(conv_attr_new->bias.data, diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc index 507456a8fefe15..6496c77ac07163 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc @@ -55,10 +55,10 @@ class MergeConvolutionWithMul : public SequenceTransformation { } ElementwiseAttributes mul_attr = - absl::any_cast(mul_node.operation.attributes); - if (!absl::holds_alternative>( + std::any_cast(mul_node.operation.attributes); + if (!std::holds_alternative>( mul_attr.param) && - !absl::holds_alternative(mul_attr.param)) { + !std::holds_alternative(mul_attr.param)) { return { TransformStatus::DECLINED, "This fuse applicable only for broadcast or scalar multiplication."}; @@ -66,25 +66,25 @@ class MergeConvolutionWithMul : public SequenceTransformation { if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { Convolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseConvolution2DWithMultiply(mul_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_TRANSPOSED)) { ConvolutionTransposedAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseConvolutionTransposedWithMultiply(mul_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::DEPTHWISE_CONVOLUTION)) { DepthwiseConvolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseDepthwiseConvolution2DWithMultiply(mul_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::FULLY_CONNECTED)) { FullyConnectedAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseFullyConnectedWithMultiply(mul_attr, conv_attr); } else { @@ -119,10 +119,10 @@ class MergeMulWithConvolution : public SequenceTransformation { } ElementwiseAttributes mul_attr = - absl::any_cast(mul_node.operation.attributes); - if (!absl::holds_alternative>( + std::any_cast(mul_node.operation.attributes); + if (!std::holds_alternative>( mul_attr.param) && - !absl::holds_alternative(mul_attr.param)) { + !std::holds_alternative(mul_attr.param)) { return { TransformStatus::DECLINED, "This fuse applicable only for broadcast or scalar multiplication."}; @@ -130,25 +130,25 @@ class MergeMulWithConvolution : public SequenceTransformation { if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { Convolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseMultiplyWithConvolution2D(mul_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_TRANSPOSED)) { ConvolutionTransposedAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseMultiplyWithConvolutionTransposed(mul_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::DEPTHWISE_CONVOLUTION)) { DepthwiseConvolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseMultiplyWithDepthwiseConvolution2D(mul_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::FULLY_CONNECTED)) { FullyConnectedAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseMultiplyWithFullyConnected(mul_attr, conv_attr); } else { @@ -168,17 +168,17 @@ class MergeMulWithConvolution : public SequenceTransformation { } // namespace std::unique_ptr NewMergeConvolutionWithMul() { - return absl::make_unique(); + return std::make_unique(); } std::unique_ptr NewMergeMulWithConvolution() { - return absl::make_unique(); + return std::make_unique(); } void FuseConvolution2DWithMultiply(const ElementwiseAttributes& mul_attr, Convolution2DAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int d = 0; d < attr->weights.shape.o; ++d) { const float multiplier = mul ? mul->data[d] : *mul_scalar; for (int s = 0; s < attr->weights.shape.i; ++s) { @@ -198,8 +198,8 @@ void FuseConvolution2DWithMultiply(const ElementwiseAttributes& mul_attr, void FuseDepthwiseConvolution2DWithMultiply( const ElementwiseAttributes& mul_attr, DepthwiseConvolution2DAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int g = 0; g < attr->weights.shape.o; ++g) { for (int s = 0; s < attr->weights.shape.i; ++s) { const int d = s * attr->weights.shape.o + g; @@ -220,8 +220,8 @@ void FuseDepthwiseConvolution2DWithMultiply( void FuseConvolutionTransposedWithMultiply( const ElementwiseAttributes& mul_attr, ConvolutionTransposedAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int d = 0; d < attr->weights.shape.o; ++d) { const float multiplier = mul ? mul->data[d] : *mul_scalar; for (int s = 0; s < attr->weights.shape.i; ++s) { @@ -240,8 +240,8 @@ void FuseConvolutionTransposedWithMultiply( void FuseFullyConnectedWithMultiply(const ElementwiseAttributes& mul_attr, FullyConnectedAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int d = 0; d < attr->weights.shape.o; ++d) { const float multiplier = mul ? mul->data[d] : *mul_scalar; for (int s = 0; s < attr->weights.shape.i; ++s) { @@ -256,8 +256,8 @@ void FuseFullyConnectedWithMultiply(const ElementwiseAttributes& mul_attr, void FuseMultiplyWithConvolution2D(const ElementwiseAttributes& mul_attr, Convolution2DAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int s = 0; s < attr->weights.shape.i; ++s) { const float multiplier = mul ? mul->data[s] : *mul_scalar; for (int d = 0; d < attr->weights.shape.o; ++d) { @@ -274,8 +274,8 @@ void FuseMultiplyWithConvolution2D(const ElementwiseAttributes& mul_attr, void FuseMultiplyWithDepthwiseConvolution2D( const ElementwiseAttributes& mul_attr, DepthwiseConvolution2DAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int s = 0; s < attr->weights.shape.i; ++s) { const float multiplier = mul ? mul->data[s] : *mul_scalar; for (int g = 0; g < attr->weights.shape.o; ++g) { @@ -292,8 +292,8 @@ void FuseMultiplyWithDepthwiseConvolution2D( void FuseMultiplyWithConvolutionTransposed( const ElementwiseAttributes& mul_attr, ConvolutionTransposedAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int s = 0; s < attr->weights.shape.i; ++s) { const float multiplier = mul ? mul->data[s] : *mul_scalar; for (int d = 0; d < attr->weights.shape.o; ++d) { @@ -309,8 +309,8 @@ void FuseMultiplyWithConvolutionTransposed( void FuseMultiplyWithFullyConnected(const ElementwiseAttributes& mul_attr, FullyConnectedAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int s = 0; s < attr->weights.shape.i; ++s) { const float multiplier = mul ? mul->data[s] : *mul_scalar; for (int d = 0; d < attr->weights.shape.o; ++d) { diff --git a/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.cc b/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.cc index 3034c91c0929d3..fc3dec545a21b3 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h" +#include #include #include #include @@ -56,7 +57,7 @@ class GlobalPoolingToReduceOp : public NodeTransformation { auto inputs = graph->FindInputs(node->id); auto outputs = graph->FindOutputs(node->id); const auto& pool_attr = - absl::any_cast(node->operation.attributes); + std::any_cast(node->operation.attributes); if (!IsGlobalAveragePooling(pool_attr, inputs[0]->tensor.shape, outputs[0]->tensor.shape)) { return {TransformStatus::SKIPPED, ""}; @@ -75,7 +76,7 @@ class GlobalPoolingToReduceOp : public NodeTransformation { } // namespace std::unique_ptr NewGlobalPoolingToReduceOp() { - return absl::make_unique(); + return std::make_unique(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc index 226e7d4b2a9696..d8e7aebb2a8960 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h" +#include #include #include #include @@ -56,7 +57,7 @@ class MakeFullyConnectedFromConvolution : public NodeTransformation { return {TransformStatus::SKIPPED, ""}; } - const auto& conv_attr = absl::any_cast( + const auto& conv_attr = std::any_cast( node->operation.attributes); if (!IsConvEquivalentToFullyConnected(conv_attr)) { return {TransformStatus::SKIPPED, ""}; @@ -76,7 +77,7 @@ class MakeFullyConnectedFromConvolution : public NodeTransformation { } // namespace std::unique_ptr NewMakeFullyConnectedFromConvolution() { - return absl::make_unique(); + return std::make_unique(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc index 783dcb02aa7d1a..24ae7894949cf6 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h" +#include #include #include #include @@ -102,7 +103,7 @@ TEST(MakeFullyConnected, Smoke) { graph.nodes()[1]->operation.type); ASSERT_EQ(ToString(OperationType::FULLY_CONNECTED), graph.nodes()[2]->operation.type); - auto fc_attr = absl::any_cast( + auto fc_attr = std::any_cast( graph.nodes()[2]->operation.attributes); EXPECT_EQ(OHWI(32, 1, 1, 16), fc_attr.weights.shape); EXPECT_EQ(Linear(32), fc_attr.bias.shape); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc index 6245f82289a6bb..865024002929f0 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h" +#include #include #include #include @@ -37,7 +38,7 @@ bool IsConstZeros(const Node& node) { return false; } auto& attr = - absl::any_cast(node.operation.attributes); + std::any_cast(node.operation.attributes); for (auto f : attr.tensor.data) { if (f != 0) { return false; @@ -62,7 +63,7 @@ class MakePaddingFromZerosConcat : public NodeTransformation { auto dep = graph->FindProducer(input->id); if (dep != nullptr && IsConstZeros(*dep)) { auto& concat_attr = - absl::any_cast(node->operation.attributes); + std::any_cast(node->operation.attributes); PadAttributes pad_attr; pad_attr.type = PaddingContentType::ZEROS; pad_attr.appended = BHWC(0, 0, 0, 0); @@ -101,7 +102,7 @@ class MakePaddingFromZerosConcat : public NodeTransformation { } // namespace std::unique_ptr NewMakePaddingFromConcat() { - return absl::make_unique(); + return std::make_unique(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc index c33960c21d0eac..abe3594d0cdbd1 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h" +#include #include #include #include @@ -71,7 +72,7 @@ TEST(MakePadding, Smoke) { ASSERT_EQ(2, graph.values().size()); auto pad_node = graph.nodes()[0]; ASSERT_EQ(ToString(OperationType::PAD), pad_node->operation.type); - auto pad_attr = absl::any_cast(pad_node->operation.attributes); + auto pad_attr = std::any_cast(pad_node->operation.attributes); EXPECT_EQ(BHWC(0, 0, 0, 0), pad_attr.prepended); EXPECT_EQ(BHWC(0, 5, 0, 0), pad_attr.appended); } diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.cc b/tensorflow/lite/delegates/gpu/gl_delegate.cc index 995cbd17af470c..7703de58f51330 100644 --- a/tensorflow/lite/delegates/gpu/gl_delegate.cc +++ b/tensorflow/lite/delegates/gpu/gl_delegate.cc @@ -26,9 +26,11 @@ limitations under the License. #include #include "absl/types/span.h" -#include "tensorflow/lite/builtin_ops.h" +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/verifier.h" // from @flatbuffers #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_builder.h" #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" @@ -38,18 +40,21 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h" #include "tensorflow/lite/delegates/gpu/gl/api.h" #include "tensorflow/lite/delegates/gpu/gl/command_queue.h" -#include "tensorflow/lite/delegates/gpu/gl/compiler.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler_options.h" #include "tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h" #include "tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h" #include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" -#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/kernels/registry.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" +#include "tensorflow/lite/delegates/gpu/gl/object_manager.h" #include "tensorflow/lite/delegates/gpu/gl/request_gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/runtime_options.h" #include "tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h" +#include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" #ifndef TFLITE_GPU_BINARY_RELEASE -#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/delegates/gpu/gl/metadata_generated.h" #include "tensorflow/lite/schema/schema_generated.h" #endif // TFLITE_GPU_BINARY_RELEASE diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index cdf4bdcf58d331..43ff934dbdf758 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -234,9 +234,10 @@ cc_library( }) + select({ ":xnnpack_use_transient_indirection_buffers_explicit": ["-DXNNPACK_DELEGATE_USE_TRANSIENT_INDIRECTION_BUFFERS=1"], "//conditions:default": [], - }), + }) + ["-DFLATBUFFERS_LOCALE_INDEPENDENT=0"], linkstatic = True, deps = [ + ":flexbuffers_util", ":quantization_util", ":tflite_with_xnnpack_dynamic_fully_connected", ":tflite_with_xnnpack_logging", @@ -260,6 +261,7 @@ cc_library( "//tensorflow/lite/tools/optimize:reduced_precision_support", "@XNNPACK", "@eigen_archive//:eigen3", + "@flatbuffers//:runtime_cc", "@pthreadpool", ], ) @@ -278,9 +280,10 @@ cc_library( name = "xnnpack_delegate_test_mode", srcs = ["xnnpack_delegate.cc"], hdrs = ["xnnpack_delegate.h"], - copts = tflite_copts() + ["-DXNNPACK_DELEGATE_TEST_MODE=1"], + copts = tflite_copts() + ["-DXNNPACK_DELEGATE_TEST_MODE=1"] + ["-DFLATBUFFERS_LOCALE_INDEPENDENT=0"], linkstatic = True, deps = [ + ":flexbuffers_util", ":quantization_util", ":weight_cache", "//tensorflow/lite:kernel_api", @@ -299,6 +302,7 @@ cc_library( "//tensorflow/lite/tools/optimize:reduced_precision_support", "@XNNPACK", "@eigen_archive//:eigen3", + "@flatbuffers//:runtime_cc", "@pthreadpool", ], ) @@ -341,6 +345,15 @@ cc_library( ], ) +cc_library( + name = "flexbuffers_util", + hdrs = ["flexbuffers_util.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@flatbuffers//:runtime_cc", + ], +) + ################################ Tester classes ################################ cc_library( @@ -2937,4 +2950,14 @@ cc_test( ], ) +cc_test( + name = "flexbuffers_util_test", + srcs = ["flexbuffers_util_test.cc"], + deps = [ + ":flexbuffers_util", + "@com_google_googletest//:gtest_main", + "@flatbuffers//:runtime_cc", + ], +) + tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]}) diff --git a/tensorflow/lite/delegates/xnnpack/flexbuffers_util.h b/tensorflow/lite/delegates/xnnpack/flexbuffers_util.h new file mode 100644 index 00000000000000..6f303c8a92a2da --- /dev/null +++ b/tensorflow/lite/delegates/xnnpack/flexbuffers_util.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_FLEXBUFFERS_UTIL_H_ +#define TENSORFLOW_LITE_DELEGATES_XNNPACK_FLEXBUFFERS_UTIL_H_ + +#include "flatbuffers/base.h" // from @flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers + +namespace tflite::xnnpack { +// We use this class defined with internal linkage as a key to prevent the +// following workaround to leak into other translation units. +struct FloatPointer { + const float* ptr = nullptr; +}; +} // namespace tflite::xnnpack + +namespace flexbuffers { + +// TODO(b/359351192): switch to xnnpack builtin. This is a workaround until we +// are able to use just the value. +// +// We go around the access policy of the `Reference` class by specializing a +// template function that was not specialized for our use case. +// +// This is weakly tolerant to an update to the `Reference` class because: +// - THIS IS MEANT TO BE TEMPORARY until we actually use the XNNPack +// implementation of SDPA (and dependent on not needing data ptr). +// - The flexbuffer spec is public and set, so the layout should not evolve +// much. +// +// The alternative was to copy/paste the code to get to the map data and grab +// the pointer which basically means rewriting flexbuffer.h. +template <> +tflite::xnnpack::FloatPointer inline flexbuffers::Reference::As< + tflite::xnnpack::FloatPointer>() const { +#if !FLATBUFFERS_LITTLEENDIAN + // Flexbuffers are always stored in little endian order. Returning a pointer + // to the float data on a big endian architecture is meaningless. + return nullptr; +#else + return {IsFloat() ? reinterpret_cast(data_) : nullptr}; +#endif +} + +} // namespace flexbuffers + +#endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_FLEXBUFFERS_UTIL_H_ diff --git a/tensorflow/lite/delegates/xnnpack/flexbuffers_util_test.cc b/tensorflow/lite/delegates/xnnpack/flexbuffers_util_test.cc new file mode 100644 index 00000000000000..d3e112bea1547c --- /dev/null +++ b/tensorflow/lite/delegates/xnnpack/flexbuffers_util_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/delegates/xnnpack/flexbuffers_util.h" + +#include +#include +#include "flatbuffers/flexbuffers.h" // from @flatbuffers + +namespace tflite::xnnpack { +namespace { + +using ::testing::Pointee; + +TEST(FlexbuffersUtilTest, FloatPointer) { + constexpr float kAValue = 3.14; + constexpr float kBValue = 56; + + flexbuffers::Builder fbb; + fbb.Map([&] { + fbb.Float("a", kAValue); + fbb.Float("b", kBValue); + }); + fbb.Finish(); + + const flexbuffers::Map map = flexbuffers::GetRoot(fbb.GetBuffer()).AsMap(); + + const flexbuffers::Reference a = map["a"]; + EXPECT_TRUE(a.IsFloat()); + EXPECT_THAT(a.As().ptr, Pointee(kAValue)); + + const flexbuffers::Reference b = map["b"]; + EXPECT_TRUE(b.IsFloat()); + EXPECT_THAT(b.As().ptr, Pointee(kBValue)); + + const flexbuffers::Reference c = map["c"]; + ASSERT_TRUE(c.IsNull()); + EXPECT_EQ(c.As().ptr, nullptr); +} + +} // namespace +} // namespace tflite::xnnpack diff --git a/tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc b/tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc index bf54f45cf04233..0a2c6d85cfa02d 100644 --- a/tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc +++ b/tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include @@ -35,6 +36,17 @@ struct SDPATestParams { int head_dim; // embedding_dim//q_heads }; +void PrintTo(const SDPATestParams& p, std::ostream* os) { + if (p.model_name != kOdmlSdpaCustom) { + *os << "{ TFLite file: " << p.model_name << ".tflite.bin }"; + } else { + *os << "{ Custom test: " << p.custom_test_name << ", b:" << p.batch + << ", isl:" << p.input_seq_len << ", msl:" << p.max_seq_len + << ", q:" << p.q_heads << ", k:" << p.kv_heads << "h:" << p.head_dim + << " }"; + } +} + std::string TestName(const testing::TestParamInfo& info) { if (info.param.model_name != kOdmlSdpaCustom) { return info.param.model_name; diff --git a/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc b/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc index 5a9120fad27f3e..0af79ba33cb2ab 100644 --- a/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc @@ -119,7 +119,7 @@ void ODMLSDPATester::Test(TfLiteDelegate* delegate) const { std::vector ODMLSDPATester::CreateTfLiteModel() const { if (!model_name_.empty() && model_name_ != kOdmlSdpaCustom) { const char kTestModelFolder[] = - "third_party/tensorflow/lite/delegates/xnnpack/"; + "tensorflow/lite/delegates/xnnpack/"; const std::string test_model = kTestModelFolder + model_name_ + ".tflite.bin"; std::string model_data; diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index db5cf51f6845ee..65b3475b75552a 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xnnpack.h" // from @XNNPACK #include "Eigen/Core" // from @eigen_archive +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "pthreadpool.h" // from @pthreadpool #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api_types.h" @@ -41,6 +42,7 @@ limitations under the License. #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/delegates/xnnpack/flexbuffers_util.h" #include "tensorflow/lite/delegates/xnnpack/quantization_util.h" #include "tensorflow/lite/delegates/xnnpack/weight_cache.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" @@ -6701,22 +6703,14 @@ class Subgraph { const TfLiteTensor* tensors, const uint8_t* buffer, const size_t buffer_size, const std::unordered_map& input_output_tensors) { - const float* scale_val = nullptr; - // ensure 28 bytes as we expect - // TODO(b/339106680): this reading method may not work for every case. - if (buffer_size == 28 && sizeof(float) == 4) { - // Custom data here is a flexbuffer map. - // byte_width is 4 for our map. - // First 5 values are "scale", then is the float value, and last is - // flexbuffer metadata. - if (strcmp("scale", reinterpret_cast(buffer)) == 0) { - constexpr size_t kScaleValOffset = 20; - scale_val = reinterpret_cast(buffer + kScaleValOffset); - } - } - + flexbuffers::Map flexbuffer_map = + flexbuffers::GetRoot(buffer, buffer_size).AsMap(); + const float* const scale_ptr = + flexbuffer_map["scale"].As().ptr; + const float* const cap_ptr = + flexbuffer_map["logit_cap"].As().ptr; return VisitDotAttentionNode(subgraph, delegate, logging_context, - node_index, node, tensors, scale_val, + node_index, node, tensors, scale_ptr, cap_ptr, input_output_tensors); } @@ -6724,6 +6718,7 @@ class Subgraph { xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, const TfLiteTensor* tensors, const float* scale_param, + const float* cap_param, const std::unordered_map& input_output_tensors) { const TfLiteTensor& query_proj = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( @@ -6946,7 +6941,45 @@ class Subgraph { permute_q_out_id, reshape_dims_k_out_id, XNN_INVALID_VALUE_ID, fc_out_id, /*flags=*/0)); } - // TODO(b/323195341): add CapTanh support. + if (cap_param != nullptr) { + uint32_t cap_val_id = XNN_INVALID_VALUE_ID; + TF_LITE_ENSURE_EQ( + logging_context, xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, /*num_dims=*/0, + /*dims=*/nullptr, cap_param, + XNN_INVALID_VALUE_ID, 0, &cap_val_id)); + uint32_t cap_div_out_id = XNN_INVALID_VALUE_ID; + TF_LITE_ENSURE_EQ( + logging_context, xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, /*num_dims=*/0, + /*dims=*/nullptr, nullptr, + XNN_INVALID_VALUE_ID, 0, &cap_div_out_id)); + TF_LITE_ENSURE_EQ( + logging_context, xnn_status_success, + xnn_define_divide(subgraph, default_out_min, default_out_max, + fc_out_id, cap_val_id, cap_div_out_id, + /*flags=*/0)); + uint32_t cap_tanh_out_id = XNN_INVALID_VALUE_ID; + TF_LITE_ENSURE_EQ( + logging_context, xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, /*num_dims=*/0, + /*dims=*/nullptr, nullptr, + XNN_INVALID_VALUE_ID, 0, &cap_tanh_out_id)); + TF_LITE_ENSURE_EQ(logging_context, xnn_status_success, + xnn_define_tanh(subgraph, cap_div_out_id, + cap_tanh_out_id, /*flags=*/0)); + uint32_t cap_logits_id = XNN_INVALID_VALUE_ID; + TF_LITE_ENSURE_EQ( + logging_context, xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, /*num_dims=*/0, + /*dims=*/nullptr, nullptr, + XNN_INVALID_VALUE_ID, 0, &cap_logits_id)); + TF_LITE_ENSURE_EQ(logging_context, xnn_status_success, + xnn_define_multiply2(subgraph, default_out_min, + default_out_max, cap_tanh_out_id, + cap_val_id, cap_logits_id, 0)); + fc_out_id = cap_logits_id; + } // element_add atten_mask and matmul_out uint32_t padded_logits_id = XNN_INVALID_VALUE_ID; TF_LITE_ENSURE_EQ( diff --git a/tensorflow/lite/java/aar_with_jni.bzl b/tensorflow/lite/java/aar_with_jni.bzl index 808183ad93b16b..f2770119daaadf 100644 --- a/tensorflow/lite/java/aar_with_jni.bzl +++ b/tensorflow/lite/java/aar_with_jni.bzl @@ -6,7 +6,8 @@ def aar_with_jni( name, android_library, headers = None, - flatten_headers = False): + flatten_headers = False, + strip_headers_prefix = ""): """Generates an Android AAR with repo root license given an Android library target. Args: @@ -18,6 +19,7 @@ def aar_with_jni( generated .aar file. This is useful for distributing self-contained .aars with native libs that can be used directly by native clients. flatten_headers: Whether to flatten the output paths of included headers. + strip_headers_prefix: The prefix to strip from the output paths of included headers. """ # Generate dummy AndroidManifest.xml for dummy apk usage @@ -83,9 +85,14 @@ zip $$origdir/$(location :{1}.aar) LICENSE """.format(src) else: cmd += """ - mkdir -p headers/$$(dirname $(location {0})) - cp -RL $$origdir/$(location {0}) headers/$(location {0}) - """.format(src) + default_dir=$$(dirname $(rootpath {0})) + modified_dir=$$(echo $$default_dir | sed -e 's/^{1}//g') + mkdir -p headers/$$modified_dir + cp -RL $$origdir/$(location {0}) headers/$$modified_dir + if [ -n "{1}" ]; then + sed -i -e 's/^#include \"{1}/#include \"/g' headers/$$modified_dir/$$(basename $(location {0})) + fi + """.format(src, strip_headers_prefix.replace("/", "\\/")) cmd += "zip -r $$origdir/$(location :{0}.aar) headers".format(name) native.genrule( diff --git a/tensorflow/lite/kernels/batch_matmul.cc b/tensorflow/lite/kernels/batch_matmul.cc index 6c72d9003ea76c..d1eb3130c78c04 100644 --- a/tensorflow/lite/kernels/batch_matmul.cc +++ b/tensorflow/lite/kernels/batch_matmul.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/optimized/batch_matmul.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" @@ -441,7 +440,6 @@ RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) { return swapped_shape; } -template TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data, const RuntimeShape& input_shape, const TfLiteTensor* input, @@ -494,18 +492,10 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data, output_size *= output_shape.Dims(i); } std::fill_n(GetTensorData(output), output_size, 0.0f); - if (kernel_type == kGenericOptimized) { - optimized_ops::BatchMatMul( - filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr, - input_offset_ptr, row_sums_ptr, GetTensorShape(output), - GetTensorData(accum_scratch), GetTensorData(output), - &(data->compute_row_sums), CpuBackendContext::GetFromContext(context)); - } else { - reference_ops::BatchMatMul( - filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr, - input_offset_ptr, row_sums_ptr, GetTensorShape(output), - GetTensorData(output), &(data->compute_row_sums)); - } + reference_ops::BatchMatMul( + filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr, + input_offset_ptr, row_sums_ptr, GetTensorShape(output), + GetTensorData(output), &(data->compute_row_sums)); return kTfLiteOk; } @@ -638,9 +628,9 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* row_sums; TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/6, &row_sums)); - return EvalHybrid( - context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized, - scaling_factors, accum_scratch, row_sums, input_offsets, output); + return EvalHybrid(context, node, data, lhs_shape, lhs, rhs_shape, rhs, + input_quantized, scaling_factors, accum_scratch, row_sums, + input_offsets, output); } else if (lhs->type == kTfLiteInt8 && rhs->type == kTfLiteInt8) { if (output->type == kTfLiteInt8) { return EvalInt8Int8(context, data, lhs_shape, lhs, rhs_shape, diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index d08592faec5856..40f3b812825497 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -126,7 +126,7 @@ constexpr int kBwAuxInputToOutputWeightsTensor = 47; // Optional constexpr int kFwOutputTensor = 0; constexpr int kBwOutputTensor = 1; // Ignored if merge_outputs is set. -// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc) +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc) // Temporary tensors. enum TemporaryTensor { diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc index 4813b7c84204e9..e58c1471457318 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc @@ -61,7 +61,7 @@ constexpr int kBwAuxWeightsTensor = 11; // Optional. constexpr int kFwOutputTensor = 0; constexpr int kBwOutputTensor = 1; // Only if merge_outputs is false. -// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc) +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc) // Temporary tensors. enum TemporaryTensor { diff --git a/tensorflow/lite/kernels/embedding_lookup.cc b/tensorflow/lite/kernels/embedding_lookup.cc index 4190fd7121c30f..d92701059822f6 100644 --- a/tensorflow/lite/kernels/embedding_lookup.cc +++ b/tensorflow/lite/kernels/embedding_lookup.cc @@ -104,13 +104,13 @@ TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node, // Propagate empty tensor if input is empty return kTfLiteOk; } - const int row_bytes = value->bytes / row_size; + const int64_t row_bytes = value->bytes / row_size; char* output_raw = GetTensorData(output); const char* value_raw = GetTensorData(value); const int32_t* lookup_data = GetTensorData(lookup); for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { - int idx = lookup_data[i]; + int64_t idx = lookup_data[i]; if (idx >= row_size || idx < 0) { TF_LITE_KERNEL_LOG(context, "Embedding Lookup: index out of bounds. " diff --git a/tensorflow/lite/kernels/embedding_lookup_test.cc b/tensorflow/lite/kernels/embedding_lookup_test.cc index d13ddd443f6891..493d086aa50804 100644 --- a/tensorflow/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/lite/kernels/embedding_lookup_test.cc @@ -92,6 +92,19 @@ class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { } } } + + template + void Set2DWeightMatrix(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(weight_); + int64_t rows = tensor->dims->data[0]; + int64_t columns = tensor->dims->data[1]; + T* data = GetTensorData(tensor); + for (int64_t i = 0; i < rows; i++) { + for (int64_t j = 0; j < columns; j++) { + data[i * columns + j] = function(i, j); + } + } + } }; class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { @@ -144,6 +157,28 @@ TEST(EmbeddingLookupOpTest, SimpleTest) { }))); } +#if !defined(MEMORY_SANITIZER) && !defined(GOOGLE_UNSUPPORTED_OS_LOONIX) && \ + defined(__LP64__) +TEST(EmbeddingLookupOpTest, LargeTableTest) { + EmbeddingLookupOpModel m({1}, {256000, 9216}); + // Choose a value specifically designed to overflow int32.max + m.SetInput({235248}); + m.Set2DWeightMatrix( + [](int i, int j) -> float { return j + i / 100.; }); + + // This will cause a lookup at index 235248 in a buffer where every row + // has 9216 entries * 4 bytes per entry, which will overflow unless + // the Op is using a 64-bit offset for address calculation. + ASSERT_EQ(m.Invoke(), kTfLiteOk); + std::vector exp(9216); + + for (int s = 0; s < exp.size(); s++) { + exp[s] = static_cast(s) + 2352.48f; + } + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(exp))); +} +#endif + TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTestUint8) { HybridEmbeddingLookupOpModel m({3}, {3, 8}, TensorType_UINT8); m.SetInput({1, 0, 2}); diff --git a/tensorflow/lite/kernels/internal/averagepool_quantized_test.cc b/tensorflow/lite/kernels/internal/averagepool_quantized_test.cc index fea343ae6b8824..5173586d423ab5 100644 --- a/tensorflow/lite/kernels/internal/averagepool_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/averagepool_quantized_test.cc @@ -34,11 +34,11 @@ namespace { // are the same. void RunOneAveragePoolTest(const PoolParams& params, const RuntimeShape& input_shape, - const int8* input_data, + const int8_t* input_data, const RuntimeShape& output_shape) { const int buffer_size = output_shape.FlatSize(); - std::vector optimized_averagePool_output(buffer_size); - std::vector reference_averagePool_output(buffer_size); + std::vector optimized_averagePool_output(buffer_size); + std::vector reference_averagePool_output(buffer_size); bool reference_success = reference_integer_ops::AveragePool( params, input_shape, input_data, output_shape, @@ -86,7 +86,7 @@ void CreateDataAndRunAveragePool(bool padding_same) { auto output_shape = RuntimeShape({batch, output_height, output_width, output_depth}); const int buffer_size = input_shape.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); FillRandom(&input_data); PoolParams params; @@ -172,17 +172,17 @@ void CreateExtremalDataAndRunAveragePool(bool padding_same) { filter_height, output_height); const int buffer_size = input_shape.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); // Test small values - int8 min = std::numeric_limits::min(); - int8 max = std::numeric_limits::min() + 10; + int8_t min = std::numeric_limits::min(); + int8_t max = std::numeric_limits::min() + 10; FillRandom(&input_data, min, max); RunOneAveragePoolTest(params, input_shape, input_data.data(), output_shape); // Test large values - min = std::numeric_limits::max() - 10; - max = std::numeric_limits::max(); + min = std::numeric_limits::max() - 10; + max = std::numeric_limits::max(); FillRandom(&input_data, min, max); RunOneAveragePoolTest(params, input_shape, input_data.data(), output_shape); } diff --git a/tensorflow/lite/kernels/internal/conv_per_channel_quantized_16x8_test.cc b/tensorflow/lite/kernels/internal/conv_per_channel_quantized_16x8_test.cc index 562797bfffeb0e..f0ad42b2cd100f 100644 --- a/tensorflow/lite/kernels/internal/conv_per_channel_quantized_16x8_test.cc +++ b/tensorflow/lite/kernels/internal/conv_per_channel_quantized_16x8_test.cc @@ -38,8 +38,8 @@ namespace { void PickOutputMultiplier( const ConvParams& params, const RuntimeShape& input_shape, - const int16* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, + const int16_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, const std::int64_t* bias_data, const RuntimeShape& output_shape, float* output_multiplier) { const int stride_width = params.stride_width; @@ -81,9 +81,9 @@ void PickOutputMultiplier( (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && (in_y < input_height); if (is_point_inside_image) { - int32 input_val = input_data[Offset(input_shape, batch, in_y, - in_x, in_channel)]; - int32 filter_val = + int32_t input_val = input_data[Offset( + input_shape, batch, in_y, in_x, in_channel)]; + int32_t filter_val = filter_data[Offset(filter_shape, output_channel, filter_y, filter_x, in_channel)]; acc += static_cast(filter_val) * @@ -296,8 +296,8 @@ void TryTestOneConvFilter(int test_num) { for (int c = 0; c < output_shape_inference.Dims(3); c++) { int offset = Offset(output_shape_inference, n, h, w, c); float float_res = output_data_float.data()[offset]; - int16 int16_res = reference_output_data.data()[offset]; - int32 output_mul = output_multiplier.data()[c]; + int16_t int16_res = reference_output_data.data()[offset]; + int32_t output_mul = output_multiplier.data()[c]; int shift = output_shift.data()[c]; float scale = (float)output_mul / (float)(1ULL << 31); if (shift > 0) scale = scale * (float)(1 << shift); diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_16x8_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_16x8_test.cc index 7d586c5ac94430..f395cdd13ff18b 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_16x8_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_16x8_test.cc @@ -38,8 +38,8 @@ namespace { void PickOutputMultiplier( const DepthwiseParams& params, const RuntimeShape& input_shape, - const int16* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, + const int16_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, const std::int64_t* bias_data, const RuntimeShape& output_shape, float* output_multiplier) { const int stride_width = params.stride_width; @@ -81,9 +81,9 @@ void PickOutputMultiplier( (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && (in_y < input_height); if (is_point_inside_image) { - int32 input_val = input_data[Offset(input_shape, batch, in_y, - in_x, in_channel)]; - int32 filter_val = filter_data[Offset( + int32_t input_val = input_data[Offset( + input_shape, batch, in_y, in_x, in_channel)]; + int32_t filter_val = filter_data[Offset( filter_shape, 0, filter_y, filter_x, output_channel)]; acc += static_cast(filter_val) * static_cast(input_val); @@ -286,8 +286,8 @@ void TryTestOneDepthwiseConv3x3Filter() { for (int c = 0; c < output_shape_inference.Dims(3); c++) { int offset = Offset(output_shape_inference, n, h, w, c); float float_res = output_data_float.data()[offset]; - int16 int16_res = reference_output_data.data()[offset]; - int32 output_mul = output_multiplier.data()[c]; + int16_t int16_res = reference_output_data.data()[offset]; + int32_t output_mul = output_multiplier.data()[c]; int shift = output_shift.data()[c]; float scale = (float)output_mul / (float)(1ULL << 31); if (shift > 0) scale = scale * (float)(1 << shift); diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_test.cc index 8336b63b0ba48e..716b0fce731298 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_test.cc @@ -39,9 +39,9 @@ namespace { void PickOutputMultiplier( const DepthwiseParams& params, const RuntimeShape& input_shape, - const int8* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, float* output_multiplier) { const int stride_width = params.stride_width; const int stride_height = params.stride_height; @@ -50,7 +50,7 @@ void PickOutputMultiplier( const int pad_width = params.padding_values.width; const int pad_height = params.padding_values.height; const int depth_multiplier = params.depth_multiplier; - const int32 input_offset = params.input_offset; + const int32_t input_offset = params.input_offset; const int batches = MatchingDim(input_shape, 0, output_shape, 0); const int input_height = input_shape.Dims(1); @@ -72,7 +72,7 @@ void PickOutputMultiplier( const int output_channel = m + in_channel * depth_multiplier; const int in_x_origin = (out_x * stride_width) - pad_width; const int in_y_origin = (out_y * stride_height) - pad_height; - int32 acc = 0; + int32_t acc = 0; for (int filter_y = 0; filter_y < filter_height; ++filter_y) { for (int filter_x = 0; filter_x < filter_width; ++filter_x) { const int in_x = in_x_origin + dilation_width_factor * filter_x; @@ -83,9 +83,9 @@ void PickOutputMultiplier( (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && (in_y < input_height); if (is_point_inside_image) { - int32 input_val = input_data[Offset(input_shape, batch, in_y, - in_x, in_channel)]; - int32 filter_val = filter_data[Offset( + int32_t input_val = input_data[Offset( + input_shape, batch, in_y, in_x, in_channel)]; + int32_t filter_val = filter_data[Offset( filter_shape, 0, filter_y, filter_x, output_channel)]; acc += filter_val * (input_val + input_offset); } diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc index c9d301ab9564c3..d5a2da2b9d58f8 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -124,7 +124,7 @@ inline void DispatchDepthwiseConvGeneral( const RuntimeShape& filter_shape, const typename QuantizationTypeImpl::ExternalType* filter_data, - const RuntimeShape& bias_shape, const int32* bias_data, + const RuntimeShape& bias_shape, const int32_t* bias_data, const std::int32_t* output_shift_adjust, const std::int32_t* output_multiplier_adjust, const RuntimeShape& output_shape, @@ -139,11 +139,11 @@ inline void DispatchDepthwiseConvGeneral( template <> inline void DispatchDepthwiseConvGeneral( const DepthwiseParams& params, const RuntimeShape& input_shape, - const int8* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const std::int32_t* output_shift_adjust, + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const std::int32_t* output_shift_adjust, const std::int32_t* output_multiplier_adjust, - const RuntimeShape& output_shape, int8* output_data, int thread_start, + const RuntimeShape& output_shape, int8_t* output_data, int thread_start, int thread_end, int thread_dim) { optimized_integer_ops::depthwise_conv::DepthwiseConvGeneral( params, output_multiplier_adjust, output_shift_adjust, input_shape, @@ -160,7 +160,7 @@ inline void DispatchDepthwiseConvImpl( const RuntimeShape& filter_shape, const typename QuantizationTypeImpl::ExternalType* filter_data, - const RuntimeShape& bias_shape, const int32* bias_data, + const RuntimeShape& bias_shape, const int32_t* bias_data, const RuntimeShape& output_shape, typename QuantizationTypeImpl::ExternalType* output_data) { @@ -349,7 +349,7 @@ inline void DispatchDepthwiseConvImpl( CpuBackendContext backend_context; backend_context.SetMaxNumThreads(test_param.num_threads); optimized_ops::DepthwiseConv< - typename QuantizationTypeImpl::ExternalType, int32>( + typename QuantizationTypeImpl::ExternalType, int32_t>( params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data, output_shape, output_data, &backend_context); } @@ -363,7 +363,7 @@ inline void DispatchDepthwiseConvImpl( const RuntimeShape& filter_shape, const typename QuantizationTypeImpl< QuantizationType::kPerChannelInt8>::ExternalType* filter_data, - const RuntimeShape& bias_shape, const int32* bias_data, + const RuntimeShape& bias_shape, const int32_t* bias_data, const RuntimeShape& output_shape, typename QuantizationTypeImpl< QuantizationType::kPerChannelInt8>::ExternalType* output_data) { @@ -530,7 +530,7 @@ inline void DispatchDepthwiseConv( const RuntimeShape& filter_shape, const typename QuantizationTypeImpl::ExternalType* filter_data, - const RuntimeShape& bias_shape, const int32* bias_data, + const RuntimeShape& bias_shape, const int32_t* bias_data, const RuntimeShape& output_shape, typename QuantizationTypeImpl::ExternalType* output_data) { @@ -546,10 +546,10 @@ template <> struct ReferenceRunner { static inline void Run( const TestParam& test_param, const tflite::DepthwiseParams& op_params, - const uint8* input_data, const RuntimeShape& input_shape, - const uint8* filter_data, const RuntimeShape& filter_shape, + const uint8_t* input_data, const RuntimeShape& input_shape, + const uint8_t* filter_data, const RuntimeShape& filter_shape, const std::int32_t* bias_data, const RuntimeShape& bias_shape, - const RuntimeShape& output_shape, uint8* reference_output_data) { + const RuntimeShape& output_shape, uint8_t* reference_output_data) { switch (test_param.output_rounding) { case DepthwiseConvOutputRounding::kUpward: reference_ops::depthwise_conv::DepthwiseConvBasicKernel< @@ -577,10 +577,10 @@ template <> struct ReferenceRunner { static inline void Run( const TestParam& test_param, const tflite::DepthwiseParams& op_params, - const int8* input_data, const RuntimeShape& input_shape, - const int8* filter_data, const RuntimeShape& filter_shape, + const int8_t* input_data, const RuntimeShape& input_shape, + const int8_t* filter_data, const RuntimeShape& filter_shape, const std::int32_t* bias_data, const RuntimeShape& bias_shape, - const RuntimeShape& output_shape, int8* reference_output_data) { + const RuntimeShape& output_shape, int8_t* reference_output_data) { switch (test_param.output_rounding) { case DepthwiseConvOutputRounding::kUpward: reference_ops::depthwise_conv::DepthwiseConvBasicKernel< @@ -646,8 +646,8 @@ int TestOneDepthwiseConvWithGivenOutputShift( op_params.output_shift = -output_shift; const int depth = output_shape.Dims(3); - std::vector output_multiplier_per_channel(depth, output_multiplier); - std::vector output_shift_per_channel(depth, -output_shift); + std::vector output_multiplier_per_channel(depth, output_multiplier); + std::vector output_shift_per_channel(depth, -output_shift); if (output_multiplier_adjust != nullptr) { for (int i = 0; i < depth; ++i) { output_multiplier_per_channel[i] += output_multiplier_adjust[i]; @@ -898,8 +898,10 @@ bool TryTestDepthwiseConv(const TestParam& test_param, if (test_param.quantization_type == QuantizationType::kPerChannelInt8) { std::vector input_data(input_buffer_size); std::vector filter_data(filter_buffer_size); - FillRandom(&input_data, static_cast(-127), static_cast(127)); - FillRandom(&filter_data, static_cast(-127), static_cast(127)); + FillRandom(&input_data, static_cast(-127), + static_cast(127)); + FillRandom(&filter_data, static_cast(-127), + static_cast(127)); std::int32_t filter_offset = 0; EXPECT_TRUE(params_specialization == ParamsSpecialization::kSymmetric); diff --git a/tensorflow/lite/kernels/internal/log_quantized_test.cc b/tensorflow/lite/kernels/internal/log_quantized_test.cc index 2a27a097d2ab4c..7d0a549cbe180c 100644 --- a/tensorflow/lite/kernels/internal/log_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/log_quantized_test.cc @@ -53,23 +53,24 @@ class LogQuantizedTest : public ::testing::Test { }; // input_integer_bits <= 30. output_integer_bits > 0. -inline int32 LogPositiveValuesViaFloat(int32 input_val, int input_integer_bits, - int output_integer_bits) { +inline int32_t LogPositiveValuesViaFloat(int32_t input_val, + int input_integer_bits, + int output_integer_bits) { const double float_log_sum_of_exps = std::log( static_cast(input_val) * 0.5 / (1 << (30 - input_integer_bits))); static constexpr double min_int = - static_cast(std::numeric_limits::min()); + static_cast(std::numeric_limits::min()); static constexpr double max_int = - static_cast(std::numeric_limits::max()); + static_cast(std::numeric_limits::max()); double double_result = tflite::TfLiteRound(float_log_sum_of_exps * (1 << (31 - output_integer_bits))); return static_cast( std::min(max_int, std::max(min_int, double_result))); } -void CheckOutputData(const std::vector& test_output, - const std::vector& reference_output, - const std::vector& test_input, +void CheckOutputData(const std::vector& test_output, + const std::vector& reference_output, + const std::vector& test_input, const string& check_label, int input_integer_bits, int output_integer_bits, int tolerance) { // In the special case of small input, specifically raw value of 5, a rounding @@ -107,8 +108,8 @@ void CheckOutputData(const std::vector& test_output, } } -void RightShiftVector(const std::vector& shifts, - std::vector* vec) { +void RightShiftVector(const std::vector& shifts, + std::vector* vec) { const int n = vec->size(); ASSERT_EQ(n, shifts.size()); for (int i = 0; i < n; ++i) { @@ -117,15 +118,15 @@ void RightShiftVector(const std::vector& shifts, } template -void RunSingleTest(const std::vector& test_input, +void RunSingleTest(const std::vector& test_input, const string& check_label, int tolerance) { const int n = test_input.size(); - std::vector float_gen_output(n, 0); - std::vector quantized_output(n, 0); + std::vector float_gen_output(n, 0); + std::vector quantized_output(n, 0); // Workaround the stupid things that intelligent humans do. // Consequence of __builtin_clz(0u) may equal 31 instead of 32. - std::vector fudged_input(n, 0); + std::vector fudged_input(n, 0); for (int i = 0; i < n; ++i) { fudged_input[i] = std::max(test_input[i], 2); } @@ -134,7 +135,7 @@ void RunSingleTest(const std::vector& test_input, quantized_output[i] = tflite::log_x_for_x_greater_than_or_equal_to_1_impl( - gemmlowp::FixedPoint::FromRaw( + gemmlowp::FixedPoint::FromRaw( fudged_input[i])) .raw(); float_gen_output[i] = LogPositiveValuesViaFloat( @@ -151,8 +152,9 @@ void RunSingleTest(const std::vector& test_input, } template -void RunSingleTest(const std::vector& test_input, int input_integer_bits, - const string& check_label, int tolerance) { +void RunSingleTest(const std::vector& test_input, + int input_integer_bits, const string& check_label, + int tolerance) { #define INPUT_CASE(K) \ case K: \ return RunSingleTest(test_input, check_label, \ @@ -195,9 +197,9 @@ void RunSingleTest(const std::vector& test_input, int input_integer_bits, #undef INPUT_CASE } -void RunSingleTest(const std::vector& test_input, int input_integer_bits, - int output_integer_bits, const string& check_label, - int tolerance) { +void RunSingleTest(const std::vector& test_input, + int input_integer_bits, int output_integer_bits, + const string& check_label, int tolerance) { #define OUTPUT_CASE(K) \ case K: \ return RunSingleTest(test_input, input_integer_bits, check_label, \ @@ -248,9 +250,9 @@ void RunUniformTest(int test_size, int input_integer_bits, test_data[0] = 2; test_data[1] = 3; test_data[2] = 4; - test_data[3] = std::numeric_limits::max() - 2; - test_data[4] = std::numeric_limits::max() - 1; - test_data[5] = std::numeric_limits::max(); + test_data[3] = std::numeric_limits::max() - 2; + test_data[4] = std::numeric_limits::max() - 1; + test_data[5] = std::numeric_limits::max(); RunSingleTest(test_data, input_integer_bits, output_integer_bits, check_label + " / uniform test", tolerance); diff --git a/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc index 72e4685d1e949a..3dfbd6930fe8c8 100644 --- a/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -34,11 +34,11 @@ limitations under the License. namespace tflite { namespace { -void RunLogSoftmaxFloatReference(const uint8* input_data, +void RunLogSoftmaxFloatReference(const uint8_t* input_data, const RuntimeShape& shape_common, - int32 input_offset, const double input_scale, + int32_t input_offset, const double input_scale, int stride, float beta, - uint8* reference_output_data) { + uint8_t* reference_output_data) { const int ref_buffer_size = shape_common.FlatSize(); std::vector reference_dequant_data(ref_buffer_size); std::vector reference_output_float_data(ref_buffer_size); @@ -67,11 +67,11 @@ void RunLogSoftmaxFloatReference(const uint8* input_data, // - input and output data type // - Dequnatize function // - clamping values -void RunLogSoftmaxFloatReference(const int8* input_data, +void RunLogSoftmaxFloatReference(const int8_t* input_data, const RuntimeShape& shape_common, - int32 input_offset, const double input_scale, + int32_t input_offset, const double input_scale, int stride, float beta, - int8* reference_output_data) { + int8_t* reference_output_data) { const int ref_buffer_size = shape_common.FlatSize(); std::vector reference_dequant_data(ref_buffer_size); std::vector reference_output_float_data(ref_buffer_size); @@ -143,21 +143,22 @@ void CheckOutputData(const T* test_output, const T* reference_output, // Runs the LogSoftmax and compares against the float reference implementation // and the quantized reference implementation. -void RunOneLogSoftmaxTest(const uint8* input_data, - const RuntimeShape& shape_common, int32 input_offset, - const double input_scale, int stride, float beta) { +void RunOneLogSoftmaxTest(const uint8_t* input_data, + const RuntimeShape& shape_common, + int32_t input_offset, const double input_scale, + int stride, float beta) { const int buffer_size = shape_common.FlatSize(); - std::vector optimized_logsoftmax_output(buffer_size); - std::vector reference_float_logsoftmax_output(buffer_size); - std::vector reference_quant_logsoftmax_output(buffer_size); + std::vector optimized_logsoftmax_output(buffer_size); + std::vector reference_float_logsoftmax_output(buffer_size); + std::vector reference_quant_logsoftmax_output(buffer_size); RunLogSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, reference_float_logsoftmax_output.data()); - int32 input_beta_multiplier; + int32_t input_beta_multiplier; int input_beta_left_shift; - int32 reverse_scaling_divisor; + int32_t reverse_scaling_divisor; int reverse_scaling_right_shift; static const int kScaledDiffIntegerBits = 5; tflite::PreprocessLogSoftmaxScalingExp( @@ -201,20 +202,22 @@ void RunOneLogSoftmaxTest(const uint8* input_data, // Runs the LogSoftmax and compares against the float reference implementation // and the int8 quantized reference implementation. -void RunOneLogSoftmaxTest(const int8* input_data, - const RuntimeShape& shape_common, int32 input_offset, - const double input_scale, int stride, float beta) { +void RunOneLogSoftmaxTest(const int8_t* input_data, + const RuntimeShape& shape_common, + int32_t input_offset, const double input_scale, + int stride, float beta) { const int buffer_size = shape_common.FlatSize(); - std::vector quantized_logsoftmax_reference_implementation(buffer_size); - std::vector float_logsoftmax_optimized_implementation(buffer_size); + std::vector quantized_logsoftmax_reference_implementation( + buffer_size); + std::vector float_logsoftmax_optimized_implementation(buffer_size); RunLogSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, float_logsoftmax_optimized_implementation.data()); - int32 input_beta_multiplier; + int32_t input_beta_multiplier; int input_beta_left_shift; - int32 reverse_scaling_divisor; + int32_t reverse_scaling_divisor; int reverse_scaling_right_shift; static const int kScaledDiffIntegerBits = 5; tflite::PreprocessLogSoftmaxScalingExp( @@ -258,7 +261,7 @@ bool TryOneUniformLogSoftmax() { const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200); const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); - const int32 input_offset = UniformRandomInt(-256, 0); + const int32_t input_offset = UniformRandomInt(-256, 0); static constexpr float beta = 1.0f; auto shape_common = @@ -291,7 +294,7 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) { const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200); const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); - const int32 input_offset = UniformRandomInt(-256, 0); + const int32_t input_offset = UniformRandomInt(-256, 0); static constexpr float beta = 1.0f; // Extra parameters for skyscraper input patterns. const double middle_proportion = @@ -303,7 +306,7 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) { RuntimeShape({batch, input_height, input_width, input_depth}); const int buffer_size = shape_common.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, sides_max); RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, diff --git a/tensorflow/lite/kernels/internal/maxpool_quantized_test.cc b/tensorflow/lite/kernels/internal/maxpool_quantized_test.cc index 84afd3ddd52211..50b39085387b1b 100644 --- a/tensorflow/lite/kernels/internal/maxpool_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/maxpool_quantized_test.cc @@ -33,11 +33,12 @@ namespace { // Runs the reference and optimized MaxPool functions and asserts the values // are the same. void RunOneMaxPoolTest(const PoolParams& params, - const RuntimeShape& input_shape, const int8* input_data, + const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& output_shape) { const int buffer_size = output_shape.FlatSize(); - std::vector optimized_maxpool_output(buffer_size); - std::vector reference_maxpool_output(buffer_size); + std::vector optimized_maxpool_output(buffer_size); + std::vector reference_maxpool_output(buffer_size); reference_integer_ops::MaxPool(params, input_shape, input_data, output_shape, reference_maxpool_output.data()); @@ -80,7 +81,7 @@ void CreateDataAndRunMaxPool(bool padding_same) { auto output_shape = RuntimeShape({batch, output_height, output_width, output_depth}); const int buffer_size = input_shape.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); FillRandom(&input_data); PoolParams params; diff --git a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h index 502ecf0ee6426e..726a279bfaef13 100644 --- a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h @@ -117,111 +117,6 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data, } } -inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, - const RuntimeShape& rhs_shape, const int8_t* rhs_data, - const float* scaling_factors, - const int32_t* input_offset, int32_t* row_sums, - const RuntimeShape& output_shape, - int32_t* accum_scratch, float* output_data, - bool* compute_row_sums, CpuBackendContext* context) { - using ::tflite::cpu_backend_gemm::Gemm; - using ::tflite::cpu_backend_gemm::GemmParams; - using ::tflite::cpu_backend_gemm::MatrixParams; - - const RuntimeShape extended_lhs_shape = - RuntimeShape::ExtendedShape(5, lhs_shape); - const RuntimeShape extended_rhs_shape = - RuntimeShape::ExtendedShape(5, rhs_shape); - - // Determine which dimension is the broadcast dimension. - auto broadcast_dim = [](int lhs_dim, int rhs_dim) { - if (lhs_dim == rhs_dim) return lhs_dim; - if (lhs_dim == 1) return rhs_dim; - TFLITE_DCHECK_EQ(rhs_dim, 1); - return lhs_dim; - }; - - // Compute the "extent" for iterating on this dimension. - // If we are broadcasting, then don't advance (i.e return 0). - auto extent = [](const RuntimeShape& shape, int x) { - if (shape.Dims(x) == 1) { - return 0; - } - int prod = 1; - for (int i = x + 1; i < shape.DimensionsCount(); ++i) { - prod *= shape.Dims(i); - } - return prod; - }; - - const int batch_dim0 = - broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0)); - const int batch_dim1 = - broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1)); - const int batch_dim2 = - broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2)); - - const int lhs_ext0 = extent(extended_lhs_shape, 0); - const int lhs_ext1 = extent(extended_lhs_shape, 1); - const int lhs_ext2 = extent(extended_lhs_shape, 2); - const int rhs_ext0 = extent(extended_rhs_shape, 0); - const int rhs_ext1 = extent(extended_rhs_shape, 1); - const int rhs_ext2 = extent(extended_rhs_shape, 2); - - // Set params for each matrix multiply. - const int lhs_rows = extended_lhs_shape.Dims(3); - const int rhs_cols = extended_rhs_shape.Dims(4); - const int accum_depth = extended_lhs_shape.Dims(4); - - const int ioff_ext0 = rhs_ext0 == 0 ? 0 : rhs_cols; - const int ioff_ext1 = rhs_ext1 == 0 ? 0 : rhs_cols; - const int ioff_ext2 = rhs_ext2 == 0 ? 0 : rhs_cols; - const int woff_ext0 = lhs_ext0 == 0 ? 0 : lhs_rows; - const int woff_ext1 = lhs_ext1 == 0 ? 0 : lhs_rows; - const int woff_ext2 = lhs_ext2 == 0 ? 0 : lhs_rows; - - if (!compute_row_sums || *compute_row_sums) { - int num_weights_matrices = 1; - for (int i = 1; i < extended_lhs_shape.DimensionsCount() - 2; ++i) { - num_weights_matrices *= extended_lhs_shape.Dims(i); - } - tensor_utils::ReductionSumVector( - lhs_data, row_sums, num_weights_matrices * lhs_rows, accum_depth); - if (compute_row_sums) { - *compute_row_sums = false; - } - } - - for (int b0 = 0; b0 < batch_dim0; ++b0) { - const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0); - const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0); - const int32_t* ioff_ptr0 = input_offset + (b0 * ioff_ext0); - const float* scale_ptr0 = scaling_factors + (b0 * ioff_ext0); - int32_t* woff_ptr0 = row_sums + (b0 * woff_ext0); - for (int b1 = 0; b1 < batch_dim1; ++b1) { - const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1; - const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1; - const int32_t* ioff_ptr1 = ioff_ptr0 + (b1 * ioff_ext1); - const float* scale_ptr1 = scale_ptr0 + (b1 * ioff_ext1); - int32_t* woff_ptr1 = woff_ptr0 + (b1 * woff_ext1); - for (int b2 = 0; b2 < batch_dim2; ++b2) { - const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2; - const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2; - const int32_t* ioff_ptr2 = ioff_ptr1 + (b2 * ioff_ext2); - const float* scale_ptr2 = scale_ptr1 + (b2 * ioff_ext2); - int32_t* woff_ptr2 = woff_ptr1 + (b2 * woff_ext2); - float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) + - b1 * batch_dim2 + b2) * - lhs_rows * rhs_cols; - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - lhs_ptr2, lhs_rows, accum_depth, rhs_ptr2, scale_ptr2, rhs_cols, - out_ptr, /*per_channel_scale=*/nullptr, ioff_ptr2, accum_scratch, - woff_ptr2, compute_row_sums, context); - } - } - } -} - inline void BatchMatMul(const FullyConnectedParams& params, const RuntimeShape& lhs_shape, const int8_t* lhs_data, const RuntimeShape& rhs_shape, const int8_t* rhs_data, diff --git a/tensorflow/lite/kernels/internal/resize_bilinear_test.cc b/tensorflow/lite/kernels/internal/resize_bilinear_test.cc index f65127d029fce3..ee60b084edfbd2 100644 --- a/tensorflow/lite/kernels/internal/resize_bilinear_test.cc +++ b/tensorflow/lite/kernels/internal/resize_bilinear_test.cc @@ -55,7 +55,7 @@ void TestOneResizeBilinear(const tflite::ResizeBilinearParams& op_params, FillRandom(&input_data, min_amplitude, max_amplitude); RuntimeShape output_size_dims({1, 1, 1, 2}); - std::vector output_size_data = {output_height, output_width}; + std::vector output_size_data = {output_height, output_width}; reference_ops::ResizeBilinear(op_params, input_dims_inference, input_data.data(), output_size_dims, @@ -66,7 +66,7 @@ void TestOneResizeBilinear(const tflite::ResizeBilinearParams& op_params, output_size_data.data(), output_dims_inference, output_data.data()); bool strict_match = false; - if (std::is_same::value && ((depth % 8) == 0) && + if (std::is_same::value && ((depth % 8) == 0) && ((input_width * 8) == output_width) && ((input_height * 8) == output_height)) { strict_match = true; @@ -111,9 +111,9 @@ TEST_P(ResizeBilinearImplTest, TestResizeBilinearUint8) { const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200); const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200); - TestOneResizeBilinear(op_params, batch, depth, input_width, - input_height, output_width, output_height, - 0.025); + TestOneResizeBilinear(op_params, batch, depth, input_width, + input_height, output_width, output_height, + 0.025); } } @@ -136,9 +136,9 @@ TEST_P(ResizeBilinearImplTest, TestResizeBilinearUint8_2x2) { // versions. error_threshold = 1e-3; } - TestOneResizeBilinear(op_params, batch, depth, input_width, - input_height, output_width, output_height, - error_threshold); + TestOneResizeBilinear(op_params, batch, depth, input_width, + input_height, output_width, output_height, + error_threshold); } } @@ -217,7 +217,7 @@ TEST(ResizeBilinear, TestResizeBilinearHalfPixelCentersFloat_3x3to2x2) { std::vector output_data(output_buffer_size, 3); RuntimeShape output_size_dims({1, 1, 1, 2}); - std::vector output_size_data = {2, 2}; + std::vector output_size_data = {2, 2}; tflite::ResizeBilinearParams op_params; op_params.align_corners = false; @@ -261,7 +261,7 @@ TEST(ResizeBilinear, TestResizeBilinearHalfPixelCentersFloat_2x2to4x4) { std::vector output_data(output_buffer_size, 3); RuntimeShape output_size_dims({1, 1, 1, 2}); - std::vector output_size_data = {4, 4}; + std::vector output_size_data = {4, 4}; tflite::ResizeBilinearParams op_params; op_params.align_corners = false; @@ -312,7 +312,7 @@ void TestResizeBilinearHalfPixelCenters_2x2to4x6() { std::vector output_data(output_buffer_size, 3); RuntimeShape output_size_dims({1, 1, 1, 2}); - std::vector output_size_data = {4, 6}; + std::vector output_size_data = {4, 6}; tflite::ResizeBilinearParams op_params; op_params.align_corners = false; @@ -394,9 +394,9 @@ TEST_P(ResizeBilinearImplX8ChannelTest, TestResizeBilinearX8ChannelUint8) { const int output_width = input_width * scale_factor; const int output_height = input_height * scale_factor; - TestOneResizeBilinear(op_params, batch, depth, input_width, - input_height, output_width, output_height, - 0.025); + TestOneResizeBilinear(op_params, batch, depth, input_width, + input_height, output_width, output_height, + 0.025); } } @@ -418,9 +418,9 @@ TEST_P(ResizeBilinearImplX8ChannelTest, TestResizeBilinearX8ChannelInt8) { const int output_width = input_width * scale_factor; const int output_height = input_height * scale_factor; - TestOneResizeBilinear(op_params, batch, depth, input_width, - input_height, output_width, output_height, - 0.025); + TestOneResizeBilinear(op_params, batch, depth, input_width, + input_height, output_width, output_height, + 0.025); } } diff --git a/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc b/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc index debeb36e48fb9e..31ff68cc3ec3c8 100644 --- a/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc +++ b/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc @@ -28,7 +28,7 @@ namespace { template void TestReferenceResizeNearestNeighbor( const RuntimeShape& input_shape, const std::vector& input_data, - const std::vector& output_size_data, + const std::vector& output_size_data, const RuntimeShape& output_shape, const std::vector& expected_output_data, bool align_corners = false, bool half_pixel_centers = false) { @@ -48,7 +48,7 @@ void TestReferenceResizeNearestNeighbor( TEST(ResizeNearestNeighborReference, Test2x2To1x1) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {1, 1}; + std::vector output_size_data = {1, 1}; RuntimeShape output_shape = {1, 1, 1, 1}; std::vector output_data = {1}; @@ -59,7 +59,7 @@ TEST(ResizeNearestNeighborReference, Test2x2To1x1) { TEST(ResizeNearestNeighborReference, Test2x2To1x1_AlignCorners) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {1, 1}; + std::vector output_size_data = {1, 1}; RuntimeShape output_shape = {1, 1, 1, 1}; std::vector output_data = {1}; @@ -71,7 +71,7 @@ TEST(ResizeNearestNeighborReference, Test2x2To1x1_AlignCorners) { TEST(ResizeNearestNeighborReference, Test2x2To1x1_HalfPixelCenters) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {1, 1}; + std::vector output_size_data = {1, 1}; RuntimeShape output_shape = {1, 1, 1, 1}; std::vector output_data = {4}; @@ -82,10 +82,10 @@ TEST(ResizeNearestNeighborReference, Test2x2To1x1_HalfPixelCenters) { TEST(ResizeNearestNeighborReference, Test2x2To3x3) { RuntimeShape input_shape = {1, 2, 2, 1}; - std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {3, 3}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; - std::vector output_data = {1, 1, 2, 1, 1, 2, 3, 3, 4}; + std::vector output_data = {1, 1, 2, 1, 1, 2, 3, 3, 4}; TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, output_shape, output_data); @@ -94,7 +94,7 @@ TEST(ResizeNearestNeighborReference, Test2x2To3x3) { TEST(ResizeNearestNeighborReference, Test2x2To3x3Int16) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {3, 3}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; std::vector output_data = {1, 1, 2, 1, 1, 2, 3, 3, 4}; @@ -104,10 +104,10 @@ TEST(ResizeNearestNeighborReference, Test2x2To3x3Int16) { TEST(ResizeNearestNeighborReference, Test2x2To3x3_AlignCorners) { RuntimeShape input_shape = {1, 2, 2, 1}; - std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {3, 3}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; - std::vector output_data = {1, 2, 2, 3, 4, 4, 3, 4, 4}; + std::vector output_data = {1, 2, 2, 3, 4, 4, 3, 4, 4}; TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, output_shape, output_data, @@ -116,10 +116,10 @@ TEST(ResizeNearestNeighborReference, Test2x2To3x3_AlignCorners) { TEST(ResizeNearestNeighborReference, Test2x2To3x3_HalfPixelCenters) { RuntimeShape input_shape = {1, 2, 2, 1}; - std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {3, 3}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; - std::vector output_data = {1, 2, 2, 3, 4, 4, 3, 4, 4}; + std::vector output_data = {1, 2, 2, 3, 4, 4, 3, 4, 4}; TestReferenceResizeNearestNeighbor( input_shape, input_data, output_size_data, output_shape, output_data, @@ -129,7 +129,7 @@ TEST(ResizeNearestNeighborReference, Test2x2To3x3_HalfPixelCenters) { TEST(ResizeNearestNeighborReference, Test3x3To2x2) { RuntimeShape input_shape = {1, 3, 3, 1}; std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - std::vector output_size_data = {2, 2}; + std::vector output_size_data = {2, 2}; RuntimeShape output_shape = {1, 2, 2, 1}; std::vector output_data = {1, 2, 4, 5}; @@ -140,7 +140,7 @@ TEST(ResizeNearestNeighborReference, Test3x3To2x2) { TEST(ResizeNearestNeighborReference, Test3x3To2x2_AlignCorners) { RuntimeShape input_shape = {1, 3, 3, 1}; std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - std::vector output_size_data = {2, 2}; + std::vector output_size_data = {2, 2}; RuntimeShape output_shape = {1, 2, 2, 1}; std::vector output_data = {1, 3, 7, 9}; @@ -152,7 +152,7 @@ TEST(ResizeNearestNeighborReference, Test3x3To2x2_AlignCorners) { TEST(ResizeNearestNeighborReference, Test3x3To2x2_HalfPixelCenters) { RuntimeShape input_shape = {1, 3, 3, 1}; std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - std::vector output_size_data = {2, 2}; + std::vector output_size_data = {2, 2}; RuntimeShape output_shape = {1, 2, 2, 1}; std::vector output_data = {1, 3, 7, 9}; @@ -163,10 +163,10 @@ TEST(ResizeNearestNeighborReference, Test3x3To2x2_HalfPixelCenters) { TEST(ResizeNearestNeighborReference, Test2x2To2x5) { RuntimeShape input_shape = {1, 2, 2, 1}; - std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {2, 5}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {2, 5}; RuntimeShape output_shape = {1, 2, 5, 1}; - std::vector output_data = {1, 1, 1, 2, 2, 3, 3, 3, 4, 4}; + std::vector output_data = {1, 1, 1, 2, 2, 3, 3, 3, 4, 4}; TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, output_shape, output_data); @@ -174,10 +174,10 @@ TEST(ResizeNearestNeighborReference, Test2x2To2x5) { TEST(ResizeNearestNeighborReference, Test2x2To2x5_HalfPixelCenters) { RuntimeShape input_shape = {1, 2, 2, 1}; - std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {2, 5}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {2, 5}; RuntimeShape output_shape = {1, 2, 5, 1}; - std::vector output_data = {1, 1, 2, 2, 2, 3, 3, 4, 4, 4}; + std::vector output_data = {1, 1, 2, 2, 2, 3, 3, 4, 4, 4}; TestReferenceResizeNearestNeighbor( input_shape, input_data, output_size_data, output_shape, output_data, @@ -186,11 +186,11 @@ TEST(ResizeNearestNeighborReference, Test2x2To2x5_HalfPixelCenters) { TEST(ResizeNearestNeighborReference, Test4x4To3x3) { RuntimeShape input_shape = {1, 4, 4, 1}; - std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16}; - std::vector output_size_data = {3, 3}; + std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; - std::vector output_data = {1, 2, 3, 5, 6, 7, 9, 10, 11}; + std::vector output_data = {1, 2, 3, 5, 6, 7, 9, 10, 11}; TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, output_shape, output_data); @@ -198,11 +198,11 @@ TEST(ResizeNearestNeighborReference, Test4x4To3x3) { TEST(ResizeNearestNeighborReference, Test4x4To3x3_AlignCorners) { RuntimeShape input_shape = {1, 4, 4, 1}; - std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16}; - std::vector output_size_data = {3, 3}; + std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; - std::vector output_data = {1, 3, 4, 9, 11, 12, 13, 15, 16}; + std::vector output_data = {1, 3, 4, 9, 11, 12, 13, 15, 16}; TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, output_shape, output_data, @@ -211,11 +211,11 @@ TEST(ResizeNearestNeighborReference, Test4x4To3x3_AlignCorners) { TEST(ResizeNearestNeighborReference, Test4x4To3x3_HalfPixelCenters) { RuntimeShape input_shape = {1, 4, 4, 1}; - std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16}; - std::vector output_size_data = {3, 3}; + std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; - std::vector output_data = {1, 3, 4, 9, 11, 12, 13, 15, 16}; + std::vector output_data = {1, 3, 4, 9, 11, 12, 13, 15, 16}; TestReferenceResizeNearestNeighbor( input_shape, input_data, output_size_data, output_shape, output_data, @@ -225,7 +225,7 @@ TEST(ResizeNearestNeighborReference, Test4x4To3x3_HalfPixelCenters) { TEST(ResizeNearestNeighborReference, Test2x2To5x2) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {5, 2}; + std::vector output_size_data = {5, 2}; RuntimeShape output_shape = {1, 5, 2, 1}; std::vector output_data = {1, 2, 1, 2, 1, 2, 3, 4, 3, 4}; @@ -236,7 +236,7 @@ TEST(ResizeNearestNeighborReference, Test2x2To5x2) { TEST(ResizeNearestNeighborReference, Test2x2To5x2_HalfPixelCenters) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {5, 2}; + std::vector output_size_data = {5, 2}; RuntimeShape output_shape = {1, 5, 2, 1}; std::vector output_data = {1, 2, 1, 2, 3, 4, 3, 4, 3, 4}; @@ -249,7 +249,7 @@ TEST(ResizeNearestNeighborReference, Test2x2To5x2_HalfPixelCenters_AlignCorners) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {5, 2}; + std::vector output_size_data = {5, 2}; RuntimeShape output_shape = {1, 5, 2, 1}; std::vector output_data = {2, 2, 2, 2, 4, 4, 4, 4, 4, 4}; @@ -260,11 +260,11 @@ TEST(ResizeNearestNeighborReference, TEST(ResizeNearestNeighborReference, Test2x2To4x4) { RuntimeShape input_shape = {1, 2, 2, 1}; - std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {4, 4}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {4, 4}; RuntimeShape output_shape = {1, 4, 4, 1}; - std::vector output_data = {1, 1, 2, 2, 1, 1, 2, 2, - 3, 3, 4, 4, 3, 3, 4, 4}; + std::vector output_data = {1, 1, 2, 2, 1, 1, 2, 2, + 3, 3, 4, 4, 3, 3, 4, 4}; TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, output_shape, output_data); @@ -279,7 +279,7 @@ TEST(ResizeNearestNeighborReference, Test2x2x2x2To2x3x3x2) { RuntimeShape input_shape = {2, 2, 2, 2}; std::vector input_data = {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8}; - std::vector output_size_data = {3, 3}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {2, 3, 3, 2}; // Output: // [ [ 1, 1 ], [ 1, 1 ], [ 2, 2 ], @@ -300,7 +300,7 @@ TEST(ResizeNearestNeighborReference, Test2x2x2x2To2x3x3x2_AlignCorners) { RuntimeShape input_shape = {2, 2, 2, 2}; std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}; - std::vector output_size_data = {3, 3}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {2, 3, 3, 2}; std::vector output_data = { 1, 2, 3, 4, 3, 4, 5, 6, 7, 8, 7, 8, 5, 6, 7, 8, 7, 8, @@ -316,7 +316,7 @@ TEST(ResizeNearestNeighborReference, Test2x2x2x2To2x3x3x2_HalfPixelCenters) { RuntimeShape input_shape = {2, 2, 2, 2}; std::vector input_data = {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8}; - std::vector output_size_data = {3, 3}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {2, 3, 3, 2}; std::vector output_data = {1, 1, 2, 2, 2, 2, 3, 3, 4, 4, 4, 4, 3, 3, 4, 4, 4, 4, 5, 5, 6, 6, 6, 6, @@ -332,7 +332,7 @@ TEST(ResizeNearestNeighborReference, RuntimeShape input_shape = {2, 2, 2, 2}; std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}; - std::vector output_size_data = {3, 3}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {2, 3, 3, 2}; std::vector output_data = {1, 2, 3, 4, 3, 4, 5, 6, 7, 8, 7, 8, 5, 6, 7, 8, 7, 8, 1, 2, 3, 4, 3, 4, @@ -351,14 +351,14 @@ void TestOptimizedResizeNearestNeighbor(int batch, int depth, int input_width, RuntimeShape input_shape({batch, input_height, input_width, depth}); RuntimeShape output_shape({batch, output_height, output_width, depth}); - std::vector input_data(input_shape.FlatSize(), 0); - FillRandom(&input_data, static_cast(0), static_cast(255)); + std::vector input_data(input_shape.FlatSize(), 0); + FillRandom(&input_data, static_cast(0), static_cast(255)); - std::vector reference_output_data(output_shape.FlatSize(), 0); + std::vector reference_output_data(output_shape.FlatSize(), 0); // Initialize the output data with something other than zero, so we can catch // issue with kernels failing to initialize the output. - std::vector output_data(output_shape.FlatSize(), 3); - std::vector output_size_data = {output_height, output_width}; + std::vector output_data(output_shape.FlatSize(), 3); + std::vector output_size_data = {output_height, output_width}; ResizeNearestNeighborParams op_params{/*align_corners=*/false, /*half_pixel_centers=*/false}; @@ -412,22 +412,22 @@ bool is_valid_scale(int input_width, int input_height, int output_width, const float width_scale_float = static_cast(input_width) / output_width; - int32 height_scale_int = (input_height << 16) / output_height + 1; - int32 width_scale_int = (input_width << 16) / output_width + 1; + int32_t height_scale_int = (input_height << 16) / output_height + 1; + int32_t width_scale_int = (input_width << 16) / output_width + 1; for (int y = 0; y < output_height; ++y) { - int32 in_y_float = - std::min(static_cast(std::floor(y * height_scale_float)), + int32_t in_y_float = + std::min(static_cast(std::floor(y * height_scale_float)), input_height - 1); - int32 in_y_int = std::min((y * height_scale_int) >> 16, input_height - 1); + int32_t in_y_int = std::min((y * height_scale_int) >> 16, input_height - 1); if (in_y_int != in_y_float) { return false; } for (int x = 0; x < output_width; ++x) { - int32 in_x_float = - std::min(static_cast(std::floor(x * width_scale_float)), + int32_t in_x_float = + std::min(static_cast(std::floor(x * width_scale_float)), input_width - 1); - int32 in_x_int = std::min((x * width_scale_int) >> 16, input_width - 1); + int32_t in_x_int = std::min((x * width_scale_int) >> 16, input_width - 1); if (in_x_int != in_x_float) { return false; } diff --git a/tensorflow/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/lite/kernels/internal/softmax_quantized_test.cc index 9b5ef171eaf9b5..4f736225d3508a 100644 --- a/tensorflow/lite/kernels/internal/softmax_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/softmax_quantized_test.cc @@ -32,11 +32,11 @@ limitations under the License. namespace tflite { namespace { -void RunSoftmaxFloatReference(const uint8* input_data, +void RunSoftmaxFloatReference(const uint8_t* input_data, const RuntimeShape& shape_common, - int32 input_offset, const double input_scale, + int32_t input_offset, const double input_scale, int stride, float beta, - uint8* reference_output_data) { + uint8_t* reference_output_data) { const int ref_buffer_size = shape_common.FlatSize(); std::vector reference_dequant_data(ref_buffer_size); std::vector reference_output_float_data(ref_buffer_size); @@ -103,18 +103,18 @@ void CheckOutputData(const T* test_output, const T* reference_output, // Runs the Softmax and compares against the float reference implementation and // the quantized reference implementation. -void RunOneSoftmaxTest(const uint8* input_data, - const RuntimeShape& shape_common, int32 input_offset, +void RunOneSoftmaxTest(const uint8_t* input_data, + const RuntimeShape& shape_common, int32_t input_offset, const double input_scale, int stride, float beta) { const int buffer_size = shape_common.FlatSize(); - std::vector optimized_softmax_output(buffer_size); - std::vector reference_float_softmax_output(buffer_size); - std::vector reference_quant_softmax_output(buffer_size); + std::vector optimized_softmax_output(buffer_size); + std::vector reference_float_softmax_output(buffer_size); + std::vector reference_quant_softmax_output(buffer_size); RunSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, reference_float_softmax_output.data()); - int32 input_beta_multiplier; + int32_t input_beta_multiplier; int input_beta_left_shift; static const int kScaledDiffIntegerBits = 5; tflite::PreprocessSoftmaxScaling(beta, input_scale, kScaledDiffIntegerBits, @@ -180,14 +180,14 @@ bool TryOneUniformSoftmax() { const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200); const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); - const int32 input_offset = UniformRandomInt(-256, 0); + const int32_t input_offset = UniformRandomInt(-256, 0); const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); auto shape_common = RuntimeShape({batch, input_height, input_width, input_depth}); const int buffer_size = shape_common.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); FillRandom(&input_data); RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); @@ -213,7 +213,7 @@ bool TryOneSkyscraperSoftmax(bool small_depth) { const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200); const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); - const int32 input_offset = UniformRandomInt(-256, 0); + const int32_t input_offset = UniformRandomInt(-256, 0); const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); // Extra parameters for skyscraper input patterns. const double middle_proportion = @@ -225,7 +225,7 @@ bool TryOneSkyscraperSoftmax(bool small_depth) { RuntimeShape({batch, input_height, input_width, input_depth}); const int buffer_size = shape_common.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, sides_max); RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, diff --git a/tensorflow/lite/kernels/internal/tensor_test.cc b/tensorflow/lite/kernels/internal/tensor_test.cc index d746d66dc94359..0006f385d7b863 100644 --- a/tensorflow/lite/kernels/internal/tensor_test.cc +++ b/tensorflow/lite/kernels/internal/tensor_test.cc @@ -24,28 +24,28 @@ using ::testing::ElementsAre; TEST(TensorTest, GetTensorShape4D) { RuntimeShape d = GetTensorShape({2, 3, 4, 5}); EXPECT_THAT( - std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), + std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), ElementsAre(2, 3, 4, 5)); } TEST(TensorTest, GetTensorShape3D) { RuntimeShape d = GetTensorShape({3, 4, 5}); EXPECT_THAT( - std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), + std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), ElementsAre(3, 4, 5)); } TEST(TensorTest, GetTensorShape2D) { RuntimeShape d = GetTensorShape({4, 5}); EXPECT_THAT( - std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), + std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), ElementsAre(4, 5)); } TEST(TensorTest, GetTensorShape1D) { RuntimeShape d = GetTensorShape({5}); EXPECT_THAT( - std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), + std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), ElementsAre(5)); } diff --git a/tensorflow/lite/kernels/internal/test_util.h b/tensorflow/lite/kernels/internal/test_util.h index ec64590d0d3508..7e17170cfa57e5 100644 --- a/tensorflow/lite/kernels/internal/test_util.h +++ b/tensorflow/lite/kernels/internal/test_util.h @@ -93,8 +93,8 @@ void FillRandom(std::vector* vec) { // the depth) with higher values than the surround. template void FillRandomSkyscraper(std::vector* vec, int depth, - double middle_proportion, uint8 middle_min, - uint8 sides_max) { + double middle_proportion, uint8_t middle_min, + uint8_t sides_max) { for (auto base_it = std::begin(*vec); base_it != std::end(*vec); base_it += depth) { auto left_it = base_it + std::ceil(0.5 * depth * (1.0 - middle_proportion)); diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc index 39f7bc7da53a49..2333caebce546a 100644 --- a/tensorflow/lite/kernels/kernel_util.cc +++ b/tensorflow/lite/kernels/kernel_util.cc @@ -456,6 +456,12 @@ std::string GetShapeDebugString(const TfLiteIntArray* shape) { return str; } +std::string GetTensorDebugString(const TfLiteTensor* tensor) { + return std::string("{\n type: ") + TfLiteTypeGetName(tensor->type) + + "\n data: {...}\n dims: " + GetShapeDebugString(tensor->dims) + + "\n}"; +} + TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, const TfLiteTensor* input1, const TfLiteTensor* input2, diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h index e318118fb649f3..070f363b5a6412 100644 --- a/tensorflow/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -310,6 +310,8 @@ TfLiteStatus GetOutputShapeFromInput(TfLiteContext* context, std::string GetShapeDebugString(const TfLiteIntArray* shape); +std::string GetTensorDebugString(const TfLiteTensor* tensor); + #endif // !defined(TF_LITE_STATIC_MEMORY) // Calculates the output_shape that is necessary for element-wise operations diff --git a/tensorflow/lite/kernels/variants/BUILD b/tensorflow/lite/kernels/variants/BUILD index 46c7755cef5248..531fc8bfe0f6eb 100644 --- a/tensorflow/lite/kernels/variants/BUILD +++ b/tensorflow/lite/kernels/variants/BUILD @@ -308,7 +308,7 @@ cc_library( srcs = ["tensor_array.cc"], hdrs = ["tensor_array.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//tensorflow/lite:__subpackages__"], + visibility = ["//visibility:private"], deps = [ "//tensorflow/lite:array", "//tensorflow/lite:util", diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 2389b3b8d393e3..403eb9549369a2 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -600,6 +600,7 @@ def build_conversion_flags( qdq_conversion_mode=None, disable_per_channel_quantization_for_dense_layers=False, enable_composite_direct_lowering=False, + model_origin_framework=lite_constants.UNSET, **_, ): """Builds protocol buffer describing a conversion of a model. @@ -731,6 +732,8 @@ def build_conversion_flags( layers. The flag works only for integer quantized model. enable_composite_direct_lowering: If set, attempts to lower composite ops directly to tflite ops. + model_origin_framework: A str specifying the framework of the original + model. Can be {TENSORFLOW, KERAS, JAX, PYTORCH} Returns: conversion_flags: protocol buffer describing the conversion process. @@ -854,6 +857,11 @@ def build_conversion_flags( conversion_flags.enable_composite_direct_lowering = ( enable_composite_direct_lowering ) + conversion_flags.model_origin_framework = ( + _conversion_flags_pb2.TocoFlags.ModelOriginFramework.Value( + model_origin_framework + ) + ) return conversion_flags diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py index 670340e8dba7fa..e49c63763c222c 100644 --- a/tensorflow/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -310,15 +310,13 @@ class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase): # Model must have at least 7 bytes to hold model identifier def testTooShortModelContent(self): - with self.assertRaisesRegex( - ValueError, - 'Model provided must have at least 7 bytes to hold identifier.', - ): + with self.assertRaisesRegex(ValueError, + 'The model is not a valid Flatbuffer buffer'): interpreter_wrapper.Interpreter(model_content=b'short') def testInvalidModelContent(self): with self.assertRaisesRegex(ValueError, - 'Model provided has model identifier \''): + 'The model is not a valid Flatbuffer buffer'): interpreter_wrapper.Interpreter(model_content=b'wrong_identifier') def testInvalidModelFile(self): diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 7ab81eec5d58fd..a14d6dcc9e2121 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include #include #include #include @@ -745,12 +746,32 @@ PyObject* InterpreterWrapper::GetTensor(int tensor_index, tensor->type != kTfLiteVariant) { // Make a buffer copy but we must tell Numpy It owns that data or else // it will leak. - void* data = malloc(tensor->bytes); + size_t numpy_bytes = tensor->bytes; + if (tensor->type == kTfLiteInt4) { + // Numpy doesn't have int4 type, so we double the size of the buffer + // to hold int8 type for each (4-bit packed) element. + numpy_bytes *= 2; + } + void* data = malloc(numpy_bytes); if (!data) { PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed."); return nullptr; } - memcpy(data, tensor->data.raw, tensor->bytes); + if (tensor->type == kTfLiteInt4) { + int8_t* tensor_data = reinterpret_cast(tensor->data.raw); + int8_t* numpy_data = static_cast(data); + // Unpack each 4-bit value to an 8-bit container. + for (size_t i = 0; i < tensor->bytes; i++) { + int8_t byte = tensor_data[i]; + int8_t lower = static_cast(byte << 4) >> 4; + int8_t upper = static_cast(byte >> 4); + numpy_data[2 * i] = lower; + numpy_data[2 * i + 1] = upper; + } + } else { + memcpy(data, tensor->data.raw, tensor->bytes); + } + PyObject* np_array; if (tensor->sparsity == nullptr) { np_array = @@ -866,7 +887,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( return nullptr; } std::unique_ptr model = - Model::BuildFromBuffer(buf, length, error_reporter.get()); + Model::VerifyAndBuildFromBuffer(buf, length, /*extra_verifier=*/nullptr, + error_reporter.get()); return CreateInterpreterWrapper( std::move(model), op_resolver_id, std::move(error_reporter), registerers_by_name, registerers_by_func, error_msg, preserve_all_tensors, diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 2005f80d03bc12..117abe593de4bf 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -674,6 +674,7 @@ def __init__(self): self._experimental_qdq_conversion_mode = None self._experimental_disable_per_channel_quantization_for_dense_layers = False self._experimental_enable_composite_direct_lowering = False + self.model_origin_framework = constants.UNSET # Debug parameters self.ir_dump_dir = None @@ -836,6 +837,7 @@ def _get_base_converter_args(self): "enable_composite_direct_lowering": ( self._experimental_enable_composite_direct_lowering ), + "model_origin_framework": self.model_origin_framework, } if self.saved_model_dir: diff --git a/tensorflow/lite/python/lite_constants.py b/tensorflow/lite/python/lite_constants.py index 4700a5920b57c0..843c2225eb6f2f 100644 --- a/tensorflow/lite/python/lite_constants.py +++ b/tensorflow/lite/python/lite_constants.py @@ -31,6 +31,21 @@ TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF TFLITE = _toco_flags_pb2.TFLITE GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT +UNSET = _toco_flags_pb2.TocoFlags.ModelOriginFramework.Name( + _toco_flags_pb2.TocoFlags.UNSET +) +TENSORFLOW = _toco_flags_pb2.TocoFlags.ModelOriginFramework.Name( + _toco_flags_pb2.TocoFlags.TENSORFLOW +) +KERAS = _toco_flags_pb2.TocoFlags.ModelOriginFramework.Name( + _toco_flags_pb2.TocoFlags.KERAS +) +JAX = _toco_flags_pb2.TocoFlags.ModelOriginFramework.Name( + _toco_flags_pb2.TocoFlags.JAX +) +PYTORCH = _toco_flags_pb2.TocoFlags.ModelOriginFramework.Name( + _toco_flags_pb2.TocoFlags.PYTORCH +) _tf_export(v1=["lite.constants.FLOAT"]).export_constant(__name__, "FLOAT") _tf_export(v1=["lite.constants.FLOAT16"]).export_constant(__name__, "FLOAT16") @@ -65,6 +80,11 @@ "TENSORFLOW_GRAPHDEF", "TFLITE", "GRAPHVIZ_DOT", + "UNSET", + "TENSORFLOW", + "KERAS", + "JAX", + "PYTORCH", "EXPERIMENTAL_USE_TOCO_API_DIRECTLY", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/lite/stateful_error_reporter.h b/tensorflow/lite/stateful_error_reporter.h index cf6693431f9118..10dc09646cb273 100644 --- a/tensorflow/lite/stateful_error_reporter.h +++ b/tensorflow/lite/stateful_error_reporter.h @@ -15,9 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_ #define TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_ +// LINT.IfChange #include -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" namespace tflite { @@ -30,5 +31,6 @@ class StatefulErrorReporter : public ErrorReporter { }; } // namespace tflite +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/stateful_error_reporter.h) #endif // TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_ diff --git a/tensorflow/lite/testdata/no_signatures.bin b/tensorflow/lite/testdata/no_signatures.bin new file mode 100644 index 00000000000000..1a6f71b7936722 Binary files /dev/null and b/tensorflow/lite/testdata/no_signatures.bin differ diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index e7f4fdafd7ad1e..5d3976fb9ee5af 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -217,6 +217,33 @@ cc_library( ], ) +cc_library( + name = "matchers", + testonly = True, + srcs = ["matchers.h"], + hdrs = ["matchers.h"], + deps = [ + "//tensorflow/lite/core/c:common", + "//tensorflow/lite/kernels:kernel_util", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "matchers_test", + srcs = ["matchers_test.cc"], + deps = [ + ":matchers", + "//tensorflow/lite/core/c:c_api_types", + "//tensorflow/lite/core/c:common", + "@com_google_absl//absl/base", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "message", srcs = ["message.cc"], diff --git a/tensorflow/lite/testing/matchers.h b/tensorflow/lite/testing/matchers.h new file mode 100644 index 00000000000000..604b3dd9ff6cfe --- /dev/null +++ b/tensorflow/lite/testing/matchers.h @@ -0,0 +1,272 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_TESTING_MATCHERS_H_ +#define TENSORFLOW_LITE_TESTING_MATCHERS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +// gMock matchers for TfLiteTensors. +// +// EXPECT_THAT(a, EqualsTensor(b)); +// EXPECT_THAT(a, Approximately(EqualsTensor(b))); +// EXPECT_THAT(a, Approximately(EqualsTensor(b), /*margin*/)); +// EXPECT_THAT(a, Approximately(EqualsTensor(b), /*margin=*/0, /*fraction*/)); +// +// TODO: who/impjdi - Expand to more dtypes than just float. +// TODO: who/impjdi - Add cross-dtype matchers. + +inline void PrintTo(const TfLiteTensor& tensor, std::ostream* os) { + *os << "\n" << ::tflite::GetTensorDebugString(&tensor); +} + +namespace testing { +namespace tflite { +namespace internal { + +enum class FloatComparison { kExact, kApproximate }; + +struct TensorComparison { + FloatComparison float_comp = FloatComparison::kExact; + bool custom_margin = false; + bool custom_fraction = false; + double margin = 0.0; // only used if custom_margin == true + double fraction = 0.0; // only used if custom_fraction == true +}; + +class TensorMatcher { + public: + TensorMatcher(const TensorComparison& comp, const TfLiteTensor& expected) + : comp_(comp), expected_(expected) {} + + bool MatchAndExplain(const TfLiteTensor& actual, + MatchResultListener* listener) const { + const bool match = Match(actual); + if (listener->IsInterested() && !match) *listener << DescribeDiff(actual); + return match; + } + + void DescribeTo(std::ostream* os) const { Describe(os, "is "); } + void DescribeNegationTo(std::ostream* os) const { Describe(os, "is not "); } + + void SetCompareApproximately() { + comp_.float_comp = FloatComparison::kApproximate; + } + + void SetMargin(double margin) { + ABSL_QCHECK_GE(margin, 0.0) // Crash OK + << "Using a negative margin for Approximately"; + comp_.custom_margin = true; + comp_.margin = margin; + } + + void SetFraction(double fraction) { + ABSL_QCHECK(0.0 <= fraction && fraction < 1.0) // Crash OK + << "Fraction for Approximately must be >= 0.0 and < 1.0"; + comp_.custom_fraction = true; + comp_.fraction = fraction; + } + + private: + static std::string TensorIndex(int index, const TfLiteIntArray* dims) { + if (!dims->size) return ""; + std::vector index_nd(dims->size); + for (int i = dims->size - 1; i >= 0; --i) { + index_nd[i] = index % dims->data[i]; + index /= dims->data[i]; + } + return absl::StrCat("[", absl::StrJoin(index_nd, "]["), "]"); + } + + bool CompareFloat(float x, float y) const { + switch (comp_.float_comp) { + case FloatComparison::kExact: + return x == y; + case FloatComparison::kApproximate: + if (x == y) return true; + float fraction, margin; + if (comp_.custom_margin || comp_.custom_fraction) { + fraction = comp_.fraction; + margin = comp_.margin; + } else { + constexpr float kEpsilon = 32 * FLT_EPSILON; + if (std::fabs(x) <= kEpsilon && std::fabs(y) <= kEpsilon) return true; + fraction = kEpsilon; + margin = kEpsilon; + } + if (!std::isfinite(x) || !std::isfinite(y)) return false; + float relative_margin = fraction * std::max(std::fabs(x), std::fabs(y)); + return std::fabs(x - y) <= std::max(margin, relative_margin); + } + return false; + } + + void Describe(std::ostream* os, std::string_view prefix) const { + *os << prefix; + if (comp_.float_comp == FloatComparison::kApproximate) { + *os << "approximately "; + if (comp_.custom_margin || comp_.custom_fraction) { + *os << "("; + if (comp_.custom_margin) { + std::stringstream ss; + ss << std::setprecision(std::numeric_limits::digits10 + 2) + << comp_.margin; + *os << "absolute error of float values <= " << ss.str(); + } + if (comp_.custom_margin && comp_.custom_fraction) { + *os << " or "; + } + if (comp_.custom_fraction) { + std::stringstream ss; + ss << std::setprecision(std::numeric_limits::digits10 + 2) + << comp_.fraction; + *os << "relative error of float values <= " << ss.str(); + } + *os << ") "; + } + } + *os << "equal to "; + PrintTo(expected_, os); + } + + std::string DescribeDiff(const TfLiteTensor& actual) const { + if (actual.type != expected_.type) { + return absl::StrCat( + "dtypes don't match: ", TfLiteTypeGetName(actual.type), " vs ", + TfLiteTypeGetName(expected_.type)); + } + if (!actual.dims) return "actual.dims is null."; + if (!expected_.dims) return "expected.dims is null."; + if (actual.dims->size != expected_.dims->size) { + return absl::StrCat("dims don't match: ", actual.dims->size, "D vs ", + expected_.dims->size, "D"); + } + if (int n = actual.dims->size; + std::memcmp(actual.dims->data, expected_.dims->data, n * sizeof(int))) { + return absl::StrCat( + "shapes don't match: ", ::tflite::GetShapeDebugString(actual.dims), + " vs ", ::tflite::GetShapeDebugString(expected_.dims)); + } + if (!actual.data.raw) return "actual.data is null."; + if (!expected_.data.raw) return "expected.data is null."; + if (actual.bytes != expected_.bytes) { + return absl::StrCat("bytes don't match: ", actual.bytes, " vs ", + expected_.bytes); + } + std::string error = "\n"; + TfLiteIntArray* dims = actual.dims; + int n = ::tflite::NumElements(dims); + constexpr int kMaxMismatches = 20; + for (int i = 0, j = 0; i < n; ++i) { + if (!CompareFloat(actual.data.f[i], expected_.data.f[i])) { + absl::StrAppend(&error, "data", TensorIndex(i, dims), + " don't match: ", actual.data.f[i], " vs ", + expected_.data.f[i], "\n"); + ++j; + } + if (j == kMaxMismatches) { + absl::StrAppend(&error, "Too many mismatches; stopping after ", j, + ".\n"); + break; + } + } + return error; + } + + bool Match(const TfLiteTensor& actual) const { + if (actual.type != expected_.type) return false; + if (!actual.dims) return false; + if (!expected_.dims) return false; + if (actual.dims->size != expected_.dims->size) return false; + if (int n = actual.dims->size; + std::memcmp(actual.dims->data, expected_.dims->data, n * sizeof(int))) { + return false; + } + if (!actual.data.raw) return false; + if (!expected_.data.raw) return false; + if (actual.bytes != expected_.bytes) return false; + switch (comp_.float_comp) { + case FloatComparison::kExact: + if (int n = actual.bytes; + std::memcmp(actual.data.raw, expected_.data.raw, n)) { + return false; + } + break; + case FloatComparison::kApproximate: + for (int i = 0, n = ::tflite::NumElements(actual.dims); i < n; ++i) { + if (!CompareFloat(actual.data.f[i], expected_.data.f[i])) { + return false; + } + } + break; + }; + return true; + } + + TensorComparison comp_; + TfLiteTensor expected_; +}; + +} // namespace internal + +inline PolymorphicMatcher EqualsTensor( + const TfLiteTensor& expected) { + internal::TensorComparison comp; + return MakePolymorphicMatcher(internal::TensorMatcher(comp, expected)); +} + +template +inline InnerTensorMatcherT Approximately(InnerTensorMatcherT m) { + m.mutable_impl().SetCompareApproximately(); + return m; +} + +template +inline InnerTensorMatcherT Approximately(InnerTensorMatcherT m, double margin) { + m.mutable_impl().SetCompareApproximately(); + m.mutable_impl().SetMargin(margin); + return m; +} + +template +inline InnerTensorMatcherT Approximately(InnerTensorMatcherT m, double margin, + double fraction) { + m.mutable_impl().SetCompareApproximately(); + m.mutable_impl().SetMargin(margin); + m.mutable_impl().SetFraction(fraction); + return m; +} + +} // namespace tflite +} // namespace testing + +#endif // TENSORFLOW_LITE_TESTING_MATCHERS_H_ diff --git a/tensorflow/lite/testing/matchers_test.cc b/tensorflow/lite/testing/matchers_test.cc new file mode 100644 index 00000000000000..bae6cff1af3a08 --- /dev/null +++ b/tensorflow/lite/testing/matchers_test.cc @@ -0,0 +1,132 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/testing/matchers.h" + +#include +#include +#include + +#include +#include +#include "absl/base/casts.h" +#include "absl/types/span.h" +#include "tensorflow/lite/core/c/c_api_types.h" +#include "tensorflow/lite/core/c/common.h" + +namespace tflite { +namespace { + +// A wrapper of TfLiteTensor that frees dims at destruction. +struct Tensor : public TfLiteTensor { + template + Tensor(TfLiteType dtype, const std::vector& shape, absl::Span buf) { + type = dtype; + dims = TfLiteIntArrayCreate(shape.size()); + std::memcpy(dims->data, shape.data(), shape.size() * sizeof(int)); + data = {.data = buf.data()}; + bytes = buf.size() * sizeof(T); + } + ~Tensor() { TfLiteIntArrayFree(dims); } +}; + +// Delegate pretty print to PrintTo(TfLiteTensor&). +void PrintTo(const Tensor& tensor, std::ostream* os) { // NOLINT + PrintTo(absl::implicit_cast(tensor), os); +} + +using ::testing::tflite::Approximately; +using ::testing::tflite::EqualsTensor; + +TEST(TensorMatcherTest, ExactlyEqualsSelf) { + float data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(data)); + EXPECT_THAT(a, EqualsTensor(a)); +} + +TEST(TensorMatcherTest, ExactlyEqualsSame) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {2.71828f, 3.14159f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT(a, EqualsTensor(b)); +} + +TEST(TensorMatcherTest, DoesNotExactlyEqualDifferentType) { + float data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(data)); + Tensor b(TfLiteType::kTfLiteInt32, {1, 2}, absl::MakeSpan(data)); + EXPECT_THAT(a, Not(EqualsTensor(b))); +} + +TEST(TensorMatcherTest, DoesNotExactlyEqualDifferentDims) { + float data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(data)); + Tensor b(TfLiteType::kTfLiteFloat32, {2, 1}, absl::MakeSpan(data)); + EXPECT_THAT(a, Not(EqualsTensor(b))); +} + +TEST(TensorMatcherTest, DoesNotExactlyEqualDifferentData) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {3.14159f, 2.71828f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT(a, Not(EqualsTensor(b))); +} + +TEST(TensorMatcherTest, ApproximatelyEqualsDefaultMargin) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {2.718277f, 3.141593f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT(a, Approximately(EqualsTensor(b))); +} + +TEST(TensorMatcherTest, ApproximatelyEqualsWithLooseMargin) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {2.72f, 3.14f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT(a, Approximately(EqualsTensor(b), /*margin=*/0.01)); +} + +TEST(TensorMatcherTest, DoesNotApproximatelyEqualWithTightMargin) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {2.72f, 3.14f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT(a, Not(Approximately(EqualsTensor(b), /*margin=*/0.001))); +} + +TEST(TensorMatcherTest, ApproximatelyEqualsWithLooseFraction) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {2.72f, 3.14f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT( + a, Approximately(EqualsTensor(b), /*margin=*/0.0, /*fraction=*/0.999)); +} + +TEST(TensorMatcherTest, DoesNotApproximatelyEqualWithTightFraction) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {2.72f, 3.14f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT(a, Not(Approximately(EqualsTensor(b), /*margin=*/0.0, + /*fraction=*/0.0001))); +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/lite/toco/tflite/BUILD b/tensorflow/lite/toco/tflite/BUILD index 6bd43f8091d32f..7377ec00d6b666 100644 --- a/tensorflow/lite/toco/tflite/BUILD +++ b/tensorflow/lite/toco/tflite/BUILD @@ -112,6 +112,7 @@ cc_library( deps = [ ":operator", ":types", + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:quantize_weights", "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/platform:status", @@ -122,7 +123,6 @@ cc_library( "//tensorflow/lite/toco:model", "//tensorflow/lite/toco:toco_port", "//tensorflow/lite/toco:tooling_util", - "//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/lite/tools/versioning", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index e9124e89f2a892..44223eac63c130 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -23,6 +23,7 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" #include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/status.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/lite/toco/tflite/types.h" #include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/lite/tools/optimize/quantize_weights.h" #include "tensorflow/lite/tools/versioning/runtime_version.h" #include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" @@ -670,19 +670,19 @@ tensorflow::Status Export( flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240); const uint8_t* buffer = builder.GetBufferPointer(); const ::tflite::Model* input_model = ::tflite::GetModel(buffer); - ::tflite::optimize::BufferType quantized_type; + ::mlir::lite::toco_legacy::BufferType quantized_type; if (params.quantize_weights == QuantizedBufferType::INT8) { - quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8; + quantized_type = ::mlir::lite::toco_legacy::BufferType::QUANTIZED_INT8; } else if (params.quantize_weights == QuantizedBufferType::FLOAT16) { - quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16; + quantized_type = ::mlir::lite::toco_legacy::BufferType::QUANTIZED_FLOAT16; } else { return tensorflow::errors::InvalidArgument( "Quantized type not recognized"); } - if (!::tflite::optimize::QuantizeWeights( + if (!::mlir::lite::toco_legacy::QuantizeWeights( &q_builder, input_model, quantized_type, !params.disable_per_channel, - ::tflite::optimize::QuantizerType::OLD_QUANTIZER) + ::mlir::lite::toco_legacy::QuantizerType::OLD_QUANTIZER) .ok()) { return tensorflow::errors::InvalidArgument( "Quantize weights transformation failed."); diff --git a/tensorflow/lite/toco/toco_flags.proto b/tensorflow/lite/toco/toco_flags.proto index 1760841a333f6a..ac5ed8c3ef6ae2 100644 --- a/tensorflow/lite/toco/toco_flags.proto +++ b/tensorflow/lite/toco/toco_flags.proto @@ -41,7 +41,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 64. +// Next ID to use: 65. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -360,4 +360,16 @@ message TocoFlags { // Enables the attempt to directly lower composites into tflite ops. // WARNING: Experimental interface, subject to change. optional bool enable_composite_direct_lowering = 63 [default = false]; + + // The source model framework. + enum ModelOriginFramework { + UNSET = 0; + TENSORFLOW = 1; + KERAS = 2; + JAX = 3; + PYTORCH = 4; + } + + // The source model type. + optional ModelOriginFramework model_origin_framework = 64 [default = UNSET]; } diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index b08a2d913b6ec7..b9260be0b9eac3 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -414,6 +414,19 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/core/c:common", "//tensorflow/lite/kernels:kernel_util", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + copts = tflite_copts(), + deps = [ + ":utils", + "//tensorflow/lite/c:common", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index 63aae4ff6b0029..b26dbde5a742d9 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -162,6 +162,7 @@ cc_library( ":benchmark_model_lib", ":benchmark_utils", ":profiling_listener", + "//tensorflow/core/example:example_protos_cc_impl", "//tensorflow/lite:framework", "//tensorflow/lite:simple_memory_arena_debug_dump", "//tensorflow/lite:string_util", @@ -180,6 +181,7 @@ cc_library( "//tensorflow/lite/tools/delegates:tflite_execution_providers", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@ruy//ruy/profiler", ], ) diff --git a/tensorflow/lite/tools/benchmark/CMakeLists.txt b/tensorflow/lite/tools/benchmark/CMakeLists.txt index 56794382ff45a8..eb0862f58aea00 100644 --- a/tensorflow/lite/tools/benchmark/CMakeLists.txt +++ b/tensorflow/lite/tools/benchmark/CMakeLists.txt @@ -47,6 +47,8 @@ list(APPEND TFLITE_BENCHMARK_LIBS list(APPEND TFLITE_BENCHMARK_LIBS profiling_info_proto + feature_proto + example_proto protobuf::libprotobuf ) diff --git a/tensorflow/lite/tools/benchmark/README.md b/tensorflow/lite/tools/benchmark/README.md index e92d841b9c6a87..4b2f82fed258d0 100644 --- a/tensorflow/lite/tools/benchmark/README.md +++ b/tensorflow/lite/tools/benchmark/README.md @@ -90,6 +90,15 @@ and the following optional parameters: and the path to include the name of the output CSV; otherwise results are printed to `stdout`. +* `output_filepath`: `str` (default="") \ + File path to save output tensor data to. If specified, the output tensor + values are saved as binary data in the file. + +* `output_proto_filepath`: `str` (default="") \ + File path to save output tensor data as tensorflow example proto. If + specified, the output tensor values are saved in tensorflow example and then + serialized to the file. + * `print_preinvoke_state`: `bool` (default=false) \ Whether to print out the TfLite interpreter internals just before calling tflite::Interpreter::Invoke. The internals will include allocated memory diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 8fb5b23b7860d9..15a18a6f5c4196 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -36,7 +37,10 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "ruy/profiler/profiler.h" // from @ruy +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature.pb.h" #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/kernels/register.h" @@ -87,6 +91,49 @@ const char* kOpProfilingOutputModes[] = {kOpProfilingOutputModeStdout, kOpProfilingOutputModeCsv, kOpProfilingOutputModeProto}; +// Sets feature values in the tensorflow::Example proto from the tflite tensor. +// Returns an error if the tensor type is not supported or the tensor dime is a +// nullptr. +TfLiteStatus MaybeSetFeatureValuesFromTensor(const TfLiteTensor& tensor, + tensorflow::Example& example) { + if (tensor.dims == nullptr) { + return kTfLiteError; + } + + int total_elements = 1; + for (int i = 0; i < tensor.dims->size; i++) { + total_elements *= tensor.dims->data[i]; + } + tensorflow::Feature& feature = + (*example.mutable_features()->mutable_feature())[tensor.name]; + switch (tensor.type) { + case kTfLiteFloat32: + case kTfLiteFloat64: + feature.mutable_float_list()->mutable_value()->Resize(total_elements, 0); + return utils::TfLiteTensorToFloat32Array( + tensor, + absl::MakeSpan( + feature.mutable_float_list()->mutable_value()->mutable_data(), + feature.float_list().value_size())); + case kTfLiteUInt8: + case kTfLiteInt8: + case kTfLiteUInt16: + case kTfLiteInt16: + case kTfLiteInt32: + case kTfLiteUInt32: + case kTfLiteUInt64: + case kTfLiteInt64: + feature.mutable_int64_list()->mutable_value()->Resize(total_elements, 0); + return utils::TfLiteTensorToInt64Array( + tensor, + absl::MakeSpan( + feature.mutable_int64_list()->mutable_value()->mutable_data(), + feature.int64_list().value_size())); + default: + return kTfLiteError; + } +} + // Dumps ruy profiling events if the ruy profiler is enabled. class RuyProfileListener : public BenchmarkListener { public: @@ -153,17 +200,37 @@ class OutputSaver : public BenchmarkListener { } void OnBenchmarkEnd(const BenchmarkResults& results) override { - std::string path = params_->Get("output_filepath"); - if (path.empty()) return; + // If the output_filepath is specified, save the output tensors to the file. + const std::string path = params_->Get("output_filepath"); + if (!path.empty()) { + std::ofstream ofs(path, std::ofstream::out); + if (ofs.good()) { + for (int i = 0; i < interpreter_runner_->outputs().size(); i++) { + int tensor_index = interpreter_runner_->outputs()[i]; + ofs.write(interpreter_runner_->tensor(tensor_index)->data.raw, + interpreter_runner_->tensor(tensor_index)->bytes); + } + ofs.close(); + } + } - std::ofstream ofs(path, std::ofstream::out); - if (ofs.good()) { + // If the output_proto_filepath is specified, save the output tensors as + // tensorflow::Example proto and serialize it to the file. + const std::string output_proto_path = + params_->Get("output_proto_filepath"); + if (!output_proto_path.empty()) { + tensorflow::Example example; for (int i = 0; i < interpreter_runner_->outputs().size(); i++) { - int tensor_index = interpreter_runner_->outputs()[i]; - ofs.write(interpreter_runner_->tensor(tensor_index)->data.raw, - interpreter_runner_->tensor(tensor_index)->bytes); + const int tensor_index = interpreter_runner_->outputs()[i]; + const TfLiteTensor& tensor = + *(interpreter_runner_->tensor(tensor_index)); + MaybeSetFeatureValuesFromTensor(tensor, example); + } + std::ofstream ofs(output_proto_path, std::ios::out | std::ios::binary); + if (ofs.good()) { + example.SerializeToOstream(&ofs); + ofs.close(); } - ofs.close(); } } @@ -518,6 +585,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() { BenchmarkParam::Create(false)); default_params.AddParam("output_filepath", BenchmarkParam::Create("")); + default_params.AddParam("output_proto_filepath", + BenchmarkParam::Create("")); default_params.AddParam("tensor_name_display_length", BenchmarkParam::Create(25)); @@ -622,6 +691,9 @@ std::vector BenchmarkTfLiteModel::GetFlags() { CreateFlag( "output_filepath", ¶ms_, "File path to export outputs layer as binary data."), + CreateFlag( + "output_proto_filepath", ¶ms_, + "File path to export outputs layer as tf example proto."), CreateFlag( "tensor_name_display_length", ¶ms_, "The number of characters to show for the tensor's name when " @@ -700,6 +772,9 @@ void BenchmarkTfLiteModel::LogParams() { "Constant CAST output cache", verbose); LOG_BENCHMARK_PARAM(std::string, "output_filepath", "File path to export outputs layer to", verbose); + LOG_BENCHMARK_PARAM(std::string, "output_proto_filepath", + "File path to export outputs layer as tf example to", + verbose); LOG_BENCHMARK_PARAM(int32_t, "tensor_name_display_length", "Tensor name display length", verbose); LOG_BENCHMARK_PARAM(int32_t, "tensor_type_display_length", diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index 8a21ae38052b17..71bfa0de5fa4a3 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( xnnpack GIT_REPOSITORY https://github.com/google/XNNPACK # Sync with tensorflow/workspace2.bzl - GIT_TAG d25d603e0b708d856e4cafca7dac1e6b7126c320 + GIT_TAG 9ddeb74f9f6866174d61888947e4aa9ffe963b1b GIT_PROGRESS TRUE PREFIX "${CMAKE_BINARY_DIR}" SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack" diff --git a/tensorflow/lite/tools/utils.cc b/tensorflow/lite/tools/utils.cc index 12396ed7c3ce05..b8c18b24c5cce6 100644 --- a/tensorflow/lite/tools/utils.cc +++ b/tensorflow/lite/tools/utils.cc @@ -17,8 +17,10 @@ limitations under the License. #include #include +#include #include +#include "absl/types/span.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -55,6 +57,30 @@ inline InputTensorData CreateInputTensorData(int num_elements, return tmp; } +// Converts a TfLiteTensor to a float array. Returns an error if the tensor +// dimension is a null pointer. +template +TfLiteStatus ConvertToArray(const TfLiteTensor& tflite_tensor, + absl::Span& values) { + if (tflite_tensor.dims == nullptr) { + return kTfLiteError; + } + + int total_elements = 1; + for (int i = 0; i < tflite_tensor.dims->size; i++) { + total_elements *= tflite_tensor.dims->data[i]; + } + if (total_elements != values.size()) { + return kTfLiteError; + } + const TensorType* tensor_data = + reinterpret_cast(tflite_tensor.data.data); + for (int i = 0; i < total_elements; i++) { + values[i] = static_cast(tensor_data[i]); + } + return kTfLiteOk; +} + } // namespace InputTensorData CreateRandomTensorData(const TfLiteTensor& tensor, @@ -168,5 +194,41 @@ void GetDataRangesForType(TfLiteType type, float* low_range, } } +TfLiteStatus TfLiteTensorToFloat32Array(const TfLiteTensor& tensor, + absl::Span values) { + switch (tensor.type) { + case kTfLiteFloat32: + return ConvertToArray(tensor, values); + case kTfLiteFloat64: + return ConvertToArray(tensor, values); + default: + return kTfLiteError; + } +} + +TfLiteStatus TfLiteTensorToInt64Array(const TfLiteTensor& tensor, + absl::Span values) { + switch (tensor.type) { + case kTfLiteUInt8: + return ConvertToArray(tensor, values); + case kTfLiteInt8: + return ConvertToArray(tensor, values); + case kTfLiteUInt16: + return ConvertToArray(tensor, values); + case kTfLiteInt16: + return ConvertToArray(tensor, values); + case kTfLiteInt32: + return ConvertToArray(tensor, values); + case kTfLiteUInt32: + return ConvertToArray(tensor, values); + case kTfLiteUInt64: + return ConvertToArray(tensor, values); + case kTfLiteInt64: + return ConvertToArray(tensor, values); + default: + return kTfLiteError; + } +} + } // namespace utils } // namespace tflite diff --git a/tensorflow/lite/tools/utils.h b/tensorflow/lite/tools/utils.h index 2fc9c62de119d2..12d69e29dc2dd6 100644 --- a/tensorflow/lite/tools/utils.h +++ b/tensorflow/lite/tools/utils.h @@ -16,8 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_UTILS_H_ #define TENSORFLOW_LITE_TOOLS_UTILS_H_ +#include #include +#include +#include "absl/types/span.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" namespace tflite { @@ -43,6 +47,14 @@ InputTensorData CreateRandomTensorData(const TfLiteTensor& tensor, // benchmarking and/or testing purposes. void GetDataRangesForType(TfLiteType type, float* low_range, float* high_range); +// Converts TfLiteTensor to float array. Returns an error if the tensor type is +// not supported or the values size is not equal to the tensor dimension. +TfLiteStatus TfLiteTensorToFloat32Array(const TfLiteTensor& tensor, + absl::Span values); + +// Same as above, but converts to int64_t array. +TfLiteStatus TfLiteTensorToInt64Array(const TfLiteTensor& tensor, + absl::Span values); } // namespace utils } // namespace tflite diff --git a/tensorflow/lite/tools/utils_test.cc b/tensorflow/lite/tools/utils_test.cc new file mode 100644 index 00000000000000..ce519827aaf12f --- /dev/null +++ b/tensorflow/lite/tools/utils_test.cc @@ -0,0 +1,92 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/tools/utils.h" + +#include + +#include +#include + +#include +#include +#include "absl/types/span.h" +#include "tensorflow/lite/c/common.h" + +namespace tflite::tools { +namespace { +using ::testing::FloatEq; + +// Helper function to test TfLiteTensorToFloat32Array. +template +void TestTfLiteTensorToFloat32Array(TfLiteType type) { + T data[] = {1, 2, 3, 4}; + TfLiteTensor tensor; + tensor.data.data = data; + tensor.type = type; + // Create an int array with 1 dimension and the array size is 4. + tensor.dims = TfLiteIntArrayCreate(1); + tensor.dims->data[0] = 4; + std::vector result(4, 0.0); + const auto status = + utils::TfLiteTensorToFloat32Array(tensor, absl::MakeSpan(result)); + TfLiteIntArrayFree(tensor.dims); + ASSERT_EQ(status, kTfLiteOk); + ASSERT_EQ(result.size(), 4); + for (int i = 0; i < 4; ++i) { + EXPECT_THAT(result[i], FloatEq(static_cast(data[i]))); + } +} + +// Helper function to test TfLiteTensorToFloat32Array. +template +void TestTfLiteTensorToInt64Array(TfLiteType type) { + T data[] = {1, 2, 3, 4}; + TfLiteTensor tensor; + tensor.data.data = data; + tensor.type = type; + // Create an int array with 1 dimension and the array size is 4. + tensor.dims = TfLiteIntArrayCreate(1); + tensor.dims->data[0] = 4; + std::vector result(4, 0); + const auto status = + utils::TfLiteTensorToInt64Array(tensor, absl::MakeSpan(result)); + TfLiteIntArrayFree(tensor.dims); + ASSERT_EQ(status, kTfLiteOk); + ASSERT_EQ(result.size(), 4); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(result[i], static_cast(data[i])); + } +} + +// Tests TfLiteTensorToFloat32Array for supported TfLiteTypes. +TEST(Utils, TfLiteTensorToFloat32Array) { + TestTfLiteTensorToFloat32Array(kTfLiteFloat32); + TestTfLiteTensorToFloat32Array(kTfLiteFloat64); +} + +TEST(Utils, TfLiteTensorToInt64Array) { + TestTfLiteTensorToInt64Array(kTfLiteInt8); + TestTfLiteTensorToInt64Array(kTfLiteUInt8); + TestTfLiteTensorToInt64Array(kTfLiteInt16); + TestTfLiteTensorToInt64Array(kTfLiteUInt16); + TestTfLiteTensorToInt64Array(kTfLiteInt32); + TestTfLiteTensorToInt64Array(kTfLiteUInt32); + TestTfLiteTensorToInt64Array(kTfLiteInt64); + TestTfLiteTensorToInt64Array(kTfLiteUInt64); +} + +} // namespace +} // namespace tflite::tools diff --git a/tensorflow/lite/tools/versioning/gpu_compatibility.cc b/tensorflow/lite/tools/versioning/gpu_compatibility.cc index dd8658bda26a28..061eaca7a3c05b 100644 --- a/tensorflow/lite/tools/versioning/gpu_compatibility.cc +++ b/tensorflow/lite/tools/versioning/gpu_compatibility.cc @@ -1085,7 +1085,8 @@ absl::Status CheckGpuDelegateCompatibility(const OpSignature& op_sig, /*required_const_inputs=*/0, /*required_outputs=*/1)); - // Two arguments elemenetwise operations + // Two arguments elementwise operations + case kTfLiteBuiltinAtan2: case kTfLiteBuiltinDiv: case kTfLiteBuiltinEqual: case kTfLiteBuiltinFloorDiv: diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index ab15889a196aad..8e09aa303c21a4 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -235,6 +235,7 @@ tf_staging/third_party/googleapis/build_rules.bzl: tf_staging/third_party/googleapis/googleapis.BUILD: tf_staging/third_party/googleapis/repository_rules.bzl: tf_staging/third_party/gpus/BUILD: +tf_staging/third_party/gpus/compiler_common_tools.bzl: tf_staging/third_party/gpus/crosstool/BUILD.rocm.tpl: tf_staging/third_party/gpus/crosstool/BUILD.sycl.tpl: tf_staging/third_party/gpus/crosstool/BUILD.tpl: @@ -252,6 +253,27 @@ tf_staging/third_party/gpus/cuda/LICENSE: tf_staging/third_party/gpus/cuda/build_defs.bzl.tpl: tf_staging/third_party/gpus/cuda/cuda_config.h.tpl: tf_staging/third_party/gpus/cuda/cuda_config.py.tpl: +tf_staging/third_party/gpus/cuda/hermetic/BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/BUILD: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_configure.bzl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl: tf_staging/third_party/gpus/cuda_configure.bzl: tf_staging/third_party/gpus/find_cuda_config:.py tf_staging/third_party/gpus/rocm/BUILD.tpl: @@ -284,6 +306,9 @@ tf_staging/third_party/nccl/archive.BUILD: tf_staging/third_party/nccl/archive.patch: tf_staging/third_party/nccl/build_defs.bzl.tpl: tf_staging/third_party/nccl/generated_names.bzl.tpl: +tf_staging/third_party/nccl/hermetic/BUILD: +tf_staging/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl: +tf_staging/third_party/nccl/hermetic/nccl_configure.bzl: tf_staging/third_party/nccl/nccl_configure.bzl: tf_staging/third_party/nccl/system.BUILD.tpl: tf_staging/third_party/nlohmann_json.BUILD: @@ -321,6 +346,7 @@ tf_staging/third_party/remote_config/remote_platform_configure.bzl: tf_staging/third_party/repo.bzl: tf_staging/third_party/six.BUILD: tf_staging/third_party/snappy.BUILD: +tf_staging/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD: tf_staging/third_party/sqlite.BUILD: tf_staging/third_party/stablehlo/BUILD: tf_staging/third_party/systemlibs/BUILD.tpl: diff --git a/tensorflow/python/client/BUILD b/tensorflow/python/client/BUILD index 355f2c04e20dd4..a1d8ae68ecdd80 100644 --- a/tensorflow/python/client/BUILD +++ b/tensorflow/python/client/BUILD @@ -292,6 +292,7 @@ py_strict_library( "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:nest", + "//tensorflow/python/util:numpy_compat", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", "@pypi_wrapt//:pkg", diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index d42e18551808d6..87b794fe094156 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -41,6 +41,7 @@ from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import nest +from tensorflow.python.util import numpy_compat from tensorflow.python.util.compat import collections_abc from tensorflow.python.util.tf_export import tf_export @@ -140,14 +141,19 @@ def _get_feeds_for_indexed_slices(feed, feed_val): def _convert_to_numpy_obj(numpy_dtype, obj): """Explicitly convert obj based on numpy type except for string type.""" - return numpy_dtype(obj) if numpy_dtype is not object else str(obj) + return ( + numpy_dtype(np.array(obj).astype(numpy_dtype)) + if numpy_dtype is not object + else str(obj) + ) def register_session_run_conversion_functions( tensor_type, fetch_function, feed_function=None, - feed_function_for_partial_run=None): + feed_function_for_partial_run=None, +): """Register fetch and feed conversion functions for `tf.Session.run()`. This function registers a triple of conversion functions for fetching and/or @@ -1181,7 +1187,7 @@ def _feed_fn(feed, feed_val): np_val = subfeed_val.to_numpy_array() feed_handles[subfeed_t.ref()] = subfeed_val else: - np_val = np.asarray(subfeed_val, dtype=subfeed_dtype) + np_val = numpy_compat.np_asarray(subfeed_val, subfeed_dtype) if (not is_tensor_handle_feed and not subfeed_t.get_shape().is_compatible_with(np_val.shape)): diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 1464424d01871b..29b1611b41ef35 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 8, 8) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 8, 20) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compiler/tensorrt/README.md b/tensorflow/python/compiler/tensorrt/README.md index 4c1d96bbed7e99..ec95cb6de69d30 100644 --- a/tensorflow/python/compiler/tensorrt/README.md +++ b/tensorflow/python/compiler/tensorrt/README.md @@ -1,5 +1,7 @@ # Using TensorRT in TensorFlow (TF-TRT) +Note: Starting from v.2.18.0, TensorFlow doesn't support TensorRT. + This module provides necessary bindings and introduces `TRTEngineOp` operator that wraps a subgraph in TensorRT. This module is under active development. diff --git a/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi b/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi index e88ec5672773ef..29126c1902939e 100644 --- a/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi +++ b/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi @@ -14,4 +14,3 @@ # ============================================================================== def TF_DATA_DefaultProtocol() -> str: ... -def TF_DATA_DisableCompressionAtRegistrationTime() -> bool: ... diff --git a/tensorflow/python/data/experimental/service/utils_wrapper.cc b/tensorflow/python/data/experimental/service/utils_wrapper.cc index f94982931e148b..c725ff3f58ec13 100644 --- a/tensorflow/python/data/experimental/service/utils_wrapper.cc +++ b/tensorflow/python/data/experimental/service/utils_wrapper.cc @@ -23,8 +23,4 @@ limitations under the License. PYBIND11_MODULE(_pywrap_utils_exp, m) { m.def("TF_DATA_DefaultProtocol", []() -> std::string { return tensorflow::data::DefaultProtocol(); }); - - m.def("TF_DATA_DisableCompressionAtRegistrationTime", []() -> bool { - return tensorflow::data::DisableCompressionAtRegistrationTime(); - }); }; diff --git a/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc index cd3d2eeb1d1c8a..1f334a09464de4 100644 --- a/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc +++ b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc @@ -19,15 +19,9 @@ limitations under the License. #include "Python.h" #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 -#include "tensorflow/c/c_api.h" -#include "tensorflow/c/c_api_experimental.h" -#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/parallel_device/parallel_device.h" -#include "tensorflow/c/safe_ptr.h" -#include "tensorflow/python/lib/core/py_exception_registry.h" #include "tensorflow/python/lib/core/pybind11_lib.h" -#include "tensorflow/python/lib/core/pybind11_status.h" #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" namespace py = pybind11; diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 532d7f1555521f..8c49758d560dcd 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -508,11 +508,12 @@ def testEagerTensorFormatForVariant(self): f"{t!r}", ">") def testNumpyTooManyDimensions(self): - t = constant_op.constant(1., shape=[1] * 33) + max_dims = 64 if np.lib.NumpyVersion(np.__version__) >= "2.0.0.dev0" else 32 + t = constant_op.constant(1., shape=[1] * (max_dims + 1)) with self.assertRaisesRegex( errors.InvalidArgumentError, - "Cannot convert tensor with 33 dimensions to NumPy array. NumPy arrays " - "can have at most 32 dimensions"): + "Cannot convert tensor with %d dimensions to NumPy array. NumPy arrays " + "can have at most %d dimensions"% (max_dims + 1, max_dims)): t.numpy() def testNumpyDimsTooBig(self): diff --git a/tensorflow/python/flags_pybind.pyi b/tensorflow/python/flags_pybind.pyi index b34ed2f4b68c19..7c450b682a40a8 100644 --- a/tensorflow/python/flags_pybind.pyi +++ b/tensorflow/python/flags_pybind.pyi @@ -24,6 +24,7 @@ class Flags: enable_function_pruning_before_inlining: Flag enable_nested_function_shape_inference: Flag enable_quantized_dtypes_training: Flag + enable_skip_encapsulation_for_non_tpu_graphs: Flag enable_tf2min_ici_weight: Flag graph_building_optimization: Flag more_stack_traces: Flag diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 69e16041623de8..1a18e8feabd380 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -6,6 +6,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") # Placeholder: load py_proto_library load( "//tensorflow:tensorflow.bzl", + "if_hermetic_cuda_tools", "if_not_windows", "if_oss", "if_xla_available", @@ -1046,6 +1047,13 @@ tf_python_pybind_extension( "python_api_dispatcher.h", "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs", ], + # This data is needed to add hermetic CUDA tools in python runfiles. + data = if_hermetic_cuda_tools( + [ + "@cuda_nvcc//:ptxas", + "@cuda_nvcc//:nvvm", + ], + ), enable_stub_generation = True, pytype_srcs = [ "_pywrap_python_api_dispatcher.pyi", @@ -2051,6 +2059,7 @@ py_strict_library( "//tensorflow/python/types:internal", "//tensorflow/python/util:compat", "//tensorflow/python/util:nest", + "//tensorflow/python/util:numpy_compat", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", ], diff --git a/tensorflow/python/framework/extension_type_test.py b/tensorflow/python/framework/extension_type_test.py index 0169690eaf3c33..a97180fab1ce8a 100644 --- a/tensorflow/python/framework/extension_type_test.py +++ b/tensorflow/python/framework/extension_type_test.py @@ -130,7 +130,7 @@ def _masked_array_repr(values, mask): """Returns a string representation for a masked numpy array.""" assert len(values) == len(mask) if len(values.shape) == 1: - items = [repr(v) if m else '_' for (v, m) in zip(values, mask)] + items = [repr(v.item()) if m else '_' for (v, m) in zip(values, mask)] else: items = [_masked_array_repr(v, m) for (v, m) in zip(values, mask)] return '[%s]' % ', '.join(items) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 34b1eed754bbed..823ced42bf766e 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -213,7 +213,15 @@ def numpy_text(tensor, is_repr=False) -> str: """Human readable representation of a tensor's numpy value.""" if tensor.dtype.is_numpy_compatible: # pylint: disable=protected-access - text = repr(tensor._numpy()) if is_repr else str(tensor._numpy()) + tensor_numpy = tensor._numpy() + if is_repr: + if np.isscalar(tensor_numpy) and not isinstance(tensor_numpy, bytes): + # .item() converts the numpy scalars to python items. + text = repr(tensor_numpy.item()) + else: + text = repr(tensor_numpy) + else: + text = str(tensor_numpy) # pylint: enable=protected-access else: text = "" diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 59fbeb3429c68d..d629fcdbf1787d 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== """Utilities to create TensorProtos.""" + import typing from typing import Protocol + import numpy as np from tensorflow.core.framework import tensor_pb2 @@ -27,8 +29,10 @@ from tensorflow.python.types import internal from tensorflow.python.util import compat from tensorflow.python.util import nest +from tensorflow.python.util import numpy_compat from tensorflow.python.util.tf_export import tf_export + # Fallback in case fast_tensor_util is not properly compiled. # pylint: disable=g-import-not-at-top try: @@ -519,7 +523,7 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False, nparray = np.empty(shape, dtype=np_dt) else: _AssertCompatible(values, dtype) - nparray = np.array(values, dtype=np_dt) + nparray = numpy_compat.np_array(values, np_dt) # check to them. # We need to pass in quantized values as tuples, so don't apply the shape if (list(nparray.shape) != _GetDenseDimensions(values) and diff --git a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py index 7304f262b720b1..b4d4ed25a950b3 100644 --- a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py @@ -377,7 +377,7 @@ def testReverse0DimAuto(self): self.assertAllEqual(x_tf, x_np) def _reverse1DimAuto(self, np_dtype): - x_np = np.array([1, 200, 3, 40, 5], dtype=np_dtype) + x_np = np.array([1, 120, 3, 40, 5], dtype=np_dtype) for use_gpu in [False, True]: for axis_dtype in [dtypes.int32, dtypes.int64]: @@ -388,7 +388,7 @@ def _reverse1DimAuto(self, np_dtype): self.assertAllEqual(x_tf, np.asarray(x_np)[::-1]) def _reverse2DimAuto(self, np_dtype): - x_np = np.array([[1, 200, 3], [4, 5, 60]], dtype=np_dtype) + x_np = np.array([[1, 120, 3], [4, 5, 60]], dtype=np_dtype) for reverse_f in [array_ops.reverse_v2, array_ops.reverse]: for use_gpu in [False, True]: diff --git a/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py index 88d51257b517be..67eb28739df0b9 100644 --- a/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py @@ -336,7 +336,7 @@ def expected_pinv(self, a, rcond): a_pinv = np.zeros(s, dtype=a.dtype) for i in np.ndindex(a.shape[:(a.ndim - 2)]): a_pinv[i] = np.linalg.pinv( - a[i], rcond=rcond if isinstance(rcond, float) else rcond[i]) + a[i], rcond=rcond if isinstance(rcond.tolist(), float) else rcond[i]) return a_pinv def test_symmetric(self): diff --git a/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py b/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py index cc1800755ed2fa..2320ba25b88897 100644 --- a/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py @@ -1254,7 +1254,7 @@ def _ConstructAndTestGradient(self, err_tolerance = 1e-4 else: if x_init_value is None: - x_init_value = np.asfarray( + x_init_value = np.asarray( np.arange(1, total_size + 1), dtype=np.float32).reshape(input_sizes) func_name = "max_pool" @@ -1332,7 +1332,7 @@ def _ConstructAndTestSecondGradient(self, err_tolerance = 1e-3 else: if x_init_value is None: - x_init_value = np.asfarray( + x_init_value = np.asarray( np.arange(1, total_size + 1), dtype=np.float32).reshape(input_sizes) func_name = "max_pool" diff --git a/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py index 4273c209d42213..5e01d981a90062 100644 --- a/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py @@ -348,8 +348,8 @@ def testRepr(self): with context.eager_mode(): v = resource_variable_ops.ResourceVariable(1) text = "%r" % v - self.assertEqual( - "", text) + error_msg = "" + self.assertEqual(error_msg, text) def testReprUnavailable(self): with context.eager_mode(): diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 4cd43ae0c37d28..1c81b35e48cc5e 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -716,16 +716,22 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj, // These objects are efficiently handled by Numpy. We transform them into // Numpy arrays and handle them in the Numpy case below. Note that Tensors // implement the __array__ function, and will be handled in this shortcut. - Safe_PyObjectPtr array = - make_safe(PyArray_FromArrayAttr(obj, nullptr, nullptr)); - if (array == nullptr) { - return nullptr; + // We used to call PyArray_FromArrayAttr here, but NumPy 2.0 changed its + // semantics such that it errors if a copy of the array is required. + // (Ideally no copy would be needed here, but that would be a larger change.) + Safe_PyObjectPtr array; + if (PyObject_HasAttrString(obj, "__array__")) { + array = make_safe(PyObject_CallMethod(obj, "__array__", nullptr)); + if (array == nullptr) { + return nullptr; + } + if (!PyArray_Check(array.get())) { + PyErr_SetString(PyExc_ValueError, + "Value returned by __array__ is not a NumPy array"); + return nullptr; + } } - if (array.get() == Py_NotImplemented) { - // The Py_NotImplemented returned from PyArray_FromArrayAttr is not - // Py_INCREF'ed, so we don't want the Safe_PyObjectPtr to Py_DECREF it. - array.release(); - + if (!array) { // Try __array_interface__ objects (such as PIL Image). array = make_safe(PyArray_FromInterface(obj)); if (array == nullptr) { diff --git a/tensorflow/python/lib/io/BUILD b/tensorflow/python/lib/io/BUILD index 0a97ef20c14055..57c9b7be07f095 100644 --- a/tensorflow/python/lib/io/BUILD +++ b/tensorflow/python/lib/io/BUILD @@ -62,7 +62,6 @@ py_strict_library( "//tensorflow:__subpackages__", "//tensorflow:internal", "//third_party/cloud_tpu/convergence_tools:__subpackages__", - "//third_party/proto_splitter:__subpackages__", # TODO(b/277279227): remove this dep from proto_splitter "//third_party/py/tf_slim:__subpackages__", ], deps = [ diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 76dd388f386a96..29f16a77bbca1d 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1276,7 +1276,39 @@ def _maybe_cast(elem): return _maybe_cast -_NON_AUTOPACKABLE_TYPES = set(np.core.numerictypes.ScalarType) +_NON_AUTOPACKABLE_TYPES = set(( + int, + float, + complex, + bool, + bytes, + str, + memoryview, + np.bool_, + np.complex64, + np.clongdouble, + np.complex128, + np.float16, + np.float32, + np.float64, + np.longdouble, + np.int8, + np.int16, + np.int32, + np.int64, + np.longlong, + np.timedelta64, + np.datetime64, + np.object_, + np.bytes_, + np.str_, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.ulonglong, + np.void, +)) _NON_AUTOPACKABLE_TYPES.add(np.ndarray) diff --git a/tensorflow/python/ops/bitwise_ops_test.py b/tensorflow/python/ops/bitwise_ops_test.py index 83e464d53d1a8d..f1b679b3de40d1 100644 --- a/tensorflow/python/ops/bitwise_ops_test.py +++ b/tensorflow/python/ops/bitwise_ops_test.py @@ -60,7 +60,7 @@ def count_bits(x): for dtype in dtype_list: with self.cached_session(): print("PopulationCount test: ", dtype) - inputs = np.array(raw_inputs, dtype=dtype.as_numpy_dtype) + inputs = np.array(raw_inputs).astype(dtype.as_numpy_dtype) truth = [count_bits(x) for x in inputs] input_tensor = constant_op.constant(inputs, dtype=dtype) popcnt_result = self.evaluate( diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 80463a67efc9ac..4b6b11853d4c4c 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1416,8 +1416,15 @@ def assert_rank_in( except ValueError as e: if e.args[0] == 'Static rank condition failed': raise ValueError( - '%sTensor %s must have rank in %s. Received rank %d, ' - 'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape())) + '%sTensor %s must have rank in %s. Received rank %d, shape %s' + % ( + message, + name, + tuple(r.item() for r in e.args[2]), + e.args[1], + x.get_shape(), + ) + ) else: raise diff --git a/tensorflow/python/ops/gradient_checker_v2_test.py b/tensorflow/python/ops/gradient_checker_v2_test.py index 19835aeb09e4cb..362feab73b70cb 100644 --- a/tensorflow/python/ops/gradient_checker_v2_test.py +++ b/tensorflow/python/ops/gradient_checker_v2_test.py @@ -255,7 +255,8 @@ def f(x): *gradient_checker.compute_gradient(f, [x])) # Typical test would assert error < max_err, so assert this test would # raise AssertionError, since NaN is not < 1.0. - with self.assertRaisesRegex(AssertionError, "nan not less than 1.0"): + error_msg = r"(nan|np.float32\(nan\)) not less than 1.0" + with self.assertRaisesRegex(AssertionError, error_msg): self.assertLess(error, 1.0) def testGradGrad(self): diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 510865596fe4b5..d55366762b8a11 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2010,7 +2010,12 @@ def range(start, limit=None, delta=1, dtype=None, name="range"): # pylint: disa # infer dtype if not explicitly provided if dtype is None: dtype_hierarchy = [ - dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64 + dtypes.int32, + dtypes.int64, + dtypes.float16, + dtypes.bfloat16, + dtypes.float32, + dtypes.float64, ] assert all(arg.dtype in dtype_hierarchy for arg in [start, limit, delta]) inferred_dtype = max([arg.dtype for arg in [start, limit, delta]], diff --git a/tensorflow/python/ops/ragged/convert_to_tensor_or_ragged_tensor_op_test.py b/tensorflow/python/ops/ragged/convert_to_tensor_or_ragged_tensor_op_test.py index 92a58e8190b275..4957ed02cce806 100644 --- a/tensorflow/python/ops/ragged/convert_to_tensor_or_ragged_tensor_op_test.py +++ b/tensorflow/python/ops/ragged/convert_to_tensor_or_ragged_tensor_op_test.py @@ -130,7 +130,8 @@ def testConvertRaggedTensorValue(self, value=ragged_factory_ops.constant_value([['a', 'b'], ['c']], dtype=str), dtype=dtypes.int32, - message=r"invalid literal for int\(\) with base 10: 'a'"), + message=(r"invalid literal for int\(\) with base 10: " + r"('a'|np.str_\('a'\))")), ]) def testConvertRaggedTensorValueError(self, value, @@ -216,7 +217,8 @@ def testConvertNumpyArray(self, dict( value=np.array([['a', 'b'], ['c', 'd']], dtype=str), dtype=dtypes.int32, - message=r"invalid literal for int\(\) with base 10: 'a'"), + message=(r"invalid literal for int\(\) with base 10: " + r"('a'|np.str_\('a'\))")), ]) def testConvertNumpyArrayError(self, value, diff --git a/tensorflow/python/ops/ragged/ragged_factory_ops.py b/tensorflow/python/ops/ragged/ragged_factory_ops.py index 9e096e01b56d7a..215304c867507c 100644 --- a/tensorflow/python/ops/ragged/ragged_factory_ops.py +++ b/tensorflow/python/ops/ragged/ragged_factory_ops.py @@ -150,14 +150,19 @@ def _ragged_factory(values, row_splits): return ragged_tensor_value.RaggedTensorValue(values, row_splits) def _inner_factory(pylist, dtype, shape, name=None): # pylint: disable=unused-argument - return np.reshape(np.array(pylist, dtype=dtype), shape) + if dtype is object or dtype is None: + return np.reshape(np.array(pylist, dtype=dtype), shape) + else: + return np.reshape(np.array(pylist).astype(dtype), shape) - return _constant_value(_ragged_factory, _inner_factory, pylist, dtype, - ragged_rank, inner_shape) + return _constant_value( + _ragged_factory, _inner_factory, pylist, dtype, ragged_rank, inner_shape + ) -def _constant_value(ragged_factory, inner_factory, pylist, dtype, ragged_rank, - inner_shape): +def _constant_value( + ragged_factory, inner_factory, pylist, dtype, ragged_rank, inner_shape +): """Constructs a constant RaggedTensor or RaggedTensorValue. Args: diff --git a/tensorflow/python/ops/v1_compat_tests/gradient_checker_test.py b/tensorflow/python/ops/v1_compat_tests/gradient_checker_test.py index d6b7d12999ba52..03b864b01d86dd 100644 --- a/tensorflow/python/ops/v1_compat_tests/gradient_checker_test.py +++ b/tensorflow/python/ops/v1_compat_tests/gradient_checker_test.py @@ -192,7 +192,8 @@ def testNaNGradFails(self): error = gradient_checker.compute_gradient_error(x, (), y, ()) # Typical test would assert error < max_err, so assert this test would # raise AssertionError, since NaN is not < 1.0. - with self.assertRaisesRegex(AssertionError, "False is not true"): + error_msg = "(False|np.False_) is not true" + with self.assertRaisesRegex(AssertionError, error_msg): self.assertTrue(error < 1.0) diff --git a/tensorflow/python/platform/BUILD b/tensorflow/python/platform/BUILD index 0ca7e7bfae738f..7c6c086871fcae 100644 --- a/tensorflow/python/platform/BUILD +++ b/tensorflow/python/platform/BUILD @@ -225,7 +225,6 @@ py_strict_library( "//tensorflow_models:__subpackages__", "//third_party/cloud_tpu/convergence_tools:__subpackages__", "//third_party/mlperf:__subpackages__", - "//third_party/proto_splitter:__subpackages__", # TODO(b/277279227): remove this dep from proto_splitter "//third_party/py/tf_slim:__subpackages__", ], deps = [ @@ -301,7 +300,7 @@ py_strict_library( py_strict_library( name = "gfile", srcs = ["gfile.py"], - visibility = visibility + ["//third_party/py/tf_slim/training:__pkg__"], + visibility = visibility, deps = [ "//tensorflow/python/lib/io:file_io", "//tensorflow/python/util:deprecation", diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD index cfc2790371eed4..ee56bb821a2f30 100644 --- a/tensorflow/python/util/BUILD +++ b/tensorflow/python/util/BUILD @@ -889,6 +889,20 @@ py_strict_library( ], ) +py_strict_library( + name = "numpy_compat", + srcs = ["numpy_compat.py"], + compatible_with = get_compatible_with_portable(), + visibility = util_subpackage_visibility, + deps = [ + # global_test_configuration is added here because all major tests depend on this + # library. It isn't possible to add these test dependencies via tensorflow.bzl's + # py test because not all tensorflow tests use tensorflow.bzl's py test. + "//tensorflow/python:global_test_configuration", + "//third_party/py/numpy", + ], +) + py_strict_library( name = "object_identity", srcs = ["object_identity.py"], diff --git a/tensorflow/python/util/numpy_compat.py b/tensorflow/python/util/numpy_compat.py new file mode 100644 index 00000000000000..87a705066a8273 --- /dev/null +++ b/tensorflow/python/util/numpy_compat.py @@ -0,0 +1,66 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Functions for NumPy 1.x vs. 2.x compatibility.""" + +import numpy as np + + +def np_array(values, dtype): + """Creates a NumPy array containing input values. + + It will make a copy of the object. + + In NumPy 2.x and later, strict type casting can lead to errors when values + overflow the specified dtype. This function addresses this by replacing direct + np.array(..., dtype=...) calls with np.array(...).astype(...). This allows for + intended overflows, aligning with the behavior of older NumPy versions. + + Args: + values: Array_like objects. E.g., a python list, tuple, or an object + whose __array__ method returns an array. + dtype: The desired numpy data type for the array. + + Returns: + A NumPy array with the specified data type. + """ + if dtype is not None and np.issubdtype(dtype, np.number): + return np.array(values).astype(dtype) + else: + return np.array(values, dtype=dtype) + + +def np_asarray(values, dtype): + """Converts input values to a NumPy array. + + It will not make a copy. + + In NumPy 2.x and later, strict type casting can lead to errors when values + overflow the specified dtype. This function addresses this by replacing direct + np.array(..., dtype=...) calls with np.array(...).astype(...). This allows for + intended overflows, aligning with the behavior of older NumPy versions. + + Args: + values: Array_like objects. E.g., a python list, tuple, or an object + whose __array__ method returns an array. + dtype: The desired numpy data type for the array. + + Returns: + A NumPy array with the specified data type. + """ + if dtype is not None and np.issubdtype(dtype, np.number): + return np.asarray(values).astype(dtype) + else: + return np.asarray(values, dtype=dtype) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index d1e35f02a0f190..c1c50767cdad69 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -70,6 +70,7 @@ load( "tsl_gpu_library", _cc_header_only_library = "cc_header_only_library", _if_cuda_or_rocm = "if_cuda_or_rocm", + _if_hermetic_cuda_tools = "if_hermetic_cuda_tools", _if_nccl = "if_nccl", _transitive_hdrs = "transitive_hdrs", ) @@ -800,7 +801,7 @@ def tf_cc_shared_object( testonly = kwargs.pop("testonly", False) for name_os, name_os_major, name_os_full in names: - # Windows DLLs cant be versioned + # Windows DLLs can't be versioned if name_os.endswith(".dll"): name_os_major = name_os name_os_full = name_os @@ -1075,7 +1076,8 @@ def tf_cc_binary( ], ), tags = tags, - data = depset(data + added_data_deps), + data = depset(data + added_data_deps).to_list() + + tf_binary_additional_srcs(fullversion = True), linkopts = linkopts + _rpath_linkopts(name_os), visibility = visibility, **kwargs @@ -1568,7 +1570,7 @@ def tf_cc_test( ), data = data + tf_binary_dynamic_kernel_dsos() + - tf_binary_additional_srcs(), + tf_binary_additional_srcs(fullversion = True), exec_properties = tf_exec_properties(kwargs), **kwargs ) @@ -1733,6 +1735,7 @@ def tf_gpu_only_cc_test( tf_gpu_kernel_library( name = gpu_lib_name, srcs = srcs + tf_binary_additional_srcs(), + data = tf_binary_additional_srcs(fullversion = True), deps = deps, testonly = 1, features = features, @@ -3574,3 +3577,6 @@ def replace_with_portable_tf_lib_when_required(non_portable_tf_deps, use_lib_wit def tf_python_framework_friends(): return ["//tensorflow:__subpackages__"] + +def if_hermetic_cuda_tools(if_true, if_false = []): + return _if_hermetic_cuda_tools(if_true, if_false) diff --git a/tensorflow/tools/pip_package/build_pip_package.py b/tensorflow/tools/pip_package/build_pip_package.py index 9588fc19e3d4e9..1846082b8147b8 100644 --- a/tensorflow/tools/pip_package/build_pip_package.py +++ b/tensorflow/tools/pip_package/build_pip_package.py @@ -69,6 +69,36 @@ def prepare_headers(headers: list[str], srcs_dir: str) -> None: srcs_dir: target directory where headers are copied to. """ path_to_exclude = [ + "cuda_cccl/_virtual_includes", + "cuda_cublas/_virtual_includes", + "cuda_cudart/_virtual_includes", + "cuda_cudnn/_virtual_includes", + "cuda_cufft/_virtual_includes", + "cuda_cupti/_virtual_includes", + "cuda_curand/_virtual_includes", + "cuda_cusolver/_virtual_includes", + "cuda_cusparse/_virtual_includes", + "cuda_nccl/_virtual_includes", + "cuda_nvcc/_virtual_includes", + "cuda_nvjitlink/_virtual_includes", + "cuda_nvml/_virtual_includes", + "cuda_nvrtc/_virtual_includes", + "cuda_nvtx/_virtual_includes", + "external/cuda_cccl", + "external/cuda_cublas", + "external/cuda_cudart", + "external/cuda_cudnn", + "external/cuda_cufft", + "external/cuda_cupti", + "external/cuda_curand", + "external/cuda_cusolver", + "external/cuda_cusparse", + "external/cuda_nccl", + "external/cuda_nvcc", + "external/cuda_nvjitlink", + "external/cuda_nvml", + "external/cuda_nvrtc", + "external/cuda_nvtx", "external/pypi", "external/jsoncpp_git/src", "local_config_cuda/cuda/_virtual_includes", diff --git a/tensorflow/tools/proto_splitter/cc/BUILD b/tensorflow/tools/proto_splitter/cc/BUILD index da86a3ae4401f4..105cecfae4465c 100644 --- a/tensorflow/tools/proto_splitter/cc/BUILD +++ b/tensorflow/tools/proto_splitter/cc/BUILD @@ -179,11 +179,11 @@ cc_library( hdrs = ["graph_def_splitter.h"], deps = [ ":composable_splitter", + ":composable_splitter_base", ":large_node_splitter", ":max_size", ":repeated_field_splitter", ":size_splitter", - ":split", ":util", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", @@ -309,18 +309,15 @@ tf_cc_test( cc_library( name = "large_node_splitter", - srcs = ["large_node_splitter.cc"], hdrs = ["large_node_splitter.h"], deps = [ ":composable_splitter", + ":composable_splitter_base", ":max_size", ":size_splitter", ":util", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/tensorflow/tools/proto_splitter/cc/graph_def_splitter.cc b/tensorflow/tools/proto_splitter/cc/graph_def_splitter.cc index 81e8d5d9a3aec4..7f274734a6b76e 100644 --- a/tensorflow/tools/proto_splitter/cc/graph_def_splitter.cc +++ b/tensorflow/tools/proto_splitter/cc/graph_def_splitter.cc @@ -31,18 +31,18 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/tools/proto_splitter/cc/composable_splitter.h" +#include "tensorflow/tools/proto_splitter/cc/composable_splitter_base.h" #include "tensorflow/tools/proto_splitter/cc/large_node_splitter.h" #include "tensorflow/tools/proto_splitter/cc/max_size.h" #include "tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h" #include "tensorflow/tools/proto_splitter/cc/size_splitter.h" -#include "tensorflow/tools/proto_splitter/cc/split.h" #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" -namespace tensorflow { -namespace tools::proto_splitter { +namespace tensorflow::tools::proto_splitter { namespace { @@ -144,7 +144,7 @@ class FunctionDefSplitter : public SizeSplitter { LargeNodeSplitterFactory large_node_splitter_factory; std::vector factories = { &constant_splitter_factory, &large_node_splitter_factory}; - auto ret = RepeatedFieldSplitters::Create( + auto ret = RepeatedFieldSplitter::Create( message(), this, &fields, "node_def"s, &factories); if (!ret.ok()) return ret.status(); auto splitter = ret.value(); @@ -184,7 +184,7 @@ absl::Status GraphDefSplitter::BuildChunks() { LargeNodeSplitterFactory large_node_splitter_factory; std::vector factories = {&constant_splitter_factory, &large_node_splitter_factory}; - auto node_splitter_ret = RepeatedFieldSplitters::Create( + auto node_splitter_ret = RepeatedFieldSplitter::Create( g, this, &field_in_parent, "node"s, &factories); if (!node_splitter_ret.ok()) return node_splitter_ret.status(); auto node_splitter = node_splitter_ret.value(); @@ -193,7 +193,7 @@ absl::Status GraphDefSplitter::BuildChunks() { std::vector library_field = {"library"s}; std::vector fn_factories = {&function_splitter_factory}; auto library_splitter_ret = - RepeatedFieldSplitters::Create( + RepeatedFieldSplitter::Create( g->mutable_library(), this, &library_field, "function"s, &fn_factories); if (!library_splitter_ret.ok()) return library_splitter_ret.status(); @@ -238,5 +238,4 @@ absl::Status GraphDefSplitter::BuildChunks() { return absl::OkStatus(); } -} // namespace tools::proto_splitter -} // namespace tensorflow +} // namespace tensorflow::tools::proto_splitter diff --git a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc index 1fb19f5263008a..b5c27118cf0cbc 100644 --- a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc +++ b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc @@ -44,23 +44,6 @@ namespace { using ::tensorflow::proto_splitter::ChunkedMessage; -// Ensures that all Messages are less than the max size. std::string chunks are -// not limited by the max size, so they are ignored in this check. -#define EXPECT_CHUNK_SIZES(chunks, max_size) \ - do { \ - for (auto chunk : *chunks) { \ - if (std::holds_alternative>( \ - chunk)) { \ - EXPECT_LE(std::get>(chunk) \ - ->ByteSizeLong(), \ - max_size); \ - } else if (std::holds_alternative(chunk)) { \ - EXPECT_LE(std::get(chunk)->ByteSizeLong(), \ - max_size); \ - } \ - } \ - } while (0) - TEST(GraphDefSplitterTest, TestLargeConstant) { GraphDef proto; const std::string graph_def_path = diff --git a/tensorflow/tools/proto_splitter/cc/large_node_splitter.cc b/tensorflow/tools/proto_splitter/cc/large_node_splitter.cc deleted file mode 100644 index cf0ff26f51f985..00000000000000 --- a/tensorflow/tools/proto_splitter/cc/large_node_splitter.cc +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/tools/proto_splitter/cc/large_node_splitter.h" - -#include -#include - -#include "absl/memory/memory.h" -#include "tensorflow/core/framework/function.pb.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/tensor.pb.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/tools/proto_splitter/cc/max_size.h" -#include "tensorflow/tools/proto_splitter/cc/size_splitter.h" -#include "tensorflow/tools/proto_splitter/cc/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace tools::proto_splitter { - -template -absl::StatusOr LargeNodeSplitter::BuildChunksReturnSize() { - MessageType* msg = - tsl::protobuf::DynamicCastToGenerated(message()); - int initial_size = GetInitialSize(); - std::shared_ptr new_msg = std::make_shared(); - msg->Swap(new_msg.get()); - std::vector fields = {}; - auto x = std::make_unique(new_msg); - TF_RETURN_IF_ERROR(AddChunk(std::move(x), &fields, index_)); - return initial_size; -} - -template -absl::StatusOr> -LargeNodeSplitterFactory::CreateSplitter( - tsl::protobuf::Message* message, ComposableSplitterBase* parent_splitter, - std::vector* fields_in_parent, int size) { - if (!(LARGE_SIZE_CHECK(size, GetMaxSize()))) return nullptr; - LargeNodeSplitter* splitter = new LargeNodeSplitter( - message, parent_splitter, fields_in_parent); - return absl::WrapUnique(splitter); -} - -template class LargeNodeSplitter; -template class LargeNodeSplitter; -template class LargeNodeSplitter; -template class LargeNodeSplitter; -template class LargeNodeSplitterFactory; -template class LargeNodeSplitterFactory; -template class LargeNodeSplitterFactory; -template class LargeNodeSplitterFactory; - -} // namespace tools::proto_splitter -} // namespace tensorflow diff --git a/tensorflow/tools/proto_splitter/cc/large_node_splitter.h b/tensorflow/tools/proto_splitter/cc/large_node_splitter.h index e5969cf652dd37..15c9964fa44644 100644 --- a/tensorflow/tools/proto_splitter/cc/large_node_splitter.h +++ b/tensorflow/tools/proto_splitter/cc/large_node_splitter.h @@ -20,7 +20,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "tensorflow/tools/proto_splitter/cc/composable_splitter.h" +#include "tensorflow/tools/proto_splitter/cc/composable_splitter_base.h" +#include "tensorflow/tools/proto_splitter/cc/max_size.h" #include "tensorflow/tools/proto_splitter/cc/size_splitter.h" +#include "tensorflow/tools/proto_splitter/cc/util.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace tools::proto_splitter { @@ -40,6 +44,19 @@ class LargeNodeSplitter : public SizeSplitter { int* index_ = nullptr; }; +template +absl::StatusOr LargeNodeSplitter::BuildChunksReturnSize() { + MessageType* msg = + tsl::protobuf::DynamicCastToGenerated(message()); + int initial_size = GetInitialSize(); + std::shared_ptr new_msg = std::make_shared(); + msg->Swap(new_msg.get()); + std::vector fields = {}; + auto x = std::make_unique(new_msg); + TF_RETURN_IF_ERROR(AddChunk(std::move(x), &fields, index_)); + return initial_size; +} + template class LargeNodeSplitterFactory : public SizeSplitterFactory { public: @@ -50,6 +67,17 @@ class LargeNodeSplitterFactory : public SizeSplitterFactory { std::vector* fields_in_parent, int size) override; }; +template +absl::StatusOr> +LargeNodeSplitterFactory::CreateSplitter( + tsl::protobuf::Message* message, ComposableSplitterBase* parent_splitter, + std::vector* fields_in_parent, int size) { + if (!(LARGE_SIZE_CHECK(size, GetMaxSize()))) return nullptr; + LargeNodeSplitter* splitter = new LargeNodeSplitter( + message, parent_splitter, fields_in_parent); + return absl::WrapUnique(splitter); +} + } // namespace tools::proto_splitter } // namespace tensorflow diff --git a/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc b/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc index 01601c7e22a1fc..552009f3916e61 100644 --- a/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc +++ b/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h" +#include #include #include #include @@ -24,67 +25,63 @@ limitations under the License. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/tools/proto_splitter/cc/composable_splitter.h" #include "tensorflow/tools/proto_splitter/cc/max_size.h" -#include "tensorflow/tools/proto_splitter/cc/split.h" +#include "tensorflow/tools/proto_splitter/cc/size_splitter.h" #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" -namespace tensorflow { -namespace tools::proto_splitter { +namespace tensorflow::tools::proto_splitter { // Additional bytes added to each node to account for the extra info needed to // encode the field key (realistically 3 but making it 5 for some wiggle room). constexpr int kExtraBytes = 5; template -absl::StatusOr> -RepeatedFieldSplitters::Create( +absl::StatusOr> +RepeatedFieldSplitter::Create( tsl::protobuf::Message* message, ComposableSplitter* parent_splitter, std::vector* fields_in_parent, const FieldType& repeated_field, std::vector* splitter_factories) { - // std::vector all_fields = *fields_in_parent; - // all_fields.push_back(repeated_field); - // std::vector - TF_ASSIGN_OR_RETURN(auto field_ret, GetField(*message, {repeated_field})); if (!field_ret.field->is_repeated()) { return absl::FailedPreconditionError("Unable to split non-repeated field."); } - auto ret = RepeatedFieldSplitters( + auto ret = RepeatedFieldSplitter( message, parent_splitter, fields_in_parent, repeated_field, splitter_factories); return ret; } template -absl::StatusOr RepeatedFieldSplitters< - ParentMessage, RepeatedMessage>::BuildChunksReturnSize() { - // std::vector all_fields = *fields_in_parent(); - // all_fields.push_back(repeated_field_); - - TF_ASSIGN_OR_RETURN(auto ret, GetMutableField(message(), {repeated_field_})); +absl::StatusOr +RepeatedFieldSplitter::BuildChunksReturnSize() { + TF_ASSIGN_OR_RETURN(MutableFieldResult mfr, + GetMutableField(message(), {repeated_field_})); + tsl::protobuf::Message* parent = mfr.parent; + const tsl::protobuf::FieldDescriptor* repeated_field = mfr.field; uint64_t max_size = GetMaxSize(); size_t initial_size = GetInitialSize(); // List of indices at which to split the repeated field. For example, [3, 5] // means that the field list is split into: [:3], [3:5], [5:] - std::vector repeated_msg_split = {0}; + std::vector repeated_msg_split; // Track the total byte size of the current node split. uint64_t total_size = 0; // Linearly iterate through all nodes. It may be possible to optimize this // further by making best guesses as to where to split the nodes, since // most nodes (aside from constants) are relatively small. - int repeated_field_size = - ret.parent->GetReflection()->FieldSize(*ret.parent, ret.field); - for (int i = 0; i < repeated_field_size; ++i) { + int repeated_field_length = + parent->GetReflection()->FieldSize(*parent, repeated_field); + for (int i = 0; i < repeated_field_length; ++i) { tsl::protobuf::Message* node = - ret.parent->GetReflection()->MutableRepeatedMessage(ret.parent, - ret.field, i); + parent->GetReflection()->MutableRepeatedMessage(parent, repeated_field, + i); auto node_size = node->ByteSizeLong(); std::vector new_fields = {repeated_field_, i}; @@ -106,25 +103,20 @@ absl::StatusOr RepeatedFieldSplitters< total_size += node_size + kExtraBytes; } - if (repeated_msg_split.size() > 1) { + if (!repeated_msg_split.empty()) { auto repeated_nodes_ptrs = - ret.parent->GetReflection() - ->template MutableRepeatedPtrField(ret.parent, - ret.field); - - int start = repeated_msg_split[0]; + parent->GetReflection() + ->template MutableRepeatedPtrField(parent, + repeated_field); - std::vector extracted_nodes; - extracted_nodes.resize(repeated_field_size - start); - repeated_nodes_ptrs->ExtractSubrange(start, repeated_field_size - start, + std::vector extracted_nodes(repeated_field_length); + repeated_nodes_ptrs->ExtractSubrange(0, repeated_field_length, &extracted_nodes.at(0)); - repeated_msg_split.push_back(repeated_field_size); - auto extracted_node = extracted_nodes.begin(); - - for (int i = 1; i < repeated_msg_split.size(); ++i) { - start = repeated_msg_split[i - 1]; - int end = repeated_msg_split[i]; + // Last range end is the size of the repeated field. + repeated_msg_split.push_back(repeated_field_length); + int range_start = 0; + for (int range_end : repeated_msg_split) { auto new_msg = std::make_shared(); std::vector empty_fields; auto x = std::make_unique(new_msg); @@ -134,10 +126,12 @@ absl::StatusOr RepeatedFieldSplitters< TF_ASSIGN_OR_RETURN(auto new_ret, GetMutableField(new_msg.get(), repeated_field_)); - for (int j = 0; j < end - start; ++j) { + for (int j = range_start; j < range_end; ++j) { new_msg->GetReflection()->AddAllocatedMessage( - new_msg.get(), new_ret.field, *extracted_node++); + new_msg.get(), new_ret.field, extracted_nodes[j]); } + + range_start = range_end; } } @@ -147,9 +141,8 @@ absl::StatusOr RepeatedFieldSplitters< } // Declare template classes to fix linking error. -template class RepeatedFieldSplitters; -template class RepeatedFieldSplitters; -template class RepeatedFieldSplitters; +template class RepeatedFieldSplitter; +template class RepeatedFieldSplitter; +template class RepeatedFieldSplitter; -} // namespace tools::proto_splitter -} // namespace tensorflow +} // namespace tensorflow::tools::proto_splitter diff --git a/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h b/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h index eef7247a1925ef..5395f76ad9b69f 100644 --- a/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h +++ b/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h @@ -20,25 +20,24 @@ limitations under the License. #include "absl/status/statusor.h" #include "tensorflow/tools/proto_splitter/cc/composable_splitter.h" #include "tensorflow/tools/proto_splitter/cc/size_splitter.h" +#include "tensorflow/tools/proto_splitter/cc/util.h" #include "tsl/platform/protobuf.h" -namespace tensorflow { -namespace tools::proto_splitter { +namespace tensorflow::tools::proto_splitter { // Splitter that works on repeated message fields. template -class RepeatedFieldSplitters : public SizeSplitter { +class RepeatedFieldSplitter : public SizeSplitter { public: - static absl::StatusOr Create( + static absl::StatusOr Create( tsl::protobuf::Message* message, ComposableSplitter* parent_splitter, std::vector* fields_in_parent, const FieldType& repeated_field, std::vector* splitter_factories); absl::StatusOr BuildChunksReturnSize() override; - FieldType repeated_field_; private: - explicit RepeatedFieldSplitters( + explicit RepeatedFieldSplitter( tsl::protobuf::Message* message, ComposableSplitter* parent_splitter, std::vector* fields_in_parent, const FieldType& repeated_field, std::vector* splitter_factories) @@ -46,10 +45,10 @@ class RepeatedFieldSplitters : public SizeSplitter { repeated_field_(repeated_field), splitter_factories_(splitter_factories) {} + FieldType repeated_field_; std::vector* splitter_factories_; }; -} // namespace tools::proto_splitter -} // namespace tensorflow +} // namespace tensorflow::tools::proto_splitter #endif // TENSORFLOW_TOOLS_PROTO_SPLITTER_CC_REPEATED_FIELD_SPLITTER_H_ diff --git a/tensorflow/tools/proto_splitter/cc/test_util.h b/tensorflow/tools/proto_splitter/cc/test_util.h index 9187521fc14712..dd73cbd3bd1b00 100644 --- a/tensorflow/tools/proto_splitter/cc/test_util.h +++ b/tensorflow/tools/proto_splitter/cc/test_util.h @@ -28,6 +28,23 @@ limitations under the License. namespace tensorflow { namespace tools::proto_splitter { +// Ensures that all Messages are less than the max size. std::string chunks are +// not limited by the max size, so they are ignored in this check. +#define EXPECT_CHUNK_SIZES(chunks, max_size) \ + do { \ + for (auto chunk : *chunks) { \ + if (std::holds_alternative>( \ + chunk)) { \ + EXPECT_LE(std::get>(chunk) \ + ->ByteSizeLong(), \ + max_size); \ + } else if (std::holds_alternative(chunk)) { \ + EXPECT_LE(std::get(chunk)->ByteSizeLong(), \ + max_size); \ + } \ + } \ + } while (0) + inline std::string SerializeAsString(const tsl::protobuf::Message& message) { std::string result; { diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/code_check_full.bats b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/code_check_full.bats index b007c07b974934..f575f22005911d 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/code_check_full.bats +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/code_check_full.bats @@ -216,6 +216,8 @@ EOF bazel cquery \ --experimental_cc_shared_library \ --@local_config_cuda//:enable_cuda \ + --repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ + --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" \ "somepath(//tensorflow/tools/pip_package:build_pip_package, " \ "@local_config_cuda//cuda:cudart + "\ "@local_config_cuda//cuda:cudart + "\ @@ -236,6 +238,8 @@ EOF bazel cquery \ --experimental_cc_shared_library \ --@local_config_cuda//:enable_cuda \ + --repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ + --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" \ --define framework_shared_object=false \ "somepath(//tensorflow/tools/pip_package:build_pip_package, " \ "@local_config_cuda//cuda:cudart + "\ diff --git a/tensorflow/tools/toolchains/remote_config/configs.bzl b/tensorflow/tools/toolchains/remote_config/configs.bzl index f0fa44c759b346..abf72cbc605e91 100644 --- a/tensorflow/tools/toolchains/remote_config/configs.bzl +++ b/tensorflow/tools/toolchains/remote_config/configs.bzl @@ -225,8 +225,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -236,8 +236,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "9.1", + cuda_version = "12.3.2", + cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -248,8 +248,8 @@ def initialize_rbe_configs(): name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -258,8 +258,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -270,8 +270,8 @@ def initialize_rbe_configs(): name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -479,7 +479,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -558,7 +558,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -710,11 +710,11 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.6", }, @@ -749,11 +749,11 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.6", }, @@ -788,12 +788,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "8.6", }, ) @@ -826,12 +826,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "8.6", }, ) @@ -864,12 +864,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "9.1", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "9.1.1", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "10.0", }, ) diff --git a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl index ae776c2a2fd388..9c4c93c988901e 100644 --- a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl +++ b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl @@ -1,9 +1,9 @@ """Macro that creates external repositories for remote config.""" load("//tensorflow/tools/toolchains/remote_config:containers.bzl", "containers") -load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure") -load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") +load("//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure") +load("//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure") load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") load("//third_party/tensorrt:tensorrt_configure.bzl", "remote_tensorrt_configure") @@ -42,7 +42,7 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_CUDNN_VERSION": cudnn_version, "TF_CUDA_VERSION": cuda_version, "CUDNN_INSTALL_PATH": cudnn_install_path if cudnn_install_path != None else "/usr/lib/x86_64-linux-gnu", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": tensorrt_version if tensorrt_version != None else "", "TENSORRT_INSTALL_PATH": tensorrt_install_path if tensorrt_install_path != None else "/usr/lib/x86_64-linux-gnu", "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "", @@ -51,20 +51,26 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_SYSROOT": sysroot if sysroot else "", }) - container_name = "cuda%s-cudnn%s-%s" % (cuda_version, cudnn_version, os) + cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) + cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) + container_name = "cuda%s-cudnn%s-%s" % ( + cuda_version_in_container, + cudnn_version_in_container, + os, + ) container_image = _container_image_uri(container_name) exec_properties = { "container-image": container_image, "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, @@ -175,13 +181,13 @@ def sigbuild_tf_configs(name_container_map, env): "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 2959fe82ffbb7c..db96e3fc4383b6 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -29,7 +29,6 @@ load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") load("//third_party/FP16:workspace.bzl", FP16 = "repo") load("//third_party/gemmlowp:workspace.bzl", gemmlowp = "repo") load("//third_party/git:git_configure.bzl", "git_configure") -load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") load("//third_party/gpus:sycl_configure.bzl", "sycl_configure") load("//third_party/hexagon:workspace.bzl", hexagon_nn = "repo") @@ -42,7 +41,6 @@ load("//third_party/kissfft:workspace.bzl", kissfft = "repo") load("//third_party/libprotobuf_mutator:workspace.bzl", libprotobuf_mutator = "repo") load("//third_party/llvm:setup.bzl", "llvm_setup") load("//third_party/nasm:workspace.bzl", nasm = "repo") -load("//third_party/nccl:nccl_configure.bzl", "nccl_configure") load("//third_party/opencl_headers:workspace.bzl", opencl_headers = "repo") load("//third_party/pasta:workspace.bzl", pasta = "repo") load("//third_party/py:python_configure.bzl", "python_configure") @@ -106,9 +104,7 @@ def _tf_toolchains(): # Note that we check the minimum bazel version in WORKSPACE. clang6_configure(name = "local_config_clang6") cc_download_clang_toolchain(name = "local_config_download_clang") - cuda_configure(name = "local_config_cuda") tensorrt_configure(name = "local_config_tensorrt") - nccl_configure(name = "local_config_nccl") git_configure(name = "local_config_git") syslibs_configure(name = "local_config_syslibs") python_configure(name = "local_config_python") @@ -154,18 +150,18 @@ def _tf_repositories(): # LINT.IfChange tf_http_archive( name = "XNNPACK", - sha256 = "c4b8e34fe70cb5ccbc1c176a2119f1be673ec982ea2b4a78bc8102877cc24e14", - strip_prefix = "XNNPACK-d25d603e0b708d856e4cafca7dac1e6b7126c320", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/d25d603e0b708d856e4cafca7dac1e6b7126c320.zip"), + sha256 = "0e5d5c16686beff813e3946b26ca412f28acaf611228d20728ffb6479264fe19", + strip_prefix = "XNNPACK-9ddeb74f9f6866174d61888947e4aa9ffe963b1b", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/9ddeb74f9f6866174d61888947e4aa9ffe963b1b.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) # XNNPack dependency. tf_http_archive( name = "KleidiAI", - sha256 = "e1a3a6a27dcae459e61c33f5eb235a7c809c3208b3b8a649f361a641269ebdc8", - strip_prefix = "kleidiai-8fda0bd9224cad4360c011a09bbb582c5ab7496a", - urls = tf_mirror_urls("https://gitlab.arm.com/kleidi/kleidiai/-/archive/8fda0bd9224cad4360c011a09bbb582c5ab7496a/kleidiai-8fda0bd9224cad4360c011a09bbb582c5ab7496a.zip"), + sha256 = "88233e427be6579560073267575f00f3b5fc370a31a43bbdd87a1810bd4bf1b6", + strip_prefix = "kleidiai-cddf991af5de49fd34949fa39690e4e906e04074", + urls = tf_mirror_urls("https://gitlab.arm.com/kleidi/kleidiai/-/archive/cddf991af5de49fd34949fa39690e4e906e04074/kleidiai-cddf991af5de49fd34949fa39690e4e906e04074.zip"), ) tf_http_archive( @@ -789,9 +785,9 @@ def _tf_repositories(): tf_http_archive( name = "pybind11", - urls = tf_mirror_urls("https://github.com/pybind/pybind11/archive/v2.10.4.tar.gz"), - sha256 = "832e2f309c57da9c1e6d4542dedd34b24e4192ecb4d62f6f4866a737454c9970", - strip_prefix = "pybind11-2.10.4", + urls = tf_mirror_urls("https://github.com/pybind/pybind11/archive/v2.13.4.tar.gz"), + sha256 = "efc901aa0aab439a3fea6efeaf930b5a349fb06394bf845c64ce15a9cf8f0240", + strip_prefix = "pybind11-2.13.4", build_file = "//third_party:pybind11.BUILD", system_build_file = "//third_party/systemlibs:pybind11.BUILD", ) diff --git a/third_party/gpus/check_cuda_libs.py b/third_party/gpus/check_cuda_libs.py index afd6380b0ac203..b1a10a86b9aac6 100644 --- a/third_party/gpus/check_cuda_libs.py +++ b/third_party/gpus/check_cuda_libs.py @@ -14,6 +14,9 @@ # ============================================================================== """Verifies that a list of libraries is installed on the system. +NB: DEPRECATED! This script is a part of the deprecated `cuda_configure` rule. +Please use `hermetic/cuda_configure` instead. + Takes a list of arguments with every two subsequent arguments being a logical tuple of (path, check_soname). The path to the library and either True or False to indicate whether to check the soname field on the shared library. diff --git a/third_party/gpus/compiler_common_tools.bzl b/third_party/gpus/compiler_common_tools.bzl new file mode 100644 index 00000000000000..bd07f49ec457bb --- /dev/null +++ b/third_party/gpus/compiler_common_tools.bzl @@ -0,0 +1,174 @@ +"""Common compiler functions. """ + +load( + "//third_party/remote_config:common.bzl", + "err_out", + "raw_exec", + "realpath", +) + +def to_list_of_strings(elements): + """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'. + + This is to be used to put a list of strings into the bzl file templates + so it gets interpreted as list of strings in Starlark. + + Args: + elements: list of string elements + + Returns: + single string of elements wrapped in quotes separated by a comma.""" + quoted_strings = ["\"" + element + "\"" for element in elements] + return ", ".join(quoted_strings) + +_INC_DIR_MARKER_BEGIN = "#include <...>" + +# OSX add " (framework directory)" at the end of line, strip it. +_OSX_FRAMEWORK_SUFFIX = " (framework directory)" +_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX) + +# TODO(dzc): Once these functions have been factored out of Bazel's +# cc_configure.bzl, load them from @bazel_tools instead. +def _cxx_inc_convert(path): + """Convert path returned by cc -E xc++ in a complete path.""" + path = path.strip() + if path.endswith(_OSX_FRAMEWORK_SUFFIX): + path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip() + return path + +def _normalize_include_path(repository_ctx, path): + """Normalizes include paths before writing them to the crosstool. + + If path points inside the 'crosstool' folder of the repository, a relative + path is returned. + If path points outside the 'crosstool' folder, an absolute path is returned. + """ + path = str(repository_ctx.path(path)) + crosstool_folder = str(repository_ctx.path(".").get_child("crosstool")) + + if path.startswith(crosstool_folder): + # We drop the path to "$REPO/crosstool" and a trailing path separator. + return path[len(crosstool_folder) + 1:] + return path + +def _is_compiler_option_supported(repository_ctx, cc, option): + """Checks that `option` is supported by the C compiler. Doesn't %-escape the option.""" + result = repository_ctx.execute([ + cc, + option, + "-o", + "/dev/null", + "-c", + str(repository_ctx.path("tools/cpp/empty.cc")), + ]) + return result.stderr.find(option) == -1 + +def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sys_root): + """Compute the list of default C or C++ include directories.""" + if lang_is_cpp: + lang = "c++" + else: + lang = "c" + sysroot = [] + if tf_sys_root: + sysroot += ["--sysroot", tf_sys_root] + result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] + + sysroot) + stderr = err_out(result) + index1 = stderr.find(_INC_DIR_MARKER_BEGIN) + if index1 == -1: + return [] + index1 = stderr.find("\n", index1) + if index1 == -1: + return [] + index2 = stderr.rfind("\n ") + if index2 == -1 or index2 < index1: + return [] + index2 = stderr.find("\n", index2 + 1) + if index2 == -1: + inc_dirs = stderr[index1 + 1:] + else: + inc_dirs = stderr[index1 + 1:index2].strip() + + print_resource_dir_supported = _is_compiler_option_supported( + repository_ctx, + cc, + "-print-resource-dir", + ) + + if print_resource_dir_supported: + resource_dir = repository_ctx.execute( + [cc, "-print-resource-dir"], + ).stdout.strip() + "/share" + inc_dirs += "\n" + resource_dir + + compiler_includes = [ + _normalize_include_path(repository_ctx, _cxx_inc_convert(p)) + for p in inc_dirs.split("\n") + ] + + # The compiler might be on a symlink, e.g. /symlink -> /opt/gcc + # The above keeps only the resolved paths to the default includes (e.g. /opt/gcc/include/c++/11) + # but Bazel might encounter either (usually reported by the compiler) + # especially when a compiler wrapper (e.g. ccache) is used. + # So we need to also include paths where symlinks are not resolved. + + # Try to find real path to CC installation to "see through" compiler wrappers + # GCC has the path to g++ + index1 = result.stderr.find("COLLECT_GCC=") + if index1 != -1: + index1 = result.stderr.find("=", index1) + index2 = result.stderr.find("\n", index1) + cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname.dirname + else: + # Clang has the directory + index1 = result.stderr.find("InstalledDir: ") + if index1 != -1: + index1 = result.stderr.find(" ", index1) + index2 = result.stderr.find("\n", index1) + cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname + else: + # Fallback to the CC path + cc_topdir = repository_ctx.path(cc).dirname.dirname + + # We now have the compiler installation prefix, e.g. /symlink/gcc + # And the resolved installation prefix, e.g. /opt/gcc + cc_topdir_resolved = str(realpath(repository_ctx, cc_topdir)).strip() + cc_topdir = str(cc_topdir).strip() + + # If there is (any!) symlink involved we add paths where the unresolved installation prefix is kept. + # e.g. [/opt/gcc/include/c++/11, /opt/gcc/lib/gcc/x86_64-linux-gnu/11/include, /other/path] + # adds [/symlink/include/c++/11, /symlink/lib/gcc/x86_64-linux-gnu/11/include] + if cc_topdir_resolved != cc_topdir: + unresolved_compiler_includes = [ + cc_topdir + inc[len(cc_topdir_resolved):] + for inc in compiler_includes + if inc.startswith(cc_topdir_resolved) + ] + compiler_includes = compiler_includes + unresolved_compiler_includes + return compiler_includes + +def get_cxx_inc_directories(repository_ctx, cc, tf_sys_root): + """Compute the list of default C and C++ include directories.""" + + # For some reason `clang -xc` sometimes returns include paths that are + # different from the ones from `clang -xc++`. (Symlink and a dir) + # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists + includes_cpp = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + True, + tf_sys_root, + ) + includes_c = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + False, + tf_sys_root, + ) + + return includes_cpp + [ + inc + for inc in includes_c + if inc not in includes_cpp + ] diff --git a/third_party/gpus/crosstool/BUILD.tpl b/third_party/gpus/crosstool/BUILD.tpl index 8eda7a1cf6ac2b..b9553d9b99ecfe 100644 --- a/third_party/gpus/crosstool/BUILD.tpl +++ b/third_party/gpus/crosstool/BUILD.tpl @@ -2,6 +2,7 @@ # Update cuda_configure.bzl#verify_build_defines when adding new variables. load(":cc_toolchain_config.bzl", "cc_toolchain_config") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") licenses(["restricted"]) @@ -133,9 +134,17 @@ filegroup( srcs = [], ) +filegroup( + name = "cuda_nvcc_files", + srcs = %{cuda_nvcc_files}, +) + filegroup( name = "crosstool_wrapper_driver_is_not_gcc", - srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], + srcs = [ + ":cuda_nvcc_files", + ":clang/bin/crosstool_wrapper_driver_is_not_gcc" + ], ) filegroup( diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl index c46e09484fdfad..eb3a1d8c8ddf02 100644 --- a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl +++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl @@ -181,6 +181,9 @@ def InvokeNvcc(argv, log=False): nvccopts += ['--keep', '--keep-dir', tempdir] # Force C++17 dialect (note, everything in just one string!) nvccopts += ['--std c++17'] + # This is so that nvcc does not complain about MSVC or CLANG. + nvccopts += ['-allow-unsupported-compiler'] + nvccopts += ['--expt-extended-lambda', '--expt-relaxed-constexpr'] if log: Log([NVCC_PATH] + nvccopts) diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index 44cdbe34b25f86..094431dcedfc12 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -1,6 +1,10 @@ +# NB: DEPRECATED! This file is a part of the deprecated `cuda_configure` rule. +# Please use `hermetic/cuda_configure` instead. + load(":build_defs.bzl", "cuda_header_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like @@ -144,7 +148,6 @@ cc_library( name = "cusolver", srcs = ["cuda/lib/%{cusolver_lib}"], data = ["cuda/lib/%{cusolver_lib}"], - linkopts = ["-lgomp"], linkstatic = 1, ) @@ -220,7 +223,6 @@ cc_library( name = "cusparse", srcs = ["cuda/lib/%{cusparse_lib}"], data = ["cuda/lib/%{cusparse_lib}"], - linkopts = ["-lgomp"], linkstatic = 1, ) @@ -242,6 +244,41 @@ py_library( srcs = ["cuda/cuda_config.py"], ) +# Build setting that is always true (i.e. it can not be changed on the +# command line). It is used to create the config settings below that are +# always or never satisfied. +bool_setting( + name = "true_setting", + visibility = ["//visibility:private"], + build_setting_default = True, +) + +# Config settings whether TensorFlow is built with hermetic CUDA. +# These configs are never satisfied. +config_setting( + name = "hermetic_cuda_tools", + flag_values = {":true_setting": "False"}, +) + +# Flag indicating if we should include hermetic CUDA libs. +bool_flag( + name = "include_hermetic_cuda_libs", + build_setting_default = False, +) + +config_setting( + name = "hermetic_cuda_libs", + flag_values = {":true_setting": "False"}, +) + +selects.config_setting_group( + name = "hermetic_cuda_tools_and_libs", + match_all = [ + ":hermetic_cuda_libs", + ":hermetic_cuda_tools" + ], +) + %{copy_rules} cc_library( diff --git a/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/gpus/cuda/BUILD.windows.tpl index dee0e898d9ae7a..6b25c8398a7144 100644 --- a/third_party/gpus/cuda/BUILD.windows.tpl +++ b/third_party/gpus/cuda/BUILD.windows.tpl @@ -1,3 +1,7 @@ +# NB: DEPRECATED! This file is a part of the deprecated `cuda_configure` rule. +# Hermetic CUDA repository rule doesn't support Windows. +# Please use `hermetic/cuda_configure`. + load(":build_defs.bzl", "cuda_header_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") diff --git a/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/gpus/cuda/build_defs.bzl.tpl index bc865cecb3240a..d1c50ea6377b9e 100644 --- a/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/gpus/cuda/build_defs.bzl.tpl @@ -104,9 +104,16 @@ def if_cuda_newer_than(wanted_ver, if_true, if_false = []): wanted_major = int(wanted_ver.split('_')[0]) wanted_minor = int(wanted_ver.split('_')[1]) - configured_version = "%{cuda_version}" - configured_major = int(configured_version.split('.')[0]) - configured_minor = int(configured_version.split('.')[1]) + # Strip "64_" which appears in the CUDA version on Windows. + configured_version = "%{cuda_version}".rsplit("_", 1)[-1] + configured_version_parts = configured_version.split('.') + + # On Windows, the major and minor versions are concatenated without a period and the minor only contains one digit. + if len(configured_version_parts) == 1: + configured_version_parts = [configured_version[0:-1], configured_version[-1:]] + + configured_major = int(configured_version_parts[0]) + configured_minor = int(configured_version_parts[1]) if %{cuda_is_configured} and (wanted_major, wanted_minor) <= (configured_major, configured_minor): return select({"//conditions:default": if_true}) @@ -142,9 +149,13 @@ def cuda_header_library( **kwargs ) -def cuda_library(copts = [], **kwargs): +def cuda_library(copts = [], tags = [],**kwargs): """Wrapper over cc_library which adds default CUDA options.""" - native.cc_library(copts = cuda_default_copts() + copts, **kwargs) + native.cc_library( + copts = cuda_default_copts() + copts, + tags = tags + ["gpu"], + **kwargs + ) def cuda_cc_test(copts = [], **kwargs): """Wrapper over cc_test which adds default CUDA options.""" diff --git a/third_party/gpus/cuda/hermetic/BUILD b/third_party/gpus/cuda/hermetic/BUILD new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/third_party/gpus/cuda/hermetic/BUILD.tpl b/third_party/gpus/cuda/hermetic/BUILD.tpl new file mode 100644 index 00000000000000..ccf1b9a030d5ad --- /dev/null +++ b/third_party/gpus/cuda/hermetic/BUILD.tpl @@ -0,0 +1,266 @@ +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") + +licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +# Config setting whether TensorFlow is built with CUDA support using clang. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_clang. +selects.config_setting_group( + name = "using_clang", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_clang", + ], +) + +# Config setting whether TensorFlow is built with CUDA support using nvcc. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc. +selects.config_setting_group( + name = "using_nvcc", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_nvcc", + ], +) + +# Equivalent to using_clang && -c opt. +selects.config_setting_group( + name = "using_clang_opt", + match_all = [ + ":using_clang", + ":_opt", + ], +) + +config_setting( + name = "_opt", + values = {"compilation_mode": "opt"}, +) + +# Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"' +# All clients including TensorFlow should use these directives. +cc_library( + name = "cuda_headers", + hdrs = [ + "cuda/cuda_config.h", + ], + include_prefix = "third_party/gpus", + includes = [ + ".", # required to include cuda/cuda/cuda_config.h as cuda/config.h + ], + deps = [":cudart_headers", + ":cublas_headers", + ":cccl_headers", + ":nvtx_headers", + ":nvcc_headers", + ":cusolver_headers", + ":cufft_headers", + ":cusparse_headers", + ":curand_headers", + ":cupti_headers", + ":nvml_headers"], +) + +cc_library( + name = "cudart_static", + srcs = ["@cuda_cudart//:static"], + linkopts = [ + "-ldl", + "-lpthread", + %{cudart_static_linkopt} + ], +) + +alias( + name = "cuda_driver", + actual = "@cuda_cudart//:cuda_driver", +) + +alias( + name = "cudart_headers", + actual = "@cuda_cudart//:headers", +) + +alias( + name = "cudart", + actual = "@cuda_cudart//:cudart", +) + +alias( + name = "nvtx_headers", + actual = "@cuda_nvtx//:headers", +) + +alias( + name = "nvml_headers", + actual = "@cuda_nvml//:headers", +) + +alias( + name = "nvcc_headers", + actual = "@cuda_nvcc//:headers", +) + +alias( + name = "cccl_headers", + actual = "@cuda_cccl//:headers", +) + +alias( + name = "cublas_headers", + actual = "@cuda_cublas//:headers", +) + +alias( + name = "cusolver_headers", + actual = "@cuda_cusolver//:headers", +) + +alias( + name = "cufft_headers", + actual = "@cuda_cufft//:headers", +) + +alias( + name = "cusparse_headers", + actual = "@cuda_cusparse//:headers", +) + +alias( + name = "curand_headers", + actual = "@cuda_curand//:headers", +) + +alias( + name = "cublas", + actual = "@cuda_cublas//:cublas", +) + +alias( + name = "cublasLt", + actual = "@cuda_cublas//:cublasLt", +) + +alias( + name = "cusolver", + actual = "@cuda_cusolver//:cusolver", +) + +alias( + name = "cudnn", + actual = "@cuda_cudnn//:cudnn", +) + +alias( + name = "cudnn_header", + actual = "@cuda_cudnn//:headers", +) + +alias( + name = "cufft", + actual = "@cuda_cufft//:cufft", +) + +alias( + name = "curand", + actual = "@cuda_curand//:curand", +) + +cc_library( + name = "cuda", + deps = [ + ":cublas", + ":cublasLt", + ":cuda_headers", + ":cudart", + ":cudnn", + ":cufft", + ":curand", + ], +) + +alias( + name = "cub_headers", + actual = ":cuda_headers", +) + +alias( + name = "cupti_headers", + actual = "@cuda_cupti//:headers", +) + +alias( + name = "cupti_dsos", + actual = "@cuda_cupti//:cupti", +) + +alias( + name = "cusparse", + actual = "@cuda_cusparse//:cusparse", +) + +alias( + name = "cuda-nvvm", + actual = "@cuda_nvcc//:nvvm", +) + +alias( + name = "nvjitlink", + actual = "@cuda_nvjitlink//:nvjitlink" +) + +cc_library( + name = "libdevice_root", + data = [":cuda-nvvm"], +) + +bzl_library( + name = "build_defs_bzl", + srcs = ["build_defs.bzl"], + deps = [ + "@bazel_skylib//lib:selects", + ], +) + +py_library( + name = "cuda_config_py", + srcs = ["cuda/cuda_config.py"], +) + +# Config setting whether TensorFlow is built with hermetic CUDA. +alias( + name = "hermetic_cuda_tools", + actual = "@local_config_cuda//:is_cuda_enabled", +) + +# Flag indicating if we should include hermetic CUDA libs. +bool_flag( + name = "include_hermetic_cuda_libs", + build_setting_default = False, +) + +config_setting( + name = "hermetic_cuda_libs", + flag_values = {":include_hermetic_cuda_libs": "True"}, +) + +selects.config_setting_group( + name = "hermetic_cuda_tools_and_libs", + match_all = [ + ":hermetic_cuda_libs", + ":hermetic_cuda_tools" + ], +) + +cc_library( + # This is not yet fully supported, but we need the rule + # to make bazel query happy. + name = "nvptxcompiler", +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl new file mode 100644 index 00000000000000..85c0cbbb196fef --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl @@ -0,0 +1,15 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + hdrs = glob([ + %{comment}"include/cub/**", + %{comment}"include/cuda/**", + %{comment}"include/nv/**", + %{comment}"include/thrust/**", + ]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_configure.bzl b/third_party/gpus/cuda/hermetic/cuda_configure.bzl new file mode 100644 index 00000000000000..270b73c3884855 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_configure.bzl @@ -0,0 +1,521 @@ +"""Repository rule for hermetic CUDA autoconfiguration. + +`cuda_configure` depends on the following environment variables: + + * `TF_NEED_CUDA`: Whether to enable building with CUDA. + * `TF_NVCC_CLANG`: Whether to use clang for C++ and NVCC for Cuda compilation. + * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for + both host and device code compilation. + * `TF_SYSROOT`: The sysroot to use when compiling. + * `HERMETIC_CUDA_VERSION`: The version of the CUDA toolkit. If not specified, + the version will be determined by the `TF_CUDA_VERSION`. + * `HERMETIC_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default + is `3.5,5.2`. If not specified, the value will be determined by the + `TF_CUDA_COMPUTE_CAPABILITIES`. + * `PYTHON_BIN_PATH`: The python binary path +""" + +load( + "//third_party/gpus:compiler_common_tools.bzl", + "get_cxx_inc_directories", + "to_list_of_strings", +) +load( + "//third_party/remote_config:common.bzl", + "get_cpu_value", + "get_host_environ", + "which", +) + +def _find_cc(repository_ctx): + """Find the C++ compiler.""" + cc_path_envvar = _CLANG_CUDA_COMPILER_PATH + cc_name = "clang" + + cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar) + if cc_name_from_env: + cc_name = cc_name_from_env + if cc_name.startswith("/"): + # Return the absolute path. + return cc_name + cc = which(repository_ctx, cc_name) + if cc == None: + fail(("Cannot find {}, either correct your path or set the {}" + + " environment variable").format(cc_name, cc_path_envvar)) + return cc + +def _auto_configure_fail(msg): + """Output failure message when cuda configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg)) + +def _verify_build_defines(params): + """Verify all variables that crosstool/BUILD.tpl expects are substituted. + + Args: + params: dict of variables that will be passed to the BUILD.tpl template. + """ + missing = [] + for param in [ + "cxx_builtin_include_directories", + "extra_no_canonical_prefixes_flags", + "host_compiler_path", + "host_compiler_prefix", + "host_compiler_warnings", + "linker_bin_path", + "compiler_deps", + "msvc_cl_path", + "msvc_env_include", + "msvc_env_lib", + "msvc_env_path", + "msvc_env_tmp", + "msvc_lib_path", + "msvc_link_path", + "msvc_ml_path", + "unfiltered_compile_flags", + "win_compiler_deps", + ]: + if ("%{" + param + "}") not in params: + missing.append(param) + + if missing: + _auto_configure_fail( + "BUILD.tpl template is missing these variables: " + + str(missing) + + ".\nWe only got: " + + str(params) + + ".", + ) + +def get_cuda_version(repository_ctx): + return (get_host_environ(repository_ctx, HERMETIC_CUDA_VERSION) or + get_host_environ(repository_ctx, TF_CUDA_VERSION)) + +def enable_cuda(repository_ctx): + """Returns whether to build with CUDA support.""" + return int(get_host_environ(repository_ctx, TF_NEED_CUDA, False)) + +def _flag_enabled(repository_ctx, flag_name): + return get_host_environ(repository_ctx, flag_name) == "1" + +def _use_nvcc_and_clang(repository_ctx): + # Returns the flag if we need to use clang for C++ and NVCC for Cuda. + return _flag_enabled(repository_ctx, _TF_NVCC_CLANG) + +def _tf_sysroot(repository_ctx): + return get_host_environ(repository_ctx, _TF_SYSROOT, "") + +def _py_tmpl_dict(d): + return {"%{cuda_config}": str(d)} + +def _cudart_static_linkopt(cpu_value): + """Returns additional platform-specific linkopts for cudart.""" + return "\"\"," if cpu_value == "Darwin" else "\"-lrt\"," + +def _compute_capabilities(repository_ctx): + """Returns a list of strings representing cuda compute capabilities. + + Args: + repository_ctx: the repo rule's context. + + Returns: + list of cuda architectures to compile for. 'compute_xy' refers to + both PTX and SASS, 'sm_xy' refers to SASS only. + """ + capabilities = (get_host_environ( + repository_ctx, + _HERMETIC_CUDA_COMPUTE_CAPABILITIES, + ) or + get_host_environ( + repository_ctx, + _TF_CUDA_COMPUTE_CAPABILITIES, + )) + capabilities = (capabilities or "compute_35,compute_52").split(",") + + # Map old 'x.y' capabilities to 'compute_xy'. + if len(capabilities) > 0 and all([len(x.split(".")) == 2 for x in capabilities]): + # If all capabilities are in 'x.y' format, only include PTX for the + # highest capability. + cc_list = sorted([x.replace(".", "") for x in capabilities]) + capabilities = [ + "sm_%s" % x + for x in cc_list[:-1] + ] + ["compute_%s" % cc_list[-1]] + for i, capability in enumerate(capabilities): + parts = capability.split(".") + if len(parts) != 2: + continue + capabilities[i] = "compute_%s%s" % (parts[0], parts[1]) + + # Make list unique + capabilities = dict(zip(capabilities, capabilities)).keys() + + # Validate capabilities. + for capability in capabilities: + if not capability.startswith(("compute_", "sm_")): + _auto_configure_fail("Invalid compute capability: %s" % capability) + for prefix in ["compute_", "sm_"]: + if not capability.startswith(prefix): + continue + if len(capability) == len(prefix) + 2 and capability[-2:].isdigit(): + continue + if len(capability) == len(prefix) + 3 and capability.endswith("90a"): + continue + _auto_configure_fail("Invalid compute capability: %s" % capability) + + return capabilities + +def _compute_cuda_extra_copts(compute_capabilities): + copts = ["--no-cuda-include-ptx=all"] + for capability in compute_capabilities: + if capability.startswith("compute_"): + capability = capability.replace("compute_", "sm_") + copts.append("--cuda-include-ptx=%s" % capability) + copts.append("--cuda-gpu-arch=%s" % capability) + + return str(copts) + +def _get_cuda_config(repository_ctx): + """Detects and returns information about the CUDA installation on the system. + + Args: + repository_ctx: The repository context. + + Returns: + A struct containing the following fields: + cuda_version: The version of CUDA on the system. + cudart_version: The CUDA runtime version on the system. + cudnn_version: The version of cuDNN on the system. + compute_capabilities: A list of the system's CUDA compute capabilities. + cpu_value: The name of the host operating system. + """ + + return struct( + cuda_version = get_cuda_version(repository_ctx), + cupti_version = repository_ctx.read(repository_ctx.attr.cupti_version), + cudart_version = repository_ctx.read(repository_ctx.attr.cudart_version), + cublas_version = repository_ctx.read(repository_ctx.attr.cublas_version), + cusolver_version = repository_ctx.read(repository_ctx.attr.cusolver_version), + curand_version = repository_ctx.read(repository_ctx.attr.curand_version), + cufft_version = repository_ctx.read(repository_ctx.attr.cufft_version), + cusparse_version = repository_ctx.read(repository_ctx.attr.cusparse_version), + cudnn_version = repository_ctx.read(repository_ctx.attr.cudnn_version), + compute_capabilities = _compute_capabilities(repository_ctx), + cpu_value = get_cpu_value(repository_ctx), + ) + +_DUMMY_CROSSTOOL_BZL_FILE = """ +def error_gpu_disabled(): + fail("ERROR: Building with --config=cuda but TensorFlow is not configured " + + "to build with GPU support. Please re-run ./configure and enter 'Y' " + + "at the prompt to build with GPU support.") + + native.genrule( + name = "error_gen_crosstool", + outs = ["CROSSTOOL"], + cmd = "echo 'Should not be run.' && exit 1", + ) + + native.filegroup( + name = "crosstool", + srcs = [":CROSSTOOL"], + output_licenses = ["unencumbered"], + ) +""" + +_DUMMY_CROSSTOOL_BUILD_FILE = """ +load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled") + +error_gpu_disabled() +""" + +def _create_dummy_repository(repository_ctx): + cpu_value = get_cpu_value(repository_ctx) + + # Set up BUILD file for cuda/. + repository_ctx.template( + "cuda/build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_is_configured}": "False", + "%{cuda_extra_copts}": "[]", + "%{cuda_gpu_architectures}": "[]", + "%{cuda_version}": "0.0", + }, + ) + + repository_ctx.template( + "cuda/BUILD", + repository_ctx.attr.cuda_build_tpl, + { + "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), + }, + ) + + # Set up cuda_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": "", + "%{cudart_version}": "", + "%{cupti_version}": "", + "%{cublas_version}": "", + "%{cusolver_version}": "", + "%{curand_version}": "", + "%{cufft_version}": "", + "%{cusparse_version}": "", + "%{cudnn_version}": "", + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": "", + }, + ) + + # Set up cuda_config.py, which is used by gen_build_info to provide + # static build environment info to the API + repository_ctx.template( + "cuda/cuda/cuda_config.py", + repository_ctx.attr.cuda_config_py_tpl, + _py_tmpl_dict({}), + ) + + # If cuda_configure is not configured to build with GPU support, and the user + # attempts to build with --config=cuda, add a dummy build rule to intercept + # this and fail with an actionable error message. + repository_ctx.file( + "crosstool/error_gpu_disabled.bzl", + _DUMMY_CROSSTOOL_BZL_FILE, + ) + repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE) + +def _create_local_cuda_repository(repository_ctx): + """Creates the repository containing files set up to build with CUDA.""" + cuda_config = _get_cuda_config(repository_ctx) + + # Set up BUILD file for cuda/ + repository_ctx.template( + "cuda/build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_is_configured}": "True", + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + cuda_config.compute_capabilities, + ), + "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities), + "%{cuda_version}": cuda_config.cuda_version, + }, + ) + + repository_ctx.template( + "cuda/BUILD", + repository_ctx.attr.cuda_build_tpl, + { + "%{cudart_static_linkopt}": _cudart_static_linkopt( + cuda_config.cpu_value, + ), + }, + ) + + is_nvcc_and_clang = _use_nvcc_and_clang(repository_ctx) + tf_sysroot = _tf_sysroot(repository_ctx) + + # Set up crosstool/ + cc = _find_cc(repository_ctx) + host_compiler_includes = get_cxx_inc_directories( + repository_ctx, + cc, + tf_sysroot, + ) + + cuda_defines = {} + + # We do not support hermetic CUDA on Windows. + # This ensures the CROSSTOOL file parser is happy. + cuda_defines.update({ + "%{msvc_env_tmp}": "msvc_not_used", + "%{msvc_env_path}": "msvc_not_used", + "%{msvc_env_include}": "msvc_not_used", + "%{msvc_env_lib}": "msvc_not_used", + "%{msvc_cl_path}": "msvc_not_used", + "%{msvc_ml_path}": "msvc_not_used", + "%{msvc_link_path}": "msvc_not_used", + "%{msvc_lib_path}": "msvc_not_used", + "%{win_compiler_deps}": ":empty", + }) + + cuda_defines["%{builtin_sysroot}"] = tf_sysroot + cuda_defines["%{cuda_toolkit_path}"] = repository_ctx.attr.nvcc_binary.workspace_root + cuda_defines["%{compiler}"] = "clang" + cuda_defines["%{host_compiler_prefix}"] = "/usr/bin" + cuda_defines["%{linker_bin_path}"] = "" + cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" + cuda_defines["%{unfiltered_compile_flags}"] = "" + cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings( + host_compiler_includes, + ) + cuda_defines["%{cuda_nvcc_files}"] = "if_cuda([\"@{nvcc_archive}//:bin\", \"@{nvcc_archive}//:nvvm\"])".format( + nvcc_archive = repository_ctx.attr.nvcc_binary.repo_name, + ) + + if not is_nvcc_and_clang: + cuda_defines["%{host_compiler_path}"] = str(cc) + cuda_defines["%{host_compiler_warnings}"] = """ + # Some parts of the codebase set -Werror and hit this warning, so + # switch it off for now. + "-Wno-invalid-partial-specialization" + """ + cuda_defines["%{compiler_deps}"] = ":cuda_nvcc_files" + repository_ctx.file( + "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", + "", + ) + else: + cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" + cuda_defines["%{host_compiler_warnings}"] = "" + + nvcc_relative_path = "%s/%s" % ( + repository_ctx.attr.nvcc_binary.workspace_root, + repository_ctx.attr.nvcc_binary.name, + ) + cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc" + + wrapper_defines = { + "%{cpu_compiler}": str(cc), + "%{cuda_version}": cuda_config.cuda_version, + "%{nvcc_path}": nvcc_relative_path, + "%{host_compiler_path}": str(cc), + "%{use_clang_compiler}": "True", + } + repository_ctx.template( + "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", + repository_ctx.attr.crosstool_wrapper_driver_is_not_gcc_tpl, + wrapper_defines, + ) + + _verify_build_defines(cuda_defines) + + # Only expand template variables in the BUILD file + repository_ctx.template( + "crosstool/BUILD", + repository_ctx.attr.crosstool_build_tpl, + cuda_defines, + ) + + # No templating of cc_toolchain_config - use attributes and templatize the + # BUILD file. + repository_ctx.template( + "crosstool/cc_toolchain_config.bzl", + repository_ctx.attr.cc_toolchain_config_tpl, + {}, + ) + + # Set up cuda_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": cuda_config.cuda_version, + "%{cudart_version}": cuda_config.cudart_version, + "%{cupti_version}": cuda_config.cupti_version, + "%{cublas_version}": cuda_config.cublas_version, + "%{cusolver_version}": cuda_config.cusolver_version, + "%{curand_version}": cuda_config.curand_version, + "%{cufft_version}": cuda_config.cufft_version, + "%{cusparse_version}": cuda_config.cusparse_version, + "%{cudnn_version}": cuda_config.cudnn_version, + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": ", ".join([ + cc.split("_")[1] + for cc in cuda_config.compute_capabilities + ]), + }, + ) + + # Set up cuda_config.py, which is used by gen_build_info to provide + # static build environment info to the API + repository_ctx.template( + "cuda/cuda/cuda_config.py", + repository_ctx.attr.cuda_config_py_tpl, + _py_tmpl_dict({ + "cuda_version": cuda_config.cuda_version, + "cudnn_version": cuda_config.cudnn_version, + "cuda_compute_capabilities": cuda_config.compute_capabilities, + "cpu_compiler": str(cc), + }), + ) + +def _cuda_autoconf_impl(repository_ctx): + """Implementation of the cuda_autoconf repository rule.""" + build_file = repository_ctx.attr.local_config_cuda_build_file + + if not enable_cuda(repository_ctx): + _create_dummy_repository(repository_ctx) + else: + _create_local_cuda_repository(repository_ctx) + + repository_ctx.symlink(build_file, "BUILD") + +_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH" +_PYTHON_BIN_PATH = "PYTHON_BIN_PATH" +_HERMETIC_CUDA_COMPUTE_CAPABILITIES = "HERMETIC_CUDA_COMPUTE_CAPABILITIES" +_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" +HERMETIC_CUDA_VERSION = "HERMETIC_CUDA_VERSION" +TF_CUDA_VERSION = "TF_CUDA_VERSION" +TF_NEED_CUDA = "TF_NEED_CUDA" +_TF_NVCC_CLANG = "TF_NVCC_CLANG" +_TF_SYSROOT = "TF_SYSROOT" + +_ENVIRONS = [ + _CLANG_CUDA_COMPILER_PATH, + TF_NEED_CUDA, + _TF_NVCC_CLANG, + TF_CUDA_VERSION, + HERMETIC_CUDA_VERSION, + _TF_CUDA_COMPUTE_CAPABILITIES, + _HERMETIC_CUDA_COMPUTE_CAPABILITIES, + _TF_SYSROOT, + _PYTHON_BIN_PATH, + "TMP", + "TMPDIR", + "LOCAL_CUDA_PATH", + "LOCAL_CUDNN_PATH", +] + +cuda_configure = repository_rule( + implementation = _cuda_autoconf_impl, + environ = _ENVIRONS, + attrs = { + "environ": attr.string_dict(), + "cublas_version": attr.label(default = Label("@cuda_cublas//:version.txt")), + "cudart_version": attr.label(default = Label("@cuda_cudart//:version.txt")), + "cudnn_version": attr.label(default = Label("@cuda_cudnn//:version.txt")), + "cufft_version": attr.label(default = Label("@cuda_cufft//:version.txt")), + "cupti_version": attr.label(default = Label("@cuda_cupti//:version.txt")), + "curand_version": attr.label(default = Label("@cuda_curand//:version.txt")), + "cusolver_version": attr.label(default = Label("@cuda_cusolver//:version.txt")), + "cusparse_version": attr.label(default = Label("@cuda_cusparse//:version.txt")), + "nvcc_binary": attr.label(default = Label("@cuda_nvcc//:bin/nvcc")), + "local_config_cuda_build_file": attr.label(default = Label("//third_party/gpus:local_config_cuda.BUILD")), + "build_defs_tpl": attr.label(default = Label("//third_party/gpus/cuda:build_defs.bzl.tpl")), + "cuda_build_tpl": attr.label(default = Label("//third_party/gpus/cuda/hermetic:BUILD.tpl")), + "cuda_config_tpl": attr.label(default = Label("//third_party/gpus/cuda:cuda_config.h.tpl")), + "cuda_config_py_tpl": attr.label(default = Label("//third_party/gpus/cuda:cuda_config.py.tpl")), + "crosstool_wrapper_driver_is_not_gcc_tpl": attr.label(default = Label("//third_party/gpus/crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl")), + "crosstool_build_tpl": attr.label(default = Label("//third_party/gpus/crosstool:BUILD.tpl")), + "cc_toolchain_config_tpl": attr.label(default = Label("//third_party/gpus/crosstool:cc_toolchain_config.bzl.tpl")), + }, +) +"""Detects and configures the hermetic CUDA toolchain. + +Add the following to your WORKSPACE file: + +```python +cuda_configure(name = "local_config_cuda") +``` + +Args: + name: A unique name for this workspace rule. +""" # buildifier: disable=no-effect diff --git a/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl new file mode 100644 index 00000000000000..510235d801de4e --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl @@ -0,0 +1,44 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cublas_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcublas.so.%{libcublas_version}", + deps = [":cublasLt"], +) + +cc_import( + name = "cublasLt_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcublasLt.so.%{libcublaslt_version}", +) +%{multiline_comment} +cc_library( + name = "cublas", + visibility = ["//visibility:public"], + %{comment}deps = [":cublas_shared_library"], +) + +cc_library( + name = "cublasLt", + visibility = ["//visibility:public"], + %{comment}deps = [":cublasLt_shared_library"], +) + +cc_library( + name = "headers", + %{comment}hdrs = [ + %{comment}"include/cublas.h", + %{comment}"include/cublasLt.h", + %{comment}"include/cublas_api.h", + %{comment}"include/cublas_v2.h", + %{comment}], + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl new file mode 100644 index 00000000000000..f7ba469b42b76a --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl @@ -0,0 +1,126 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) + +filegroup( + name = "static", + srcs = ["lib/libcudart_static.a"], + visibility = ["@local_config_cuda//cuda:__pkg__"], +) +%{multiline_comment} +# TODO: Replace system provided library with hermetic NVIDIA driver library. +cc_import( + name = "cuda_driver_shared_library", + interface_library = "lib/stubs/libcuda.so", + system_provided = 1, +) + +cc_import( + name = "cudart_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcudart.so.%{libcudart_version}", +) +%{multiline_comment} +cc_library( + name = "cuda_driver", + %{comment}deps = [":cuda_driver_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cudart", + %{comment}deps = [ + %{comment}":cuda_driver", + %{comment}":cudart_shared_library", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/builtin_types.h", + %{comment}"include/channel_descriptor.h", + %{comment}"include/common_functions.h", + %{comment}"include/cooperative_groups/**", + %{comment}"include/cooperative_groups.h", + %{comment}"include/cuComplex.h", + %{comment}"include/cuda.h", + %{comment}"include/cudaEGL.h", + %{comment}"include/cudaEGLTypedefs.h", + %{comment}"include/cudaGL.h", + %{comment}"include/cudaGLTypedefs.h", + %{comment}"include/cudaProfilerTypedefs.h", + %{comment}"include/cudaTypedefs.h", + %{comment}"include/cudaVDPAU.h", + %{comment}"include/cudaVDPAUTypedefs.h", + %{comment}"include/cuda_awbarrier.h", + %{comment}"include/cuda_awbarrier_helpers.h", + %{comment}"include/cuda_awbarrier_primitives.h", + %{comment}"include/cuda_bf16.h", + %{comment}"include/cuda_bf16.hpp", + %{comment}"include/cuda_device_runtime_api.h", + %{comment}"include/cuda_egl_interop.h", + %{comment}"include/cuda_fp16.h", + %{comment}"include/cuda_fp16.hpp", + %{comment}"include/cuda_fp8.h", + %{comment}"include/cuda_fp8.hpp", + %{comment}"include/cuda_gl_interop.h", + %{comment}"include/cuda_occupancy.h", + %{comment}"include/cuda_pipeline.h", + %{comment}"include/cuda_pipeline_helpers.h", + %{comment}"include/cuda_pipeline_primitives.h", + %{comment}"include/cuda_runtime.h", + %{comment}"include/cuda_runtime_api.h", + %{comment}"include/cuda_surface_types.h", + %{comment}"include/cuda_texture_types.h", + %{comment}"include/cuda_vdpau_interop.h", + %{comment}"include/cudart_platform.h", + %{comment}"include/device_atomic_functions.h", + %{comment}"include/device_atomic_functions.hpp", + %{comment}"include/device_double_functions.h", + %{comment}"include/device_functions.h", + %{comment}"include/device_launch_parameters.h", + %{comment}"include/device_types.h", + %{comment}"include/driver_functions.h", + %{comment}"include/driver_types.h", + %{comment}"include/host_config.h", + %{comment}"include/host_defines.h", + %{comment}"include/library_types.h", + %{comment}"include/math_constants.h", + %{comment}"include/math_functions.h", + %{comment}"include/mma.h", + %{comment}"include/nvfunctional", + %{comment}"include/sm_20_atomic_functions.h", + %{comment}"include/sm_20_atomic_functions.hpp", + %{comment}"include/sm_20_intrinsics.h", + %{comment}"include/sm_20_intrinsics.hpp", + %{comment}"include/sm_30_intrinsics.h", + %{comment}"include/sm_30_intrinsics.hpp", + %{comment}"include/sm_32_atomic_functions.h", + %{comment}"include/sm_32_atomic_functions.hpp", + %{comment}"include/sm_32_intrinsics.h", + %{comment}"include/sm_32_intrinsics.hpp", + %{comment}"include/sm_35_atomic_functions.h", + %{comment}"include/sm_35_intrinsics.h", + %{comment}"include/sm_60_atomic_functions.h", + %{comment}"include/sm_60_atomic_functions.hpp", + %{comment}"include/sm_61_intrinsics.h", + %{comment}"include/sm_61_intrinsics.hpp", + %{comment}"include/surface_functions.h", + %{comment}"include/surface_indirect_functions.h", + %{comment}"include/surface_types.h", + %{comment}"include/texture_fetch_functions.h", + %{comment}"include/texture_indirect_functions.h", + %{comment}"include/texture_types.h", + %{comment}"include/vector_functions.h", + %{comment}"include/vector_functions.hpp", + %{comment}"include/vector_types.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl new file mode 100644 index 00000000000000..165c5b1579e73f --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl @@ -0,0 +1,73 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cudnn_ops_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops_infer.so.%{libcudnn_ops_infer_version}", +) + +cc_import( + name = "cudnn_cnn_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn_infer.so.%{libcudnn_cnn_infer_version}", +) + +cc_import( + name = "cudnn_ops_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops_train.so.%{libcudnn_ops_train_version}", +) + +cc_import( + name = "cudnn_cnn_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn_train.so.%{libcudnn_cnn_train_version}", +) + +cc_import( + name = "cudnn_adv_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv_infer.so.%{libcudnn_adv_infer_version}", +) + +cc_import( + name = "cudnn_adv_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv_train.so.%{libcudnn_adv_train_version}", +) + +cc_import( + name = "cudnn_main", + hdrs = [":headers"], + shared_library = "lib/libcudnn.so.%{libcudnn_version}", +) +%{multiline_comment} +cc_library( + name = "cudnn", + %{comment}deps = [ + %{comment}":cudnn_ops_infer", + %{comment}":cudnn_ops_train", + %{comment}":cudnn_cnn_infer", + %{comment}":cudnn_cnn_train", + %{comment}":cudnn_adv_infer", + %{comment}":cudnn_adv_train", + %{comment}"@cuda_nvrtc//:nvrtc", + %{comment}":cudnn_main", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudnn*.h", + %{comment}]), + include_prefix = "third_party/gpus/cudnn", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl new file mode 100644 index 00000000000000..7f36054a51bb5b --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl @@ -0,0 +1,80 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cudnn_ops", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops.so.%{libcudnn_ops_version}", +) + +cc_import( + name = "cudnn_cnn", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn.so.%{libcudnn_cnn_version}", +) + +cc_import( + name = "cudnn_adv", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv.so.%{libcudnn_adv_version}", +) + +cc_import( + name = "cudnn_graph", + hdrs = [":headers"], + shared_library = "lib/libcudnn_graph.so.%{libcudnn_graph_version}", +) + +cc_import( + name = "cudnn_engines_precompiled", + hdrs = [":headers"], + shared_library = "lib/libcudnn_engines_precompiled.so.%{libcudnn_engines_precompiled_version}", +) + +cc_import( + name = "cudnn_engines_runtime_compiled", + hdrs = [":headers"], + shared_library = "lib/libcudnn_engines_runtime_compiled.so.%{libcudnn_engines_runtime_compiled_version}", +) + +cc_import( + name = "cudnn_heuristic", + hdrs = [":headers"], + shared_library = "lib/libcudnn_heuristic.so.%{libcudnn_heuristic_version}", +) + +cc_import( + name = "cudnn_main", + hdrs = [":headers"], + shared_library = "lib/libcudnn.so.%{libcudnn_version}", +) +%{multiline_comment} +cc_library( + name = "cudnn", + %{comment}deps = [ + %{comment}":cudnn_engines_precompiled", + %{comment}":cudnn_ops", + %{comment}":cudnn_graph", + %{comment}":cudnn_cnn", + %{comment}":cudnn_adv", + %{comment}":cudnn_engines_runtime_compiled", + %{comment}":cudnn_heuristic", + %{comment}"@cuda_nvrtc//:nvrtc", + %{comment}":cudnn_main", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudnn*.h", + %{comment}]), + include_prefix = "third_party/gpus/cudnn", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl new file mode 100644 index 00000000000000..48ccb0ea3cd197 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl @@ -0,0 +1,29 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cufft_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcufft.so.%{libcufft_version}", +) +%{multiline_comment} +cc_library( + name = "cufft", + %{comment}deps = [":cufft_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudalibxt.h", + %{comment}"include/cufft*.h" + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl new file mode 100644 index 00000000000000..3efe76f470953f --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl @@ -0,0 +1,59 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cupti_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcupti.so.%{libcupti_version}", +) +%{multiline_comment} +cc_library( + name = "cupti", + %{comment}deps = [":cupti_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/Openacc/**", + %{comment}"include/Openmp/**", + %{comment}"include/cuda_stdint.h", + %{comment}"include/cupti.h", + %{comment}"include/cupti_activity.h", + %{comment}"include/cupti_activity_deprecated.h", + %{comment}"include/cupti_callbacks.h", + %{comment}"include/cupti_checkpoint.h", + %{comment}"include/cupti_driver_cbid.h", + %{comment}"include/cupti_events.h", + %{comment}"include/cupti_metrics.h", + %{comment}"include/cupti_nvtx_cbid.h", + %{comment}"include/cupti_pcsampling.h", + %{comment}"include/cupti_pcsampling_util.h", + %{comment}"include/cupti_profiler_target.h", + %{comment}"include/cupti_result.h", + %{comment}"include/cupti_runtime_cbid.h", + %{comment}"include/cupti_sass_metrics.h", + %{comment}"include/cupti_target.h", + %{comment}"include/cupti_version.h", + %{comment}"include/generated_cudaGL_meta.h", + %{comment}"include/generated_cudaVDPAU_meta.h", + %{comment}"include/generated_cuda_gl_interop_meta.h", + %{comment}"include/generated_cuda_meta.h", + %{comment}"include/generated_cuda_runtime_api_meta.h", + %{comment}"include/generated_cuda_vdpau_interop_meta.h", + %{comment}"include/generated_cudart_removed_meta.h", + %{comment}"include/generated_nvtx_meta.h", + %{comment}"include/nvperf_common.h", + %{comment}"include/nvperf_cuda_host.h", + %{comment}"include/nvperf_host.h", + %{comment}"include/nvperf_target.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/extras/CUPTI/include", + includes = ["include/"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl new file mode 100644 index 00000000000000..50e5a8f18a96fd --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl @@ -0,0 +1,26 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "curand_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcurand.so.%{libcurand_version}", +) +%{multiline_comment} +cc_library( + name = "curand", + %{comment}deps = [":curand_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob(["include/curand*.h"]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl new file mode 100644 index 00000000000000..943a08ebeb96e1 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl @@ -0,0 +1,34 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cusolver_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcusolver.so.%{libcusolver_version}", + deps = [ + "@cuda_nvjitlink//:nvjitlink", + "@cuda_cusparse//:cusparse", + "@cuda_cublas//:cublas", + "@cuda_cublas//:cublasLt", + ], +) +%{multiline_comment} +cc_library( + name = "cusolver", + %{comment}deps = [":cusolver_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cusolver*.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl new file mode 100644 index 00000000000000..46b24366ce1c04 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl @@ -0,0 +1,27 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cusparse_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcusparse.so.%{libcusparse_version}", + deps = ["@cuda_nvjitlink//:nvjitlink"], +) +%{multiline_comment} +cc_library( + name = "cusparse", + %{comment}deps = [":cusparse_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = ["include/cusparse.h"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl b/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl new file mode 100644 index 00000000000000..fdda3aaf92cea5 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl @@ -0,0 +1,125 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA redistributions JSON repository initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_JSON_DICT", + "CUDNN_REDIST_JSON_DICT", +) + +def _get_env_var(ctx, name): + return ctx.os.environ.get(name) + +def _get_json_file_content(repository_ctx, url_to_sha256, json_file_name): + if len(url_to_sha256) > 1: + (url, sha256) = url_to_sha256 + else: + url = url_to_sha256[0] + sha256 = "" + repository_ctx.download( + url = tf_mirror_urls(url), + sha256 = sha256, + output = json_file_name, + ) + return repository_ctx.read(repository_ctx.path(json_file_name)) + +def _cuda_redist_json_impl(repository_ctx): + cuda_version = (_get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + _get_env_var(repository_ctx, "TF_CUDA_VERSION")) + local_cuda_path = _get_env_var(repository_ctx, "LOCAL_CUDA_PATH") + cudnn_version = (_get_env_var(repository_ctx, "HERMETIC_CUDNN_VERSION") or + _get_env_var(repository_ctx, "TF_CUDNN_VERSION")) + local_cudnn_path = _get_env_var(repository_ctx, "LOCAL_CUDNN_PATH") + supported_cuda_versions = repository_ctx.attr.cuda_json_dict.keys() + if (cuda_version and not local_cuda_path and + (cuda_version not in supported_cuda_versions)): + fail( + ("The supported CUDA versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDA_VERSION" + + " environment variable or add JSON URL for" + + " CUDA version={version}.") + .format( + supported_versions = supported_cuda_versions, + version = cuda_version, + ), + ) + supported_cudnn_versions = repository_ctx.attr.cudnn_json_dict.keys() + if cudnn_version and not local_cudnn_path and (cudnn_version not in supported_cudnn_versions): + fail( + ("The supported CUDNN versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDNN_VERSION" + + " environment variable or add JSON URL for" + + " CUDNN version={version}.") + .format( + supported_versions = supported_cudnn_versions, + version = cudnn_version, + ), + ) + cuda_redistributions = "{}" + cudnn_redistributions = "{}" + if cuda_version and not local_cuda_path: + cuda_redistributions = _get_json_file_content( + repository_ctx, + repository_ctx.attr.cuda_json_dict[cuda_version], + "redistrib_cuda_%s.json" % cuda_version, + ) + if cudnn_version and not local_cudnn_path: + cudnn_redistributions = _get_json_file_content( + repository_ctx, + repository_ctx.attr.cudnn_json_dict[cudnn_version], + "redistrib_cudnn_%s.json" % cudnn_version, + ) + + repository_ctx.file( + "distributions.bzl", + """CUDA_REDISTRIBUTIONS = {cuda_redistributions} + +CUDNN_REDISTRIBUTIONS = {cudnn_redistributions} +""".format( + cuda_redistributions = cuda_redistributions, + cudnn_redistributions = cudnn_redistributions, + ), + ) + repository_ctx.file( + "BUILD", + "", + ) + +cuda_redist_json = repository_rule( + implementation = _cuda_redist_json_impl, + attrs = { + "cuda_json_dict": attr.string_list_dict(mandatory = True), + "cudnn_json_dict": attr.string_list_dict(mandatory = True), + }, + environ = [ + "HERMETIC_CUDA_VERSION", + "HERMETIC_CUDNN_VERSION", + "TF_CUDA_VERSION", + "TF_CUDNN_VERSION", + "LOCAL_CUDA_PATH", + "LOCAL_CUDNN_PATH", + ], +) + +def cuda_json_init_repository( + cuda_json_dict = CUDA_REDIST_JSON_DICT, + cudnn_json_dict = CUDNN_REDIST_JSON_DICT): + cuda_redist_json( + name = "cuda_redist_json", + cuda_json_dict = cuda_json_dict, + cudnn_json_dict = cudnn_json_dict, + ) diff --git a/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl new file mode 100644 index 00000000000000..7757a92a90b795 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl @@ -0,0 +1,75 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "bin/nvcc", +]) + +filegroup( + name = "nvvm", + srcs = [ + "nvvm/libdevice/libdevice.10.bc", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "nvlink", + srcs = [ + "bin/nvlink", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "fatbinary", + srcs = [ + "bin/fatbinary", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "bin2c", + srcs = [ + "bin/bin2c", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "ptxas", + srcs = [ + "bin/ptxas", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "bin", + srcs = glob([ + "bin/**", + "nvvm/bin/**", + ]), + visibility = ["//visibility:public"], +) + +filegroup( + name = "link_stub", + srcs = [ + "bin/crt/link.stub", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/crt/**", + %{comment}"include/fatbinary_section.h", + %{comment}"include/nvPTXCompiler.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl new file mode 100644 index 00000000000000..9784a84471f1a7 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl @@ -0,0 +1,17 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "nvjitlink_shared_library", + shared_library = "lib/libnvJitLink.so.%{libnvjitlink_version}", +) +%{multiline_comment} +cc_library( + name = "nvjitlink", + %{comment}deps = [":nvjitlink_shared_library"], + visibility = ["//visibility:public"], +) + diff --git a/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl new file mode 100644 index 00000000000000..23ee30f09f8ff3 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl @@ -0,0 +1,10 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + %{comment}hdrs = ["include/nvml.h"], + include_prefix = "third_party/gpus/cuda/nvml/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl new file mode 100644 index 00000000000000..986ef0c8f76166 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl @@ -0,0 +1,9 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +filegroup( + name = "nvprune", + srcs = [ + "bin/nvprune", + ], + visibility = ["//visibility:public"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl new file mode 100644 index 00000000000000..de18489b455b79 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl @@ -0,0 +1,20 @@ +licenses(["restricted"]) # NVIDIA proprietary license +%{multiline_comment} +cc_import( + name = "nvrtc_main", + shared_library = "lib/libnvrtc.so.%{libnvrtc_version}", +) + +cc_import( + name = "nvrtc_builtins", + shared_library = "lib/libnvrtc-builtins.so.%{libnvrtc-builtins_version}", +) +%{multiline_comment} +cc_library( + name = "nvrtc", + %{comment}deps = [ + %{comment}":nvrtc_main", + %{comment}":nvrtc_builtins", + %{comment}], + visibility = ["//visibility:public"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl new file mode 100644 index 00000000000000..3457f41a502dee --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl @@ -0,0 +1,13 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/nvToolsExt*.h", + %{comment}"include/nvtx3/**", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl b/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl new file mode 100644 index 00000000000000..d2015e737540c3 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl @@ -0,0 +1,491 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA repositories initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_PATH_PREFIX", + "CUDNN_REDIST_PATH_PREFIX", + "REDIST_VERSIONS_TO_BUILD_TEMPLATES", +) + +OS_ARCH_DICT = { + "amd64": "x86_64-unknown-linux-gnu", + "aarch64": "aarch64-unknown-linux-gnu", +} +_REDIST_ARCH_DICT = { + "linux-x86_64": "x86_64-unknown-linux-gnu", + "linux-sbsa": "aarch64-unknown-linux-gnu", +} + +SUPPORTED_ARCHIVE_EXTENSIONS = [ + ".zip", + ".jar", + ".war", + ".aar", + ".tar", + ".tar.gz", + ".tgz", + ".tar.xz", + ".txz", + ".tar.zst", + ".tzst", + ".tar.bz2", + ".tbz", + ".ar", + ".deb", + ".whl", +] + +def get_env_var(ctx, name): + return ctx.os.environ.get(name) + +def _get_file_name(url): + last_slash_index = url.rfind("/") + return url[last_slash_index + 1:] + +def get_archive_name(url): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns the archive name without extension.""" + filename = _get_file_name(url) + for extension in SUPPORTED_ARCHIVE_EXTENSIONS: + if filename.endswith(extension): + return filename[:-len(extension)] + return filename + +LIB_EXTENSION = ".so." + +def _get_lib_name_and_version(path): + extension_index = path.rfind(LIB_EXTENSION) + last_slash_index = path.rfind("/") + lib_name = path[last_slash_index + 1:extension_index] + lib_version = path[extension_index + len(LIB_EXTENSION):] + return (lib_name, lib_version) + +def _get_libraries_by_redist_name_in_dir(repository_ctx): + lib_dir_path = repository_ctx.path("lib") + if not lib_dir_path.exists: + return [] + main_lib_name = "lib{}".format(repository_ctx.name.split("_")[1]).lower() + lib_dir_content = lib_dir_path.readdir() + return [ + str(f) + for f in lib_dir_content + if (LIB_EXTENSION in str(f) and + main_lib_name in str(f).lower()) + ] + +def get_lib_name_to_version_dict(repository_ctx): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns a dict of library names and major versions.""" + lib_name_to_version_dict = {} + for path in _get_libraries_by_redist_name_in_dir(repository_ctx): + lib_name, lib_version = _get_lib_name_and_version(path) + key = "%%{%s_version}" % lib_name.lower() + + # We need to find either major or major.minor version if there is no + # file with major version. E.g. if we have the following files: + # libcudart.so + # libcudart.so.12 + # libcudart.so.12.3.2, + # we will save save {"%{libcudart_version}": "12"}. + if len(lib_version.split(".")) == 1: + lib_name_to_version_dict[key] = lib_version + if (len(lib_version.split(".")) == 2 and + key not in lib_name_to_version_dict): + lib_name_to_version_dict[key] = lib_version + return lib_name_to_version_dict + +def create_dummy_build_file(repository_ctx, use_comment_symbols = True): + repository_ctx.template( + "BUILD", + repository_ctx.attr.build_templates[0], + { + "%{multiline_comment}": "'''" if use_comment_symbols else "", + "%{comment}": "#" if use_comment_symbols else "", + }, + ) + +def _get_build_template(repository_ctx, major_lib_version): + template = None + for i in range(0, len(repository_ctx.attr.versions)): + for dist_version in repository_ctx.attr.versions[i].split(","): + if dist_version == major_lib_version: + template = repository_ctx.attr.build_templates[i] + break + if not template: + fail("No build template found for {} version {}".format( + repository_ctx.name, + major_lib_version, + )) + return template + +def get_major_library_version(repository_ctx, lib_name_to_version_dict): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns the major library version provided the versions dict.""" + major_version = "" + if len(lib_name_to_version_dict) == 0: + return major_version + main_lib_name = "lib{}".format(repository_ctx.name.split("_")[1]) + key = "%%{%s_version}" % main_lib_name + major_version = lib_name_to_version_dict[key] + return major_version + +def create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_lib_version): + # buildifier: disable=function-docstring-args + """Creates a BUILD file for the repository.""" + if len(major_lib_version) == 0: + build_template_content = repository_ctx.read( + repository_ctx.attr.build_templates[0], + ) + if "_version}" not in build_template_content: + create_dummy_build_file(repository_ctx, use_comment_symbols = False) + else: + create_dummy_build_file(repository_ctx) + return + build_template = _get_build_template( + repository_ctx, + major_lib_version.split(".")[0], + ) + repository_ctx.template( + "BUILD", + build_template, + lib_name_to_version_dict | { + "%{multiline_comment}": "", + "%{comment}": "", + }, + ) + +def _create_symlinks(repository_ctx, local_path, dirs): + for dir in dirs: + repository_ctx.symlink( + "{path}/{dir}".format( + path = local_path, + dir = dir, + ), + dir, + ) + +def use_local_path(repository_ctx, local_path, dirs): + # buildifier: disable=function-docstring-args + """Creates repository using local redistribution paths.""" + _create_symlinks( + repository_ctx, + local_path, + dirs, + ) + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + repository_ctx.file("version.txt", major_version) + +def _use_local_cuda_path(repository_ctx, local_cuda_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic CUDA repository.""" + use_local_path( + repository_ctx, + local_cuda_path, + ["include", "lib", "bin", "nvvm"], + ) + +def _use_local_cudnn_path(repository_ctx, local_cudnn_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic CUDNN repository.""" + use_local_path(repository_ctx, local_cudnn_path, ["include", "lib"]) + +def _download_redistribution(repository_ctx, arch_key, path_prefix): + (url, sha256) = repository_ctx.attr.url_dict[arch_key] + + # If url is not relative, then appending prefix is not needed. + if not (url.startswith("http") or url.startswith("file:///")): + url = path_prefix + url + archive_name = get_archive_name(url) + file_name = _get_file_name(url) + + print("Downloading and extracting {}".format(url)) # buildifier: disable=print + repository_ctx.download( + url = tf_mirror_urls(url), + output = file_name, + sha256 = sha256, + ) + if repository_ctx.attr.override_strip_prefix: + strip_prefix = repository_ctx.attr.override_strip_prefix + else: + strip_prefix = archive_name + repository_ctx.extract( + archive = file_name, + stripPrefix = strip_prefix, + ) + repository_ctx.delete(file_name) + +def _use_downloaded_cuda_redistribution(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads CUDA redistribution and initializes hermetic CUDA repository.""" + major_version = "" + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + if not cuda_version: + # If no CUDA version is found, comment out all cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + if len(repository_ctx.attr.url_dict) == 0: + print("{} is not found in redistributions list.".format( + repository_ctx.name, + )) # buildifier: disable=print + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch_key = OS_ARCH_DICT[repository_ctx.os.arch] + if arch_key not in repository_ctx.attr.url_dict.keys(): + fail( + ("The supported platforms are {supported_platforms}." + + " Platform {platform} is not supported for {dist_name}.") + .format( + supported_platforms = repository_ctx.attr.url_dict.keys(), + platform = arch_key, + dist_name = repository_ctx.name, + ), + ) + _download_redistribution( + repository_ctx, + arch_key, + repository_ctx.attr.cuda_redist_path_prefix, + ) + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version(repository_ctx, lib_name_to_version_dict) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + repository_ctx.file("version.txt", major_version) + +def _cuda_repo_impl(repository_ctx): + local_cuda_path = get_env_var(repository_ctx, "LOCAL_CUDA_PATH") + if local_cuda_path: + _use_local_cuda_path(repository_ctx, local_cuda_path) + else: + _use_downloaded_cuda_redistribution(repository_ctx) + +cuda_repo = repository_rule( + implementation = _cuda_repo_impl, + attrs = { + "url_dict": attr.string_list_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "override_strip_prefix": attr.string(), + "cuda_redist_path_prefix": attr.string(), + }, + environ = [ + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "LOCAL_CUDA_PATH", + ], +) + +def _use_downloaded_cudnn_redistribution(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads CUDNN redistribution and initializes hermetic CUDNN repository.""" + cudnn_version = None + major_version = "" + cudnn_version = (get_env_var(repository_ctx, "HERMETIC_CUDNN_VERSION") or + get_env_var(repository_ctx, "TF_CUDNN_VERSION")) + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + if not cudnn_version: + # If no CUDNN version is found, comment out cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + if len(repository_ctx.attr.url_dict) == 0: + print("{} is not found in redistributions list.".format( + repository_ctx.name, + )) # buildifier: disable=print + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch_key = OS_ARCH_DICT[repository_ctx.os.arch] + if arch_key not in repository_ctx.attr.url_dict.keys(): + arch_key = "cuda{version}_{arch}".format( + version = cuda_version.split(".")[0], + arch = arch_key, + ) + if arch_key not in repository_ctx.attr.url_dict.keys(): + fail( + ("The supported platforms are {supported_platforms}." + + " Platform {platform} is not supported for {dist_name}.") + .format( + supported_platforms = repository_ctx.attr.url_dict.keys(), + platform = arch_key, + dist_name = repository_ctx.name, + ), + ) + + _download_redistribution( + repository_ctx, + arch_key, + repository_ctx.attr.cudnn_redist_path_prefix, + ) + + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + + repository_ctx.file("version.txt", major_version) + +def _cudnn_repo_impl(repository_ctx): + local_cudnn_path = get_env_var(repository_ctx, "LOCAL_CUDNN_PATH") + if local_cudnn_path: + _use_local_cudnn_path(repository_ctx, local_cudnn_path) + else: + _use_downloaded_cudnn_redistribution(repository_ctx) + +cudnn_repo = repository_rule( + implementation = _cudnn_repo_impl, + attrs = { + "url_dict": attr.string_list_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "override_strip_prefix": attr.string(), + "cudnn_redist_path_prefix": attr.string(), + }, + environ = [ + "HERMETIC_CUDNN_VERSION", + "TF_CUDNN_VERSION", + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "LOCAL_CUDNN_PATH", + ], +) + +def _get_redistribution_urls(dist_info): + url_dict = {} + for arch in _REDIST_ARCH_DICT.keys(): + if "relative_path" in dist_info[arch]: + url_dict[_REDIST_ARCH_DICT[arch]] = [ + dist_info[arch]["relative_path"], + dist_info[arch].get("sha256", ""), + ] + continue + + if "full_path" in dist_info[arch]: + url_dict[_REDIST_ARCH_DICT[arch]] = [ + dist_info[arch]["full_path"], + dist_info[arch].get("sha256", ""), + ] + continue + + for cuda_version, data in dist_info[arch].items(): + # CUDNN JSON might contain paths for each CUDA version. + path_key = "relative_path" + if path_key not in data.keys(): + path_key = "full_path" + url_dict["{cuda_version}_{arch}".format( + cuda_version = cuda_version, + arch = _REDIST_ARCH_DICT[arch], + )] = [data[path_key], data.get("sha256", "")] + return url_dict + +def get_version_and_template_lists(version_to_template): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns lists of versions and templates provided in the dict.""" + template_to_version_map = {} + for version, template in version_to_template.items(): + if template not in template_to_version_map.keys(): + template_to_version_map[template] = [version] + else: + template_to_version_map[template].append(version) + version_list = [] + template_list = [] + for template, versions in template_to_version_map.items(): + version_list.append(",".join(versions)) + template_list.append(Label(template)) + return (version_list, template_list) + +def cudnn_redist_init_repository( + cudnn_redistributions, + cudnn_redist_path_prefix = CUDNN_REDIST_PATH_PREFIX, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes CUDNN repository.""" + if "cudnn" in cudnn_redistributions.keys(): + url_dict = _get_redistribution_urls(cudnn_redistributions["cudnn"]) + else: + url_dict = {} + repo_data = redist_versions_to_build_templates["cudnn"] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cudnn_repo( + name = repo_data["repo_name"], + versions = versions, + build_templates = templates, + url_dict = url_dict, + cudnn_redist_path_prefix = cudnn_redist_path_prefix, + ) + +def cuda_redist_init_repositories( + cuda_redistributions, + cuda_redist_path_prefix = CUDA_REDIST_PATH_PREFIX, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes CUDA repositories.""" + for redist_name, _ in redist_versions_to_build_templates.items(): + if redist_name in ["cudnn", "cuda_nccl"]: + continue + if redist_name in cuda_redistributions.keys(): + url_dict = _get_redistribution_urls(cuda_redistributions[redist_name]) + else: + url_dict = {} + repo_data = redist_versions_to_build_templates[redist_name] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cuda_repo( + name = repo_data["repo_name"], + versions = versions, + build_templates = templates, + url_dict = url_dict, + cuda_redist_path_prefix = cuda_redist_path_prefix, + ) diff --git a/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl new file mode 100644 index 00000000000000..d7ccff736a4801 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl @@ -0,0 +1,243 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA redistribution versions.""" + +CUDA_REDIST_PATH_PREFIX = "https://developer.download.nvidia.com/compute/cuda/redist/" +CUDNN_REDIST_PATH_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redist/" + +CUDA_REDIST_JSON_DICT = { + "11.8": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_11.8.0.json", + "941a950a4ab3b95311c50df7b3c8bca973e0cdda76fc2f4b456d2d5e4dac0281", + ], + "12.1.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.1.1.json", + "bafea3cb83a4cf5c764eeedcaac0040d0d3c5db3f9a74550da0e7b6ac24d378c", + ], + "12.2.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.2.0.json", + "d883762c6339c8ebb3ffb072facc8f7265cd257d2db16a475fff9a9306ecea89", + ], + "12.3.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.3.1.json", + "b3cc4181d711cf9b6e3718f323b23813c24f9478119911d7b4bceec9b437dbc3", + ], + "12.3.2": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.3.2.json", + "1b6eacf335dd49803633fed53ef261d62c193e5a56eee5019e7d2f634e39e7ef", + ], + "12.4.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.4.0.json", + "a4f496b8d5299939b34c9ef88dc4274821f8c9451b2d7c9bcee53166932da067", + ], + "12.4.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.4.1.json", + "9cd815f3b71c2e3686ef2219b7794b81044f9dcefaa8e21dacfcb5bc4d931892", + ], + "12.5.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.5.0.json", + "166664b520bfe51f27abcc8c7a934f4cb6ea287f8c399b5f8255f6f4d214569a", + ], + "12.5.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.5.1.json", + "7ab9c76014ae4907fa1b51738af599607a5fd8ca3a5c4bb4c3b31338cc642a93", + ], + "12.6.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.0.json", + "87740b01676b3d18982982ab96ec7fa1a626d03a96df070a6b0f258d01ff5fab", + ], +} + +CUDNN_REDIST_JSON_DICT = { + "8.6": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.6.0.json", + "7f6f50bed4fd8216dc10d6ef505771dc0ecc99cce813993ab405cb507a21d51d", + ], + "8.9.4.25": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.4.25.json", + "02258dba8384860c9230fe3c78522e7bd8e350e461ccd37a8d932cb64127ba57", + ], + "8.9.6": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.6.json", + "6069ef92a2b9bb18cebfbc944964bd2b024b76f2c2c35a43812982e0bc45cf0c", + ], + "8.9.7.29": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.7.29.json", + "a0734f26f068522464fa09b2f2c186dfbe6ad7407a88ea0c50dd331f0c3389ec", + ], + "9.1.1": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.1.1.json", + "d22d569405e5683ff8e563d00d6e8c27e5e6a902c564c23d752b22a8b8b3fe20", + ], + "9.2.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.2.0.json", + "6852eb279b95d2b5775f7a7737ec133bed059107f863cdd8588f3ae6f13eadd7", + ], + "9.2.1": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.2.1.json", + "9a4198c59b2e66b2b115a736ebe4dc8f3dc6d78161bb494702f824da8fc77b99", + ], + "9.3.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.3.0.json", + "d17d9a7878365736758550294f03e633a0b023bec879bf173349bfb34781972e", + ], +} + +# The versions are different for x86 and aarch64 architectures because only +# NCCL release versions 2.20.3 and 2.20.5 have the wheels for aarch64. +CUDA_12_NCCL_WHEEL_DICT = { + "x86_64-unknown-linux-gnu": { + "version": "2.21.5", + "url": "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", + "sha256": "8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", + }, + "aarch64-unknown-linux-gnu": { + "version": "2.20.5", + "url": "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", + "sha256": "1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", + }, +} + +CUDA_11_NCCL_WHEEL_DICT = { + "x86_64-unknown-linux-gnu": { + "version": "2.21.5", + "url": "https://files.pythonhosted.org/packages/ac/9a/8b6a28b3b87d5fddab0e92cd835339eb8fbddaa71ae67518c8c1b3d05bae/nvidia_nccl_cu11-2.21.5-py3-none-manylinux2014_x86_64.whl", + "sha256": "49d8350629c7888701d1fd200934942671cb5c728f49acc5a0b3a768820bed29", + }, +} + +CUDA_NCCL_WHEELS = { + "11.8": CUDA_11_NCCL_WHEEL_DICT, + "12.1.1": CUDA_12_NCCL_WHEEL_DICT, + "12.2.0": CUDA_12_NCCL_WHEEL_DICT, + "12.3.1": CUDA_12_NCCL_WHEEL_DICT, + "12.3.2": CUDA_12_NCCL_WHEEL_DICT, + "12.4.0": CUDA_12_NCCL_WHEEL_DICT, + "12.1.0": CUDA_12_NCCL_WHEEL_DICT, + "12.5.0": CUDA_12_NCCL_WHEEL_DICT, + "12.5.1": CUDA_12_NCCL_WHEEL_DICT, + "12.6.0": CUDA_12_NCCL_WHEEL_DICT, +} + +REDIST_VERSIONS_TO_BUILD_TEMPLATES = { + "cuda_nccl": { + "repo_name": "cuda_nccl", + "version_to_template": { + "2": "//third_party/nccl/hermetic:cuda_nccl.BUILD.tpl", + }, + }, + "cudnn": { + "repo_name": "cuda_cudnn", + "version_to_template": { + "9": "//third_party/gpus/cuda/hermetic:cuda_cudnn9.BUILD.tpl", + "8": "//third_party/gpus/cuda/hermetic:cuda_cudnn.BUILD.tpl", + }, + }, + "libcublas": { + "repo_name": "cuda_cublas", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cublas.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cublas.BUILD.tpl", + }, + }, + "cuda_cudart": { + "repo_name": "cuda_cudart", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl", + }, + }, + "libcufft": { + "repo_name": "cuda_cufft", + "version_to_template": { + "11": "//third_party/gpus/cuda/hermetic:cuda_cufft.BUILD.tpl", + "10": "//third_party/gpus/cuda/hermetic:cuda_cufft.BUILD.tpl", + }, + }, + "cuda_cupti": { + "repo_name": "cuda_cupti", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cupti.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cupti.BUILD.tpl", + }, + }, + "libcurand": { + "repo_name": "cuda_curand", + "version_to_template": { + "10": "//third_party/gpus/cuda/hermetic:cuda_curand.BUILD.tpl", + }, + }, + "libcusolver": { + "repo_name": "cuda_cusolver", + "version_to_template": { + "11": "//third_party/gpus/cuda/hermetic:cuda_cusolver.BUILD.tpl", + }, + }, + "libcusparse": { + "repo_name": "cuda_cusparse", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cusparse.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cusparse.BUILD.tpl", + }, + }, + "libnvjitlink": { + "repo_name": "cuda_nvjitlink", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvjitlink.BUILD.tpl", + }, + }, + "cuda_nvrtc": { + "repo_name": "cuda_nvrtc", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", + }, + }, + "cuda_cccl": { + "repo_name": "cuda_cccl", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", + }, + }, + "cuda_nvcc": { + "repo_name": "cuda_nvcc", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", + }, + }, + "cuda_nvml_dev": { + "repo_name": "cuda_nvml", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", + }, + }, + "cuda_nvprune": { + "repo_name": "cuda_nvprune", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", + }, + }, + "cuda_nvtx": { + "repo_name": "cuda_nvtx", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", + }, + }, +} diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index fefbf081c87e1c..8bf1db2b0f8f9f 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -1,5 +1,7 @@ """Repository rule for CUDA autoconfiguration. +NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead. + `cuda_configure` depends on the following environment variables: * `TF_NEED_CUDA`: Whether to enable building with CUDA. @@ -53,6 +55,11 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "get_cxx_inc_directories", + "to_list_of_strings", +) _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" _GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX" @@ -67,20 +74,6 @@ _TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO" _TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG" _PYTHON_BIN_PATH = "PYTHON_BIN_PATH" -def to_list_of_strings(elements): - """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'. - - This is to be used to put a list of strings into the bzl file templates - so it gets interpreted as list of strings in Starlark. - - Args: - elements: list of string elements - - Returns: - single string of elements wrapped in quotes separated by a comma.""" - quoted_strings = ["\"" + element + "\"" for element in elements] - return ", ".join(quoted_strings) - def verify_build_defines(params): """Verify all variables that crosstool/BUILD.tpl expects are substituted. @@ -238,156 +231,6 @@ def find_cc(repository_ctx, use_cuda_clang): " environment variable").format(target_cc_name, cc_path_envvar)) return cc -_INC_DIR_MARKER_BEGIN = "#include <...>" - -# OSX add " (framework directory)" at the end of line, strip it. -_OSX_FRAMEWORK_SUFFIX = " (framework directory)" -_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX) - -def _cxx_inc_convert(path): - """Convert path returned by cc -E xc++ in a complete path.""" - path = path.strip() - if path.endswith(_OSX_FRAMEWORK_SUFFIX): - path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip() - return path - -def _normalize_include_path(repository_ctx, path): - """Normalizes include paths before writing them to the crosstool. - - If path points inside the 'crosstool' folder of the repository, a relative - path is returned. - If path points outside the 'crosstool' folder, an absolute path is returned. - """ - path = str(repository_ctx.path(path)) - crosstool_folder = str(repository_ctx.path(".").get_child("crosstool")) - - if path.startswith(crosstool_folder): - # We drop the path to "$REPO/crosstool" and a trailing path separator. - return path[len(crosstool_folder) + 1:] - return path - -def _is_compiler_option_supported(repository_ctx, cc, option): - """Checks that `option` is supported by the C compiler. Doesn't %-escape the option.""" - result = repository_ctx.execute([ - cc, - option, - "-o", - "/dev/null", - "-c", - str(repository_ctx.path("tools/cpp/empty.cc")), - ]) - return result.stderr.find(option) == -1 - -def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot): - """Compute the list of default C or C++ include directories.""" - if lang_is_cpp: - lang = "c++" - else: - lang = "c" - sysroot = [] - if tf_sysroot: - sysroot += ["--sysroot", tf_sysroot] - result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] + - sysroot) - stderr = err_out(result) - index1 = stderr.find(_INC_DIR_MARKER_BEGIN) - if index1 == -1: - return [] - index1 = stderr.find("\n", index1) - if index1 == -1: - return [] - index2 = stderr.rfind("\n ") - if index2 == -1 or index2 < index1: - return [] - index2 = stderr.find("\n", index2 + 1) - if index2 == -1: - inc_dirs = stderr[index1 + 1:] - else: - inc_dirs = stderr[index1 + 1:index2].strip() - - print_resource_dir_supported = _is_compiler_option_supported( - repository_ctx, - cc, - "-print-resource-dir", - ) - - if print_resource_dir_supported: - resource_dir = repository_ctx.execute( - [cc, "-print-resource-dir"], - ).stdout.strip() + "/share" - inc_dirs += "\n" + resource_dir - - compiler_includes = [ - _normalize_include_path(repository_ctx, _cxx_inc_convert(p)) - for p in inc_dirs.split("\n") - ] - - # The compiler might be on a symlink, e.g. /symlink -> /opt/gcc - # The above keeps only the resolved paths to the default includes (e.g. /opt/gcc/include/c++/11) - # but Bazel might encounter either (usually reported by the compiler) - # especially when a compiler wrapper (e.g. ccache) is used. - # So we need to also include paths where symlinks are not resolved. - - # Try to find real path to CC installation to "see through" compiler wrappers - # GCC has the path to g++ - index1 = result.stderr.find("COLLECT_GCC=") - if index1 != -1: - index1 = result.stderr.find("=", index1) - index2 = result.stderr.find("\n", index1) - cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname.dirname - else: - # Clang has the directory - index1 = result.stderr.find("InstalledDir: ") - if index1 != -1: - index1 = result.stderr.find(" ", index1) - index2 = result.stderr.find("\n", index1) - cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname - else: - # Fallback to the CC path - cc_topdir = repository_ctx.path(cc).dirname.dirname - - # We now have the compiler installation prefix, e.g. /symlink/gcc - # And the resolved installation prefix, e.g. /opt/gcc - cc_topdir_resolved = str(realpath(repository_ctx, cc_topdir)).strip() - cc_topdir = str(cc_topdir).strip() - - # If there is (any!) symlink involved we add paths where the unresolved installation prefix is kept. - # e.g. [/opt/gcc/include/c++/11, /opt/gcc/lib/gcc/x86_64-linux-gnu/11/include, /other/path] - # adds [/symlink/include/c++/11, /symlink/lib/gcc/x86_64-linux-gnu/11/include] - if cc_topdir_resolved != cc_topdir: - unresolved_compiler_includes = [ - cc_topdir + inc[len(cc_topdir_resolved):] - for inc in compiler_includes - if inc.startswith(cc_topdir_resolved) - ] - compiler_includes = compiler_includes + unresolved_compiler_includes - return compiler_includes - -def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot): - """Compute the list of default C and C++ include directories.""" - - # For some reason `clang -xc` sometimes returns include paths that are - # different from the ones from `clang -xc++`. (Symlink and a dir) - # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists - includes_cpp = _get_cxx_inc_directories_impl( - repository_ctx, - cc, - True, - tf_sysroot, - ) - includes_c = _get_cxx_inc_directories_impl( - repository_ctx, - cc, - False, - tf_sysroot, - ) - - return includes_cpp + [ - inc - for inc in includes_c - if inc not in includes_cpp - ] - def auto_configure_fail(msg): """Output failure message when cuda configuration fails.""" red = "\033[0;31m" @@ -1293,6 +1136,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" cuda_defines["%{unfiltered_compile_flags}"] = "" + cuda_defines["%{cuda_nvcc_files}"] = "[]" if is_cuda_clang and not is_nvcc_and_clang: cuda_defines["%{host_compiler_path}"] = str(cc) cuda_defines["%{host_compiler_warnings}"] = """ diff --git a/third_party/gpus/find_cuda_config.py b/third_party/gpus/find_cuda_config.py index b88694af5c014d..68623bf671da71 100644 --- a/third_party/gpus/find_cuda_config.py +++ b/third_party/gpus/find_cuda_config.py @@ -14,6 +14,9 @@ # ============================================================================== """Prints CUDA library and header directories and versions found on the system. +NB: DEPRECATED! This script is a part of the deprecated `cuda_configure` rule. +Please use `hermetic/cuda_configure` instead. + The script searches for CUDA library and header files on the system, inspects them to determine their version and prints the configuration to stdout. The paths to inspect and the required versions are specified through environment diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index ff9b53b407be44..fb63d4db886c1c 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -22,12 +22,15 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "to_list_of_strings", +) load( ":cuda_configure.bzl", "enable_cuda", "make_copy_dir_rule", "make_copy_files_rule", - "to_list_of_strings", ) load( ":sycl_configure.bzl", diff --git a/third_party/gpus/sycl_configure.bzl b/third_party/gpus/sycl_configure.bzl index 05330b2fe53195..dd80694e7274f5 100644 --- a/third_party/gpus/sycl_configure.bzl +++ b/third_party/gpus/sycl_configure.bzl @@ -16,11 +16,14 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "to_list_of_strings", +) load( ":cuda_configure.bzl", "make_copy_dir_rule", "make_copy_files_rule", - "to_list_of_strings", ) _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..821b0b238614b3 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,30 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +@@ -342,6 +342,7 @@ + "include/mlir/IR/PDLPatternMatch.h.inc", + "include/mlir/Interfaces/CallInterfaces.h", + "include/mlir/Interfaces/DataLayoutInterfaces.h", ++ "include/mlir/Interfaces/InferIntRangeInterface.h", + "include/mlir/Interfaces/SideEffectInterfaces.h", + ], + hdrs = glob([ +@@ -362,6 +363,7 @@ + ":BytecodeOpInterfaceIncGen", + ":CallOpInterfacesIncGen", + ":DataLayoutInterfacesIncGen", ++ ":InferIntRangeInterfaceIncGen", + ":OpAsmInterfaceIncGen", + ":RegionKindInterfaceIncGen", + ":SideEffectInterfacesIncGen", +@@ -5422,7 +5424,9 @@ + hdrs = glob(["include/mlir/Dialect/LLVMIR/Transforms/*.h"]), + includes = ["include"], + deps = [ ++ ":DataLayoutInterfaces", + ":FuncDialect", ++ ":InliningUtils", + ":IR", + ":LLVMDialect", + ":LLVMPassIncGen", diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 88869a4d59aeed..5f8535bcee878a 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "0c25f85e5b88102363c0cd55e1946053d5827e99" - LLVM_SHA256 = "851d958e60193edfb54d6eb8644785179eeb604edae8c026ac1819e82c059f6c" + LLVM_COMMIT = "1115dee248e68a155001ac3712a189299d104863" + LLVM_SHA256 = "cbfe9694c137ed4489b1667dd01429b7595b40aa47b8d3ae4cafa8a6cff2ef8f" tf_http_archive( name = name, diff --git a/third_party/nanobind/nanobind.BUILD b/third_party/nanobind/nanobind.BUILD index c9f307b75ef0ca..72b47585b5e5d0 100644 --- a/third_party/nanobind/nanobind.BUILD +++ b/third_party/nanobind/nanobind.BUILD @@ -4,9 +4,12 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "nanobind", - srcs = glob([ - "src/*.cpp", - ]), + srcs = glob( + [ + "src/*.cpp", + ], + exclude = ["src/nb_combined.cpp"], + ), copts = ["-fexceptions"], defines = [ "NB_BUILD=1", diff --git a/third_party/nanobind/pr438.patch b/third_party/nanobind/pr438.patch deleted file mode 100644 index edb7d61700e03b..00000000000000 --- a/third_party/nanobind/pr438.patch +++ /dev/null @@ -1,51 +0,0 @@ -diff --git a/src/nb_enum.cpp b/src/nb_enum.cpp -index 86f64d1..91f3932 100644 ---- a/src/nb_enum.cpp -+++ b/src/nb_enum.cpp -@@ -73,6 +73,13 @@ static PyObject *nb_enum_get_doc(PyObject *self, void *) { - return result; - } - -+static PyObject *nb_enum_get_value(PyObject *self, void *) { -+ enum_supplement &supp = nb_enum_supplement(Py_TYPE(self)); -+ return supp.is_signed ? nb_enum_int_signed(self) -+ : nb_enum_int_unsigned(self); -+} -+ -+ - NB_NOINLINE static PyObject *nb_enum_int_signed(PyObject *o) { - type_data *t = nb_type_data(Py_TYPE(o)); - const void *p = inst_ptr((nb_inst *) o); -@@ -141,6 +148,8 @@ error: - static PyGetSetDef nb_enum_getset[] = { - { "__doc__", nb_enum_get_doc, nullptr, nullptr, nullptr }, - { "__name__", nb_enum_get_name, nullptr, nullptr, nullptr }, -+ { "name", nb_enum_get_name, nullptr, nullptr, nullptr }, -+ { "value", nb_enum_get_value, nullptr, nullptr, nullptr }, - { nullptr, nullptr, nullptr, nullptr, nullptr } - }; - -diff --git a/tests/test_enum.py b/tests/test_enum.py -index 2a6e9ff..1063eef 100644 ---- a/tests/test_enum.py -+++ b/tests/test_enum.py -@@ -14,6 +14,9 @@ def test01_unsigned_enum(): - assert int(t.Enum.A) == 0 - assert int(t.Enum.B) == 1 - assert int(t.Enum.C) == 0xffffffff -+ assert t.Enum.A.value == 0 -+ assert t.Enum.B.value == 1 -+ assert t.Enum.C.value == 0xffffffff - assert t.Enum(0) is t.Enum.A - assert t.Enum(1) is t.Enum.B - assert t.Enum(0xffffffff) is t.Enum.C -@@ -48,6 +51,9 @@ def test02_signed_enum(): - assert int(t.SEnum.A) == 0 - assert int(t.SEnum.B) == 1 - assert int(t.SEnum.C) == -1 -+ assert t.SEnum.A.value == 0 -+ assert t.SEnum.B.value == 1 -+ assert t.SEnum.C.value == -1 - assert t.SEnum(0) is t.SEnum.A - assert t.SEnum(1) is t.SEnum.B - assert t.SEnum(-1) is t.SEnum.C \ No newline at end of file diff --git a/third_party/nanobind/pr461.patch b/third_party/nanobind/pr461.patch deleted file mode 100644 index aa0a51b68175a3..00000000000000 --- a/third_party/nanobind/pr461.patch +++ /dev/null @@ -1,39 +0,0 @@ -diff --git a/src/nb_type.cpp b/src/nb_type.cpp ---- a/src/nb_type.cpp -+++ b/src/nb_type.cpp -@@ -36,6 +36,11 @@ static PyObject **nb_weaklist_ptr(PyObje - return weaklistoffset ? (PyObject **) ((uint8_t *) self + weaklistoffset) : nullptr; - } - -+static PyGetSetDef inst_getset[] = { -+ { "__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, nullptr }, -+ { nullptr, nullptr, nullptr, nullptr, nullptr } -+}; -+ - static int inst_clear(PyObject *self) { - PyObject **dict = nb_dict_ptr(self); - if (dict) -@@ -923,8 +928,11 @@ PyObject *nb_type_new(const type_init_da - } - - bool has_traverse = false; -- for (PyType_Slot *ts = slots; ts != s; ++ts) -+ bool has_getset = false; -+ for (PyType_Slot *ts = slots; ts != s; ++ts) { - has_traverse |= ts->slot == Py_tp_traverse; -+ has_getset |= ts->slot == Py_tp_getset; -+ } - - Py_ssize_t dictoffset = 0, weaklistoffset = 0; - int num_members = 0; -@@ -948,6 +956,10 @@ PyObject *nb_type_new(const type_init_da - has_traverse = true; - } - spec.basicsize = (int) basicsize; -+ -+ if (!has_getset) { -+ *s++ = { Py_tp_getset, (void *) inst_getset }; -+ } - } - - if (is_weak_referenceable) { diff --git a/third_party/nanobind/workspace.bzl b/third_party/nanobind/workspace.bzl index 9f9022dbaa8d12..1c692d396e9b98 100644 --- a/third_party/nanobind/workspace.bzl +++ b/third_party/nanobind/workspace.bzl @@ -5,12 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): tf_http_archive( name = "nanobind", - strip_prefix = "nanobind-1.9.2", - sha256 = "149a3da40b0a988513d8cf5e71db3037373823505a3c92f87b988c92d7e0ab34", - urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v1.9.2.tar.gz"), + strip_prefix = "nanobind-2.1.0", + sha256 = "c37c53c60ada5fe1c956e24bd4b83af669a2309bf952bd251f36a7d2fa3bacf0", + urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v2.1.0.tar.gz"), build_file = "//third_party/nanobind:nanobind.BUILD", - patch_file = [ - "//third_party/nanobind:pr438.patch", # Remove when updating to nanobind 2.0.0. - "//third_party/nanobind:pr461.patch", # Remove when updating to nanobind 2.0.0. - ], ) diff --git a/third_party/nccl/build_defs.bzl.tpl b/third_party/nccl/build_defs.bzl.tpl index 53a6d4e1e41890..a0930df34ecec8 100644 --- a/third_party/nccl/build_defs.bzl.tpl +++ b/third_party/nccl/build_defs.bzl.tpl @@ -5,7 +5,6 @@ load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") # CUDA toolkit version as tuple (e.g. '(11, 1)'). _cuda_version = %{cuda_version} -_cuda_clang = %{cuda_clang} def _rdc_copts(): """Returns copts for compiling relocatable device code.""" @@ -121,25 +120,25 @@ _device_link = rule( "gpu_archs": attr.string_list(), "nvlink_args": attr.string_list(), "_nvlink": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/nvlink"), + default = Label("%{nvlink_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_fatbinary": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/fatbinary"), + default = Label("%{fatbinary_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_bin2c": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/bin2c"), + default = Label("%{bin2c_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_link_stub": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/crt/link.stub"), + default = Label("%{link_stub_label}"), allow_single_file = True, ), }, @@ -189,7 +188,7 @@ _prune_relocatable_code = rule( "input": attr.label(mandatory = True, allow_files = True), "gpu_archs": attr.string_list(), "_nvprune": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/nvprune"), + default = Label("%{nvprune_label}"), allow_single_file = True, executable = True, cfg = "host", diff --git a/third_party/nccl/hermetic/BUILD b/third_party/nccl/hermetic/BUILD new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl b/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl new file mode 100644 index 00000000000000..61d7809bcdaad1 --- /dev/null +++ b/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl @@ -0,0 +1,30 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "nccl_shared_library", + shared_library = "lib/libnccl.so.%{libnccl_version}", + hdrs = [":headers"], + deps = ["@local_config_cuda//cuda:cuda_headers", ":headers"], +) +%{multiline_comment} +cc_library( + name = "nccl", + %{comment}deps = [":nccl_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/nccl*.h", + %{comment}]), + include_prefix = "third_party/nccl", + includes = ["include/"], + strip_include_prefix = "include", + visibility = ["//visibility:public"], + deps = ["@local_config_cuda//cuda:cuda_headers"], +) diff --git a/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/nccl/hermetic/nccl_configure.bzl new file mode 100644 index 00000000000000..75f5a10b6fe24e --- /dev/null +++ b/third_party/nccl/hermetic/nccl_configure.bzl @@ -0,0 +1,183 @@ +"""Repository rule for hermetic NCCL configuration. + +`nccl_configure` depends on the following environment variables: + + * `TF_NCCL_USE_STUB`: "1" if a NCCL stub that loads NCCL dynamically should + be used, "0" if NCCL should be linked in statically. + * `HERMETIC_CUDA_VERSION`: The version of the CUDA toolkit. If not specified, + the version will be determined by the `TF_CUDA_VERSION`. + +""" + +load( + "//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "TF_NEED_CUDA", + "enable_cuda", + "get_cuda_version", +) +load( + "//third_party/remote_config:common.bzl", + "get_cpu_value", + "get_host_environ", +) + +_TF_NCCL_USE_STUB = "TF_NCCL_USE_STUB" + +_NCCL_DUMMY_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nccl", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) +""" + +_NCCL_ARCHIVE_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + data = ["@nccl_archive//:LICENSE.txt"], + visibility = ["//visibility:public"], +) + +alias( + name = "nccl", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:nccl", + "//conditions:default": "@nccl_archive//:nccl", + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hermetic_nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_config", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": ":hermetic_nccl_config", + "//conditions:default": "@nccl_archive//:nccl_config", + }), + visibility = ["//visibility:public"], +) +""" + +_NCCL_ARCHIVE_STUB_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + data = ["@nccl_archive//:LICENSE.txt"], + visibility = ["//visibility:public"], +) + +alias( + name = "nccl", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:nccl", + "//conditions:default": "@nccl_archive//:nccl_via_stub", + }), + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_headers", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:headers", + "//conditions:default": "@nccl_archive//:nccl_headers", + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hermetic_nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_config", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": ":hermetic_nccl_config", + "//conditions:default": "@nccl_archive//:nccl_config", + }), + visibility = ["//visibility:public"], +) +""" + +def _create_local_nccl_repository(repository_ctx): + cuda_version = get_cuda_version(repository_ctx).split(".")[:2] + nccl_version = repository_ctx.read(repository_ctx.attr.nccl_version) + + if get_host_environ(repository_ctx, _TF_NCCL_USE_STUB, "0") == "0": + repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT) + else: + repository_ctx.file("BUILD", _NCCL_ARCHIVE_STUB_BUILD_CONTENT) + + repository_ctx.template("generated_names.bzl", repository_ctx.attr.generated_names_tpl, {}) + repository_ctx.template( + "build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_version}": "(%s, %s)" % tuple(cuda_version), + "%{nvlink_label}": "@cuda_nvcc//:nvlink", + "%{fatbinary_label}": "@cuda_nvcc//:fatbinary", + "%{bin2c_label}": "@cuda_nvcc//:bin2c", + "%{link_stub_label}": "@cuda_nvcc//:link_stub", + "%{nvprune_label}": "@cuda_nvprune//:nvprune", + }, + ) + repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"%s\"" % nccl_version) + +def _nccl_autoconf_impl(repository_ctx): + if (not enable_cuda(repository_ctx) or + get_cpu_value(repository_ctx) != "Linux"): + # Add a dummy build file to make bazel query happy. + repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT) + repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"\"") + else: + _create_local_nccl_repository(repository_ctx) + +_ENVIRONS = [ + TF_NEED_CUDA, + TF_CUDA_VERSION, + _TF_NCCL_USE_STUB, + HERMETIC_CUDA_VERSION, + "LOCAL_NCCL_PATH", +] + +nccl_configure = repository_rule( + environ = _ENVIRONS, + implementation = _nccl_autoconf_impl, + attrs = { + "environ": attr.string_dict(), + "nccl_version": attr.label(default = Label("@cuda_nccl//:version.txt")), + "generated_names_tpl": attr.label(default = Label("//third_party/nccl:generated_names.bzl.tpl")), + "build_defs_tpl": attr.label(default = Label("//third_party/nccl:build_defs.bzl.tpl")), + }, +) +"""Downloads and configures the hermetic NCCL configuration. + +Add the following to your WORKSPACE file: + +```python +nccl_configure(name = "local_config_nccl") +``` + +Args: + name: A unique name for this workspace rule. +""" # buildifier: disable=no-effect diff --git a/third_party/nccl/hermetic/nccl_redist_init_repository.bzl b/third_party/nccl/hermetic/nccl_redist_init_repository.bzl new file mode 100644 index 00000000000000..244cb851ddf591 --- /dev/null +++ b/third_party/nccl/hermetic/nccl_redist_init_repository.bzl @@ -0,0 +1,145 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic NCCL repositories initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "OS_ARCH_DICT", + "create_build_file", + "create_dummy_build_file", + "get_archive_name", + "get_env_var", + "get_lib_name_to_version_dict", + "get_major_library_version", + "get_version_and_template_lists", + "use_local_path", +) +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_NCCL_WHEELS", + "REDIST_VERSIONS_TO_BUILD_TEMPLATES", +) + +def _use_downloaded_nccl_wheel(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads NCCL wheel and inits hermetic NCCL repository.""" + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + major_version = "" + if not cuda_version: + # If no CUDA version is found, comment out cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch = OS_ARCH_DICT[repository_ctx.os.arch] + dict_key = "{cuda_version}-{arch}".format( + cuda_version = cuda_version, + arch = arch, + ) + supported_versions = repository_ctx.attr.url_dict.keys() + if dict_key not in supported_versions: + fail( + ("The supported NCCL versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDA_VERSION" + + " environment variable or add NCCL distribution for" + + " CUDA version={version}, OS={arch}.") + .format( + supported_versions = supported_versions, + version = cuda_version, + arch = arch, + ), + ) + sha256 = repository_ctx.attr.sha256_dict[dict_key] + url = repository_ctx.attr.url_dict[dict_key] + + archive_name = get_archive_name(url) + file_name = archive_name + ".zip" + + print("Downloading and extracting {}".format(url)) # buildifier: disable=print + repository_ctx.download( + url = tf_mirror_urls(url), + output = file_name, + sha256 = sha256, + ) + repository_ctx.extract( + archive = file_name, + stripPrefix = repository_ctx.attr.strip_prefix, + ) + repository_ctx.delete(file_name) + + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + + repository_ctx.file("version.txt", major_version) + +def _use_local_nccl_path(repository_ctx, local_nccl_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic NCCL repository.""" + use_local_path(repository_ctx, local_nccl_path, ["include", "lib"]) + +def _cuda_nccl_repo_impl(repository_ctx): + local_nccl_path = get_env_var(repository_ctx, "LOCAL_NCCL_PATH") + if local_nccl_path: + _use_local_nccl_path(repository_ctx, local_nccl_path) + else: + _use_downloaded_nccl_wheel(repository_ctx) + +cuda_nccl_repo = repository_rule( + implementation = _cuda_nccl_repo_impl, + attrs = { + "sha256_dict": attr.string_dict(mandatory = True), + "url_dict": attr.string_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "strip_prefix": attr.string(), + }, + environ = ["HERMETIC_CUDA_VERSION", "TF_CUDA_VERSION", "LOCAL_NCCL_PATH"], +) + +def nccl_redist_init_repository( + cuda_nccl_wheels = CUDA_NCCL_WHEELS, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes NCCL repository.""" + nccl_artifacts_dict = {"sha256_dict": {}, "url_dict": {}} + for cuda_version, nccl_wheel_info in cuda_nccl_wheels.items(): + for arch in OS_ARCH_DICT.values(): + if arch in nccl_wheel_info.keys(): + cuda_version_to_arch_key = "%s-%s" % (cuda_version, arch) + nccl_artifacts_dict["sha256_dict"][cuda_version_to_arch_key] = nccl_wheel_info[arch].get("sha256", "") + nccl_artifacts_dict["url_dict"][cuda_version_to_arch_key] = nccl_wheel_info[arch]["url"] + repo_data = redist_versions_to_build_templates["cuda_nccl"] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cuda_nccl_repo( + name = repo_data["repo_name"], + sha256_dict = nccl_artifacts_dict["sha256_dict"], + url_dict = nccl_artifacts_dict["url_dict"], + versions = versions, + build_templates = templates, + strip_prefix = "nvidia/nccl", + ) diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl index 22cf64d4771062..59f8b5c08ef0ee 100644 --- a/third_party/nccl/nccl_configure.bzl +++ b/third_party/nccl/nccl_configure.bzl @@ -1,5 +1,7 @@ """Repository rule for NCCL configuration. +NB: DEPRECATED! Use `hermetic/nccl_configure` rule instead. + `nccl_configure` depends on the following environment variables: * `TF_NCCL_VERSION`: Installed NCCL version or empty to build from source. @@ -8,7 +10,6 @@ files. * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is `/usr/local/cuda,usr/`. - * `TF_CUDA_CLANG`: "1" if using Clang, "0" if using NVCC. * `TF_NCCL_USE_STUB`: "1" if a NCCL stub that loads NCCL dynamically should be used, "0" if NCCL should be linked in statically. @@ -33,7 +34,6 @@ _TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" _TF_NCCL_VERSION = "TF_NCCL_VERSION" _TF_NEED_CUDA = "TF_NEED_CUDA" _TF_CUDA_PATHS = "TF_CUDA_PATHS" -_TF_CUDA_CLANG = "TF_CUDA_CLANG" _TF_NCCL_USE_STUB = "TF_NCCL_USE_STUB" _DEFINE_NCCL_MAJOR = "#define NCCL_MAJOR" @@ -129,7 +129,11 @@ def _create_local_nccl_repository(repository_ctx): _label("build_defs.bzl.tpl"), { "%{cuda_version}": "(%s, %s)" % tuple(cuda_version), - "%{cuda_clang}": repr(get_host_environ(repository_ctx, _TF_CUDA_CLANG)), + "%{nvlink_label}": "@local_config_cuda//cuda:cuda/bin/nvlink", + "%{fatbinary_label}": "@local_config_cuda//cuda:cuda/bin/fatbinary", + "%{bin2c_label}": "@local_config_cuda//cuda:cuda/bin/bin2c", + "%{link_stub_label}": "@local_config_cuda//cuda:cuda/bin/crt/link.stub", + "%{nvprune_label}": "@local_config_cuda//cuda:cuda/bin/nvprune", }, ) else: @@ -181,7 +185,6 @@ _ENVIRONS = [ _TF_CUDA_COMPUTE_CAPABILITIES, _TF_NEED_CUDA, _TF_CUDA_PATHS, - _TF_CUDA_CLANG, ] remote_nccl_configure = repository_rule( diff --git a/third_party/py/python_repo.bzl b/third_party/py/python_repo.bzl index f8fdd1033b5e2f..13aed2b687129f 100644 --- a/third_party/py/python_repo.bzl +++ b/third_party/py/python_repo.bzl @@ -255,8 +255,12 @@ def _basic_wildcard_match(name, patterns, expected_match_result, match_all): def _custom_python_interpreter_impl(ctx): version = ctx.attr.version - strip_prefix = ctx.attr.strip_prefix.format(version = version) - urls = [url.format(version = version) for url in ctx.attr.urls] + version_variant = ctx.attr.version_variant + strip_prefix = ctx.attr.strip_prefix.format( + version = version, + version_variant = version_variant, + ) + urls = [url.format(version = version, version_variant = version_variant) for url in ctx.attr.urls] binary_name = ctx.attr.binary_name if not binary_name: ver_chunks = version.split(".") @@ -272,13 +276,12 @@ def _custom_python_interpreter_impl(ctx): output = srcs_dir, ) - configure_params = [] + configure_params = list(ctx.attr.configure_params) if "CC" in ctx.os.environ: configure_params.append("CC={}".format(ctx.os.environ["CC"])) if "CXX" in ctx.os.environ: configure_params.append("CXX={}".format(ctx.os.environ["CXX"])) - configure_params.append("--enable-optimizations") configure_params.append("--prefix=%s" % install_path.realpath) _exec_and_check( ctx, @@ -361,6 +364,11 @@ custom_python_interpreter = repository_rule( "strip_prefix": attr.string(), "binary_name": attr.string(mandatory = False), "version": attr.string(), + "version_variant": attr.string(), + "configure_params": attr.string_list( + mandatory = False, + default = ["--enable-optimizations"], + ), }, ) diff --git a/third_party/shardy/BUILD b/third_party/shardy/BUILD index ea1ecdb548c1f4..bf3ae84c142f65 100644 --- a/third_party/shardy/BUILD +++ b/third_party/shardy/BUILD @@ -2,4 +2,7 @@ # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) -exports_files(srcs = ["workspace.bzl"]) +exports_files(srcs = [ + "temporary.patch", + "workspace.bzl", +]) diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index f8d02d38377ed3..e69de29bb2d1d6 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,655 +0,0 @@ -diff --git a/docs/sdy_export_passes.md b/docs/sdy_export_passes.md -index 7a7e3ef..add024c 100755 ---- a/docs/sdy_export_passes.md -+++ b/docs/sdy_export_passes.md -@@ -12,12 +12,3 @@ the edge), and replaces the op with its input. - - TODO(tomnatan): consider moving the sharding to all targets that can have a - sharding attached. --### `-sdy-update-non-divisible-input-output-shardings` -- --_Makes FuncOp inputs/outputs evenly sharded, removing any need for padding due to non-divisible shardings._ -- --Users of Shardy expect the function inputs/outputs to be evenly --divisible/shardable to avoid requiring padding their tensors. Propagation --may make inputs/outputs have non-divisible shardings, so this pass updates --them to the largest dimension sharding prefix of the original sharding that --is evenly sharded. -diff --git a/shardy/dialect/sdy/ir/dialect.cc b/shardy/dialect/sdy/ir/dialect.cc -index aaa33c5..f6f88bc 100644 ---- a/shardy/dialect/sdy/ir/dialect.cc -+++ b/shardy/dialect/sdy/ir/dialect.cc -@@ -28,7 +28,6 @@ limitations under the License. - #include "llvm/ADT/SmallVector.h" - #include "llvm/Support/ErrorHandling.h" - #include "mlir/IR/BuiltinAttributes.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/OperationSupport.h" -@@ -431,8 +430,7 @@ TensorShardingPerValueAttr TensorShardingPerValueAttr::getFullyOpen( - for (Type type : types) { - int64_t rank = 0; - // TODO(tomnatan): remove mlir:: once Attribute::dyn_cast is removed. -- if (auto tensorType = mlir::dyn_cast(type)) { -- assert(tensorType.hasStaticShape()); -+ if (auto tensorType = mlir::dyn_cast(type)) { - rank = tensorType.getRank(); - } - shardingPerResult.push_back( -diff --git a/shardy/dialect/sdy/ir/ops.td b/shardy/dialect/sdy/ir/ops.td -index 9478d7b..ca67f51 100644 ---- a/shardy/dialect/sdy/ir/ops.td -+++ b/shardy/dialect/sdy/ir/ops.td -@@ -135,12 +135,12 @@ def Sdy_ManualComputationOp : Sdy_Op<"manual_computation", - }]; - - let arguments = (ins -- Variadic:$tensors, -+ Variadic:$tensors, - Sdy_TensorShardingPerValue:$in_shardings, - Sdy_TensorShardingPerValue:$out_shardings, - Sdy_ManualAxes:$manual_axes - ); -- let results = (outs Variadic:$results); -+ let results = (outs Variadic:$results); - let regions = (region SizedRegion<1>:$body); - - let assemblyFormat = [{ -@@ -249,6 +249,27 @@ def Sdy_ConstantOp : Sdy_Op<"constant", - }]; - } - -+//===----------------------------------------------------------------------===// -+// IdentityOp -+//===----------------------------------------------------------------------===// -+ -+def IdentityOp : Sdy_Op<"identity", -+ [Pure, Elementwise, SameOperandsAndResultType]> { -+ let summary = "Identity operation"; -+ -+ let description = [{ -+ An identity op that outputs the same value that it takes as input. This is -+ useful, to break a pattern where a block argument is directly used in the -+ block's terminator, which could result in canonicalization removing that -+ block argument, e.g., a block argument of a while op that could be replaced -+ with the corresponding operand as a free variable. -+ }]; -+ -+ let arguments = (ins AnyTensor:$input); -+ let results = (outs AnyTensor:$result); -+ let assemblyFormat = "attr-dict $input `:` type($input)"; -+} -+ - //===----------------------------------------------------------------------===// - // DataFlowEdgeOp - //===----------------------------------------------------------------------===// -@@ -316,10 +337,10 @@ def DataFlowEdgeOp : Sdy_Op<"data_flow_edge", - }]; - - let arguments = (ins -- AnyShaped:$input, -+ AnyRankedTensor:$input, - OptionalAttr:$sharding); - -- let results = (outs AnyShaped:$result); -+ let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = "$input (`sharding````=``` $sharding^)? attr-dict `:` type($result)"; - -@@ -360,10 +381,10 @@ def PropagationBarrierOp : Sdy_Op<"propagation_barrier", - }]; - - let arguments = (ins -- AnyRankedTensor:$input, -+ AnyTensor:$input, - Sdy_PropagationDirection:$allowed_direction - ); -- let results = (outs AnyRankedTensor:$result); -+ let results = (outs AnyTensor:$result); - let assemblyFormat = "$input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)"; - let hasVerifier = 1; - } -diff --git a/shardy/dialect/sdy/ir/test/data_flow_edge_verification.mlir b/shardy/dialect/sdy/ir/test/data_flow_edge_verification.mlir -index b247d79..c2a355d 100644 ---- a/shardy/dialect/sdy/ir/test/data_flow_edge_verification.mlir -+++ b/shardy/dialect/sdy/ir/test/data_flow_edge_verification.mlir -@@ -12,15 +12,6 @@ func.func @invalid_sharding(%arg0 : tensor<8xf32>) -> tensor<8xf32> { - - // ----- - --func.func @dynamic_shaped_type(%arg0: tensor) -- -> (tensor, tensor) { -- // expected-error @+1 {{expected sdy.data_flow_edge to have a static-shaped result}} -- %0 = sdy.data_flow_edge %arg0 : tensor -- return %arg0, %0 : tensor, tensor --} -- --// ----- -- - func.func @input_has_multiple_users(%arg0: tensor<32x96xf32>) - -> (tensor<32x96xf32>, tensor<32x96xf32>) { - // expected-error @+1 {{expected input of sdy.data_flow_edge to have a single user}} -diff --git a/shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir b/shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir -index e64c43c..9fc6e87 100644 ---- a/shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir -+++ b/shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir -@@ -16,22 +16,6 @@ func.func @sharding_rule_wrong_attr_type(%arg0: tensor<8xf32>) -> tensor<8xf32> - - // ----- - --func.func @unranked_tensor_type(%arg0: tensor<*xf32>) -> tensor<*xf32> { -- // expected-error@+1 {{operand 0 - expected a ranked tensor with a static shape}} -- %0 = stablehlo.add %arg0, %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=2, j=4}>} : tensor<*xf32> -- return %0 : tensor<*xf32> --} -- --// ----- -- --func.func @dynamic_shaped_tensor_type(%arg0: tensor) -> tensor { -- // expected-error@+1 {{operand 0 - expected a ranked tensor with a static shape}} -- %0 = stablehlo.add %arg0, %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=2, j=4}>} : tensor -- return %0 : tensor --} -- --// ----- -- - func.func @operand_mappings_wrong_rank(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> { - // expected-error@+1 {{operand 1 - mapping rank must match: 1 != 2}} - %0 = stablehlo.add %arg0, %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i])->([i, j]) {i=2, j=4}>} : tensor<2x4xf32> -diff --git a/shardy/dialect/sdy/ir/test/tensor_sharding_verification.mlir b/shardy/dialect/sdy/ir/test/tensor_sharding_verification.mlir -index 540ce8d..50394d1 100644 ---- a/shardy/dialect/sdy/ir/test/tensor_sharding_verification.mlir -+++ b/shardy/dialect/sdy/ir/test/tensor_sharding_verification.mlir -@@ -2,7 +2,7 @@ - - sdy.mesh @mesh = <"a"=2> - --// expected-error @+1 {{'func.func' op arg 0 - non-shaped tensors can only have a sharding with rank 0 and no replicated axes}} -+// expected-error @+1 {{'func.func' op arg 0 - non-ranked tensors can only have a sharding with rank 0 and no replicated axes}} - func.func @token_sharding_rank_non_zero(%arg0: !stablehlo.token {sdy.sharding=#sdy.sharding<@mesh, [{}]>}) -> !stablehlo.token { - return %arg0 : !stablehlo.token - } -@@ -11,31 +11,13 @@ func.func @token_sharding_rank_non_zero(%arg0: !stablehlo.token {sdy.sharding=#s - - sdy.mesh @mesh = <"a"=2> - --// expected-error @+1 {{'func.func' op arg 0 - non-shaped tensors can only have a sharding with rank 0 and no replicated axes}} -+// expected-error @+1 {{'func.func' op arg 0 - non-ranked tensors can only have a sharding with rank 0 and no replicated axes}} - func.func @token_sharding_with_replicated_axes(%arg0: !stablehlo.token {sdy.sharding=#sdy.sharding<@mesh, [], replicated={"a"}>}) -> !stablehlo.token { - return %arg0 : !stablehlo.token - } - - // ----- - --sdy.mesh @mesh = <"a"=2> -- --// expected-error @+1 {{'func.func' op arg 0 - only ranked tensors with a static shape can have a sharding}} --func.func @unranked_tensor_with_sharding(%arg0: tensor<*xf32> {sdy.sharding=#sdy.sharding<@mesh, []>}) -> tensor<*xf32> { -- return %arg0 : tensor<*xf32> --} -- --// ----- -- --sdy.mesh @mesh = <"a"=2> -- --// expected-error @+1 {{'func.func' op arg 0 - only ranked tensors with a static shape can have a sharding}} --func.func @dynamic_shaped_tensor_with_sharding(%arg0: tensor<*xf32> {sdy.sharding=#sdy.sharding<@mesh, [{}, {}]>}) -> tensor { -- return %arg0 : tensor<*xf32> --} -- --// ----- -- - sdy.mesh @mesh = <"a"=2, "b"=2> - - func.func @dim_shardings_rank_mismatch(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> { -diff --git a/shardy/dialect/sdy/ir/utils.cc b/shardy/dialect/sdy/ir/utils.cc -index b184794..8831d58 100644 ---- a/shardy/dialect/sdy/ir/utils.cc -+++ b/shardy/dialect/sdy/ir/utils.cc -@@ -28,7 +28,6 @@ limitations under the License. - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/BuiltinAttributes.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/Diagnostics.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/Operation.h" -@@ -92,37 +91,26 @@ std::string operationToString(Operation* op) { - return mlirToString(op); - } - --std::string valueToString(Value value) { return mlirToString(&value); } -- --ShapedType dynCastStaticShapedType(Type type) { -- if (auto shapedType = dyn_cast(type); -- shapedType && shapedType.hasStaticShape()) { -- return shapedType; -- } -- return nullptr; --} -- --bool isStaticShapedType(Type type) { -- return dynCastStaticShapedType(type) != nullptr; -+std::string valueToString(Value value) { -+ return mlirToString(&value); - } - - ArrayRef getTensorShape(Value value) { -- if (auto tensorType = dyn_cast(value.getType())) { -+ if (auto tensorType = dyn_cast(value.getType())) { - return tensorType.getShape(); - } - return {}; - } - - int64_t getTensorRank(Value value) { -- if (auto tensorType = dyn_cast(value.getType())) { -+ if (auto tensorType = dyn_cast(value.getType())) { - return tensorType.getRank(); - } - return 0; - } - - int64_t isScalar(Value value) { -- if (auto tensorType = dyn_cast(value.getType()); -- tensorType && tensorType.hasRank()) { -+ if (auto tensorType = dyn_cast(value.getType())) { - return tensorType.getRank() == 0; - } - return false; -diff --git a/shardy/dialect/sdy/ir/utils.h b/shardy/dialect/sdy/ir/utils.h -index c151955..d0868a7 100644 ---- a/shardy/dialect/sdy/ir/utils.h -+++ b/shardy/dialect/sdy/ir/utils.h -@@ -26,7 +26,6 @@ limitations under the License. - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Builders.h" - #include "mlir/IR/BuiltinAttributes.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/Operation.h" - #include "mlir/IR/PatternMatch.h" -@@ -66,23 +65,12 @@ std::string operationToString(Operation* op); - // Converts `value` to string with location information. - std::string valueToString(Value value); - --// If the given `type` is a `ShapedType` with a static shape, returns it, --// otherwise returns nullptr. --ShapedType dynCastStaticShapedType(Type type); -- --// Returns true if the given `type` is a `ShapedType` with a static shape. --bool isStaticShapedType(Type type); -- --// Returns the shape of the given `value` if its type is a `ShapeTensor`, -+// Returns the shape of the given `value` if its type is a `RankedTensorType`, - // otherwise returns an empty array. --// --// Assumes the `ShapeTensor` has a rank. - ArrayRef getTensorShape(Value value); - --// Returns the rank of the given `value` if its type is a `ShapeTensor`, -+// Returns the rank of the given `value` if its type is a `RankedTensorType`, - // otherwise returns 0. --// --// Assumes the `ShapeTensor` has a rank. - int64_t getTensorRank(Value value); - - // Returns true if the value is a tensor with rank 0. -diff --git a/shardy/dialect/sdy/ir/verifiers.cc b/shardy/dialect/sdy/ir/verifiers.cc -index 61fd0e0..015e10f 100644 ---- a/shardy/dialect/sdy/ir/verifiers.cc -+++ b/shardy/dialect/sdy/ir/verifiers.cc -@@ -30,7 +30,6 @@ limitations under the License. - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/BuiltinAttributes.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/Diagnostics.h" - #include "mlir/IR/SymbolTable.h" -@@ -200,11 +199,11 @@ LogicalResult emitBoundAxisInManualComputationError(EmitErrorFn emitError, - - // Verifies the following for `shardingAttr`: - // --// If `type` isn't a `ShapedType`, the sharding must have rank 0 and no -+// If `type` isn't a `RankedTensorType`, the sharding must have rank 0 and no - // replicated axes. - // --// - The tensor should have a rank and static shape. --// - The number of dimension shardings is equal to the rank of the tensor. -+// - The number of dimension shardings is equal to the rank of the tensor -+// (specified by `type`, which should be a `RankedTensorType`). - // - Dimensions of size 0 aren't sharded. - // - Replicated axes are ordered w.r.t. `mesh` (see - // AxisRefAttr::getMeshComparator). -@@ -221,22 +220,17 @@ LogicalResult verifyTensorShardingAttr( - TensorShardingAttr shardingAttr, Type type, MeshAttr mesh, - EmitErrorFn emitError, - ManualAxisToOwner alreadyManualAxes = ManualAxisToOwner()) { -- auto tensorType = dyn_cast(type); -+ auto tensorType = dyn_cast(type); - if (!tensorType) { - if (shardingAttr.getRank() != 0 || - !shardingAttr.getReplicatedAxes().empty()) { - return emitError( -- "non-shaped tensors can only have a sharding with rank 0 ") -+ "non-ranked tensors can only have a sharding with rank 0 ") - << "and no replicated axes. type: " << type - << ", sharding: " << shardingAttr; - } - return success(); - } -- if (!tensorType.hasStaticShape()) { -- return emitError( -- "only ranked tensors with a static shape can have a sharding. ") -- << "type: " << type; -- } - int64_t rank = tensorType.getRank(); - if (shardingAttr.getRank() != rank) { - return emitError("sharding doesn't match tensor rank: ") -@@ -432,6 +426,7 @@ LogicalResult verifyShardingRuleMapping(Operation* op, TypeRange types, - // doesn't reuse the same factor. - BitVector valueSeenFactorIndices(factorSizes.size()); - auto [type, mapping] = typeAndMapping; -+ auto tensorType = cast(type); - - EmitErrorFn valueEmitError = getEmitValueInRangeErrorFn( - [op, valueKindStr](StringRef msg) { -@@ -439,13 +434,6 @@ LogicalResult verifyShardingRuleMapping(Operation* op, TypeRange types, - }, - types.size(), index); - -- auto tensorType = dynCastStaticShapedType(type); -- if (!tensorType) { -- return valueEmitError( -- "expected a ranked tensor with a static shape. type: ") -- << type; -- } -- - if (mapping.getRank() != tensorType.getRank()) { - return valueEmitError("mapping rank must match: ") - << mapping.getRank() << " != " << tensorType.getRank(); -@@ -571,11 +559,6 @@ LogicalResult ReshardOp::verify() { - } - - LogicalResult DataFlowEdgeOp::verify() { -- if (!getType().hasStaticShape()) { -- return emitOpError( -- "expected sdy.data_flow_edge to have a static-shaped result. ") -- << "type: " << getType(); -- } - if (!getInput().hasOneUse()) { - return emitOpError( - "expected input of sdy.data_flow_edge to have a single user"); -@@ -682,8 +665,8 @@ LogicalResult verifyManualComputationValue( - for (auto [valueIndex, valueEntry] : llvm::enumerate(llvm::zip_equal( - globalTypes, localTypes, shardingPerValueAttr.getShardings()))) { - auto [globalType, localType, sharding] = valueEntry; -- auto globalRankedType = cast(globalType); -- auto localRankedType = cast(localType); -+ auto globalRankedType = globalType.template cast(); -+ auto localRankedType = localType.template cast(); - - // 5. Verify the manual axes come before any free axes in each dim sharding. - for (auto [dim, dimSharding] : -@@ -710,7 +693,7 @@ LogicalResult verifyManualComputationValue( - accumulatedManualAxesSize(op, dimSharding.getAxes(), - manualAxes, mesh)); - } -- auto expectedLocalRankedType = -+ RankedTensorType expectedLocalRankedType = - RankedTensorType::get(newDimSizes, globalRankedType.getElementType()); - if (expectedLocalRankedType != localRankedType) { - return op->emitOpError(valueKindStr) -diff --git a/shardy/dialect/sdy/transforms/export/update_non_divisible_input_output_shardings.cc b/shardy/dialect/sdy/transforms/export/update_non_divisible_input_output_shardings.cc -index 22c4269..6a4d05c 100644 ---- a/shardy/dialect/sdy/transforms/export/update_non_divisible_input_output_shardings.cc -+++ b/shardy/dialect/sdy/transforms/export/update_non_divisible_input_output_shardings.cc -@@ -23,7 +23,6 @@ limitations under the License. - #include "llvm/Support/ErrorHandling.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/BuiltinAttributes.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/TypeRange.h" - #include "mlir/IR/Value.h" -@@ -59,7 +58,8 @@ namespace { - // - [{"y","x"}] : tensor<4xf32> -> [{"y","x":(1)2}] : tensor<4xf32> - // See update_non_divisible_input_output_shardings.mlir for more examples. - TensorShardingAttr getEvenlySharded(TensorShardingAttr sharding, -- ShapedType type, func::FuncOp funcOp) { -+ RankedTensorType type, -+ func::FuncOp funcOp) { - StringRef meshName = sharding.getMeshName(); - MeshAttr mesh = getMeshAttr(funcOp, meshName); - assert(mesh && "unknown mesh"); -@@ -130,7 +130,7 @@ void updateValueShardings( - func::FuncOp funcOp) { - for (auto [index, type] : llvm::enumerate(types)) { - TensorShardingAttr sharding = getSharding(index); -- if (auto tensorType = dynCastStaticShapedType(type); -+ if (auto tensorType = dyn_cast(type); - sharding && tensorType) { - setSharding(index, getEvenlySharded(sharding, tensorType, funcOp)); - } -diff --git a/shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc b/shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc -index 91b5acb..b67c18c 100644 ---- a/shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc -+++ b/shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc -@@ -47,8 +47,8 @@ struct AddDataFlowEdgesPass - ValueRange edgeRoots = getDataFlowEdgeRoots(op); - rewriter.setInsertionPointAfter(op); - for (Value edgeRoot : edgeRoots) { -- if (!isStaticShapedType(edgeRoot.getType())) { -- // Skip non-static-shaped tensors, e.g., tokens. -+ if (!isa(edgeRoot.getType())) { -+ // Skip non-tensor values, e.g., tokens. - continue; - } - TensorShardingAttr sharding = nullptr; -diff --git a/shardy/dialect/sdy/transforms/import/test/add_data_flow_edges.mlir b/shardy/dialect/sdy/transforms/import/test/add_data_flow_edges.mlir -index 67cede6..f31387d 100644 ---- a/shardy/dialect/sdy/transforms/import/test/add_data_flow_edges.mlir -+++ b/shardy/dialect/sdy/transforms/import/test/add_data_flow_edges.mlir -@@ -66,16 +66,6 @@ func.func @optimization_barrier(%arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf3 - return %0#0, %0#1 : tensor<32x96xf32>, tensor<32x96xf32> - } - --// CHECK-LABEL: func @optimization_barrier --func.func @optimization_barrier_dynamic_shaped_tensor_skipped(%arg0: tensor<32x96xf32>, %arg1: tensor) -- -> (tensor<32x96xf32>, tensor) { -- // CHECK-NEXT: %[[OPT_BARRIER:.*]]:2 = stablehlo.optimization_barrier %arg0, %arg1 -- // CHECK: %[[EDGE_1:.*]] = sdy.data_flow_edge %[[OPT_BARRIER]]#0 -- // CHECK-NEXT: return %[[EDGE_1]], %[[OPT_BARRIER]]#1 -- %0:2 = stablehlo.optimization_barrier %arg0, %arg1 : tensor<32x96xf32>, tensor -- return %0#0, %0#1 : tensor<32x96xf32>, tensor --} -- - // CHECK-LABEL: func @while_unused_result - func.func @while_unused_result(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { - // CHECK: %[[C0:.*]] = stablehlo.constant dense<0> -diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc -index 8117426..eff74a3 100644 ---- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc -+++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc -@@ -28,7 +28,6 @@ limitations under the License. - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinOps.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/Diagnostics.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/OpDefinition.h" -@@ -46,6 +45,7 @@ limitations under the License. - #include "shardy/dialect/sdy/ir/data_flow_utils.h" - #include "shardy/dialect/sdy/ir/dialect.h" - #include "shardy/dialect/sdy/ir/utils.h" -+#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h" - #include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h" - #include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h" - #include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.h" -@@ -328,9 +328,9 @@ LogicalResult propagateFuncResults(FuncOp funcOp, - const FactorPropagation& factorPropagation) { - for (OpOperand& returnOperand : getBodyTerminatorOpOperands(funcOp)) { - Value returnValue = returnOperand.get(); -- auto tensorType = dynCastStaticShapedType(returnValue.getType()); -+ auto tensorType = dyn_cast(returnValue.getType()); - if (!tensorType) { -- // Skip non-static-shaped tensors, e.g., tokens. -+ // Skip non-tensor values, e.g., tokens. - continue; - } - int64_t resNum = returnOperand.getOperandNumber(); -@@ -436,7 +436,7 @@ class PropagateDataFlowEdgeOp : public OpRewritePattern { - return propagateTensorShardings( - sources, dataFlowEdgeOp.getResult(), - createIdentityShardingRule( -- cast(dataFlowEdgeOp.getType()), sources.size()), -+ cast(dataFlowEdgeOp.getType()), sources.size()), - dataFlowEdgeOp, rewriter, factorPropagation); - } - -diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc -index 3763581..2b8ff59 100644 ---- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc -+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc -@@ -23,7 +23,6 @@ limitations under the License. - #include - - #include "llvm/ADT/STLExtras.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/Operation.h" -@@ -88,12 +87,12 @@ OpShardingRuleBuilder::OpShardingRuleBuilder( - resultMappings.reserve(resultTypes.size()); - int64_t maxRank = 0; - for (Type operandType : operandTypes) { -- int64_t rank = cast(operandType).getRank(); -+ int64_t rank = cast(operandType).getRank(); - maxRank = std::max(maxRank, rank); - operandMappings.push_back(TensorMapping(rank)); - } - for (Type resultType : resultTypes) { -- int64_t rank = cast(resultType).getRank(); -+ int64_t rank = cast(resultType).getRank(); - maxRank = std::max(maxRank, rank); - resultMappings.push_back(TensorMapping(rank)); - } -@@ -126,7 +125,7 @@ OpShardingRuleAttr OpShardingRuleBuilder::build() { - OpShardingRuleAttr OpShardingRuleBuilder::buildPointwise(Operation* op) { - // All results should have the same shape, so we look at the first. - ArrayRef shape = -- cast(op->getResultTypes().front()).getShape(); -+ cast(op->getResultTypes().front()).getShape(); - - OpShardingRuleBuilder builder(op); - -@@ -201,7 +200,7 @@ OpShardingRuleBuilder& OpShardingRuleBuilder::addPointwiseIfDimSizesMatch( - return *this; - } - --OpShardingRuleAttr createIdentityShardingRule(ShapedType type, -+OpShardingRuleAttr createIdentityShardingRule(RankedTensorType type, - size_t numOperands, - size_t numResults) { - return OpShardingRuleBuilder(SmallVector(numOperands, type), -diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h -index 5130827..5d0b5a8 100644 ---- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h -+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h -@@ -22,7 +22,6 @@ limitations under the License. - #include - #include - --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/Operation.h" -@@ -119,7 +118,7 @@ class OpShardingRuleBuilder { - // i.e., all operands/results have the same mapping. - // - // NOTE: an empty rule {([])->([])} will be created for scalar ops. --OpShardingRuleAttr createIdentityShardingRule(ShapedType type, -+OpShardingRuleAttr createIdentityShardingRule(RankedTensorType type, - size_t numOperands = 1, - size_t numResults = 1); - -diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc -index 98fa7a1..80e4933 100644 ---- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc -+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc -@@ -144,8 +144,8 @@ OpShardingRuleAttr getOrCreateShardingRule(Operation* op, - OpShardingRuleAttr createOpShardingRule(Operation* op, - const bool conservativePropagation) { - return TypeSwitch(op) -- .Case, %arg1: tensor<8x16xf32>) - return %0 : tensor<8x16xf32> - } - --// CHECK-LABEL: func @token_func_output_skipped( -+// CHECK-LABEL: func @token_func_output_token_skipped( - // CHECK-SAME: %arg0: !stablehlo.token, - // CHECK-SAME: %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>}) - // CHECK-SAME: -> (!stablehlo.token, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}, {"b"}]>}) { --func.func @token_func_output_skipped(%arg0: !stablehlo.token, %arg1: tensor<8x16xf32>) -+func.func @token_func_output_token_skipped(%arg0: !stablehlo.token, %arg1: tensor<8x16xf32>) - -> (!stablehlo.token, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}, {"b"}]>}) { - // CHECK-NEXT: stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>]>} - %0 = stablehlo.add %arg1, %arg1 : tensor<8x16xf32> - return %arg0, %0 : !stablehlo.token, tensor<8x16xf32> - } - --// CHECK-LABEL: func @dynamic_shaped_func_output_skipped( --// CHECK-SAME: %arg0: tensor, --// CHECK-SAME: %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>}) --// CHECK-SAME: -> (tensor, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}, {"b"}]>}) { --func.func @dynamic_shaped_func_output_skipped(%arg0: tensor, %arg1: tensor<8x16xf32>) -- -> (tensor, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}, {"b"}]>}) { -- // CHECK-NEXT: stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>]>} -- %0 = stablehlo.add %arg1, %arg1 : tensor<8x16xf32> -- return %arg0, %0 : tensor, tensor<8x16xf32> --} -- - // CHECK-LABEL: func @func_result_intermediate_op_both_updated( - // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>}) - // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>}) { -diff --git a/shardy/integrations/python/ir/__init__.py b/shardy/integrations/python/ir/__init__.py -index 97e8a3b..89a06ba 100644 ---- a/shardy/integrations/python/ir/__init__.py -+++ b/shardy/integrations/python/ir/__init__.py -@@ -17,6 +17,7 @@ - # pylint: disable=g-multiple-import,g-importing-member,unused-import,useless-import-alias - from ._sdy_ops_gen import ( - ConstantOp as ConstantOp, -+ IdentityOp as IdentityOp, - ManualComputationOp as ManualComputationOp, - MeshOp as MeshOp, - ReshardOp as ReshardOp, -diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 0d420ba..88869a4 100644 ---- a/third_party/llvm/workspace.bzl -+++ b/third_party/llvm/workspace.bzl -@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") - - def repo(name): - """Imports LLVM.""" -- LLVM_COMMIT = "41491c77231e9d389ef18593be1fab4f4e810e88" -- LLVM_SHA256 = "10b17d9f8304eb7c9fb91f7b13f73e9e5ca81984aa692eac91b82d19db311547" -+ LLVM_COMMIT = "0c25f85e5b88102363c0cd55e1946053d5827e99" -+ LLVM_SHA256 = "851d958e60193edfb54d6eb8644785179eeb604edae8c026ac1819e82c059f6c" - - tf_http_archive( - name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index c50cb5177e2d70..6d91def025b34a 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "ef636ca340e01a2ef3f910bb0dffc4539019f793" - SHARDY_SHA256 = "ad87f171909ba0e7c9879e7f3e57c31e25f0fbe935e14ebe2dbd45ed4c64f632" + SHARDY_COMMIT = "7e3ddfb532b3b53cb0b108014c24a86ac147e9f6" + SHARDY_SHA256 = "1d304e1e6f1132fe3ccb969d28798bc6ee90db84d10c85113ef8573eae350325" tf_http_archive( name = "shardy", diff --git a/third_party/spirv_llvm_translator/BUILD b/third_party/spirv_llvm_translator/BUILD new file mode 100644 index 00000000000000..8d626dc7635d1a --- /dev/null +++ b/third_party/spirv_llvm_translator/BUILD @@ -0,0 +1,7 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +# spirv_llvm_translator license placeholder diff --git a/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD b/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD new file mode 100644 index 00000000000000..557e2e8f50edd2 --- /dev/null +++ b/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD @@ -0,0 +1,34 @@ +cc_library( + name = "spirv_llvm_translator", + srcs = glob([ + "lib/SPIRV/libSPIRV/*.cpp", + "lib/SPIRV/libSPIRV/*.hpp", + "lib/SPIRV/libSPIRV/*.h", + "lib/SPIRV/Mangler/*.cpp", + "lib/SPIRV/Mangler/*.h", + "lib/SPIRV/*.cpp", + "lib/SPIRV/*.hpp", + "lib/SPIRV/*.h", + ]), + hdrs = glob(["include/*"]), + includes = [ + "include/", + "lib/SPIRV/", + "lib/SPIRV/Mangler/", + "lib/SPIRV/libSPIRV/", + ], + visibility = ["//visibility:public"], + deps = [ + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BitWriter", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Demangle", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TransformUtils", + "@spirv_headers//:spirv_cpp_headers", + ], +) diff --git a/third_party/spirv_llvm_translator/spirv_llvm_translator.patch b/third_party/spirv_llvm_translator/spirv_llvm_translator.patch new file mode 100644 index 00000000000000..fc843b1b039b09 --- /dev/null +++ b/third_party/spirv_llvm_translator/spirv_llvm_translator.patch @@ -0,0 +1,25 @@ +diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h +index a828add8..924e13b4 100644 + +Spir backend uses different addrspace representations link with nvptx backend link. +We reorder the enum value here so that we can make XLA LLVM codegen simple(avoiding +changing addrspace based on device backend everywhere) + +--- a/lib/SPIRV/SPIRVInternal.h ++++ b/lib/SPIRV/SPIRVInternal.h +@@ -179,11 +179,12 @@ typedef SPIRVMap IntBoolOpMap; + "-v512:512:512-v1024:1024:1024" + + enum SPIRAddressSpace { +- SPIRAS_Private, ++ SPIRAS_Generic, + SPIRAS_Global, +- SPIRAS_Constant, ++ SPIRAS_Internal, + SPIRAS_Local, +- SPIRAS_Generic, ++ SPIRAS_Constant, ++ SPIRAS_Private, + SPIRAS_GlobalDevice, + SPIRAS_GlobalHost, + SPIRAS_Input, \ No newline at end of file diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8b137891791fe9..77fefee2b13b6d 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1 +1,28 @@ +diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel +--- stablehlo/BUILD.bazel ++++ stablehlo/BUILD.bazel +@@ -1283,6 +1283,7 @@ + "@llvm-project//mlir:AllExtensions", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", ++ "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:TosaDialect", + ], + ) +diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py b/stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py +--- stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py ++++ stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py +@@ -32,9 +32,9 @@ + + # Make LLVM and StableHLO tools available in RUN directives + tools = [ +- 'stablehlo-opt', +- 'FileCheck', +- 'stablehlo-translate', ++ 'stablehlo-opt', ++ 'FileCheck', ++ 'stablehlo-translate', + ] + tool_dirs = [ + config.llvm_tools_dir, diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index b46c39e85fc240..6c0cea3e8f16f5 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "2fcbae0c933b5cc2735523bab2de880a3a9c5e46" - STABLEHLO_SHA256 = "14f879b246266dc7c5cb49cdbf88c87ebac0444e3ebae04b57448d4bbc2fe180" + STABLEHLO_COMMIT = "23d3e1414b0be1c1b5256f0949520dc4f0a0705c" + STABLEHLO_SHA256 = "ad694a3da43a2a432c8c5f1c60be39fc211e28834cca07cf663ce8dc85d920fe" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 4de9536cbddbab..3466def95fd60d 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "60277ba976739502e45ad26585e071568fa44af1" - TFRT_SHA256 = "7634f696ad57f0ec914c4092cd8a2d19371f024abeb23d06c8eb5c18be660405" + TFRT_COMMIT = "07992d7c1ead60f610c17b7c1f9e50b6898adc87" + TFRT_SHA256 = "e1de8d371248d3dfc6e9ebd0e4094b57ce04d9545ae3756b5a84c33482614d5f" tf_http_archive( name = "tf_runtime", diff --git a/third_party/triton/llvm_integration/cl657620552.patch b/third_party/triton/llvm_integration/cl657620552.patch deleted file mode 100644 index 4a1f47d79e6c92..00000000000000 --- a/third_party/triton/llvm_integration/cl657620552.patch +++ /dev/null @@ -1,18 +0,0 @@ -# Do not upstream this patch. This has been already upstreamed in -# https://github.com/triton-lang/triton/commit/de46a0ede6efe7e93c2a9ebef639e36c6177c511 -# Next integration will include it and this patch should be removed then. - -diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc ---- a/third_party/amd/python/triton_amd.cc -+++ b/third_party/amd/python/triton_amd.cc -@@ -193,9 +193,7 @@ void init_triton_amd(py::module &&m) { - target->createMCAsmBackend(*sti, *mri, mcOptions)); - mcStreamer.reset(target->createMCObjectStreamer( - triple, ctx, std::move(mab), mab->createObjectWriter(svos), -- std::move(ce), *sti, mcOptions.MCRelaxAll, -- mcOptions.MCIncrementalLinkerCompatible, -- /*DWARFMustBeAtTheEnd=*/false)); -+ std::move(ce), *sti)); - - std::unique_ptr parser( - createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai)); diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 8162fb5fad6342..656b9c894904d8 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,6 +8,5 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ - "//third_party/triton:llvm_integration/cl657620552.patch", # Add new patches just above this line ] diff --git a/third_party/triton/temporary/cuda11-temporary.patch b/third_party/triton/temporary/cuda11-temporary.patch deleted file mode 100644 index a92166eef6df71..00000000000000 --- a/third_party/triton/temporary/cuda11-temporary.patch +++ /dev/null @@ -1,35 +0,0 @@ -# This temporary patch has already been included to the public list of Triton -# patches. It is only here temporarily to be included in the openxla version, -# but it will be removed during the next triton integration. - ---- a/third_party/nvidia/backend/driver.c -+++ b/third_party/nvidia/backend/driver.c -@@ -154,6 +154,8 @@ static PyObject *loadBinary(PyObject *se - typedef CUresult (*cuOccupancyMaxActiveClusters_t)( - int *numClusters, CUfunction func, const CUlaunchConfig *config); - -+#if CUDA_VERSION < 12000 -+#else - typedef CUresult (*cuTensorMapEncodeTiled_t)( - CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, -@@ -161,6 +161,7 @@ typedef CUresult (*cuTensorMapEncodeTile - const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, - CUtensorMapFloatOOBfill oobFill); -+#endif - - #define defineGetFunctionHandle(name, symbolName) \ - static symbolName##_t name() { \ -@@ -187,8 +187,11 @@ typedef CUresult (*cuTensorMapEncodeTile - defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, - cuOccupancyMaxActiveClusters); - -+#if CUDA_VERSION < 12000 -+#else - defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, - cuTensorMapEncodeTiled); -+#endif - - static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { - int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 388e57f849f14e..4fa55269e3323c 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -14,7 +14,5 @@ those to this list. """ temporary_patch_list = [ - "//third_party/triton:temporary/cuda11-temporary.patch", - "//third_party/triton:temporary/undo_tesla_gpu.patch", # Add new patches just above this line ] diff --git a/third_party/triton/temporary/undo_tesla_gpu.patch b/third_party/triton/temporary/undo_tesla_gpu.patch deleted file mode 100644 index 6c2d1d1d734fbc..00000000000000 --- a/third_party/triton/temporary/undo_tesla_gpu.patch +++ /dev/null @@ -1,13 +0,0 @@ -This can be removed on the next integrate as it already exists in upstream. -diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp ---- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp -+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp -@@ -21,7 +21,7 @@ namespace { - static int getMMAVersionSafe(int computeCapability, DotOp op) { - // List supported mma version in order of preference. - SmallVector versionsSupported; -- if (computeCapability < 80) { -+ if (computeCapability < 75) { - versionsSupported = {1}; - } else if (computeCapability < 90) { - versionsSupported = {2}; diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index fc6c45f7bc1e5e..e74434221f6c98 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl657175856" - TRITON_SHA256 = "316f421a7d7ead2b7e5adc2e8bb68ce1a8f7809db73dbed8abd54c35bd0c1576" + TRITON_COMMIT = "cl664783844" + TRITON_SHA256 = "d5779d331008dd3a4941dd59e61385ec964987da74454248446ac3e36b874007" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/triton/xla_extensions/sparse_dot.patch index a1c011dbb8beb5..dadc7732a4f280 100644 --- a/third_party/triton/xla_extensions/sparse_dot.patch +++ b/third_party/triton/xla_extensions/sparse_dot.patch @@ -57,7 +57,7 @@ diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dia index 012786dae..6043b764a 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp -@@ -498,6 +498,119 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, +@@ -498,6 +498,123 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, return encoding; } @@ -173,6 +173,10 @@ index 012786dae..6043b764a 100644 + ArrayRef tensorShape) const { + return ::getShapePerCTATile(getParent(), tensorShape); +} ++std::optional SparseDotMetaEncodingAttr::toLinearLayout( ++ ArrayRef shape) const { ++ return ::toLinearLayout(shape, getParent()); ++} + } // namespace gpu } // namespace triton @@ -273,9 +277,9 @@ index d74e0a224..4e45f7c4c 100644 + return op->hasTrait() || isa(op); +} + - // Replace the ForOp's yield with a new one with the given operands appended. - static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { - // Fix up the yield op. + static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, + tt::CoarseSchedule &schedule, @@ -248,19 +252,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { } else { if (!isa(user)) diff --git a/third_party/uv/uv.BUILD b/third_party/uv/uv.BUILD index b04383ad3487e7..43c194a53ea516 100644 --- a/third_party/uv/uv.BUILD +++ b/third_party/uv/uv.BUILD @@ -55,7 +55,19 @@ cc_library( # TODO: Add Linux, etc. as in https://github.com/libuv/libuv/blob/v1.38.0/CMakeLists.txt. hdrs = [ "include/uv.h", - ], + "src/heap-inl.h", + "src/idna.h", + "src/queue.h", + "src/strscpy.h", + "src/unix/atomic-ops.h", + "src/unix/internal.h", + "src/unix/spinlock.h", + "src/uv-common.h", + ] + select({ + "@platforms//os:osx": [ + "src/unix/darwin-stub.h", + ], + }) + glob(["include/uv/*.h"]), copts = [ "-fexceptions", "-Wno-unused-variable", diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index 76f824f372e0d3..9e565e91a1b903 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -219,13 +219,16 @@ build:mkl_aarch64_threadpool -c opt build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda +# Default CUDA and CUDNN versions. +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" +# This flag is needed to include hermetic CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true # CUDA: This config refers to building CUDA op kernels with clang. build:cuda_clang --config=cuda -# Enable TensorRT optimizations https://developer.nvidia.com/tensorrt -build:cuda_clang --config=tensorrt -build:cuda_clang --action_env=TF_CUDA_CLANG="1" build:cuda_clang --@local_config_cuda//:cuda_compiler=clang +build:cuda_clang --copt=-Qunused-arguments # Select supported compute capabilities (supported graphics cards). # This is the same as the official TensorFlow builds. # See https://developer.nvidia.com/cuda-gpus#compute @@ -234,22 +237,22 @@ build:cuda_clang --@local_config_cuda//:cuda_compiler=clang # release while SASS is only forward compatible inside the current # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. -build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Set lld as the linker. +build:cuda_clang --host_linkopt="-fuse-ld=lld" +build:cuda_clang --host_linkopt="-lm" +build:cuda_clang --linkopt="-fuse-ld=lld" +build:cuda_clang --linkopt="-lm" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang -build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" -build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" -build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" -build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +build:cuda_clang_official --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda_clang_official --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" # Build with nvcc for CUDA and clang for host build:nvcc_clang --config=cuda -# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang -build:nvcc_clang --action_env=TF_CUDA_CLANG="1" build:nvcc_clang --action_env=TF_NVCC_CLANG="1" build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc @@ -545,10 +548,6 @@ build:rbe_linux_cuda --config=cuda_clang_official build:rbe_linux_cuda --config=rbe_linux_cpu # For Remote build execution -- GPU configuration build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.17-clang_config_cuda" -build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.17-clang_config_tensorrt" -build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.17-clang_config_nccl" -test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --config=nvcc_clang @@ -633,7 +632,6 @@ build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-18/bin/cla # Test-related settings below this point. test:release_linux_base --build_tests_only --keep_going --test_output=errors --verbose_failures=true test:release_linux_base --local_test_jobs=HOST_CPUS -test:release_linux_base --test_env=LD_LIBRARY_PATH # Give only the list of failed tests at the end of the log test:release_linux_base --test_summary=short @@ -647,7 +645,6 @@ build:release_gpu_linux --config=release_cpu_linux # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. # Note that linux cpu and cuda builds share the same toolchain now. build:release_gpu_linux --config=cuda_clang_official -test:release_gpu_linux --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" # Local test jobs has to be 4 because parallel_gpu_execute is fragile, I think test:release_gpu_linux --test_timeout=300,450,1200,3600 --local_test_jobs=4 --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute @@ -656,6 +653,7 @@ build:release_arm64_linux --config=linux_arm64 build:release_arm64_linux --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" build:release_arm64_linux --config=mkl_aarch64_threadpool build:release_arm64_linux --copt=-flax-vector-conversions +test:release_arm64_linux --flaky_test_attempts=3 # The old gcc linux build options are preserved in the unsupported_*_linux # configs. If your project fails to build with Clang, you can use these @@ -677,9 +675,8 @@ build:unsupported_gpu_linux --config=unsupported_cpu_linux build:unsupported_gpu_linux --action_env=TF_CUDA_VERSION="11" build:unsupported_gpu_linux --action_env=TF_CUDNN_VERSION="8" build:unsupported_gpu_linux --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" -build:unsupported_gpu_linux --config=tensorrt build:unsupported_gpu_linux --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2" -build:unsupported_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64:/usr/local/tensorrt/lib" +build:unsupported_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64" build:unsupported_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain @@ -774,7 +771,7 @@ test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflo # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium --flaky_test_attempts=3 +test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 @@ -812,7 +809,7 @@ test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflo # inherit from build. build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium +build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test # CROSS-COMPILE ARM64 PYCPP diff --git a/third_party/xla/.github/workflows/bazel_query.yml b/third_party/xla/.github/workflows/bazel_query.yml index 253218acba1149..969383fb09062f 100644 --- a/third_party/xla/.github/workflows/bazel_query.yml +++ b/third_party/xla/.github/workflows/bazel_query.yml @@ -35,4 +35,6 @@ jobs: - name: "Install bazelisk" run: go install github.com/bazelbuild/bazelisk@24651ab # v1.20.0 - name: "Run bazel query //xla/..." - run: bazelisk query //xla/... + run: bazelisk query //xla/... > /dev/null + - name: "Run bazel query deps(//xla/...)" + run: bazelisk query "deps(//xla/...)" > /dev/null diff --git a/third_party/xla/.kokoro/macos/build.sh b/third_party/xla/.kokoro/macos/build.sh index c3e0c126560afb..1aedf1badf55d2 100644 --- a/third_party/xla/.kokoro/macos/build.sh +++ b/third_party/xla/.kokoro/macos/build.sh @@ -37,32 +37,6 @@ function install_build_env_tools(){ sudo wget --no-verbose -O "/usr/local/bin/bazel" \ "https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" \ && chmod +x "/usr/local/bin/bazel" - - echo "===== Installing Pyenv =====" - # Install pyenv; Set up a virtual environment to control dependencies and their - # versions - git clone --branch v2.3.17 https://github.com/pyenv/pyenv.git /Users/kbuilder/.tf_pyenv - export PYENV_ROOT=/Users/kbuilder/.tf_pyenv - export PATH="$PYENV_ROOT/bin:$PATH" # if `pyenv` is not already on PATH - eval "$(pyenv init --path)" - eval "$(pyenv init -)" - - echo "===== Installing Python =====" - # Install Python and set the local python version - pyenv install -s "${TF_PYENV_VERSION}" - pyenv rehash - pyenv local "${TF_PYENV_VERSION}" - # Do a sanity check to make sure that we using the correct Python version - echo "===== Python version =====" - python --version - # Set up virtual environment and activate it - python -m venv /Users/kbuilder/.tf-venv && source /Users/kbuilder/.tf-venv/bin/activate - - # Setup links to Python. Referenced in ./macos.bazelrc - ln -s /Users/kbuilder/.tf-venv/lib/python* /Users/kbuilder/.tf-venv/lib/python - - echo "===== Upgrading to latest pip =====" - python -m pip install --upgrade pip } # Run the tests under /Volumes/BuildData/ so that we don't run into VM @@ -72,8 +46,6 @@ export TEST_TMPDIR=/Volumes/BuildData/bazel_output install_build_env_tools -python -m pip install numpy==1.21.4 - TARGET_FILTER="-//xla/hlo/experimental/... -//xla/python_api/... -//xla/python/... -//xla/service/gpu/..." TAGS_FILTER="-no_oss,-oss_excluded,-gpu,-no_mac,-nomac,-mac_excluded,-requires-gpu-nvidia,-requires-gpu-amd" diff --git a/third_party/xla/README.md b/third_party/xla/README.md index be0325eefc03ba..1a6d70a29cde25 100644 --- a/third_party/xla/README.md +++ b/third_party/xla/README.md @@ -7,6 +7,11 @@ The XLA compiler takes models from popular ML frameworks such as PyTorch, TensorFlow, and JAX, and optimizes them for high-performance execution across different hardware platforms including GPUs, CPUs, and ML accelerators. + + + OpenXLA Ecosystem + + ## Get started If you want to use XLA to compile your ML project, refer to the corresponding diff --git a/third_party/xla/WORKSPACE b/third_party/xla/WORKSPACE index 9d046e22949091..a18ebde79da786 100644 --- a/third_party/xla/WORKSPACE +++ b/third_party/xla/WORKSPACE @@ -52,3 +52,50 @@ xla_workspace1() load(":workspace0.bzl", "xla_workspace0") xla_workspace0() + +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", +) + +cuda_json_init_repository() + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, +) + +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", +) + +cuda_configure(name = "local_config_cuda") + +load( + "@local_tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +nccl_redist_init_repository() + +load( + "@local_tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", +) + +nccl_configure(name = "local_config_nccl") diff --git a/third_party/xla/build_tools/build.py b/third_party/xla/build_tools/build.py index 14ad4fa189f666..ec989e21737e0e 100755 --- a/third_party/xla/build_tools/build.py +++ b/third_party/xla/build_tools/build.py @@ -23,7 +23,6 @@ The script also assumes that the working directory never changes modulo `cd`ing into the repo that should be built (mostly `github/xla`, but also JAX and TF). """ -import contextlib import dataclasses import enum import logging @@ -33,8 +32,8 @@ import time from typing import Any, Dict, List, Tuple -_KW_ONLY_IF_PYTHON310 = {"kw_only": True} if sys.version_info >= (3, 10) else {} +_CONTAINER_NAME = "xla_ci" # TODO(ddunleavy): move this to the bazelrc _DEFAULT_BAZEL_OPTIONS = dict( test_output="errors", @@ -54,7 +53,7 @@ tty=True, volume="./github:/github", ) - +_KW_ONLY_IF_PYTHON310 = {"kw_only": True} if sys.version_info >= (3, 10) else {} _XLA_DEFAULT_TARGET_PATTERNS = ( "//xla/...", "//build_tools/...", @@ -91,10 +90,37 @@ class BuildType(enum.Enum): @dataclasses.dataclass(frozen=True, **_KW_ONLY_IF_PYTHON310) -class DockerImage: - """Class representing a docker image.""" +class Build: + """Class representing a build of XLA.""" + type_: BuildType + repo: str image_url: str + target_patterns: Tuple[str, ...] + configs: Tuple[str, ...] = () + build_tag_filters: Tuple[str, ...] = () + test_tag_filters: Tuple[str, ...] = () + action_env: Dict[str, Any] = dataclasses.field(default_factory=dict) + test_env: Dict[str, Any] = dataclasses.field(default_factory=dict) + options: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def bazel_test_command(self) -> List[str]: + """Returns a bazel test command for this build. + + Returns: List of command line arguments + """ + options = _dict_to_cli_options(self.options) + configs = [f"--config={config}" for config in self.configs] + build_tag_filters = ( + f"--build_tag_filters={','.join(self.build_tag_filters)}" + ) + test_tag_filters = f"--test_tag_filters={','.join(self.test_tag_filters)}" + action_env = [f"--action_env={k}={v}" for k, v in self.action_env.items()] + test_env = [f"--test_env={k}={v}" for k, v in self.test_env.items()] + + tag_filters = [build_tag_filters, test_tag_filters] + all_options = tag_filters + configs + action_env + test_env + options + return ["bazel", "test", *all_options, "--", *self.target_patterns] def _pull_docker_image_with_retries(self, retries=3) -> None: """Pulls docker image with retries to avoid transient rate limit errors.""" @@ -112,10 +138,9 @@ def _pull_docker_image_with_retries(self, retries=3) -> None: # TODO(ddunleavy): get sha # _write_to_sponge_config("TF_INFO_DOCKER_SHA", sha) - @contextlib.contextmanager - def pull_and_run( + def pull_and_run_docker_image( self, - name: str = "xla_ci", + name: str, command: Tuple[str, ...] = ("bash",), **kwargs: Any, ): @@ -126,67 +151,19 @@ def pull_and_run( command: Command given to `docker run`, e.g. `bash` **kwargs: Extra options passed to `docker run`. - Yields: - A function that accepts a command as a list of args, and runs those on the - corresponding docker container. It shouldn't be used outside the `with` - block, as the container will be stopped after the end of the block. - - This manages pulling, starting, and stopping the container. Example usage: - ``` - with image.pull_and_run() as docker_exec: - docker_exec(["command", "--with", "--flags"]) - ``` + Returns: + None. """ - try: - self._pull_docker_image_with_retries() - options = _dict_to_cli_options(kwargs) - sh([ - "docker", - "run", - "--name", - name, - *options, - self.image_url, - *command, - ]) - docker_exec = lambda args: sh(["docker", "exec", name, *args]) - yield docker_exec - finally: - sh(["docker", "stop", name]) - - -@dataclasses.dataclass(frozen=True, **_KW_ONLY_IF_PYTHON310) -class Build: - """Class representing a build of XLA.""" + self._pull_docker_image_with_retries() - type_: BuildType - repo: str - docker_image: DockerImage - target_patterns: Tuple[str, ...] - configs: Tuple[str, ...] = () - build_tag_filters: Tuple[str, ...] = () - test_tag_filters: Tuple[str, ...] = () - action_env: Dict[str, Any] = dataclasses.field(default_factory=dict) - test_env: Dict[str, Any] = dataclasses.field(default_factory=dict) - options: Dict[str, Any] = dataclasses.field(default_factory=dict) + assert "workdir" not in kwargs + _, repo_name = self.repo.split("/") + workdir = f"/github/{repo_name}" - def bazel_test_command(self) -> List[str]: - """Returns a bazel test command for this build. - - Returns: List of command line arguments - """ - options = _dict_to_cli_options(self.options) - configs = [f"--config={config}" for config in self.configs] - build_tag_filters = ( - f"--build_tag_filters={','.join(self.build_tag_filters)}" - ) - test_tag_filters = f"--test_tag_filters={','.join(self.test_tag_filters)}" - action_env = [f"--action_env={k}={v}" for k, v in self.action_env.items()] - test_env = [f"--test_env={k}={v}" for k, v in self.test_env.items()] + options = ["--name", name, "--workdir", workdir] + options += _dict_to_cli_options(kwargs) - tag_filters = [build_tag_filters, test_tag_filters] - all_options = tag_filters + configs + action_env + test_env + options - return ["bazel", "test", *all_options, "--", *self.target_patterns] + sh(["docker", "run", *options, self.image_url, *command]) def _tag_filters_for_compute_capability( @@ -202,18 +179,12 @@ def _tag_filters_for_compute_capability( return tag_filters -_DEFAULT_IMAGE = DockerImage( - image_url="gcr.io/tensorflow-sigs/build:latest-python3.11", -) +_DEFAULT_IMAGE = "gcr.io/tensorflow-sigs/build:latest-python3.11" # TODO(b/338885148): Remove this once the TF containers have cuDNN 9 -_CUDNN_9_IMAGE = DockerImage( - image_url="gcr.io/tensorflow-sigs/build@sha256:0a9728e258d7e0e5830d1960a65968ffdc1d138af5441e30948918e0d50ab2c7", -) +_CUDNN_9_IMAGE = "gcr.io/tensorflow-sigs/build@sha256:0a9728e258d7e0e5830d1960a65968ffdc1d138af5441e30948918e0d50ab2c7" -_ARM64_JAX_MULTI_PYTHON_IMAGE = DockerImage( - image_url="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:jax-latest-multi-python", -) +_ARM64_JAX_MULTI_PYTHON_IMAGE = "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:jax-latest-multi-python" def nvidia_gpu_build_with_compute_capability( @@ -223,7 +194,7 @@ def nvidia_gpu_build_with_compute_capability( return Build( type_=type_, repo="openxla/xla", - docker_image=_CUDNN_9_IMAGE, + image_url=_CUDNN_9_IMAGE, target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, configs=configs, test_tag_filters=("-no_oss", "requires-gpu-nvidia") + extra_gpu_tags, @@ -245,7 +216,7 @@ def nvidia_gpu_build_with_compute_capability( _CPU_X86_BUILD = Build( type_=BuildType.CPU_X86, repo="openxla/xla", - docker_image=_DEFAULT_IMAGE, + image_url=_DEFAULT_IMAGE, configs=("warnings", "nonccl", "rbe_linux_cpu"), target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, build_tag_filters=cpu_x86_tag_filter, @@ -263,7 +234,7 @@ def nvidia_gpu_build_with_compute_capability( _CPU_ARM64_BUILD = Build( type_=BuildType.CPU_ARM64, repo="openxla/xla", - docker_image=_ARM64_JAX_MULTI_PYTHON_IMAGE, + image_url=_ARM64_JAX_MULTI_PYTHON_IMAGE, configs=("warnings", "rbe_cross_compile_linux_arm64_xla", "nonccl"), target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, options={**_DEFAULT_BAZEL_OPTIONS, "build_tests_only": True}, @@ -280,7 +251,7 @@ def nvidia_gpu_build_with_compute_capability( _JAX_CPU_BUILD = Build( type_=BuildType.JAX_CPU, repo="google/jax", - docker_image=_DEFAULT_IMAGE, + image_url=_DEFAULT_IMAGE, configs=( "avx_posix", "mkl_open_source_only", @@ -300,7 +271,7 @@ def nvidia_gpu_build_with_compute_capability( _JAX_GPU_BUILD = Build( type_=BuildType.JAX_GPU, repo="google/jax", - docker_image=_DEFAULT_IMAGE, + image_url=_DEFAULT_IMAGE, configs=( "avx_posix", "mkl_open_source_only", @@ -316,16 +287,14 @@ def nvidia_gpu_build_with_compute_capability( JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow", ), options=dict( - **_DEFAULT_BAZEL_OPTIONS, - override_repository="xla=/github/xla", - **{"//jax:build_cuda_plugin_from_source": True}, + **_DEFAULT_BAZEL_OPTIONS, override_repository="xla=/github/xla" ), ) _TENSORFLOW_CPU_BUILD = Build( type_=BuildType.TENSORFLOW_CPU, repo="tensorflow/tensorflow", - docker_image=_DEFAULT_IMAGE, + image_url=_DEFAULT_IMAGE, configs=( "release_cpu_linux", "rbe_linux_cpu", @@ -349,7 +318,7 @@ def nvidia_gpu_build_with_compute_capability( _TENSORFLOW_GPU_BUILD = Build( type_=BuildType.TENSORFLOW_GPU, repo="tensorflow/tensorflow", - docker_image=_DEFAULT_IMAGE, + image_url=_DEFAULT_IMAGE, configs=( "release_gpu_linux", "rbe_linux_cuda", @@ -414,13 +383,24 @@ def main(): "github/xla/.bazelrc", ], ) + sh( + [ + "sed", + "-i", + r"s/8\.9\.7\.29/9.1.1/g", + "github/xla/.bazelrc", + ], + ) sh(["nvidia-smi"]) - with build.docker_image.pull_and_run( - workdir=f"/github/{repo_name}", **_DEFAULT_DOCKER_OPTIONS - ) as docker_exec: - docker_exec(build.bazel_test_command()) - docker_exec(["bazel", "analyze-profile", "profile.json.gz"]) + build.pull_and_run_docker_image( + _CONTAINER_NAME, + **_DEFAULT_DOCKER_OPTIONS, + ) + docker_exec = lambda cmd: sh(["docker", "exec", _CONTAINER_NAME, *cmd]) + docker_exec(build.bazel_test_command()) + docker_exec(["bazel", "analyze-profile", "profile.json.gz"]) + sh(["docker", "stop", _CONTAINER_NAME]) if __name__ == "__main__": diff --git a/third_party/xla/build_tools/configure/BUILD b/third_party/xla/build_tools/configure/BUILD index 6b84ba404c9043..ed518510f5eae3 100644 --- a/third_party/xla/build_tools/configure/BUILD +++ b/third_party/xla/build_tools/configure/BUILD @@ -33,6 +33,7 @@ py_test( data = [ "testdata/clang.bazelrc", "testdata/cuda_clang.bazelrc", + "testdata/default_cuda_clang.bazelrc", "testdata/gcc.bazelrc", "testdata/nvcc_clang.bazelrc", "testdata/nvcc_gcc.bazelrc", diff --git a/third_party/xla/build_tools/configure/configure.py b/third_party/xla/build_tools/configure/configure.py index 39cfd7a01ecbf0..43e0f234d49cfd 100755 --- a/third_party/xla/build_tools/configure/configure.py +++ b/third_party/xla/build_tools/configure/configure.py @@ -27,11 +27,6 @@ the clang in your path. If that isn't the correct clang, you can override like `./configure.py --backend=cpu --clang_path=`. -NOTE(ddunleavy): Lots of these things should probably be outside of configure.py -but are here because of complexity in `cuda_configure.bzl` and the TF bazelrc. -Once XLA has it's own bazelrc, and cuda_configure.bzl is replaced or refactored, -we can probably make this file smaller. - TODO(ddunleavy): add more thorough validation. """ import argparse @@ -45,18 +40,9 @@ import sys from typing import Optional -_REQUIRED_CUDA_LIBRARIES = ["cublas", "cuda", "cudnn"] _DEFAULT_BUILD_AND_TEST_TAG_FILTERS = ("-no_oss",) # Assume we are being invoked from the symlink at the root of the repo _XLA_SRC_ROOT = pathlib.Path(__file__).absolute().parent -_FIND_CUDA_CONFIG = str( - _XLA_SRC_ROOT - / "third_party" - / "tsl" - / "third_party" - / "gpus" - / "find_cuda_config.py" -) _XLA_BAZELRC_NAME = "xla_configure.bazelrc" _KW_ONLY_IF_PYTHON310 = {"kw_only": True} if sys.version_info >= (3, 10) else {} @@ -224,11 +210,12 @@ class DiscoverablePathsAndVersions: ld_library_path: Optional[str] = None # CUDA specific - cublas_version: Optional[str] = None - cuda_toolkit_path: Optional[str] = None + cuda_version: Optional[str] = None cuda_compute_capabilities: Optional[list[str]] = None cudnn_version: Optional[str] = None - nccl_version: Optional[str] = None + local_cuda_path: Optional[str] = None + local_cudnn_path: Optional[str] = None + local_nccl_path: Optional[str] = None def get_relevant_paths_and_versions(self, config: "XLAConfigOptions"): """Gets paths and versions as needed by the config. @@ -247,7 +234,7 @@ def get_relevant_paths_and_versions(self, config: "XLAConfigOptions"): ) # Notably, we don't use `_find_executable_or_die` for lld, as it changes - # which commands it accepts based on it's name! ld.lld is symlinked to a + # which commands it accepts based on its name! ld.lld is symlinked to a # different executable just called lld, which should not be invoked # directly. self.lld_path = self.lld_path or shutil.which("ld.lld") @@ -261,64 +248,6 @@ def get_relevant_paths_and_versions(self, config: "XLAConfigOptions"): if not self.cuda_compute_capabilities: self.cuda_compute_capabilities = _get_cuda_compute_capabilities_or_die() - self._get_cuda_libraries_paths_and_versions_if_needed(config) - - def _get_cuda_libraries_paths_and_versions_if_needed( - self, config: "XLAConfigOptions" - ): - """Gets cuda paths and versions if user left any unspecified. - - This uses `find_cuda_config.py` to find versions for all libraries in - `_REQUIRED_CUDA_LIBRARIES`. - - Args: - config: config that determines which libraries should be found. - """ - should_find_nccl = config.using_nccl and self.nccl_version is None - any_cuda_config_unset = any([ - self.cublas_version is None, - self.cuda_toolkit_path is None, - self.cudnn_version is None, - should_find_nccl, - ]) - - maybe_nccl = ["nccl"] if should_find_nccl else [] - - if any_cuda_config_unset: - logging.info( - "Some CUDA config versions and paths were not provided, " - "so trying to find them using find_cuda_config.py" - ) - try: - find_cuda_config_proc = subprocess.run( - [ - sys.executable, - _FIND_CUDA_CONFIG, - *_REQUIRED_CUDA_LIBRARIES, - *maybe_nccl, - ], - capture_output=True, - check=True, - text=True, - ) - except subprocess.CalledProcessError as e: - logging.info("Command %s failed. Is CUDA installed?", e.cmd) - logging.info("Dumping %s ouptut:\n %s", e.cmd, e.output) - raise e - - cuda_config = dict( - tuple(line.split(": ")) - for line in find_cuda_config_proc.stdout.strip().split("\n") - ) - - self.cublas_version = self.cublas_version or cuda_config["cublas_version"] - self.cuda_toolkit_path = ( - self.cuda_toolkit_path or cuda_config["cuda_toolkit_path"] - ) - self.cudnn_version = self.cudnn_version or cuda_config["cudnn_version"] - if should_find_nccl: - self.nccl_version = self.nccl_version or cuda_config["nccl_version"] - @dataclasses.dataclass(frozen=True, **_KW_ONLY_IF_PYTHON310) class XLAConfigOptions: @@ -333,7 +262,6 @@ class XLAConfigOptions: # CUDA specific cuda_compiler: CudaCompiler using_nccl: bool - using_tensorrt: bool def to_bazelrc_lines( self, @@ -392,19 +320,31 @@ def to_bazelrc_lines( ) # Lines needed for CUDA backend regardless of CUDA/host compiler + if dpav.cuda_version: + rc.append( + f"build:cuda --repo_env HERMETIC_CUDA_VERSION={dpav.cuda_version}" + ) rc.append( - f"build --action_env CUDA_TOOLKIT_PATH={dpav.cuda_toolkit_path}" - ) - rc.append(f"build --action_env TF_CUBLAS_VERSION={dpav.cublas_version}") - rc.append( - "build --action_env" - f" TF_CUDA_COMPUTE_CAPABILITIES={','.join(dpav.cuda_compute_capabilities)}" + "build:cuda --repo_env" + f" HERMETIC_CUDA_COMPUTE_CAPABILITIES={','.join(dpav.cuda_compute_capabilities)}" ) - rc.append(f"build --action_env TF_CUDNN_VERSION={dpav.cudnn_version}") - rc.append(f"build --repo_env TF_NEED_TENSORRT={int(self.using_tensorrt)}") - if self.using_nccl: - rc.append(f"build --action_env TF_NCCL_VERSION={dpav.nccl_version}") - else: + if dpav.cudnn_version: + rc.append( + f"build:cuda --repo_env HERMETIC_CUDNN_VERSION={dpav.cudnn_version}" + ) + if dpav.local_cuda_path: + rc.append( + f"build:cuda --repo_env LOCAL_CUDA_PATH={dpav.local_cuda_path}" + ) + if dpav.local_cudnn_path: + rc.append( + f"build:cuda --repo_env LOCAL_CUDNN_PATH={dpav.local_cudnn_path}" + ) + if dpav.local_nccl_path: + rc.append( + f"build:cuda --repo_env LOCAL_NCCL_PATH={dpav.local_nccl_path}" + ) + if not self.using_nccl: rc.append("build --config nonccl") elif self.backend == Backend.ROCM: pass @@ -476,7 +416,6 @@ def _parse_args(): default="-Wno-sign-compare", ) parser.add_argument("--nccl", action="store_true") - parser.add_argument("--tensorrt", action="store_true") # Path and version overrides path_help = "Optional: will be found on PATH if possible." @@ -492,13 +431,35 @@ def _parse_args(): parser.add_argument("--lld_path", help=path_help) # CUDA specific - find_cuda_config_help = ( - "Optional: will be found using `find_cuda_config.py` if flag is not set." + parser.add_argument( + "--cuda_version", + help="Optional: CUDA will be downloaded by Bazel if the flag is set", + ) + parser.add_argument( + "--cudnn_version", + help="Optional: CUDNN will be downloaded by Bazel if the flag is set", + ) + parser.add_argument( + "--local_cuda_path", + help=( + "Optional: Local CUDA dir will be used in dependencies if the flag" + " is set" + ), + ) + parser.add_argument( + "--local_cudnn_path", + help=( + "Optional: Local CUDNN dir will be used in dependencies if the flag" + " is set" + ), + ) + parser.add_argument( + "--local_nccl_path", + help=( + "Optional: Local NCCL dir will be used in dependencies if the flag" + " is set" + ), ) - parser.add_argument("--cublas_version", help=find_cuda_config_help) - parser.add_argument("--cuda_toolkit_path", help=find_cuda_config_help) - parser.add_argument("--cudnn_version", help=find_cuda_config_help) - parser.add_argument("--nccl_version", help=find_cuda_config_help) return parser.parse_args() @@ -518,7 +479,6 @@ def main(): python_bin_path=args.python_bin_path, compiler_options=args.compiler_options, using_nccl=args.nccl, - using_tensorrt=args.tensorrt, ) bazelrc_lines = config.to_bazelrc_lines( @@ -527,11 +487,12 @@ def main(): gcc_path=args.gcc_path, lld_path=args.lld_path, ld_library_path=args.ld_library_path, - cublas_version=args.cublas_version, - cuda_compute_capabilities=args.cuda_compute_capabilities, - cuda_toolkit_path=args.cuda_toolkit_path, + cuda_version=args.cuda_version, cudnn_version=args.cudnn_version, - nccl_version=args.nccl_version, + cuda_compute_capabilities=args.cuda_compute_capabilities, + local_cuda_path=args.local_cuda_path, + local_cudnn_path=args.local_cudnn_path, + local_nccl_path=args.local_nccl_path, ) ) diff --git a/third_party/xla/build_tools/configure/configure_test.py b/third_party/xla/build_tools/configure/configure_test.py index e29e718b78547d..8457ff40aea3ee 100644 --- a/third_party/xla/build_tools/configure/configure_test.py +++ b/third_party/xla/build_tools/configure/configure_test.py @@ -34,12 +34,20 @@ # CUDA specific paths and versions _CUDA_SPECIFIC_PATHS_AND_VERSIONS = { - "cublas_version": "12.3", - "cuda_toolkit_path": "/usr/local/cuda-12.2", + "cuda_version": '"12.1.1"', "cuda_compute_capabilities": ["7.5"], - "cudnn_version": "8", + "cudnn_version": '"8.6"', + "ld_library_path": "/usr/local/nvidia/lib:/usr/local/nvidia/lib64", +} +_CUDA_COMPUTE_CAPABILITIES_AND_LD_LIBRARY_PATH = { + "cuda_compute_capabilities": [ + "sm_50", + "sm_60", + "sm_70", + "sm_80", + "compute_90", + ], "ld_library_path": "/usr/local/nvidia/lib:/usr/local/nvidia/lib64", - "nccl_version": "2", } @@ -66,6 +74,11 @@ def setUpClass(cls): with (testdata / "cuda_clang.bazelrc").open() as f: cls.cuda_clang_bazelrc_lines = [line.strip() for line in f.readlines()] + with (testdata / "default_cuda_clang.bazelrc").open() as f: + cls.default_cuda_clang_bazelrc_lines = [ + line.strip() for line in f.readlines() + ] + with (testdata / "nvcc_clang.bazelrc").open() as f: cls.nvcc_clang_bazelrc_lines = [line.strip() for line in f.readlines()] @@ -85,7 +98,6 @@ def test_clang_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.NVCC, using_nccl=False, - using_tensorrt=False, ) bazelrc_lines = config.to_bazelrc_lines( @@ -107,7 +119,6 @@ def test_gcc_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.NVCC, using_nccl=False, - using_tensorrt=False, ) bazelrc_lines = config.to_bazelrc_lines( @@ -128,7 +139,6 @@ def test_cuda_clang_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.CLANG, using_nccl=False, - using_tensorrt=False, ) bazelrc_lines = config.to_bazelrc_lines( @@ -141,6 +151,27 @@ def test_cuda_clang_bazelrc(self): self.assertEqual(bazelrc_lines, self.cuda_clang_bazelrc_lines) + def test_default_cuda_clang_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CUDA, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.CLANG, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.CLANG, + using_nccl=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + clang_path=_CLANG_PATH, + clang_major_version=17, + **_CUDA_COMPUTE_CAPABILITIES_AND_LD_LIBRARY_PATH, + ) + ) + + self.assertEqual(bazelrc_lines, self.default_cuda_clang_bazelrc_lines) + def test_nvcc_clang_bazelrc(self): config = XLAConfigOptions( backend=Backend.CUDA, @@ -150,7 +181,6 @@ def test_nvcc_clang_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.NVCC, using_nccl=False, - using_tensorrt=False, ) bazelrc_lines = config.to_bazelrc_lines( @@ -172,7 +202,6 @@ def test_nvcc_gcc_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.NVCC, using_nccl=False, - using_tensorrt=False, ) bazelrc_lines = config.to_bazelrc_lines( diff --git a/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc index a6e7a423bfc490..502bc8541c1285 100644 --- a/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc @@ -3,11 +3,9 @@ build --repo_env CC=/usr/lib/llvm-18/bin/clang build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang build --config cuda_clang build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang -build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 -build --action_env TF_CUBLAS_VERSION=12.3 -build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 -build --action_env TF_CUDNN_VERSION=8 -build --repo_env TF_NEED_TENSORRT=0 +build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 +build:cuda --repo_env HERMETIC_CUDNN_VERSION="8.6" build --config nonccl build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 build --action_env PYTHON_BIN_PATH=/usr/bin/python3 diff --git a/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc new file mode 100644 index 00000000000000..4623f6f52073fa --- /dev/null +++ b/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc @@ -0,0 +1,19 @@ +build --action_env CLANG_COMPILER_PATH=/usr/lib/llvm-18/bin/clang +build --repo_env CC=/usr/lib/llvm-18/bin/clang +build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang +build --config cuda_clang +build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=sm_50,sm_60,sm_70,sm_80,compute_90 +build --config nonccl +build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --copt -Wno-error=unused-command-line-argument +build --copt -Wno-gnu-offsetof-extensions +build --build_tag_filters -no_oss +build --test_tag_filters -no_oss +test --build_tag_filters -no_oss +test --test_tag_filters -no_oss diff --git a/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc index e147dbd687b118..8cd19224698311 100644 --- a/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc @@ -3,11 +3,9 @@ build --repo_env CC=/usr/lib/llvm-18/bin/clang build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang build --config nvcc_clang build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang -build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 -build --action_env TF_CUBLAS_VERSION=12.3 -build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 -build --action_env TF_CUDNN_VERSION=8 -build --repo_env TF_NEED_TENSORRT=0 +build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 +build:cuda --repo_env HERMETIC_CUDNN_VERSION="8.6" build --config nonccl build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 build --action_env PYTHON_BIN_PATH=/usr/bin/python3 diff --git a/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc b/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc index 863209697362de..be90a87545368b 100644 --- a/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc @@ -1,10 +1,8 @@ build --action_env GCC_HOST_COMPILER_PATH=/usr/bin/gcc build --config cuda -build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 -build --action_env TF_CUBLAS_VERSION=12.3 -build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 -build --action_env TF_CUDNN_VERSION=8 -build --repo_env TF_NEED_TENSORRT=0 +build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 +build:cuda --repo_env HERMETIC_CUDNN_VERSION="8.6" build --config nonccl build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 build --action_env PYTHON_BIN_PATH=/usr/bin/python3 diff --git a/third_party/xla/docs/build_from_source.md b/third_party/xla/docs/build_from_source.md index c273f7f3cdf8c0..8b30f9995d08e3 100644 --- a/third_party/xla/docs/build_from_source.md +++ b/third_party/xla/docs/build_from_source.md @@ -65,12 +65,11 @@ docker exec xla_gpu ./configure.py --backend=CUDA docker exec xla_gpu bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` -If you want to build XLA targets with GPU support without Docker you need to -install the following additional dependencies: -[`cuda-12.3`](https://developer.nvidia.com/cuda-downloads), -[`cuDNN-8.9`](https://developer.nvidia.com/cudnn). +For more details regarding +[TensorFlow's GPU docker images you can check out this document.](https://www.tensorflow.org/install/source#gpu_support_3) -Then configure and build targets using the following commands: +You can build XLA targets with GPU support without Docker as well. Configure and +build targets using the following commands: ``` ./configure.py --backend=CUDA @@ -79,4 +78,4 @@ bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` For more details regarding -[TensorFlow's GPU docker images you can check out this document.](https://www.tensorflow.org/install/source#gpu_support_3) +[hermetic CUDA you can check out this document.](docs/hermetic_cuda.md) diff --git a/third_party/xla/docs/developer_guide.md b/third_party/xla/docs/developer_guide.md index 53b3efcd8cab5c..b736309b7fbc59 100644 --- a/third_party/xla/docs/developer_guide.md +++ b/third_party/xla/docs/developer_guide.md @@ -64,6 +64,16 @@ docker exec xla ./configure.py --backend=CUDA docker exec xla bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` +**NB:** please note that with hermetic CUDA rules, you don't have to build XLA +in Docker. You can build XLA for GPU on your machine without GPUs and without +NVIDIA driver installed: + +```sh +./configure.py --backend=CUDA + +bazel build --test_output=all --spawn_strategy=sandboxed //xla/... +``` + Your first build will take quite a while because it has to build the entire stack, including XLA, MLIR, and StableHLO. diff --git a/third_party/xla/docs/hermetic_cuda.md b/third_party/xla/docs/hermetic_cuda.md new file mode 100644 index 00000000000000..18cc228d743461 --- /dev/null +++ b/third_party/xla/docs/hermetic_cuda.md @@ -0,0 +1,544 @@ +# Hermetic CUDA overview + +Hermetic CUDA uses a specific downloadable version of CUDA instead of the user’s +locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL distributions, +and then use CUDA libraries and tools as dependencies in various Bazel targets. +This enables more reproducible builds for Google ML projects and supported CUDA +versions. + +## Supported hermetic CUDA, CUDNN versions + +The supported CUDA versions are specified in `CUDA_REDIST_JSON_DICT` +dictionary, +[third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl). + +The supported CUDNN versions are specified in `CUDNN_REDIST_JSON_DICT` +dictionary, +[third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl). + +The `.bazelrc` files of individual projects have `HERMETIC_CUDA_VERSION`, +`HERMETIC_CUDNN_VERSION` environment variables set to the versions used by +default when `--config=cuda` is specified in Bazel command options. + +## Environment variables controlling the hermetic CUDA/CUDNN versions + +`HERMETIC_CUDA_VERSION` environment variable should consist of major, minor and +patch CUDA version, e.g. `12.3.2`. +`HERMETIC_CUDNN_VERSION` environment variable should consist of major, minor and +patch CUDNN version, e.g. `9.1.1`. + +Three ways to set the environment variables for Bazel commands: + +``` +# Add an entry to your `.bazelrc` file +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" + +# OR pass it directly to your specific build command +bazel build --config=cuda \ +--repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ +--repo_env=HERMETIC_CUDNN_VERSION="9.1.1" + +# OR set the environment variable globally in your shell: +export HERMETIC_CUDA_VERSION="12.3.2" +export LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" +export HERMETIC_CUDNN_VERSION="9.1.1" +``` + +If `HERMETIC_CUDA_VERSION` and `HERMETIC_CUDNN_VERSION` are not present, the +hermetic CUDA/CUDNN repository rules will look up `TF_CUDA_VERSION` and +`TF_CUDNN_VERSION` environment variables values. This is made for the backward +compatibility with non-hermetic CUDA/CUDNN repository rules. + +The mapping between CUDA version and NCCL distribution version to be downloaded +is specified in [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl) + +## Upgrade hermetic CUDA/CUDNN version +1. Create and submit a pull request with updated `CUDA_REDIST_JSON_DICT`, + `CUDA_REDIST_JSON_DICT` dictionaries in + [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl). + + Update `CUDA_NCCL_WHEELS` in + [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl) + if needed. + + Update `REDIST_VERSIONS_TO_BUILD_TEMPLATES` in + [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl) + if needed. + +2. For RBE executions: update `TF_CUDA_VERSION` and/or `TF_CUDNN_VERSION` in + [toolchains/remote_config/rbe_config.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl). + +3. For RBE executions: update `cuda_version`, `cudnn_version`, `TF_CUDA_VERSION` + and `TF_CUDNN_VERSION` in + [toolchains/remote_config/configs.bzl](https://github.com/openxla/xla/blob/main/tools/toolchains/remote_config/configs.bzl). + +4. For each Google ML project create a separate pull request with updated + `HERMETIC_CUDA_VERSION` and `HERMETIC_CUDNN_VERSION` in `.bazelrc` file. + + The PR presubmit job executions will launch bazel tests and download hermetic + CUDA/CUDNN distributions. Verify that the presubmit jobs passed before + submitting the PR. + +## Pointing to CUDA/CUDNN/NCCL redistributions on local file system + +You can use the local CUDA/CUDNN/NCCL dirs as a source of redistributions. The following additional environment variables are required: + +``` +LOCAL_CUDA_PATH +LOCAL_CUDNN_PATH +LOCAL_NCCL_PATH +``` + +Example: + +``` +# Add an entry to your `.bazelrc` file +build:cuda --repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" +build:cuda --repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" +build:cuda --repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" + +# OR pass it directly to your specific build command +bazel build --config=cuda \ +--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \ +--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \ +--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" + +# OR set the environment variable globally in your shell: +export LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" +export LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" +export LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" +``` + +The structure of the folders inside CUDA dir should be the following (as if the archived redistributions were unpacked into one place): + +``` +/ + include/ + bin/ + lib/ + nvvm/ +``` + +The structure of the folders inside CUDNN dir should be the following: + +``` + + include/ + lib/ +``` + +The structure of the folders inside NCCL dir should be the following: + +``` + + include/ + lib/ +``` + +## Custom CUDA/CUDNN archives and NCCL wheels + +There are three options that allow usage of custom CUDA/CUDNN distributions. + +### Custom CUDA/CUDNN redistribution JSON files + +This option allows to use custom distributions for all CUDA/CUDNN dependencies +in Google ML projects. + +1. Create `cuda_redist.json` and/or `cudnn_redist.json` files. + + `cuda_redist.json` show follow the format below: + + ``` + { + "cuda_cccl": { + "linux-x86_64": { + "relative_path": "cuda_cccl-linux-x86_64-12.4.99-archive.tar.xz", + }, + "linux-sbsa": { + "relative_path": "cuda_cccl-linux-sbsa-12.4.99-archive.tar.xz", + } + }, + } + ``` + + `cudnn_redist.json` show follow the format below: + + ``` + { + "cudnn": { + "linux-x86_64": { + "cuda12": { + "relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.0.0.312_cuda12-archive.tar.xz", + } + }, + "linux-sbsa": { + "cuda12": { + "relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.0.0.312_cuda12-archive.tar.xz", + } + } + } + } + ``` + + The `relative_path` field can be replaced with `full_path` for the full URLs + and absolute local paths starting with `file:///`. + +2. In the downstream project dependent on XLA, update the hermetic cuda JSON + repository call in `WORKSPACE` file. Both web links and local file paths are + allowed. Example: + + ``` + _CUDA_JSON_DICT = { + "12.4.0": [ + "file:///home/user/Downloads/redistrib_12.4.0_updated.json", + ], + } + + _CUDNN_JSON_DICT = { + "9.0.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.0.0.json", + ], + } + + cuda_json_init_repository( + cuda_json_dict = _CUDA_JSON_DICT, + cudnn_json_dict = _CUDNN_JSON_DICT, + ) + ``` + + If JSON files contain relative paths to distributions, the path prefix should + be updated in `cuda_redist_init_repositories()` and + `cudnn_redist_init_repository()` calls. Example + + ``` + cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, + cuda_redist_path_prefix = "file:///usr/Downloads/dists/", + ) + ``` + +### Custom CUDA/CUDNN distributions + +This option allows to use custom distributions for some CUDA/CUDNN dependencies +in Google ML projects. + +1. In the downstream project dependent on XLA, remove the lines below: + + ``` + <...> + "CUDA_REDIST_JSON_DICT", + <...> + "CUDNN_REDIST_JSON_DICT", + <...> + + cuda_json_init_repository( + cuda_json_dict = CUDA_REDIST_JSON_DICT, + cudnn_json_dict = CUDNN_REDIST_JSON_DICT, + ) + + load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", + ) + ``` + +2. In the same `WORKSPACE` file, create dictionaries with distribution paths. + + The dictionary with CUDA distributions show follow the format below: + + ``` + _CUSTOM_CUDA_REDISTRIBUTIONS = { + "cuda_cccl": { + "linux-x86_64": { + "relative_path": "cuda_cccl-linux-x86_64-12.4.99-archive.tar.xz", + }, + "linux-sbsa": { + "relative_path": "cuda_cccl-linux-sbsa-12.4.99-archive.tar.xz", + } + }, + } + ``` + + The dictionary with CUDNN distributions show follow the format below: + + ``` + _CUSTOM_CUDNN_REDISTRIBUTIONS = { + "cudnn": { + "linux-x86_64": { + "cuda12": { + "relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.0.0.312_cuda12-archive.tar.xz", + } + }, + "linux-sbsa": { + "cuda12": { + "relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.0.0.312_cuda12-archive.tar.xz", + } + } + } + } + ``` + + The `relative_path` field can be replaced with `full_path` for the full URLs + and absolute local paths starting with `file:///`. + +2. In the same `WORKSPACE` file, pass the created dictionaries to the repository + rule. If the dictionaries contain relative paths to distributions, the path + prefix should be updated in `cuda_redist_init_repositories()` and + `cudnn_redist_init_repository()` calls. + + ``` + cuda_redist_init_repositories( + cuda_redistributions = _CUSTOM_CUDA_REDISTRIBUTIONS, + cuda_redist_path_prefix = "file:///home/usr/Downloads/dists/", + ) + + cudnn_redist_init_repository( + cudnn_redistributions = _CUSTOM_CUDNN_REDISTRIBUTIONS, + cudnn_redist_path_prefix = "file:///home/usr/Downloads/dists/cudnn/" + ) + ``` +### Combination of the options above + +In the example below, `CUDA_REDIST_JSON_DICT` is merged with custom JSON data in +`_CUDA_JSON_DICT`, and `CUDNN_REDIST_JSON_DICT` is merged with +`_CUDNN_JSON_DICT`. + +The distributions data in `_CUDA_DIST_DICT` overrides the content of resulting +CUDA JSON file, and the distributions data in `_CUDNN_DIST_DICT` overrides the +content of resulting CUDNN JSON file. The NCCL wheels data is merged from +`CUDA_NCCL_WHEELS` and `_NCCL_WHEEL_DICT`. + +``` +load( + //third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_PATH_PREFIX", + "CUDA_NCCL_WHEELS", + "CUDA_REDIST_JSON_DICT", + "CUDNN_REDIST_PATH_PREFIX", + "CUDNN_REDIST_JSON_DICT", +) + +_CUDA_JSON_DICT = { + "12.4.0": [ + "file:///usr/Downloads/redistrib_12.4.0_updated.json", + ], +} + +_CUDNN_JSON_DICT = { + "9.0.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.0.0.json", + ], +} + +cuda_json_init_repository( + cuda_json_dict = CUDA_REDIST_JSON_DICT | _CUDA_JSON_DICT, + cudnn_json_dict = CUDNN_REDIST_JSON_DICT | _CUDNN_JSON_DICT, +) + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) + +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +_CUDA_DIST_DICT = { + "cuda_cccl": { + "linux-x86_64": { + "relative_path": "cuda_cccl-linux-x86_64-12.4.99-archive.tar.xz", + }, + "linux-sbsa": { + "relative_path": "cuda_cccl-linux-sbsa-12.4.99-archive.tar.xz", + }, + }, + "libcusolver": { + "linux-x86_64": { + "full_path": "file:///usr/Downloads/dists/libcusolver-linux-x86_64-11.6.0.99-archive.tar.xz", + }, + "linux-sbsa": { + "relative_path": "libcusolver-linux-sbsa-11.6.0.99-archive.tar.xz", + }, + }, +} + +_CUDNN_DIST_DICT = { + "cudnn": { + "linux-x86_64": { + "cuda12": { + "relative_path": "cudnn-linux-x86_64-9.0.0.312_cuda12-archive.tar.xz", + }, + }, + "linux-sbsa": { + "cuda12": { + "relative_path": "cudnn-linux-sbsa-9.0.0.312_cuda12-archive.tar.xz", + }, + }, + }, +} + +cudnn_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS | _CUDA_DIST_DICT, + cuda_redist_path_prefix = "file:///usr/Downloads/dists/", +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS | _CUDNN_DIST_DICT, + cudnn_redist_path_prefix = "file:///usr/Downloads/dists/cudnn/" +) + +load( + "//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +_NCCL_WHEEL_DICT = { + "12.4.0": { + "x86_64-unknown-linux-gnu": { + "url": "https://files.pythonhosted.org/packages/38/00/d0d4e48aef772ad5aebcf70b73028f88db6e5640b36c38e90445b7a57c45/nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", + }, + }, +} + +nccl_redist_init_repository( + cuda_nccl_wheels = CUDA_NCCL_WHEELS | _NCCL_WHEEL_DICT, +) +``` + +## DEPRECATED: Non-hermetic CUDA/CUDNN usage +Though non-hermetic CUDA/CUDNN usage is deprecated, it might be used for +some experiments currently unsupported officially (for example, building wheels +on Windows with CUDA). + +Here are the steps to use non-hermetic CUDA installed locally in Google ML +projects: + +1. Delete calls to hermetic CUDA repository rules from the `WORKSPACE` + file of the project dependent on XLA. + +2. Add the calls to non-hermetic CUDA repository rules to the bottom of the + `WORKSPACE` file. + + For XLA and JAX: + ``` + load("@local_tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure") + cuda_configure(name = "local_config_cuda") + load("@local_tsl//third_party/nccl:nccl_configure.bzl", "nccl_configure") + nccl_configure(name = "local_config_nccl") + ``` + + For Tensorflow: + ``` + load("@local_tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure") + cuda_configure(name = "local_config_cuda") + load("@local_tsl//third_party/nccl:nccl_configure.bzl", "nccl_configure") + nccl_configure(name = "local_config_nccl") + ``` + +3. Set the following environment variables directly in your shell or in + `.bazelrc` file as shown below: + ``` + build:cuda --action_env=TF_CUDA_VERSION= + build:cuda --action_env=TF_CUDNN_VERSION= + build:cuda --action_env=TF_CUDA_COMPUTE_CAPABILITIES= + build:cuda --action_env=LD_LIBRARY_PATH= + build:cuda --action_env=CUDA_TOOLKIT_PATH= + build:cuda --action_env=TF_CUDA_PATHS= + build:cuda --action_env=NCCL_INSTALL_PATH= + ``` + + Note that `TF_CUDA_VERSION` and `TF_CUDNN_VERSION` should consist of major and + minor versions only (e.g. `12.3` for CUDA and `9.1` for CUDNN). + +4. Now you can run `bazel` command to use locally installed CUDA and CUDNN. + + For XLA, no changes in the command options are needed. + + For JAX, use `--override_repository=tsl=` flag in the Bazel command + options. + + For Tensorflow, use `--override_repository=local_tsl=` flag in the + Bazel command options. + +## Configure hermetic CUDA + +1. In the downstream project dependent on XLA, add the following lines to the + bottom of the `WORKSPACE` file: + + Note: use @local_tsl instead of @tsl in Tensorflow project. + + ``` + load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", + ) + + cuda_json_init_repository() + + load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", + ) + load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", + ) + + cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, + ) + + cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, + ) + + load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", + ) + + cuda_configure(name = "local_config_cuda") + + load( + "@local_tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", + ) + + nccl_redist_init_repository() + + load( + "@local_tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", + ) + + nccl_configure(name = "local_config_nccl") + ``` + +2. To select specific versions of hermetic CUDA and CUDNN, set the + `HERMETIC_CUDA_VERSION` and `HERMETIC_CUDNN_VERSION` environment variables + respectively. Use only supported versions. You may set the environment + variables directly in your shell or in `.bazelrc` file as shown below: + ``` + build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" + build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" + build:cuda --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" + ``` + +3. To enable Hermetic CUDA during test execution, or when running a binary via + bazel, make sure to add `--@local_config_cuda//cuda:include_hermetic_cuda_libs=true` + flag to your bazel command. You can provide it either directly in a shell or + in `.bazelrc`: + ``` + test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true + ``` + The flag is needed to make sure that CUDA dependencies are properly provided + to test executables. The flag is false by default to avoid unwanted coupling + of Google-released Python wheels to CUDA binaries. diff --git a/third_party/xla/docs/images/openxla.svg b/third_party/xla/docs/images/openxla.svg new file mode 100644 index 00000000000000..bb97db4af1c268 --- /dev/null +++ b/third_party/xla/docs/images/openxla.svg @@ -0,0 +1,266 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/third_party/xla/docs/images/openxla_dark.svg b/third_party/xla/docs/images/openxla_dark.svg new file mode 100644 index 00000000000000..ae2dc4c874c13f --- /dev/null +++ b/third_party/xla/docs/images/openxla_dark.svg @@ -0,0 +1,255 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files index baafd35265caaf..5759a24c5d6d54 100644 --- a/third_party/xla/opensource_only.files +++ b/third_party/xla/opensource_only.files @@ -34,6 +34,7 @@ third_party/py/python_init_toolchains.bzl: third_party/py/python_repo.bzl: third_party/python_runtime/BUILD: third_party/repo.bzl: +third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD: third_party/stablehlo/BUILD: tools/toolchains/BUILD: tools/toolchains/clang6/BUILD: diff --git a/third_party/xla/third_party/nanobind/nanobind.BUILD b/third_party/xla/third_party/nanobind/nanobind.BUILD index c9f307b75ef0ca..72b47585b5e5d0 100644 --- a/third_party/xla/third_party/nanobind/nanobind.BUILD +++ b/third_party/xla/third_party/nanobind/nanobind.BUILD @@ -4,9 +4,12 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "nanobind", - srcs = glob([ - "src/*.cpp", - ]), + srcs = glob( + [ + "src/*.cpp", + ], + exclude = ["src/nb_combined.cpp"], + ), copts = ["-fexceptions"], defines = [ "NB_BUILD=1", diff --git a/third_party/xla/third_party/nanobind/pr438.patch b/third_party/xla/third_party/nanobind/pr438.patch deleted file mode 100644 index edb7d61700e03b..00000000000000 --- a/third_party/xla/third_party/nanobind/pr438.patch +++ /dev/null @@ -1,51 +0,0 @@ -diff --git a/src/nb_enum.cpp b/src/nb_enum.cpp -index 86f64d1..91f3932 100644 ---- a/src/nb_enum.cpp -+++ b/src/nb_enum.cpp -@@ -73,6 +73,13 @@ static PyObject *nb_enum_get_doc(PyObject *self, void *) { - return result; - } - -+static PyObject *nb_enum_get_value(PyObject *self, void *) { -+ enum_supplement &supp = nb_enum_supplement(Py_TYPE(self)); -+ return supp.is_signed ? nb_enum_int_signed(self) -+ : nb_enum_int_unsigned(self); -+} -+ -+ - NB_NOINLINE static PyObject *nb_enum_int_signed(PyObject *o) { - type_data *t = nb_type_data(Py_TYPE(o)); - const void *p = inst_ptr((nb_inst *) o); -@@ -141,6 +148,8 @@ error: - static PyGetSetDef nb_enum_getset[] = { - { "__doc__", nb_enum_get_doc, nullptr, nullptr, nullptr }, - { "__name__", nb_enum_get_name, nullptr, nullptr, nullptr }, -+ { "name", nb_enum_get_name, nullptr, nullptr, nullptr }, -+ { "value", nb_enum_get_value, nullptr, nullptr, nullptr }, - { nullptr, nullptr, nullptr, nullptr, nullptr } - }; - -diff --git a/tests/test_enum.py b/tests/test_enum.py -index 2a6e9ff..1063eef 100644 ---- a/tests/test_enum.py -+++ b/tests/test_enum.py -@@ -14,6 +14,9 @@ def test01_unsigned_enum(): - assert int(t.Enum.A) == 0 - assert int(t.Enum.B) == 1 - assert int(t.Enum.C) == 0xffffffff -+ assert t.Enum.A.value == 0 -+ assert t.Enum.B.value == 1 -+ assert t.Enum.C.value == 0xffffffff - assert t.Enum(0) is t.Enum.A - assert t.Enum(1) is t.Enum.B - assert t.Enum(0xffffffff) is t.Enum.C -@@ -48,6 +51,9 @@ def test02_signed_enum(): - assert int(t.SEnum.A) == 0 - assert int(t.SEnum.B) == 1 - assert int(t.SEnum.C) == -1 -+ assert t.SEnum.A.value == 0 -+ assert t.SEnum.B.value == 1 -+ assert t.SEnum.C.value == -1 - assert t.SEnum(0) is t.SEnum.A - assert t.SEnum(1) is t.SEnum.B - assert t.SEnum(-1) is t.SEnum.C \ No newline at end of file diff --git a/third_party/xla/third_party/nanobind/pr461.patch b/third_party/xla/third_party/nanobind/pr461.patch deleted file mode 100644 index aa0a51b68175a3..00000000000000 --- a/third_party/xla/third_party/nanobind/pr461.patch +++ /dev/null @@ -1,39 +0,0 @@ -diff --git a/src/nb_type.cpp b/src/nb_type.cpp ---- a/src/nb_type.cpp -+++ b/src/nb_type.cpp -@@ -36,6 +36,11 @@ static PyObject **nb_weaklist_ptr(PyObje - return weaklistoffset ? (PyObject **) ((uint8_t *) self + weaklistoffset) : nullptr; - } - -+static PyGetSetDef inst_getset[] = { -+ { "__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, nullptr }, -+ { nullptr, nullptr, nullptr, nullptr, nullptr } -+}; -+ - static int inst_clear(PyObject *self) { - PyObject **dict = nb_dict_ptr(self); - if (dict) -@@ -923,8 +928,11 @@ PyObject *nb_type_new(const type_init_da - } - - bool has_traverse = false; -- for (PyType_Slot *ts = slots; ts != s; ++ts) -+ bool has_getset = false; -+ for (PyType_Slot *ts = slots; ts != s; ++ts) { - has_traverse |= ts->slot == Py_tp_traverse; -+ has_getset |= ts->slot == Py_tp_getset; -+ } - - Py_ssize_t dictoffset = 0, weaklistoffset = 0; - int num_members = 0; -@@ -948,6 +956,10 @@ PyObject *nb_type_new(const type_init_da - has_traverse = true; - } - spec.basicsize = (int) basicsize; -+ -+ if (!has_getset) { -+ *s++ = { Py_tp_getset, (void *) inst_getset }; -+ } - } - - if (is_weak_referenceable) { diff --git a/third_party/xla/third_party/nanobind/workspace.bzl b/third_party/xla/third_party/nanobind/workspace.bzl index 9f9022dbaa8d12..1c692d396e9b98 100644 --- a/third_party/xla/third_party/nanobind/workspace.bzl +++ b/third_party/xla/third_party/nanobind/workspace.bzl @@ -5,12 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): tf_http_archive( name = "nanobind", - strip_prefix = "nanobind-1.9.2", - sha256 = "149a3da40b0a988513d8cf5e71db3037373823505a3c92f87b988c92d7e0ab34", - urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v1.9.2.tar.gz"), + strip_prefix = "nanobind-2.1.0", + sha256 = "c37c53c60ada5fe1c956e24bd4b83af669a2309bf952bd251f36a7d2fa3bacf0", + urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v2.1.0.tar.gz"), build_file = "//third_party/nanobind:nanobind.BUILD", - patch_file = [ - "//third_party/nanobind:pr438.patch", # Remove when updating to nanobind 2.0.0. - "//third_party/nanobind:pr461.patch", # Remove when updating to nanobind 2.0.0. - ], ) diff --git a/third_party/xla/third_party/py/python_repo.bzl b/third_party/xla/third_party/py/python_repo.bzl index f8fdd1033b5e2f..13aed2b687129f 100644 --- a/third_party/xla/third_party/py/python_repo.bzl +++ b/third_party/xla/third_party/py/python_repo.bzl @@ -255,8 +255,12 @@ def _basic_wildcard_match(name, patterns, expected_match_result, match_all): def _custom_python_interpreter_impl(ctx): version = ctx.attr.version - strip_prefix = ctx.attr.strip_prefix.format(version = version) - urls = [url.format(version = version) for url in ctx.attr.urls] + version_variant = ctx.attr.version_variant + strip_prefix = ctx.attr.strip_prefix.format( + version = version, + version_variant = version_variant, + ) + urls = [url.format(version = version, version_variant = version_variant) for url in ctx.attr.urls] binary_name = ctx.attr.binary_name if not binary_name: ver_chunks = version.split(".") @@ -272,13 +276,12 @@ def _custom_python_interpreter_impl(ctx): output = srcs_dir, ) - configure_params = [] + configure_params = list(ctx.attr.configure_params) if "CC" in ctx.os.environ: configure_params.append("CC={}".format(ctx.os.environ["CC"])) if "CXX" in ctx.os.environ: configure_params.append("CXX={}".format(ctx.os.environ["CXX"])) - configure_params.append("--enable-optimizations") configure_params.append("--prefix=%s" % install_path.realpath) _exec_and_check( ctx, @@ -361,6 +364,11 @@ custom_python_interpreter = repository_rule( "strip_prefix": attr.string(), "binary_name": attr.string(mandatory = False), "version": attr.string(), + "version_variant": attr.string(), + "configure_params": attr.string_list( + mandatory = False, + default = ["--enable-optimizations"], + ), }, ) diff --git a/third_party/xla/third_party/shardy/BUILD b/third_party/xla/third_party/shardy/BUILD index ea1ecdb548c1f4..bf3ae84c142f65 100644 --- a/third_party/xla/third_party/shardy/BUILD +++ b/third_party/xla/third_party/shardy/BUILD @@ -2,4 +2,7 @@ # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) -exports_files(srcs = ["workspace.bzl"]) +exports_files(srcs = [ + "temporary.patch", + "workspace.bzl", +]) diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index f8d02d38377ed3..e69de29bb2d1d6 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,655 +0,0 @@ -diff --git a/docs/sdy_export_passes.md b/docs/sdy_export_passes.md -index 7a7e3ef..add024c 100755 ---- a/docs/sdy_export_passes.md -+++ b/docs/sdy_export_passes.md -@@ -12,12 +12,3 @@ the edge), and replaces the op with its input. - - TODO(tomnatan): consider moving the sharding to all targets that can have a - sharding attached. --### `-sdy-update-non-divisible-input-output-shardings` -- --_Makes FuncOp inputs/outputs evenly sharded, removing any need for padding due to non-divisible shardings._ -- --Users of Shardy expect the function inputs/outputs to be evenly --divisible/shardable to avoid requiring padding their tensors. Propagation --may make inputs/outputs have non-divisible shardings, so this pass updates --them to the largest dimension sharding prefix of the original sharding that --is evenly sharded. -diff --git a/shardy/dialect/sdy/ir/dialect.cc b/shardy/dialect/sdy/ir/dialect.cc -index aaa33c5..f6f88bc 100644 ---- a/shardy/dialect/sdy/ir/dialect.cc -+++ b/shardy/dialect/sdy/ir/dialect.cc -@@ -28,7 +28,6 @@ limitations under the License. - #include "llvm/ADT/SmallVector.h" - #include "llvm/Support/ErrorHandling.h" - #include "mlir/IR/BuiltinAttributes.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/OperationSupport.h" -@@ -431,8 +430,7 @@ TensorShardingPerValueAttr TensorShardingPerValueAttr::getFullyOpen( - for (Type type : types) { - int64_t rank = 0; - // TODO(tomnatan): remove mlir:: once Attribute::dyn_cast is removed. -- if (auto tensorType = mlir::dyn_cast(type)) { -- assert(tensorType.hasStaticShape()); -+ if (auto tensorType = mlir::dyn_cast(type)) { - rank = tensorType.getRank(); - } - shardingPerResult.push_back( -diff --git a/shardy/dialect/sdy/ir/ops.td b/shardy/dialect/sdy/ir/ops.td -index 9478d7b..ca67f51 100644 ---- a/shardy/dialect/sdy/ir/ops.td -+++ b/shardy/dialect/sdy/ir/ops.td -@@ -135,12 +135,12 @@ def Sdy_ManualComputationOp : Sdy_Op<"manual_computation", - }]; - - let arguments = (ins -- Variadic:$tensors, -+ Variadic:$tensors, - Sdy_TensorShardingPerValue:$in_shardings, - Sdy_TensorShardingPerValue:$out_shardings, - Sdy_ManualAxes:$manual_axes - ); -- let results = (outs Variadic:$results); -+ let results = (outs Variadic:$results); - let regions = (region SizedRegion<1>:$body); - - let assemblyFormat = [{ -@@ -249,6 +249,27 @@ def Sdy_ConstantOp : Sdy_Op<"constant", - }]; - } - -+//===----------------------------------------------------------------------===// -+// IdentityOp -+//===----------------------------------------------------------------------===// -+ -+def IdentityOp : Sdy_Op<"identity", -+ [Pure, Elementwise, SameOperandsAndResultType]> { -+ let summary = "Identity operation"; -+ -+ let description = [{ -+ An identity op that outputs the same value that it takes as input. This is -+ useful, to break a pattern where a block argument is directly used in the -+ block's terminator, which could result in canonicalization removing that -+ block argument, e.g., a block argument of a while op that could be replaced -+ with the corresponding operand as a free variable. -+ }]; -+ -+ let arguments = (ins AnyTensor:$input); -+ let results = (outs AnyTensor:$result); -+ let assemblyFormat = "attr-dict $input `:` type($input)"; -+} -+ - //===----------------------------------------------------------------------===// - // DataFlowEdgeOp - //===----------------------------------------------------------------------===// -@@ -316,10 +337,10 @@ def DataFlowEdgeOp : Sdy_Op<"data_flow_edge", - }]; - - let arguments = (ins -- AnyShaped:$input, -+ AnyRankedTensor:$input, - OptionalAttr:$sharding); - -- let results = (outs AnyShaped:$result); -+ let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = "$input (`sharding````=``` $sharding^)? attr-dict `:` type($result)"; - -@@ -360,10 +381,10 @@ def PropagationBarrierOp : Sdy_Op<"propagation_barrier", - }]; - - let arguments = (ins -- AnyRankedTensor:$input, -+ AnyTensor:$input, - Sdy_PropagationDirection:$allowed_direction - ); -- let results = (outs AnyRankedTensor:$result); -+ let results = (outs AnyTensor:$result); - let assemblyFormat = "$input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)"; - let hasVerifier = 1; - } -diff --git a/shardy/dialect/sdy/ir/test/data_flow_edge_verification.mlir b/shardy/dialect/sdy/ir/test/data_flow_edge_verification.mlir -index b247d79..c2a355d 100644 ---- a/shardy/dialect/sdy/ir/test/data_flow_edge_verification.mlir -+++ b/shardy/dialect/sdy/ir/test/data_flow_edge_verification.mlir -@@ -12,15 +12,6 @@ func.func @invalid_sharding(%arg0 : tensor<8xf32>) -> tensor<8xf32> { - - // ----- - --func.func @dynamic_shaped_type(%arg0: tensor) -- -> (tensor, tensor) { -- // expected-error @+1 {{expected sdy.data_flow_edge to have a static-shaped result}} -- %0 = sdy.data_flow_edge %arg0 : tensor -- return %arg0, %0 : tensor, tensor --} -- --// ----- -- - func.func @input_has_multiple_users(%arg0: tensor<32x96xf32>) - -> (tensor<32x96xf32>, tensor<32x96xf32>) { - // expected-error @+1 {{expected input of sdy.data_flow_edge to have a single user}} -diff --git a/shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir b/shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir -index e64c43c..9fc6e87 100644 ---- a/shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir -+++ b/shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir -@@ -16,22 +16,6 @@ func.func @sharding_rule_wrong_attr_type(%arg0: tensor<8xf32>) -> tensor<8xf32> - - // ----- - --func.func @unranked_tensor_type(%arg0: tensor<*xf32>) -> tensor<*xf32> { -- // expected-error@+1 {{operand 0 - expected a ranked tensor with a static shape}} -- %0 = stablehlo.add %arg0, %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=2, j=4}>} : tensor<*xf32> -- return %0 : tensor<*xf32> --} -- --// ----- -- --func.func @dynamic_shaped_tensor_type(%arg0: tensor) -> tensor { -- // expected-error@+1 {{operand 0 - expected a ranked tensor with a static shape}} -- %0 = stablehlo.add %arg0, %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=2, j=4}>} : tensor -- return %0 : tensor --} -- --// ----- -- - func.func @operand_mappings_wrong_rank(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> { - // expected-error@+1 {{operand 1 - mapping rank must match: 1 != 2}} - %0 = stablehlo.add %arg0, %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i])->([i, j]) {i=2, j=4}>} : tensor<2x4xf32> -diff --git a/shardy/dialect/sdy/ir/test/tensor_sharding_verification.mlir b/shardy/dialect/sdy/ir/test/tensor_sharding_verification.mlir -index 540ce8d..50394d1 100644 ---- a/shardy/dialect/sdy/ir/test/tensor_sharding_verification.mlir -+++ b/shardy/dialect/sdy/ir/test/tensor_sharding_verification.mlir -@@ -2,7 +2,7 @@ - - sdy.mesh @mesh = <"a"=2> - --// expected-error @+1 {{'func.func' op arg 0 - non-shaped tensors can only have a sharding with rank 0 and no replicated axes}} -+// expected-error @+1 {{'func.func' op arg 0 - non-ranked tensors can only have a sharding with rank 0 and no replicated axes}} - func.func @token_sharding_rank_non_zero(%arg0: !stablehlo.token {sdy.sharding=#sdy.sharding<@mesh, [{}]>}) -> !stablehlo.token { - return %arg0 : !stablehlo.token - } -@@ -11,31 +11,13 @@ func.func @token_sharding_rank_non_zero(%arg0: !stablehlo.token {sdy.sharding=#s - - sdy.mesh @mesh = <"a"=2> - --// expected-error @+1 {{'func.func' op arg 0 - non-shaped tensors can only have a sharding with rank 0 and no replicated axes}} -+// expected-error @+1 {{'func.func' op arg 0 - non-ranked tensors can only have a sharding with rank 0 and no replicated axes}} - func.func @token_sharding_with_replicated_axes(%arg0: !stablehlo.token {sdy.sharding=#sdy.sharding<@mesh, [], replicated={"a"}>}) -> !stablehlo.token { - return %arg0 : !stablehlo.token - } - - // ----- - --sdy.mesh @mesh = <"a"=2> -- --// expected-error @+1 {{'func.func' op arg 0 - only ranked tensors with a static shape can have a sharding}} --func.func @unranked_tensor_with_sharding(%arg0: tensor<*xf32> {sdy.sharding=#sdy.sharding<@mesh, []>}) -> tensor<*xf32> { -- return %arg0 : tensor<*xf32> --} -- --// ----- -- --sdy.mesh @mesh = <"a"=2> -- --// expected-error @+1 {{'func.func' op arg 0 - only ranked tensors with a static shape can have a sharding}} --func.func @dynamic_shaped_tensor_with_sharding(%arg0: tensor<*xf32> {sdy.sharding=#sdy.sharding<@mesh, [{}, {}]>}) -> tensor { -- return %arg0 : tensor<*xf32> --} -- --// ----- -- - sdy.mesh @mesh = <"a"=2, "b"=2> - - func.func @dim_shardings_rank_mismatch(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> { -diff --git a/shardy/dialect/sdy/ir/utils.cc b/shardy/dialect/sdy/ir/utils.cc -index b184794..8831d58 100644 ---- a/shardy/dialect/sdy/ir/utils.cc -+++ b/shardy/dialect/sdy/ir/utils.cc -@@ -28,7 +28,6 @@ limitations under the License. - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/BuiltinAttributes.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/Diagnostics.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/Operation.h" -@@ -92,37 +91,26 @@ std::string operationToString(Operation* op) { - return mlirToString(op); - } - --std::string valueToString(Value value) { return mlirToString(&value); } -- --ShapedType dynCastStaticShapedType(Type type) { -- if (auto shapedType = dyn_cast(type); -- shapedType && shapedType.hasStaticShape()) { -- return shapedType; -- } -- return nullptr; --} -- --bool isStaticShapedType(Type type) { -- return dynCastStaticShapedType(type) != nullptr; -+std::string valueToString(Value value) { -+ return mlirToString(&value); - } - - ArrayRef getTensorShape(Value value) { -- if (auto tensorType = dyn_cast(value.getType())) { -+ if (auto tensorType = dyn_cast(value.getType())) { - return tensorType.getShape(); - } - return {}; - } - - int64_t getTensorRank(Value value) { -- if (auto tensorType = dyn_cast(value.getType())) { -+ if (auto tensorType = dyn_cast(value.getType())) { - return tensorType.getRank(); - } - return 0; - } - - int64_t isScalar(Value value) { -- if (auto tensorType = dyn_cast(value.getType()); -- tensorType && tensorType.hasRank()) { -+ if (auto tensorType = dyn_cast(value.getType())) { - return tensorType.getRank() == 0; - } - return false; -diff --git a/shardy/dialect/sdy/ir/utils.h b/shardy/dialect/sdy/ir/utils.h -index c151955..d0868a7 100644 ---- a/shardy/dialect/sdy/ir/utils.h -+++ b/shardy/dialect/sdy/ir/utils.h -@@ -26,7 +26,6 @@ limitations under the License. - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Builders.h" - #include "mlir/IR/BuiltinAttributes.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/Operation.h" - #include "mlir/IR/PatternMatch.h" -@@ -66,23 +65,12 @@ std::string operationToString(Operation* op); - // Converts `value` to string with location information. - std::string valueToString(Value value); - --// If the given `type` is a `ShapedType` with a static shape, returns it, --// otherwise returns nullptr. --ShapedType dynCastStaticShapedType(Type type); -- --// Returns true if the given `type` is a `ShapedType` with a static shape. --bool isStaticShapedType(Type type); -- --// Returns the shape of the given `value` if its type is a `ShapeTensor`, -+// Returns the shape of the given `value` if its type is a `RankedTensorType`, - // otherwise returns an empty array. --// --// Assumes the `ShapeTensor` has a rank. - ArrayRef getTensorShape(Value value); - --// Returns the rank of the given `value` if its type is a `ShapeTensor`, -+// Returns the rank of the given `value` if its type is a `RankedTensorType`, - // otherwise returns 0. --// --// Assumes the `ShapeTensor` has a rank. - int64_t getTensorRank(Value value); - - // Returns true if the value is a tensor with rank 0. -diff --git a/shardy/dialect/sdy/ir/verifiers.cc b/shardy/dialect/sdy/ir/verifiers.cc -index 61fd0e0..015e10f 100644 ---- a/shardy/dialect/sdy/ir/verifiers.cc -+++ b/shardy/dialect/sdy/ir/verifiers.cc -@@ -30,7 +30,6 @@ limitations under the License. - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/BuiltinAttributes.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/Diagnostics.h" - #include "mlir/IR/SymbolTable.h" -@@ -200,11 +199,11 @@ LogicalResult emitBoundAxisInManualComputationError(EmitErrorFn emitError, - - // Verifies the following for `shardingAttr`: - // --// If `type` isn't a `ShapedType`, the sharding must have rank 0 and no -+// If `type` isn't a `RankedTensorType`, the sharding must have rank 0 and no - // replicated axes. - // --// - The tensor should have a rank and static shape. --// - The number of dimension shardings is equal to the rank of the tensor. -+// - The number of dimension shardings is equal to the rank of the tensor -+// (specified by `type`, which should be a `RankedTensorType`). - // - Dimensions of size 0 aren't sharded. - // - Replicated axes are ordered w.r.t. `mesh` (see - // AxisRefAttr::getMeshComparator). -@@ -221,22 +220,17 @@ LogicalResult verifyTensorShardingAttr( - TensorShardingAttr shardingAttr, Type type, MeshAttr mesh, - EmitErrorFn emitError, - ManualAxisToOwner alreadyManualAxes = ManualAxisToOwner()) { -- auto tensorType = dyn_cast(type); -+ auto tensorType = dyn_cast(type); - if (!tensorType) { - if (shardingAttr.getRank() != 0 || - !shardingAttr.getReplicatedAxes().empty()) { - return emitError( -- "non-shaped tensors can only have a sharding with rank 0 ") -+ "non-ranked tensors can only have a sharding with rank 0 ") - << "and no replicated axes. type: " << type - << ", sharding: " << shardingAttr; - } - return success(); - } -- if (!tensorType.hasStaticShape()) { -- return emitError( -- "only ranked tensors with a static shape can have a sharding. ") -- << "type: " << type; -- } - int64_t rank = tensorType.getRank(); - if (shardingAttr.getRank() != rank) { - return emitError("sharding doesn't match tensor rank: ") -@@ -432,6 +426,7 @@ LogicalResult verifyShardingRuleMapping(Operation* op, TypeRange types, - // doesn't reuse the same factor. - BitVector valueSeenFactorIndices(factorSizes.size()); - auto [type, mapping] = typeAndMapping; -+ auto tensorType = cast(type); - - EmitErrorFn valueEmitError = getEmitValueInRangeErrorFn( - [op, valueKindStr](StringRef msg) { -@@ -439,13 +434,6 @@ LogicalResult verifyShardingRuleMapping(Operation* op, TypeRange types, - }, - types.size(), index); - -- auto tensorType = dynCastStaticShapedType(type); -- if (!tensorType) { -- return valueEmitError( -- "expected a ranked tensor with a static shape. type: ") -- << type; -- } -- - if (mapping.getRank() != tensorType.getRank()) { - return valueEmitError("mapping rank must match: ") - << mapping.getRank() << " != " << tensorType.getRank(); -@@ -571,11 +559,6 @@ LogicalResult ReshardOp::verify() { - } - - LogicalResult DataFlowEdgeOp::verify() { -- if (!getType().hasStaticShape()) { -- return emitOpError( -- "expected sdy.data_flow_edge to have a static-shaped result. ") -- << "type: " << getType(); -- } - if (!getInput().hasOneUse()) { - return emitOpError( - "expected input of sdy.data_flow_edge to have a single user"); -@@ -682,8 +665,8 @@ LogicalResult verifyManualComputationValue( - for (auto [valueIndex, valueEntry] : llvm::enumerate(llvm::zip_equal( - globalTypes, localTypes, shardingPerValueAttr.getShardings()))) { - auto [globalType, localType, sharding] = valueEntry; -- auto globalRankedType = cast(globalType); -- auto localRankedType = cast(localType); -+ auto globalRankedType = globalType.template cast(); -+ auto localRankedType = localType.template cast(); - - // 5. Verify the manual axes come before any free axes in each dim sharding. - for (auto [dim, dimSharding] : -@@ -710,7 +693,7 @@ LogicalResult verifyManualComputationValue( - accumulatedManualAxesSize(op, dimSharding.getAxes(), - manualAxes, mesh)); - } -- auto expectedLocalRankedType = -+ RankedTensorType expectedLocalRankedType = - RankedTensorType::get(newDimSizes, globalRankedType.getElementType()); - if (expectedLocalRankedType != localRankedType) { - return op->emitOpError(valueKindStr) -diff --git a/shardy/dialect/sdy/transforms/export/update_non_divisible_input_output_shardings.cc b/shardy/dialect/sdy/transforms/export/update_non_divisible_input_output_shardings.cc -index 22c4269..6a4d05c 100644 ---- a/shardy/dialect/sdy/transforms/export/update_non_divisible_input_output_shardings.cc -+++ b/shardy/dialect/sdy/transforms/export/update_non_divisible_input_output_shardings.cc -@@ -23,7 +23,6 @@ limitations under the License. - #include "llvm/Support/ErrorHandling.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/BuiltinAttributes.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/TypeRange.h" - #include "mlir/IR/Value.h" -@@ -59,7 +58,8 @@ namespace { - // - [{"y","x"}] : tensor<4xf32> -> [{"y","x":(1)2}] : tensor<4xf32> - // See update_non_divisible_input_output_shardings.mlir for more examples. - TensorShardingAttr getEvenlySharded(TensorShardingAttr sharding, -- ShapedType type, func::FuncOp funcOp) { -+ RankedTensorType type, -+ func::FuncOp funcOp) { - StringRef meshName = sharding.getMeshName(); - MeshAttr mesh = getMeshAttr(funcOp, meshName); - assert(mesh && "unknown mesh"); -@@ -130,7 +130,7 @@ void updateValueShardings( - func::FuncOp funcOp) { - for (auto [index, type] : llvm::enumerate(types)) { - TensorShardingAttr sharding = getSharding(index); -- if (auto tensorType = dynCastStaticShapedType(type); -+ if (auto tensorType = dyn_cast(type); - sharding && tensorType) { - setSharding(index, getEvenlySharded(sharding, tensorType, funcOp)); - } -diff --git a/shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc b/shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc -index 91b5acb..b67c18c 100644 ---- a/shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc -+++ b/shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc -@@ -47,8 +47,8 @@ struct AddDataFlowEdgesPass - ValueRange edgeRoots = getDataFlowEdgeRoots(op); - rewriter.setInsertionPointAfter(op); - for (Value edgeRoot : edgeRoots) { -- if (!isStaticShapedType(edgeRoot.getType())) { -- // Skip non-static-shaped tensors, e.g., tokens. -+ if (!isa(edgeRoot.getType())) { -+ // Skip non-tensor values, e.g., tokens. - continue; - } - TensorShardingAttr sharding = nullptr; -diff --git a/shardy/dialect/sdy/transforms/import/test/add_data_flow_edges.mlir b/shardy/dialect/sdy/transforms/import/test/add_data_flow_edges.mlir -index 67cede6..f31387d 100644 ---- a/shardy/dialect/sdy/transforms/import/test/add_data_flow_edges.mlir -+++ b/shardy/dialect/sdy/transforms/import/test/add_data_flow_edges.mlir -@@ -66,16 +66,6 @@ func.func @optimization_barrier(%arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf3 - return %0#0, %0#1 : tensor<32x96xf32>, tensor<32x96xf32> - } - --// CHECK-LABEL: func @optimization_barrier --func.func @optimization_barrier_dynamic_shaped_tensor_skipped(%arg0: tensor<32x96xf32>, %arg1: tensor) -- -> (tensor<32x96xf32>, tensor) { -- // CHECK-NEXT: %[[OPT_BARRIER:.*]]:2 = stablehlo.optimization_barrier %arg0, %arg1 -- // CHECK: %[[EDGE_1:.*]] = sdy.data_flow_edge %[[OPT_BARRIER]]#0 -- // CHECK-NEXT: return %[[EDGE_1]], %[[OPT_BARRIER]]#1 -- %0:2 = stablehlo.optimization_barrier %arg0, %arg1 : tensor<32x96xf32>, tensor -- return %0#0, %0#1 : tensor<32x96xf32>, tensor --} -- - // CHECK-LABEL: func @while_unused_result - func.func @while_unused_result(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { - // CHECK: %[[C0:.*]] = stablehlo.constant dense<0> -diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc -index 8117426..eff74a3 100644 ---- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc -+++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc -@@ -28,7 +28,6 @@ limitations under the License. - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinOps.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/Diagnostics.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/OpDefinition.h" -@@ -46,6 +45,7 @@ limitations under the License. - #include "shardy/dialect/sdy/ir/data_flow_utils.h" - #include "shardy/dialect/sdy/ir/dialect.h" - #include "shardy/dialect/sdy/ir/utils.h" -+#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h" - #include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h" - #include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h" - #include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.h" -@@ -328,9 +328,9 @@ LogicalResult propagateFuncResults(FuncOp funcOp, - const FactorPropagation& factorPropagation) { - for (OpOperand& returnOperand : getBodyTerminatorOpOperands(funcOp)) { - Value returnValue = returnOperand.get(); -- auto tensorType = dynCastStaticShapedType(returnValue.getType()); -+ auto tensorType = dyn_cast(returnValue.getType()); - if (!tensorType) { -- // Skip non-static-shaped tensors, e.g., tokens. -+ // Skip non-tensor values, e.g., tokens. - continue; - } - int64_t resNum = returnOperand.getOperandNumber(); -@@ -436,7 +436,7 @@ class PropagateDataFlowEdgeOp : public OpRewritePattern { - return propagateTensorShardings( - sources, dataFlowEdgeOp.getResult(), - createIdentityShardingRule( -- cast(dataFlowEdgeOp.getType()), sources.size()), -+ cast(dataFlowEdgeOp.getType()), sources.size()), - dataFlowEdgeOp, rewriter, factorPropagation); - } - -diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc -index 3763581..2b8ff59 100644 ---- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc -+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc -@@ -23,7 +23,6 @@ limitations under the License. - #include - - #include "llvm/ADT/STLExtras.h" --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/Operation.h" -@@ -88,12 +87,12 @@ OpShardingRuleBuilder::OpShardingRuleBuilder( - resultMappings.reserve(resultTypes.size()); - int64_t maxRank = 0; - for (Type operandType : operandTypes) { -- int64_t rank = cast(operandType).getRank(); -+ int64_t rank = cast(operandType).getRank(); - maxRank = std::max(maxRank, rank); - operandMappings.push_back(TensorMapping(rank)); - } - for (Type resultType : resultTypes) { -- int64_t rank = cast(resultType).getRank(); -+ int64_t rank = cast(resultType).getRank(); - maxRank = std::max(maxRank, rank); - resultMappings.push_back(TensorMapping(rank)); - } -@@ -126,7 +125,7 @@ OpShardingRuleAttr OpShardingRuleBuilder::build() { - OpShardingRuleAttr OpShardingRuleBuilder::buildPointwise(Operation* op) { - // All results should have the same shape, so we look at the first. - ArrayRef shape = -- cast(op->getResultTypes().front()).getShape(); -+ cast(op->getResultTypes().front()).getShape(); - - OpShardingRuleBuilder builder(op); - -@@ -201,7 +200,7 @@ OpShardingRuleBuilder& OpShardingRuleBuilder::addPointwiseIfDimSizesMatch( - return *this; - } - --OpShardingRuleAttr createIdentityShardingRule(ShapedType type, -+OpShardingRuleAttr createIdentityShardingRule(RankedTensorType type, - size_t numOperands, - size_t numResults) { - return OpShardingRuleBuilder(SmallVector(numOperands, type), -diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h -index 5130827..5d0b5a8 100644 ---- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h -+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h -@@ -22,7 +22,6 @@ limitations under the License. - #include - #include - --#include "mlir/IR/BuiltinTypeInterfaces.h" - #include "mlir/IR/BuiltinTypes.h" - #include "mlir/IR/MLIRContext.h" - #include "mlir/IR/Operation.h" -@@ -119,7 +118,7 @@ class OpShardingRuleBuilder { - // i.e., all operands/results have the same mapping. - // - // NOTE: an empty rule {([])->([])} will be created for scalar ops. --OpShardingRuleAttr createIdentityShardingRule(ShapedType type, -+OpShardingRuleAttr createIdentityShardingRule(RankedTensorType type, - size_t numOperands = 1, - size_t numResults = 1); - -diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc -index 98fa7a1..80e4933 100644 ---- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc -+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc -@@ -144,8 +144,8 @@ OpShardingRuleAttr getOrCreateShardingRule(Operation* op, - OpShardingRuleAttr createOpShardingRule(Operation* op, - const bool conservativePropagation) { - return TypeSwitch(op) -- .Case, %arg1: tensor<8x16xf32>) - return %0 : tensor<8x16xf32> - } - --// CHECK-LABEL: func @token_func_output_skipped( -+// CHECK-LABEL: func @token_func_output_token_skipped( - // CHECK-SAME: %arg0: !stablehlo.token, - // CHECK-SAME: %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>}) - // CHECK-SAME: -> (!stablehlo.token, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}, {"b"}]>}) { --func.func @token_func_output_skipped(%arg0: !stablehlo.token, %arg1: tensor<8x16xf32>) -+func.func @token_func_output_token_skipped(%arg0: !stablehlo.token, %arg1: tensor<8x16xf32>) - -> (!stablehlo.token, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}, {"b"}]>}) { - // CHECK-NEXT: stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>]>} - %0 = stablehlo.add %arg1, %arg1 : tensor<8x16xf32> - return %arg0, %0 : !stablehlo.token, tensor<8x16xf32> - } - --// CHECK-LABEL: func @dynamic_shaped_func_output_skipped( --// CHECK-SAME: %arg0: tensor, --// CHECK-SAME: %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>}) --// CHECK-SAME: -> (tensor, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}, {"b"}]>}) { --func.func @dynamic_shaped_func_output_skipped(%arg0: tensor, %arg1: tensor<8x16xf32>) -- -> (tensor, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}, {"b"}]>}) { -- // CHECK-NEXT: stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>]>} -- %0 = stablehlo.add %arg1, %arg1 : tensor<8x16xf32> -- return %arg0, %0 : tensor, tensor<8x16xf32> --} -- - // CHECK-LABEL: func @func_result_intermediate_op_both_updated( - // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>}) - // CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>}) { -diff --git a/shardy/integrations/python/ir/__init__.py b/shardy/integrations/python/ir/__init__.py -index 97e8a3b..89a06ba 100644 ---- a/shardy/integrations/python/ir/__init__.py -+++ b/shardy/integrations/python/ir/__init__.py -@@ -17,6 +17,7 @@ - # pylint: disable=g-multiple-import,g-importing-member,unused-import,useless-import-alias - from ._sdy_ops_gen import ( - ConstantOp as ConstantOp, -+ IdentityOp as IdentityOp, - ManualComputationOp as ManualComputationOp, - MeshOp as MeshOp, - ReshardOp as ReshardOp, -diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 0d420ba..88869a4 100644 ---- a/third_party/llvm/workspace.bzl -+++ b/third_party/llvm/workspace.bzl -@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") - - def repo(name): - """Imports LLVM.""" -- LLVM_COMMIT = "41491c77231e9d389ef18593be1fab4f4e810e88" -- LLVM_SHA256 = "10b17d9f8304eb7c9fb91f7b13f73e9e5ca81984aa692eac91b82d19db311547" -+ LLVM_COMMIT = "0c25f85e5b88102363c0cd55e1946053d5827e99" -+ LLVM_SHA256 = "851d958e60193edfb54d6eb8644785179eeb604edae8c026ac1819e82c059f6c" - - tf_http_archive( - name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index c50cb5177e2d70..6d91def025b34a 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "ef636ca340e01a2ef3f910bb0dffc4539019f793" - SHARDY_SHA256 = "ad87f171909ba0e7c9879e7f3e57c31e25f0fbe935e14ebe2dbd45ed4c64f632" + SHARDY_COMMIT = "7e3ddfb532b3b53cb0b108014c24a86ac147e9f6" + SHARDY_SHA256 = "1d304e1e6f1132fe3ccb969d28798bc6ee90db84d10c85113ef8573eae350325" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/spirv_llvm_translator/BUILD b/third_party/xla/third_party/spirv_llvm_translator/BUILD new file mode 100644 index 00000000000000..8d626dc7635d1a --- /dev/null +++ b/third_party/xla/third_party/spirv_llvm_translator/BUILD @@ -0,0 +1,7 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +# spirv_llvm_translator license placeholder diff --git a/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD new file mode 100644 index 00000000000000..557e2e8f50edd2 --- /dev/null +++ b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD @@ -0,0 +1,34 @@ +cc_library( + name = "spirv_llvm_translator", + srcs = glob([ + "lib/SPIRV/libSPIRV/*.cpp", + "lib/SPIRV/libSPIRV/*.hpp", + "lib/SPIRV/libSPIRV/*.h", + "lib/SPIRV/Mangler/*.cpp", + "lib/SPIRV/Mangler/*.h", + "lib/SPIRV/*.cpp", + "lib/SPIRV/*.hpp", + "lib/SPIRV/*.h", + ]), + hdrs = glob(["include/*"]), + includes = [ + "include/", + "lib/SPIRV/", + "lib/SPIRV/Mangler/", + "lib/SPIRV/libSPIRV/", + ], + visibility = ["//visibility:public"], + deps = [ + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BitWriter", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Demangle", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TransformUtils", + "@spirv_headers//:spirv_cpp_headers", + ], +) diff --git a/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.patch b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.patch new file mode 100644 index 00000000000000..fc843b1b039b09 --- /dev/null +++ b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.patch @@ -0,0 +1,25 @@ +diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h +index a828add8..924e13b4 100644 + +Spir backend uses different addrspace representations link with nvptx backend link. +We reorder the enum value here so that we can make XLA LLVM codegen simple(avoiding +changing addrspace based on device backend everywhere) + +--- a/lib/SPIRV/SPIRVInternal.h ++++ b/lib/SPIRV/SPIRVInternal.h +@@ -179,11 +179,12 @@ typedef SPIRVMap IntBoolOpMap; + "-v512:512:512-v1024:1024:1024" + + enum SPIRAddressSpace { +- SPIRAS_Private, ++ SPIRAS_Generic, + SPIRAS_Global, +- SPIRAS_Constant, ++ SPIRAS_Internal, + SPIRAS_Local, +- SPIRAS_Generic, ++ SPIRAS_Constant, ++ SPIRAS_Private, + SPIRAS_GlobalDevice, + SPIRAS_GlobalHost, + SPIRAS_Input, \ No newline at end of file diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 8b137891791fe9..77fefee2b13b6d 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1 +1,28 @@ +diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel +--- stablehlo/BUILD.bazel ++++ stablehlo/BUILD.bazel +@@ -1283,6 +1283,7 @@ + "@llvm-project//mlir:AllExtensions", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", ++ "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:TosaDialect", + ], + ) +diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py b/stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py +--- stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py ++++ stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py +@@ -32,9 +32,9 @@ + + # Make LLVM and StableHLO tools available in RUN directives + tools = [ +- 'stablehlo-opt', +- 'FileCheck', +- 'stablehlo-translate', ++ 'stablehlo-opt', ++ 'FileCheck', ++ 'stablehlo-translate', + ] + tool_dirs = [ + config.llvm_tools_dir, diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index b46c39e85fc240..6c0cea3e8f16f5 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "2fcbae0c933b5cc2735523bab2de880a3a9c5e46" - STABLEHLO_SHA256 = "14f879b246266dc7c5cb49cdbf88c87ebac0444e3ebae04b57448d4bbc2fe180" + STABLEHLO_COMMIT = "23d3e1414b0be1c1b5256f0949520dc4f0a0705c" + STABLEHLO_SHA256 = "ad694a3da43a2a432c8c5f1c60be39fc211e28834cca07cf663ce8dc85d920fe" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/triton/llvm_integration/cl657620552.patch b/third_party/xla/third_party/triton/llvm_integration/cl657620552.patch deleted file mode 100644 index 4a1f47d79e6c92..00000000000000 --- a/third_party/xla/third_party/triton/llvm_integration/cl657620552.patch +++ /dev/null @@ -1,18 +0,0 @@ -# Do not upstream this patch. This has been already upstreamed in -# https://github.com/triton-lang/triton/commit/de46a0ede6efe7e93c2a9ebef639e36c6177c511 -# Next integration will include it and this patch should be removed then. - -diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc ---- a/third_party/amd/python/triton_amd.cc -+++ b/third_party/amd/python/triton_amd.cc -@@ -193,9 +193,7 @@ void init_triton_amd(py::module &&m) { - target->createMCAsmBackend(*sti, *mri, mcOptions)); - mcStreamer.reset(target->createMCObjectStreamer( - triple, ctx, std::move(mab), mab->createObjectWriter(svos), -- std::move(ce), *sti, mcOptions.MCRelaxAll, -- mcOptions.MCIncrementalLinkerCompatible, -- /*DWARFMustBeAtTheEnd=*/false)); -+ std::move(ce), *sti)); - - std::unique_ptr parser( - createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai)); diff --git a/third_party/xla/third_party/triton/llvm_integration/series.bzl b/third_party/xla/third_party/triton/llvm_integration/series.bzl index 8162fb5fad6342..656b9c894904d8 100644 --- a/third_party/xla/third_party/triton/llvm_integration/series.bzl +++ b/third_party/xla/third_party/triton/llvm_integration/series.bzl @@ -8,6 +8,5 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ - "//third_party/triton:llvm_integration/cl657620552.patch", # Add new patches just above this line ] diff --git a/third_party/xla/third_party/triton/temporary/cuda11-temporary.patch b/third_party/xla/third_party/triton/temporary/cuda11-temporary.patch deleted file mode 100644 index a92166eef6df71..00000000000000 --- a/third_party/xla/third_party/triton/temporary/cuda11-temporary.patch +++ /dev/null @@ -1,35 +0,0 @@ -# This temporary patch has already been included to the public list of Triton -# patches. It is only here temporarily to be included in the openxla version, -# but it will be removed during the next triton integration. - ---- a/third_party/nvidia/backend/driver.c -+++ b/third_party/nvidia/backend/driver.c -@@ -154,6 +154,8 @@ static PyObject *loadBinary(PyObject *se - typedef CUresult (*cuOccupancyMaxActiveClusters_t)( - int *numClusters, CUfunction func, const CUlaunchConfig *config); - -+#if CUDA_VERSION < 12000 -+#else - typedef CUresult (*cuTensorMapEncodeTiled_t)( - CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, -@@ -161,6 +161,7 @@ typedef CUresult (*cuTensorMapEncodeTile - const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, - CUtensorMapFloatOOBfill oobFill); -+#endif - - #define defineGetFunctionHandle(name, symbolName) \ - static symbolName##_t name() { \ -@@ -187,8 +187,11 @@ typedef CUresult (*cuTensorMapEncodeTile - defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, - cuOccupancyMaxActiveClusters); - -+#if CUDA_VERSION < 12000 -+#else - defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, - cuTensorMapEncodeTiled); -+#endif - - static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { - int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, diff --git a/third_party/xla/third_party/triton/temporary/series.bzl b/third_party/xla/third_party/triton/temporary/series.bzl index 388e57f849f14e..4fa55269e3323c 100644 --- a/third_party/xla/third_party/triton/temporary/series.bzl +++ b/third_party/xla/third_party/triton/temporary/series.bzl @@ -14,7 +14,5 @@ those to this list. """ temporary_patch_list = [ - "//third_party/triton:temporary/cuda11-temporary.patch", - "//third_party/triton:temporary/undo_tesla_gpu.patch", # Add new patches just above this line ] diff --git a/third_party/xla/third_party/triton/temporary/undo_tesla_gpu.patch b/third_party/xla/third_party/triton/temporary/undo_tesla_gpu.patch deleted file mode 100644 index 6c2d1d1d734fbc..00000000000000 --- a/third_party/xla/third_party/triton/temporary/undo_tesla_gpu.patch +++ /dev/null @@ -1,13 +0,0 @@ -This can be removed on the next integrate as it already exists in upstream. -diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp ---- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp -+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp -@@ -21,7 +21,7 @@ namespace { - static int getMMAVersionSafe(int computeCapability, DotOp op) { - // List supported mma version in order of preference. - SmallVector versionsSupported; -- if (computeCapability < 80) { -+ if (computeCapability < 75) { - versionsSupported = {1}; - } else if (computeCapability < 90) { - versionsSupported = {2}; diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index fc6c45f7bc1e5e..e74434221f6c98 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl657175856" - TRITON_SHA256 = "316f421a7d7ead2b7e5adc2e8bb68ce1a8f7809db73dbed8abd54c35bd0c1576" + TRITON_COMMIT = "cl664783844" + TRITON_SHA256 = "d5779d331008dd3a4941dd59e61385ec964987da74454248446ac3e36b874007" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch index a1c011dbb8beb5..dadc7732a4f280 100644 --- a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch +++ b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch @@ -57,7 +57,7 @@ diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dia index 012786dae..6043b764a 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp -@@ -498,6 +498,119 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, +@@ -498,6 +498,123 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, return encoding; } @@ -173,6 +173,10 @@ index 012786dae..6043b764a 100644 + ArrayRef tensorShape) const { + return ::getShapePerCTATile(getParent(), tensorShape); +} ++std::optional SparseDotMetaEncodingAttr::toLinearLayout( ++ ArrayRef shape) const { ++ return ::toLinearLayout(shape, getParent()); ++} + } // namespace gpu } // namespace triton @@ -273,9 +277,9 @@ index d74e0a224..4e45f7c4c 100644 + return op->hasTrait() || isa(op); +} + - // Replace the ForOp's yield with a new one with the given operands appended. - static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { - // Fix up the yield op. + static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, + tt::CoarseSchedule &schedule, @@ -248,19 +252,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { } else { if (!isa(user)) diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index 76f824f372e0d3..9e565e91a1b903 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -219,13 +219,16 @@ build:mkl_aarch64_threadpool -c opt build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda +# Default CUDA and CUDNN versions. +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" +# This flag is needed to include hermetic CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true # CUDA: This config refers to building CUDA op kernels with clang. build:cuda_clang --config=cuda -# Enable TensorRT optimizations https://developer.nvidia.com/tensorrt -build:cuda_clang --config=tensorrt -build:cuda_clang --action_env=TF_CUDA_CLANG="1" build:cuda_clang --@local_config_cuda//:cuda_compiler=clang +build:cuda_clang --copt=-Qunused-arguments # Select supported compute capabilities (supported graphics cards). # This is the same as the official TensorFlow builds. # See https://developer.nvidia.com/cuda-gpus#compute @@ -234,22 +237,22 @@ build:cuda_clang --@local_config_cuda//:cuda_compiler=clang # release while SASS is only forward compatible inside the current # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. -build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Set lld as the linker. +build:cuda_clang --host_linkopt="-fuse-ld=lld" +build:cuda_clang --host_linkopt="-lm" +build:cuda_clang --linkopt="-fuse-ld=lld" +build:cuda_clang --linkopt="-lm" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang -build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" -build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" -build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" -build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +build:cuda_clang_official --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda_clang_official --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" # Build with nvcc for CUDA and clang for host build:nvcc_clang --config=cuda -# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang -build:nvcc_clang --action_env=TF_CUDA_CLANG="1" build:nvcc_clang --action_env=TF_NVCC_CLANG="1" build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc @@ -545,10 +548,6 @@ build:rbe_linux_cuda --config=cuda_clang_official build:rbe_linux_cuda --config=rbe_linux_cpu # For Remote build execution -- GPU configuration build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.17-clang_config_cuda" -build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.17-clang_config_tensorrt" -build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.17-clang_config_nccl" -test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --config=nvcc_clang @@ -633,7 +632,6 @@ build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-18/bin/cla # Test-related settings below this point. test:release_linux_base --build_tests_only --keep_going --test_output=errors --verbose_failures=true test:release_linux_base --local_test_jobs=HOST_CPUS -test:release_linux_base --test_env=LD_LIBRARY_PATH # Give only the list of failed tests at the end of the log test:release_linux_base --test_summary=short @@ -647,7 +645,6 @@ build:release_gpu_linux --config=release_cpu_linux # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. # Note that linux cpu and cuda builds share the same toolchain now. build:release_gpu_linux --config=cuda_clang_official -test:release_gpu_linux --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" # Local test jobs has to be 4 because parallel_gpu_execute is fragile, I think test:release_gpu_linux --test_timeout=300,450,1200,3600 --local_test_jobs=4 --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute @@ -656,6 +653,7 @@ build:release_arm64_linux --config=linux_arm64 build:release_arm64_linux --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" build:release_arm64_linux --config=mkl_aarch64_threadpool build:release_arm64_linux --copt=-flax-vector-conversions +test:release_arm64_linux --flaky_test_attempts=3 # The old gcc linux build options are preserved in the unsupported_*_linux # configs. If your project fails to build with Clang, you can use these @@ -677,9 +675,8 @@ build:unsupported_gpu_linux --config=unsupported_cpu_linux build:unsupported_gpu_linux --action_env=TF_CUDA_VERSION="11" build:unsupported_gpu_linux --action_env=TF_CUDNN_VERSION="8" build:unsupported_gpu_linux --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" -build:unsupported_gpu_linux --config=tensorrt build:unsupported_gpu_linux --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2" -build:unsupported_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64:/usr/local/tensorrt/lib" +build:unsupported_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64" build:unsupported_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain @@ -774,7 +771,7 @@ test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflo # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium --flaky_test_attempts=3 +test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 @@ -812,7 +809,7 @@ test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflo # inherit from build. build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium +build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test # CROSS-COMPILE ARM64 PYCPP diff --git a/third_party/xla/third_party/tsl/WORKSPACE b/third_party/xla/third_party/tsl/WORKSPACE index 19350e3dbba762..a83a9e63f4143a 100644 --- a/third_party/xla/third_party/tsl/WORKSPACE +++ b/third_party/xla/third_party/tsl/WORKSPACE @@ -50,3 +50,50 @@ tsl_workspace1() load(":workspace0.bzl", "tsl_workspace0") tsl_workspace0() + +load( + "//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", +) + +cuda_json_init_repository() + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, +) + +load( + "//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", +) + +cuda_configure(name = "local_config_cuda") + +load( + "//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +nccl_redist_init_repository() + +load( + "//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", +) + +nccl_configure(name = "local_config_nccl") diff --git a/third_party/xla/third_party/tsl/opensource_only.files b/third_party/xla/third_party/tsl/opensource_only.files index 300ae95c10aec2..f93d02d633d3c7 100644 --- a/third_party/xla/third_party/tsl/opensource_only.files +++ b/third_party/xla/third_party/tsl/opensource_only.files @@ -21,6 +21,7 @@ third_party/git/BUILD.tpl: third_party/git/BUILD: third_party/git/git_configure.bzl: third_party/gpus/BUILD: +third_party/gpus/compiler_common_tools.bzl: third_party/gpus/crosstool/BUILD.rocm.tpl: third_party/gpus/crosstool/BUILD.sycl.tpl: third_party/gpus/crosstool/BUILD.tpl: @@ -38,6 +39,27 @@ third_party/gpus/cuda/LICENSE: third_party/gpus/cuda/build_defs.bzl.tpl: third_party/gpus/cuda/cuda_config.h.tpl: third_party/gpus/cuda/cuda_config.py.tpl: +third_party/gpus/cuda/hermetic/BUILD.tpl: +third_party/gpus/cuda/hermetic/BUILD: +third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_configure.bzl: +third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl: +third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl: third_party/gpus/cuda_configure.bzl: third_party/gpus/find_cuda_config:.py third_party/gpus/rocm/BUILD.tpl: @@ -67,6 +89,9 @@ third_party/nccl/archive.BUILD: third_party/nccl/archive.patch: third_party/nccl/build_defs.bzl.tpl: third_party/nccl/generated_names.bzl.tpl: +third_party/nccl/hermetic/BUILD: +third_party/nccl/hermetic/cuda_nccl.BUILD.tpl: +third_party/nccl/hermetic/nccl_configure.bzl: third_party/nccl/nccl_configure.bzl: third_party/nccl/system.BUILD.tpl: third_party/nvtx/BUILD: @@ -93,6 +118,7 @@ third_party/remote_config/remote_platform_configure.bzl: third_party/repo.bzl: third_party/six.BUILD: third_party/snappy.BUILD: +third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD: third_party/systemlibs/BUILD.tpl: third_party/systemlibs/BUILD: third_party/systemlibs/absl_py.BUILD: diff --git a/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py b/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py index afd6380b0ac203..b1a10a86b9aac6 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py +++ b/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py @@ -14,6 +14,9 @@ # ============================================================================== """Verifies that a list of libraries is installed on the system. +NB: DEPRECATED! This script is a part of the deprecated `cuda_configure` rule. +Please use `hermetic/cuda_configure` instead. + Takes a list of arguments with every two subsequent arguments being a logical tuple of (path, check_soname). The path to the library and either True or False to indicate whether to check the soname field on the shared library. diff --git a/third_party/xla/third_party/tsl/third_party/gpus/compiler_common_tools.bzl b/third_party/xla/third_party/tsl/third_party/gpus/compiler_common_tools.bzl new file mode 100644 index 00000000000000..bd07f49ec457bb --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/compiler_common_tools.bzl @@ -0,0 +1,174 @@ +"""Common compiler functions. """ + +load( + "//third_party/remote_config:common.bzl", + "err_out", + "raw_exec", + "realpath", +) + +def to_list_of_strings(elements): + """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'. + + This is to be used to put a list of strings into the bzl file templates + so it gets interpreted as list of strings in Starlark. + + Args: + elements: list of string elements + + Returns: + single string of elements wrapped in quotes separated by a comma.""" + quoted_strings = ["\"" + element + "\"" for element in elements] + return ", ".join(quoted_strings) + +_INC_DIR_MARKER_BEGIN = "#include <...>" + +# OSX add " (framework directory)" at the end of line, strip it. +_OSX_FRAMEWORK_SUFFIX = " (framework directory)" +_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX) + +# TODO(dzc): Once these functions have been factored out of Bazel's +# cc_configure.bzl, load them from @bazel_tools instead. +def _cxx_inc_convert(path): + """Convert path returned by cc -E xc++ in a complete path.""" + path = path.strip() + if path.endswith(_OSX_FRAMEWORK_SUFFIX): + path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip() + return path + +def _normalize_include_path(repository_ctx, path): + """Normalizes include paths before writing them to the crosstool. + + If path points inside the 'crosstool' folder of the repository, a relative + path is returned. + If path points outside the 'crosstool' folder, an absolute path is returned. + """ + path = str(repository_ctx.path(path)) + crosstool_folder = str(repository_ctx.path(".").get_child("crosstool")) + + if path.startswith(crosstool_folder): + # We drop the path to "$REPO/crosstool" and a trailing path separator. + return path[len(crosstool_folder) + 1:] + return path + +def _is_compiler_option_supported(repository_ctx, cc, option): + """Checks that `option` is supported by the C compiler. Doesn't %-escape the option.""" + result = repository_ctx.execute([ + cc, + option, + "-o", + "/dev/null", + "-c", + str(repository_ctx.path("tools/cpp/empty.cc")), + ]) + return result.stderr.find(option) == -1 + +def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sys_root): + """Compute the list of default C or C++ include directories.""" + if lang_is_cpp: + lang = "c++" + else: + lang = "c" + sysroot = [] + if tf_sys_root: + sysroot += ["--sysroot", tf_sys_root] + result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] + + sysroot) + stderr = err_out(result) + index1 = stderr.find(_INC_DIR_MARKER_BEGIN) + if index1 == -1: + return [] + index1 = stderr.find("\n", index1) + if index1 == -1: + return [] + index2 = stderr.rfind("\n ") + if index2 == -1 or index2 < index1: + return [] + index2 = stderr.find("\n", index2 + 1) + if index2 == -1: + inc_dirs = stderr[index1 + 1:] + else: + inc_dirs = stderr[index1 + 1:index2].strip() + + print_resource_dir_supported = _is_compiler_option_supported( + repository_ctx, + cc, + "-print-resource-dir", + ) + + if print_resource_dir_supported: + resource_dir = repository_ctx.execute( + [cc, "-print-resource-dir"], + ).stdout.strip() + "/share" + inc_dirs += "\n" + resource_dir + + compiler_includes = [ + _normalize_include_path(repository_ctx, _cxx_inc_convert(p)) + for p in inc_dirs.split("\n") + ] + + # The compiler might be on a symlink, e.g. /symlink -> /opt/gcc + # The above keeps only the resolved paths to the default includes (e.g. /opt/gcc/include/c++/11) + # but Bazel might encounter either (usually reported by the compiler) + # especially when a compiler wrapper (e.g. ccache) is used. + # So we need to also include paths where symlinks are not resolved. + + # Try to find real path to CC installation to "see through" compiler wrappers + # GCC has the path to g++ + index1 = result.stderr.find("COLLECT_GCC=") + if index1 != -1: + index1 = result.stderr.find("=", index1) + index2 = result.stderr.find("\n", index1) + cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname.dirname + else: + # Clang has the directory + index1 = result.stderr.find("InstalledDir: ") + if index1 != -1: + index1 = result.stderr.find(" ", index1) + index2 = result.stderr.find("\n", index1) + cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname + else: + # Fallback to the CC path + cc_topdir = repository_ctx.path(cc).dirname.dirname + + # We now have the compiler installation prefix, e.g. /symlink/gcc + # And the resolved installation prefix, e.g. /opt/gcc + cc_topdir_resolved = str(realpath(repository_ctx, cc_topdir)).strip() + cc_topdir = str(cc_topdir).strip() + + # If there is (any!) symlink involved we add paths where the unresolved installation prefix is kept. + # e.g. [/opt/gcc/include/c++/11, /opt/gcc/lib/gcc/x86_64-linux-gnu/11/include, /other/path] + # adds [/symlink/include/c++/11, /symlink/lib/gcc/x86_64-linux-gnu/11/include] + if cc_topdir_resolved != cc_topdir: + unresolved_compiler_includes = [ + cc_topdir + inc[len(cc_topdir_resolved):] + for inc in compiler_includes + if inc.startswith(cc_topdir_resolved) + ] + compiler_includes = compiler_includes + unresolved_compiler_includes + return compiler_includes + +def get_cxx_inc_directories(repository_ctx, cc, tf_sys_root): + """Compute the list of default C and C++ include directories.""" + + # For some reason `clang -xc` sometimes returns include paths that are + # different from the ones from `clang -xc++`. (Symlink and a dir) + # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists + includes_cpp = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + True, + tf_sys_root, + ) + includes_c = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + False, + tf_sys_root, + ) + + return includes_cpp + [ + inc + for inc in includes_c + if inc not in includes_cpp + ] diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.tpl index 8eda7a1cf6ac2b..b9553d9b99ecfe 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.tpl @@ -2,6 +2,7 @@ # Update cuda_configure.bzl#verify_build_defines when adding new variables. load(":cc_toolchain_config.bzl", "cc_toolchain_config") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") licenses(["restricted"]) @@ -133,9 +134,17 @@ filegroup( srcs = [], ) +filegroup( + name = "cuda_nvcc_files", + srcs = %{cuda_nvcc_files}, +) + filegroup( name = "crosstool_wrapper_driver_is_not_gcc", - srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], + srcs = [ + ":cuda_nvcc_files", + ":clang/bin/crosstool_wrapper_driver_is_not_gcc" + ], ) filegroup( diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl index c46e09484fdfad..eb3a1d8c8ddf02 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl @@ -181,6 +181,9 @@ def InvokeNvcc(argv, log=False): nvccopts += ['--keep', '--keep-dir', tempdir] # Force C++17 dialect (note, everything in just one string!) nvccopts += ['--std c++17'] + # This is so that nvcc does not complain about MSVC or CLANG. + nvccopts += ['-allow-unsupported-compiler'] + nvccopts += ['--expt-extended-lambda', '--expt-relaxed-constexpr'] if log: Log([NVCC_PATH] + nvccopts) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl index 44cdbe34b25f86..094431dcedfc12 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl @@ -1,6 +1,10 @@ +# NB: DEPRECATED! This file is a part of the deprecated `cuda_configure` rule. +# Please use `hermetic/cuda_configure` instead. + load(":build_defs.bzl", "cuda_header_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like @@ -144,7 +148,6 @@ cc_library( name = "cusolver", srcs = ["cuda/lib/%{cusolver_lib}"], data = ["cuda/lib/%{cusolver_lib}"], - linkopts = ["-lgomp"], linkstatic = 1, ) @@ -220,7 +223,6 @@ cc_library( name = "cusparse", srcs = ["cuda/lib/%{cusparse_lib}"], data = ["cuda/lib/%{cusparse_lib}"], - linkopts = ["-lgomp"], linkstatic = 1, ) @@ -242,6 +244,41 @@ py_library( srcs = ["cuda/cuda_config.py"], ) +# Build setting that is always true (i.e. it can not be changed on the +# command line). It is used to create the config settings below that are +# always or never satisfied. +bool_setting( + name = "true_setting", + visibility = ["//visibility:private"], + build_setting_default = True, +) + +# Config settings whether TensorFlow is built with hermetic CUDA. +# These configs are never satisfied. +config_setting( + name = "hermetic_cuda_tools", + flag_values = {":true_setting": "False"}, +) + +# Flag indicating if we should include hermetic CUDA libs. +bool_flag( + name = "include_hermetic_cuda_libs", + build_setting_default = False, +) + +config_setting( + name = "hermetic_cuda_libs", + flag_values = {":true_setting": "False"}, +) + +selects.config_setting_group( + name = "hermetic_cuda_tools_and_libs", + match_all = [ + ":hermetic_cuda_libs", + ":hermetic_cuda_tools" + ], +) + %{copy_rules} cc_library( diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl index dee0e898d9ae7a..6b25c8398a7144 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl @@ -1,3 +1,7 @@ +# NB: DEPRECATED! This file is a part of the deprecated `cuda_configure` rule. +# Hermetic CUDA repository rule doesn't support Windows. +# Please use `hermetic/cuda_configure`. + load(":build_defs.bzl", "cuda_header_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl index bc865cecb3240a..d1c50ea6377b9e 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl @@ -104,9 +104,16 @@ def if_cuda_newer_than(wanted_ver, if_true, if_false = []): wanted_major = int(wanted_ver.split('_')[0]) wanted_minor = int(wanted_ver.split('_')[1]) - configured_version = "%{cuda_version}" - configured_major = int(configured_version.split('.')[0]) - configured_minor = int(configured_version.split('.')[1]) + # Strip "64_" which appears in the CUDA version on Windows. + configured_version = "%{cuda_version}".rsplit("_", 1)[-1] + configured_version_parts = configured_version.split('.') + + # On Windows, the major and minor versions are concatenated without a period and the minor only contains one digit. + if len(configured_version_parts) == 1: + configured_version_parts = [configured_version[0:-1], configured_version[-1:]] + + configured_major = int(configured_version_parts[0]) + configured_minor = int(configured_version_parts[1]) if %{cuda_is_configured} and (wanted_major, wanted_minor) <= (configured_major, configured_minor): return select({"//conditions:default": if_true}) @@ -142,9 +149,13 @@ def cuda_header_library( **kwargs ) -def cuda_library(copts = [], **kwargs): +def cuda_library(copts = [], tags = [],**kwargs): """Wrapper over cc_library which adds default CUDA options.""" - native.cc_library(copts = cuda_default_copts() + copts, **kwargs) + native.cc_library( + copts = cuda_default_copts() + copts, + tags = tags + ["gpu"], + **kwargs + ) def cuda_cc_test(copts = [], **kwargs): """Wrapper over cc_test which adds default CUDA options.""" diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl new file mode 100644 index 00000000000000..ccf1b9a030d5ad --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl @@ -0,0 +1,266 @@ +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") + +licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +# Config setting whether TensorFlow is built with CUDA support using clang. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_clang. +selects.config_setting_group( + name = "using_clang", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_clang", + ], +) + +# Config setting whether TensorFlow is built with CUDA support using nvcc. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc. +selects.config_setting_group( + name = "using_nvcc", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_nvcc", + ], +) + +# Equivalent to using_clang && -c opt. +selects.config_setting_group( + name = "using_clang_opt", + match_all = [ + ":using_clang", + ":_opt", + ], +) + +config_setting( + name = "_opt", + values = {"compilation_mode": "opt"}, +) + +# Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"' +# All clients including TensorFlow should use these directives. +cc_library( + name = "cuda_headers", + hdrs = [ + "cuda/cuda_config.h", + ], + include_prefix = "third_party/gpus", + includes = [ + ".", # required to include cuda/cuda/cuda_config.h as cuda/config.h + ], + deps = [":cudart_headers", + ":cublas_headers", + ":cccl_headers", + ":nvtx_headers", + ":nvcc_headers", + ":cusolver_headers", + ":cufft_headers", + ":cusparse_headers", + ":curand_headers", + ":cupti_headers", + ":nvml_headers"], +) + +cc_library( + name = "cudart_static", + srcs = ["@cuda_cudart//:static"], + linkopts = [ + "-ldl", + "-lpthread", + %{cudart_static_linkopt} + ], +) + +alias( + name = "cuda_driver", + actual = "@cuda_cudart//:cuda_driver", +) + +alias( + name = "cudart_headers", + actual = "@cuda_cudart//:headers", +) + +alias( + name = "cudart", + actual = "@cuda_cudart//:cudart", +) + +alias( + name = "nvtx_headers", + actual = "@cuda_nvtx//:headers", +) + +alias( + name = "nvml_headers", + actual = "@cuda_nvml//:headers", +) + +alias( + name = "nvcc_headers", + actual = "@cuda_nvcc//:headers", +) + +alias( + name = "cccl_headers", + actual = "@cuda_cccl//:headers", +) + +alias( + name = "cublas_headers", + actual = "@cuda_cublas//:headers", +) + +alias( + name = "cusolver_headers", + actual = "@cuda_cusolver//:headers", +) + +alias( + name = "cufft_headers", + actual = "@cuda_cufft//:headers", +) + +alias( + name = "cusparse_headers", + actual = "@cuda_cusparse//:headers", +) + +alias( + name = "curand_headers", + actual = "@cuda_curand//:headers", +) + +alias( + name = "cublas", + actual = "@cuda_cublas//:cublas", +) + +alias( + name = "cublasLt", + actual = "@cuda_cublas//:cublasLt", +) + +alias( + name = "cusolver", + actual = "@cuda_cusolver//:cusolver", +) + +alias( + name = "cudnn", + actual = "@cuda_cudnn//:cudnn", +) + +alias( + name = "cudnn_header", + actual = "@cuda_cudnn//:headers", +) + +alias( + name = "cufft", + actual = "@cuda_cufft//:cufft", +) + +alias( + name = "curand", + actual = "@cuda_curand//:curand", +) + +cc_library( + name = "cuda", + deps = [ + ":cublas", + ":cublasLt", + ":cuda_headers", + ":cudart", + ":cudnn", + ":cufft", + ":curand", + ], +) + +alias( + name = "cub_headers", + actual = ":cuda_headers", +) + +alias( + name = "cupti_headers", + actual = "@cuda_cupti//:headers", +) + +alias( + name = "cupti_dsos", + actual = "@cuda_cupti//:cupti", +) + +alias( + name = "cusparse", + actual = "@cuda_cusparse//:cusparse", +) + +alias( + name = "cuda-nvvm", + actual = "@cuda_nvcc//:nvvm", +) + +alias( + name = "nvjitlink", + actual = "@cuda_nvjitlink//:nvjitlink" +) + +cc_library( + name = "libdevice_root", + data = [":cuda-nvvm"], +) + +bzl_library( + name = "build_defs_bzl", + srcs = ["build_defs.bzl"], + deps = [ + "@bazel_skylib//lib:selects", + ], +) + +py_library( + name = "cuda_config_py", + srcs = ["cuda/cuda_config.py"], +) + +# Config setting whether TensorFlow is built with hermetic CUDA. +alias( + name = "hermetic_cuda_tools", + actual = "@local_config_cuda//:is_cuda_enabled", +) + +# Flag indicating if we should include hermetic CUDA libs. +bool_flag( + name = "include_hermetic_cuda_libs", + build_setting_default = False, +) + +config_setting( + name = "hermetic_cuda_libs", + flag_values = {":include_hermetic_cuda_libs": "True"}, +) + +selects.config_setting_group( + name = "hermetic_cuda_tools_and_libs", + match_all = [ + ":hermetic_cuda_libs", + ":hermetic_cuda_tools" + ], +) + +cc_library( + # This is not yet fully supported, but we need the rule + # to make bazel query happy. + name = "nvptxcompiler", +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl new file mode 100644 index 00000000000000..85c0cbbb196fef --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl @@ -0,0 +1,15 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + hdrs = glob([ + %{comment}"include/cub/**", + %{comment}"include/cuda/**", + %{comment}"include/nv/**", + %{comment}"include/thrust/**", + ]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_configure.bzl new file mode 100644 index 00000000000000..270b73c3884855 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_configure.bzl @@ -0,0 +1,521 @@ +"""Repository rule for hermetic CUDA autoconfiguration. + +`cuda_configure` depends on the following environment variables: + + * `TF_NEED_CUDA`: Whether to enable building with CUDA. + * `TF_NVCC_CLANG`: Whether to use clang for C++ and NVCC for Cuda compilation. + * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for + both host and device code compilation. + * `TF_SYSROOT`: The sysroot to use when compiling. + * `HERMETIC_CUDA_VERSION`: The version of the CUDA toolkit. If not specified, + the version will be determined by the `TF_CUDA_VERSION`. + * `HERMETIC_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default + is `3.5,5.2`. If not specified, the value will be determined by the + `TF_CUDA_COMPUTE_CAPABILITIES`. + * `PYTHON_BIN_PATH`: The python binary path +""" + +load( + "//third_party/gpus:compiler_common_tools.bzl", + "get_cxx_inc_directories", + "to_list_of_strings", +) +load( + "//third_party/remote_config:common.bzl", + "get_cpu_value", + "get_host_environ", + "which", +) + +def _find_cc(repository_ctx): + """Find the C++ compiler.""" + cc_path_envvar = _CLANG_CUDA_COMPILER_PATH + cc_name = "clang" + + cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar) + if cc_name_from_env: + cc_name = cc_name_from_env + if cc_name.startswith("/"): + # Return the absolute path. + return cc_name + cc = which(repository_ctx, cc_name) + if cc == None: + fail(("Cannot find {}, either correct your path or set the {}" + + " environment variable").format(cc_name, cc_path_envvar)) + return cc + +def _auto_configure_fail(msg): + """Output failure message when cuda configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg)) + +def _verify_build_defines(params): + """Verify all variables that crosstool/BUILD.tpl expects are substituted. + + Args: + params: dict of variables that will be passed to the BUILD.tpl template. + """ + missing = [] + for param in [ + "cxx_builtin_include_directories", + "extra_no_canonical_prefixes_flags", + "host_compiler_path", + "host_compiler_prefix", + "host_compiler_warnings", + "linker_bin_path", + "compiler_deps", + "msvc_cl_path", + "msvc_env_include", + "msvc_env_lib", + "msvc_env_path", + "msvc_env_tmp", + "msvc_lib_path", + "msvc_link_path", + "msvc_ml_path", + "unfiltered_compile_flags", + "win_compiler_deps", + ]: + if ("%{" + param + "}") not in params: + missing.append(param) + + if missing: + _auto_configure_fail( + "BUILD.tpl template is missing these variables: " + + str(missing) + + ".\nWe only got: " + + str(params) + + ".", + ) + +def get_cuda_version(repository_ctx): + return (get_host_environ(repository_ctx, HERMETIC_CUDA_VERSION) or + get_host_environ(repository_ctx, TF_CUDA_VERSION)) + +def enable_cuda(repository_ctx): + """Returns whether to build with CUDA support.""" + return int(get_host_environ(repository_ctx, TF_NEED_CUDA, False)) + +def _flag_enabled(repository_ctx, flag_name): + return get_host_environ(repository_ctx, flag_name) == "1" + +def _use_nvcc_and_clang(repository_ctx): + # Returns the flag if we need to use clang for C++ and NVCC for Cuda. + return _flag_enabled(repository_ctx, _TF_NVCC_CLANG) + +def _tf_sysroot(repository_ctx): + return get_host_environ(repository_ctx, _TF_SYSROOT, "") + +def _py_tmpl_dict(d): + return {"%{cuda_config}": str(d)} + +def _cudart_static_linkopt(cpu_value): + """Returns additional platform-specific linkopts for cudart.""" + return "\"\"," if cpu_value == "Darwin" else "\"-lrt\"," + +def _compute_capabilities(repository_ctx): + """Returns a list of strings representing cuda compute capabilities. + + Args: + repository_ctx: the repo rule's context. + + Returns: + list of cuda architectures to compile for. 'compute_xy' refers to + both PTX and SASS, 'sm_xy' refers to SASS only. + """ + capabilities = (get_host_environ( + repository_ctx, + _HERMETIC_CUDA_COMPUTE_CAPABILITIES, + ) or + get_host_environ( + repository_ctx, + _TF_CUDA_COMPUTE_CAPABILITIES, + )) + capabilities = (capabilities or "compute_35,compute_52").split(",") + + # Map old 'x.y' capabilities to 'compute_xy'. + if len(capabilities) > 0 and all([len(x.split(".")) == 2 for x in capabilities]): + # If all capabilities are in 'x.y' format, only include PTX for the + # highest capability. + cc_list = sorted([x.replace(".", "") for x in capabilities]) + capabilities = [ + "sm_%s" % x + for x in cc_list[:-1] + ] + ["compute_%s" % cc_list[-1]] + for i, capability in enumerate(capabilities): + parts = capability.split(".") + if len(parts) != 2: + continue + capabilities[i] = "compute_%s%s" % (parts[0], parts[1]) + + # Make list unique + capabilities = dict(zip(capabilities, capabilities)).keys() + + # Validate capabilities. + for capability in capabilities: + if not capability.startswith(("compute_", "sm_")): + _auto_configure_fail("Invalid compute capability: %s" % capability) + for prefix in ["compute_", "sm_"]: + if not capability.startswith(prefix): + continue + if len(capability) == len(prefix) + 2 and capability[-2:].isdigit(): + continue + if len(capability) == len(prefix) + 3 and capability.endswith("90a"): + continue + _auto_configure_fail("Invalid compute capability: %s" % capability) + + return capabilities + +def _compute_cuda_extra_copts(compute_capabilities): + copts = ["--no-cuda-include-ptx=all"] + for capability in compute_capabilities: + if capability.startswith("compute_"): + capability = capability.replace("compute_", "sm_") + copts.append("--cuda-include-ptx=%s" % capability) + copts.append("--cuda-gpu-arch=%s" % capability) + + return str(copts) + +def _get_cuda_config(repository_ctx): + """Detects and returns information about the CUDA installation on the system. + + Args: + repository_ctx: The repository context. + + Returns: + A struct containing the following fields: + cuda_version: The version of CUDA on the system. + cudart_version: The CUDA runtime version on the system. + cudnn_version: The version of cuDNN on the system. + compute_capabilities: A list of the system's CUDA compute capabilities. + cpu_value: The name of the host operating system. + """ + + return struct( + cuda_version = get_cuda_version(repository_ctx), + cupti_version = repository_ctx.read(repository_ctx.attr.cupti_version), + cudart_version = repository_ctx.read(repository_ctx.attr.cudart_version), + cublas_version = repository_ctx.read(repository_ctx.attr.cublas_version), + cusolver_version = repository_ctx.read(repository_ctx.attr.cusolver_version), + curand_version = repository_ctx.read(repository_ctx.attr.curand_version), + cufft_version = repository_ctx.read(repository_ctx.attr.cufft_version), + cusparse_version = repository_ctx.read(repository_ctx.attr.cusparse_version), + cudnn_version = repository_ctx.read(repository_ctx.attr.cudnn_version), + compute_capabilities = _compute_capabilities(repository_ctx), + cpu_value = get_cpu_value(repository_ctx), + ) + +_DUMMY_CROSSTOOL_BZL_FILE = """ +def error_gpu_disabled(): + fail("ERROR: Building with --config=cuda but TensorFlow is not configured " + + "to build with GPU support. Please re-run ./configure and enter 'Y' " + + "at the prompt to build with GPU support.") + + native.genrule( + name = "error_gen_crosstool", + outs = ["CROSSTOOL"], + cmd = "echo 'Should not be run.' && exit 1", + ) + + native.filegroup( + name = "crosstool", + srcs = [":CROSSTOOL"], + output_licenses = ["unencumbered"], + ) +""" + +_DUMMY_CROSSTOOL_BUILD_FILE = """ +load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled") + +error_gpu_disabled() +""" + +def _create_dummy_repository(repository_ctx): + cpu_value = get_cpu_value(repository_ctx) + + # Set up BUILD file for cuda/. + repository_ctx.template( + "cuda/build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_is_configured}": "False", + "%{cuda_extra_copts}": "[]", + "%{cuda_gpu_architectures}": "[]", + "%{cuda_version}": "0.0", + }, + ) + + repository_ctx.template( + "cuda/BUILD", + repository_ctx.attr.cuda_build_tpl, + { + "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), + }, + ) + + # Set up cuda_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": "", + "%{cudart_version}": "", + "%{cupti_version}": "", + "%{cublas_version}": "", + "%{cusolver_version}": "", + "%{curand_version}": "", + "%{cufft_version}": "", + "%{cusparse_version}": "", + "%{cudnn_version}": "", + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": "", + }, + ) + + # Set up cuda_config.py, which is used by gen_build_info to provide + # static build environment info to the API + repository_ctx.template( + "cuda/cuda/cuda_config.py", + repository_ctx.attr.cuda_config_py_tpl, + _py_tmpl_dict({}), + ) + + # If cuda_configure is not configured to build with GPU support, and the user + # attempts to build with --config=cuda, add a dummy build rule to intercept + # this and fail with an actionable error message. + repository_ctx.file( + "crosstool/error_gpu_disabled.bzl", + _DUMMY_CROSSTOOL_BZL_FILE, + ) + repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE) + +def _create_local_cuda_repository(repository_ctx): + """Creates the repository containing files set up to build with CUDA.""" + cuda_config = _get_cuda_config(repository_ctx) + + # Set up BUILD file for cuda/ + repository_ctx.template( + "cuda/build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_is_configured}": "True", + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + cuda_config.compute_capabilities, + ), + "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities), + "%{cuda_version}": cuda_config.cuda_version, + }, + ) + + repository_ctx.template( + "cuda/BUILD", + repository_ctx.attr.cuda_build_tpl, + { + "%{cudart_static_linkopt}": _cudart_static_linkopt( + cuda_config.cpu_value, + ), + }, + ) + + is_nvcc_and_clang = _use_nvcc_and_clang(repository_ctx) + tf_sysroot = _tf_sysroot(repository_ctx) + + # Set up crosstool/ + cc = _find_cc(repository_ctx) + host_compiler_includes = get_cxx_inc_directories( + repository_ctx, + cc, + tf_sysroot, + ) + + cuda_defines = {} + + # We do not support hermetic CUDA on Windows. + # This ensures the CROSSTOOL file parser is happy. + cuda_defines.update({ + "%{msvc_env_tmp}": "msvc_not_used", + "%{msvc_env_path}": "msvc_not_used", + "%{msvc_env_include}": "msvc_not_used", + "%{msvc_env_lib}": "msvc_not_used", + "%{msvc_cl_path}": "msvc_not_used", + "%{msvc_ml_path}": "msvc_not_used", + "%{msvc_link_path}": "msvc_not_used", + "%{msvc_lib_path}": "msvc_not_used", + "%{win_compiler_deps}": ":empty", + }) + + cuda_defines["%{builtin_sysroot}"] = tf_sysroot + cuda_defines["%{cuda_toolkit_path}"] = repository_ctx.attr.nvcc_binary.workspace_root + cuda_defines["%{compiler}"] = "clang" + cuda_defines["%{host_compiler_prefix}"] = "/usr/bin" + cuda_defines["%{linker_bin_path}"] = "" + cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" + cuda_defines["%{unfiltered_compile_flags}"] = "" + cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings( + host_compiler_includes, + ) + cuda_defines["%{cuda_nvcc_files}"] = "if_cuda([\"@{nvcc_archive}//:bin\", \"@{nvcc_archive}//:nvvm\"])".format( + nvcc_archive = repository_ctx.attr.nvcc_binary.repo_name, + ) + + if not is_nvcc_and_clang: + cuda_defines["%{host_compiler_path}"] = str(cc) + cuda_defines["%{host_compiler_warnings}"] = """ + # Some parts of the codebase set -Werror and hit this warning, so + # switch it off for now. + "-Wno-invalid-partial-specialization" + """ + cuda_defines["%{compiler_deps}"] = ":cuda_nvcc_files" + repository_ctx.file( + "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", + "", + ) + else: + cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" + cuda_defines["%{host_compiler_warnings}"] = "" + + nvcc_relative_path = "%s/%s" % ( + repository_ctx.attr.nvcc_binary.workspace_root, + repository_ctx.attr.nvcc_binary.name, + ) + cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc" + + wrapper_defines = { + "%{cpu_compiler}": str(cc), + "%{cuda_version}": cuda_config.cuda_version, + "%{nvcc_path}": nvcc_relative_path, + "%{host_compiler_path}": str(cc), + "%{use_clang_compiler}": "True", + } + repository_ctx.template( + "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", + repository_ctx.attr.crosstool_wrapper_driver_is_not_gcc_tpl, + wrapper_defines, + ) + + _verify_build_defines(cuda_defines) + + # Only expand template variables in the BUILD file + repository_ctx.template( + "crosstool/BUILD", + repository_ctx.attr.crosstool_build_tpl, + cuda_defines, + ) + + # No templating of cc_toolchain_config - use attributes and templatize the + # BUILD file. + repository_ctx.template( + "crosstool/cc_toolchain_config.bzl", + repository_ctx.attr.cc_toolchain_config_tpl, + {}, + ) + + # Set up cuda_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": cuda_config.cuda_version, + "%{cudart_version}": cuda_config.cudart_version, + "%{cupti_version}": cuda_config.cupti_version, + "%{cublas_version}": cuda_config.cublas_version, + "%{cusolver_version}": cuda_config.cusolver_version, + "%{curand_version}": cuda_config.curand_version, + "%{cufft_version}": cuda_config.cufft_version, + "%{cusparse_version}": cuda_config.cusparse_version, + "%{cudnn_version}": cuda_config.cudnn_version, + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": ", ".join([ + cc.split("_")[1] + for cc in cuda_config.compute_capabilities + ]), + }, + ) + + # Set up cuda_config.py, which is used by gen_build_info to provide + # static build environment info to the API + repository_ctx.template( + "cuda/cuda/cuda_config.py", + repository_ctx.attr.cuda_config_py_tpl, + _py_tmpl_dict({ + "cuda_version": cuda_config.cuda_version, + "cudnn_version": cuda_config.cudnn_version, + "cuda_compute_capabilities": cuda_config.compute_capabilities, + "cpu_compiler": str(cc), + }), + ) + +def _cuda_autoconf_impl(repository_ctx): + """Implementation of the cuda_autoconf repository rule.""" + build_file = repository_ctx.attr.local_config_cuda_build_file + + if not enable_cuda(repository_ctx): + _create_dummy_repository(repository_ctx) + else: + _create_local_cuda_repository(repository_ctx) + + repository_ctx.symlink(build_file, "BUILD") + +_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH" +_PYTHON_BIN_PATH = "PYTHON_BIN_PATH" +_HERMETIC_CUDA_COMPUTE_CAPABILITIES = "HERMETIC_CUDA_COMPUTE_CAPABILITIES" +_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" +HERMETIC_CUDA_VERSION = "HERMETIC_CUDA_VERSION" +TF_CUDA_VERSION = "TF_CUDA_VERSION" +TF_NEED_CUDA = "TF_NEED_CUDA" +_TF_NVCC_CLANG = "TF_NVCC_CLANG" +_TF_SYSROOT = "TF_SYSROOT" + +_ENVIRONS = [ + _CLANG_CUDA_COMPILER_PATH, + TF_NEED_CUDA, + _TF_NVCC_CLANG, + TF_CUDA_VERSION, + HERMETIC_CUDA_VERSION, + _TF_CUDA_COMPUTE_CAPABILITIES, + _HERMETIC_CUDA_COMPUTE_CAPABILITIES, + _TF_SYSROOT, + _PYTHON_BIN_PATH, + "TMP", + "TMPDIR", + "LOCAL_CUDA_PATH", + "LOCAL_CUDNN_PATH", +] + +cuda_configure = repository_rule( + implementation = _cuda_autoconf_impl, + environ = _ENVIRONS, + attrs = { + "environ": attr.string_dict(), + "cublas_version": attr.label(default = Label("@cuda_cublas//:version.txt")), + "cudart_version": attr.label(default = Label("@cuda_cudart//:version.txt")), + "cudnn_version": attr.label(default = Label("@cuda_cudnn//:version.txt")), + "cufft_version": attr.label(default = Label("@cuda_cufft//:version.txt")), + "cupti_version": attr.label(default = Label("@cuda_cupti//:version.txt")), + "curand_version": attr.label(default = Label("@cuda_curand//:version.txt")), + "cusolver_version": attr.label(default = Label("@cuda_cusolver//:version.txt")), + "cusparse_version": attr.label(default = Label("@cuda_cusparse//:version.txt")), + "nvcc_binary": attr.label(default = Label("@cuda_nvcc//:bin/nvcc")), + "local_config_cuda_build_file": attr.label(default = Label("//third_party/gpus:local_config_cuda.BUILD")), + "build_defs_tpl": attr.label(default = Label("//third_party/gpus/cuda:build_defs.bzl.tpl")), + "cuda_build_tpl": attr.label(default = Label("//third_party/gpus/cuda/hermetic:BUILD.tpl")), + "cuda_config_tpl": attr.label(default = Label("//third_party/gpus/cuda:cuda_config.h.tpl")), + "cuda_config_py_tpl": attr.label(default = Label("//third_party/gpus/cuda:cuda_config.py.tpl")), + "crosstool_wrapper_driver_is_not_gcc_tpl": attr.label(default = Label("//third_party/gpus/crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl")), + "crosstool_build_tpl": attr.label(default = Label("//third_party/gpus/crosstool:BUILD.tpl")), + "cc_toolchain_config_tpl": attr.label(default = Label("//third_party/gpus/crosstool:cc_toolchain_config.bzl.tpl")), + }, +) +"""Detects and configures the hermetic CUDA toolchain. + +Add the following to your WORKSPACE file: + +```python +cuda_configure(name = "local_config_cuda") +``` + +Args: + name: A unique name for this workspace rule. +""" # buildifier: disable=no-effect diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl new file mode 100644 index 00000000000000..510235d801de4e --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl @@ -0,0 +1,44 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cublas_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcublas.so.%{libcublas_version}", + deps = [":cublasLt"], +) + +cc_import( + name = "cublasLt_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcublasLt.so.%{libcublaslt_version}", +) +%{multiline_comment} +cc_library( + name = "cublas", + visibility = ["//visibility:public"], + %{comment}deps = [":cublas_shared_library"], +) + +cc_library( + name = "cublasLt", + visibility = ["//visibility:public"], + %{comment}deps = [":cublasLt_shared_library"], +) + +cc_library( + name = "headers", + %{comment}hdrs = [ + %{comment}"include/cublas.h", + %{comment}"include/cublasLt.h", + %{comment}"include/cublas_api.h", + %{comment}"include/cublas_v2.h", + %{comment}], + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl new file mode 100644 index 00000000000000..f7ba469b42b76a --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl @@ -0,0 +1,126 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) + +filegroup( + name = "static", + srcs = ["lib/libcudart_static.a"], + visibility = ["@local_config_cuda//cuda:__pkg__"], +) +%{multiline_comment} +# TODO: Replace system provided library with hermetic NVIDIA driver library. +cc_import( + name = "cuda_driver_shared_library", + interface_library = "lib/stubs/libcuda.so", + system_provided = 1, +) + +cc_import( + name = "cudart_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcudart.so.%{libcudart_version}", +) +%{multiline_comment} +cc_library( + name = "cuda_driver", + %{comment}deps = [":cuda_driver_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cudart", + %{comment}deps = [ + %{comment}":cuda_driver", + %{comment}":cudart_shared_library", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/builtin_types.h", + %{comment}"include/channel_descriptor.h", + %{comment}"include/common_functions.h", + %{comment}"include/cooperative_groups/**", + %{comment}"include/cooperative_groups.h", + %{comment}"include/cuComplex.h", + %{comment}"include/cuda.h", + %{comment}"include/cudaEGL.h", + %{comment}"include/cudaEGLTypedefs.h", + %{comment}"include/cudaGL.h", + %{comment}"include/cudaGLTypedefs.h", + %{comment}"include/cudaProfilerTypedefs.h", + %{comment}"include/cudaTypedefs.h", + %{comment}"include/cudaVDPAU.h", + %{comment}"include/cudaVDPAUTypedefs.h", + %{comment}"include/cuda_awbarrier.h", + %{comment}"include/cuda_awbarrier_helpers.h", + %{comment}"include/cuda_awbarrier_primitives.h", + %{comment}"include/cuda_bf16.h", + %{comment}"include/cuda_bf16.hpp", + %{comment}"include/cuda_device_runtime_api.h", + %{comment}"include/cuda_egl_interop.h", + %{comment}"include/cuda_fp16.h", + %{comment}"include/cuda_fp16.hpp", + %{comment}"include/cuda_fp8.h", + %{comment}"include/cuda_fp8.hpp", + %{comment}"include/cuda_gl_interop.h", + %{comment}"include/cuda_occupancy.h", + %{comment}"include/cuda_pipeline.h", + %{comment}"include/cuda_pipeline_helpers.h", + %{comment}"include/cuda_pipeline_primitives.h", + %{comment}"include/cuda_runtime.h", + %{comment}"include/cuda_runtime_api.h", + %{comment}"include/cuda_surface_types.h", + %{comment}"include/cuda_texture_types.h", + %{comment}"include/cuda_vdpau_interop.h", + %{comment}"include/cudart_platform.h", + %{comment}"include/device_atomic_functions.h", + %{comment}"include/device_atomic_functions.hpp", + %{comment}"include/device_double_functions.h", + %{comment}"include/device_functions.h", + %{comment}"include/device_launch_parameters.h", + %{comment}"include/device_types.h", + %{comment}"include/driver_functions.h", + %{comment}"include/driver_types.h", + %{comment}"include/host_config.h", + %{comment}"include/host_defines.h", + %{comment}"include/library_types.h", + %{comment}"include/math_constants.h", + %{comment}"include/math_functions.h", + %{comment}"include/mma.h", + %{comment}"include/nvfunctional", + %{comment}"include/sm_20_atomic_functions.h", + %{comment}"include/sm_20_atomic_functions.hpp", + %{comment}"include/sm_20_intrinsics.h", + %{comment}"include/sm_20_intrinsics.hpp", + %{comment}"include/sm_30_intrinsics.h", + %{comment}"include/sm_30_intrinsics.hpp", + %{comment}"include/sm_32_atomic_functions.h", + %{comment}"include/sm_32_atomic_functions.hpp", + %{comment}"include/sm_32_intrinsics.h", + %{comment}"include/sm_32_intrinsics.hpp", + %{comment}"include/sm_35_atomic_functions.h", + %{comment}"include/sm_35_intrinsics.h", + %{comment}"include/sm_60_atomic_functions.h", + %{comment}"include/sm_60_atomic_functions.hpp", + %{comment}"include/sm_61_intrinsics.h", + %{comment}"include/sm_61_intrinsics.hpp", + %{comment}"include/surface_functions.h", + %{comment}"include/surface_indirect_functions.h", + %{comment}"include/surface_types.h", + %{comment}"include/texture_fetch_functions.h", + %{comment}"include/texture_indirect_functions.h", + %{comment}"include/texture_types.h", + %{comment}"include/vector_functions.h", + %{comment}"include/vector_functions.hpp", + %{comment}"include/vector_types.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl new file mode 100644 index 00000000000000..165c5b1579e73f --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl @@ -0,0 +1,73 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cudnn_ops_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops_infer.so.%{libcudnn_ops_infer_version}", +) + +cc_import( + name = "cudnn_cnn_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn_infer.so.%{libcudnn_cnn_infer_version}", +) + +cc_import( + name = "cudnn_ops_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops_train.so.%{libcudnn_ops_train_version}", +) + +cc_import( + name = "cudnn_cnn_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn_train.so.%{libcudnn_cnn_train_version}", +) + +cc_import( + name = "cudnn_adv_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv_infer.so.%{libcudnn_adv_infer_version}", +) + +cc_import( + name = "cudnn_adv_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv_train.so.%{libcudnn_adv_train_version}", +) + +cc_import( + name = "cudnn_main", + hdrs = [":headers"], + shared_library = "lib/libcudnn.so.%{libcudnn_version}", +) +%{multiline_comment} +cc_library( + name = "cudnn", + %{comment}deps = [ + %{comment}":cudnn_ops_infer", + %{comment}":cudnn_ops_train", + %{comment}":cudnn_cnn_infer", + %{comment}":cudnn_cnn_train", + %{comment}":cudnn_adv_infer", + %{comment}":cudnn_adv_train", + %{comment}"@cuda_nvrtc//:nvrtc", + %{comment}":cudnn_main", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudnn*.h", + %{comment}]), + include_prefix = "third_party/gpus/cudnn", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl new file mode 100644 index 00000000000000..7f36054a51bb5b --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl @@ -0,0 +1,80 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cudnn_ops", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops.so.%{libcudnn_ops_version}", +) + +cc_import( + name = "cudnn_cnn", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn.so.%{libcudnn_cnn_version}", +) + +cc_import( + name = "cudnn_adv", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv.so.%{libcudnn_adv_version}", +) + +cc_import( + name = "cudnn_graph", + hdrs = [":headers"], + shared_library = "lib/libcudnn_graph.so.%{libcudnn_graph_version}", +) + +cc_import( + name = "cudnn_engines_precompiled", + hdrs = [":headers"], + shared_library = "lib/libcudnn_engines_precompiled.so.%{libcudnn_engines_precompiled_version}", +) + +cc_import( + name = "cudnn_engines_runtime_compiled", + hdrs = [":headers"], + shared_library = "lib/libcudnn_engines_runtime_compiled.so.%{libcudnn_engines_runtime_compiled_version}", +) + +cc_import( + name = "cudnn_heuristic", + hdrs = [":headers"], + shared_library = "lib/libcudnn_heuristic.so.%{libcudnn_heuristic_version}", +) + +cc_import( + name = "cudnn_main", + hdrs = [":headers"], + shared_library = "lib/libcudnn.so.%{libcudnn_version}", +) +%{multiline_comment} +cc_library( + name = "cudnn", + %{comment}deps = [ + %{comment}":cudnn_engines_precompiled", + %{comment}":cudnn_ops", + %{comment}":cudnn_graph", + %{comment}":cudnn_cnn", + %{comment}":cudnn_adv", + %{comment}":cudnn_engines_runtime_compiled", + %{comment}":cudnn_heuristic", + %{comment}"@cuda_nvrtc//:nvrtc", + %{comment}":cudnn_main", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudnn*.h", + %{comment}]), + include_prefix = "third_party/gpus/cudnn", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl new file mode 100644 index 00000000000000..48ccb0ea3cd197 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl @@ -0,0 +1,29 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cufft_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcufft.so.%{libcufft_version}", +) +%{multiline_comment} +cc_library( + name = "cufft", + %{comment}deps = [":cufft_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudalibxt.h", + %{comment}"include/cufft*.h" + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl new file mode 100644 index 00000000000000..3efe76f470953f --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl @@ -0,0 +1,59 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cupti_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcupti.so.%{libcupti_version}", +) +%{multiline_comment} +cc_library( + name = "cupti", + %{comment}deps = [":cupti_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/Openacc/**", + %{comment}"include/Openmp/**", + %{comment}"include/cuda_stdint.h", + %{comment}"include/cupti.h", + %{comment}"include/cupti_activity.h", + %{comment}"include/cupti_activity_deprecated.h", + %{comment}"include/cupti_callbacks.h", + %{comment}"include/cupti_checkpoint.h", + %{comment}"include/cupti_driver_cbid.h", + %{comment}"include/cupti_events.h", + %{comment}"include/cupti_metrics.h", + %{comment}"include/cupti_nvtx_cbid.h", + %{comment}"include/cupti_pcsampling.h", + %{comment}"include/cupti_pcsampling_util.h", + %{comment}"include/cupti_profiler_target.h", + %{comment}"include/cupti_result.h", + %{comment}"include/cupti_runtime_cbid.h", + %{comment}"include/cupti_sass_metrics.h", + %{comment}"include/cupti_target.h", + %{comment}"include/cupti_version.h", + %{comment}"include/generated_cudaGL_meta.h", + %{comment}"include/generated_cudaVDPAU_meta.h", + %{comment}"include/generated_cuda_gl_interop_meta.h", + %{comment}"include/generated_cuda_meta.h", + %{comment}"include/generated_cuda_runtime_api_meta.h", + %{comment}"include/generated_cuda_vdpau_interop_meta.h", + %{comment}"include/generated_cudart_removed_meta.h", + %{comment}"include/generated_nvtx_meta.h", + %{comment}"include/nvperf_common.h", + %{comment}"include/nvperf_cuda_host.h", + %{comment}"include/nvperf_host.h", + %{comment}"include/nvperf_target.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/extras/CUPTI/include", + includes = ["include/"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl new file mode 100644 index 00000000000000..50e5a8f18a96fd --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl @@ -0,0 +1,26 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "curand_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcurand.so.%{libcurand_version}", +) +%{multiline_comment} +cc_library( + name = "curand", + %{comment}deps = [":curand_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob(["include/curand*.h"]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl new file mode 100644 index 00000000000000..943a08ebeb96e1 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl @@ -0,0 +1,34 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cusolver_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcusolver.so.%{libcusolver_version}", + deps = [ + "@cuda_nvjitlink//:nvjitlink", + "@cuda_cusparse//:cusparse", + "@cuda_cublas//:cublas", + "@cuda_cublas//:cublasLt", + ], +) +%{multiline_comment} +cc_library( + name = "cusolver", + %{comment}deps = [":cusolver_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cusolver*.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl new file mode 100644 index 00000000000000..46b24366ce1c04 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl @@ -0,0 +1,27 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cusparse_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcusparse.so.%{libcusparse_version}", + deps = ["@cuda_nvjitlink//:nvjitlink"], +) +%{multiline_comment} +cc_library( + name = "cusparse", + %{comment}deps = [":cusparse_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = ["include/cusparse.h"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl new file mode 100644 index 00000000000000..fdda3aaf92cea5 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl @@ -0,0 +1,125 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA redistributions JSON repository initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_JSON_DICT", + "CUDNN_REDIST_JSON_DICT", +) + +def _get_env_var(ctx, name): + return ctx.os.environ.get(name) + +def _get_json_file_content(repository_ctx, url_to_sha256, json_file_name): + if len(url_to_sha256) > 1: + (url, sha256) = url_to_sha256 + else: + url = url_to_sha256[0] + sha256 = "" + repository_ctx.download( + url = tf_mirror_urls(url), + sha256 = sha256, + output = json_file_name, + ) + return repository_ctx.read(repository_ctx.path(json_file_name)) + +def _cuda_redist_json_impl(repository_ctx): + cuda_version = (_get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + _get_env_var(repository_ctx, "TF_CUDA_VERSION")) + local_cuda_path = _get_env_var(repository_ctx, "LOCAL_CUDA_PATH") + cudnn_version = (_get_env_var(repository_ctx, "HERMETIC_CUDNN_VERSION") or + _get_env_var(repository_ctx, "TF_CUDNN_VERSION")) + local_cudnn_path = _get_env_var(repository_ctx, "LOCAL_CUDNN_PATH") + supported_cuda_versions = repository_ctx.attr.cuda_json_dict.keys() + if (cuda_version and not local_cuda_path and + (cuda_version not in supported_cuda_versions)): + fail( + ("The supported CUDA versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDA_VERSION" + + " environment variable or add JSON URL for" + + " CUDA version={version}.") + .format( + supported_versions = supported_cuda_versions, + version = cuda_version, + ), + ) + supported_cudnn_versions = repository_ctx.attr.cudnn_json_dict.keys() + if cudnn_version and not local_cudnn_path and (cudnn_version not in supported_cudnn_versions): + fail( + ("The supported CUDNN versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDNN_VERSION" + + " environment variable or add JSON URL for" + + " CUDNN version={version}.") + .format( + supported_versions = supported_cudnn_versions, + version = cudnn_version, + ), + ) + cuda_redistributions = "{}" + cudnn_redistributions = "{}" + if cuda_version and not local_cuda_path: + cuda_redistributions = _get_json_file_content( + repository_ctx, + repository_ctx.attr.cuda_json_dict[cuda_version], + "redistrib_cuda_%s.json" % cuda_version, + ) + if cudnn_version and not local_cudnn_path: + cudnn_redistributions = _get_json_file_content( + repository_ctx, + repository_ctx.attr.cudnn_json_dict[cudnn_version], + "redistrib_cudnn_%s.json" % cudnn_version, + ) + + repository_ctx.file( + "distributions.bzl", + """CUDA_REDISTRIBUTIONS = {cuda_redistributions} + +CUDNN_REDISTRIBUTIONS = {cudnn_redistributions} +""".format( + cuda_redistributions = cuda_redistributions, + cudnn_redistributions = cudnn_redistributions, + ), + ) + repository_ctx.file( + "BUILD", + "", + ) + +cuda_redist_json = repository_rule( + implementation = _cuda_redist_json_impl, + attrs = { + "cuda_json_dict": attr.string_list_dict(mandatory = True), + "cudnn_json_dict": attr.string_list_dict(mandatory = True), + }, + environ = [ + "HERMETIC_CUDA_VERSION", + "HERMETIC_CUDNN_VERSION", + "TF_CUDA_VERSION", + "TF_CUDNN_VERSION", + "LOCAL_CUDA_PATH", + "LOCAL_CUDNN_PATH", + ], +) + +def cuda_json_init_repository( + cuda_json_dict = CUDA_REDIST_JSON_DICT, + cudnn_json_dict = CUDNN_REDIST_JSON_DICT): + cuda_redist_json( + name = "cuda_redist_json", + cuda_json_dict = cuda_json_dict, + cudnn_json_dict = cudnn_json_dict, + ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl new file mode 100644 index 00000000000000..7757a92a90b795 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl @@ -0,0 +1,75 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "bin/nvcc", +]) + +filegroup( + name = "nvvm", + srcs = [ + "nvvm/libdevice/libdevice.10.bc", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "nvlink", + srcs = [ + "bin/nvlink", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "fatbinary", + srcs = [ + "bin/fatbinary", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "bin2c", + srcs = [ + "bin/bin2c", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "ptxas", + srcs = [ + "bin/ptxas", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "bin", + srcs = glob([ + "bin/**", + "nvvm/bin/**", + ]), + visibility = ["//visibility:public"], +) + +filegroup( + name = "link_stub", + srcs = [ + "bin/crt/link.stub", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/crt/**", + %{comment}"include/fatbinary_section.h", + %{comment}"include/nvPTXCompiler.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl new file mode 100644 index 00000000000000..9784a84471f1a7 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl @@ -0,0 +1,17 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "nvjitlink_shared_library", + shared_library = "lib/libnvJitLink.so.%{libnvjitlink_version}", +) +%{multiline_comment} +cc_library( + name = "nvjitlink", + %{comment}deps = [":nvjitlink_shared_library"], + visibility = ["//visibility:public"], +) + diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl new file mode 100644 index 00000000000000..23ee30f09f8ff3 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl @@ -0,0 +1,10 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + %{comment}hdrs = ["include/nvml.h"], + include_prefix = "third_party/gpus/cuda/nvml/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl new file mode 100644 index 00000000000000..986ef0c8f76166 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl @@ -0,0 +1,9 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +filegroup( + name = "nvprune", + srcs = [ + "bin/nvprune", + ], + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl new file mode 100644 index 00000000000000..de18489b455b79 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl @@ -0,0 +1,20 @@ +licenses(["restricted"]) # NVIDIA proprietary license +%{multiline_comment} +cc_import( + name = "nvrtc_main", + shared_library = "lib/libnvrtc.so.%{libnvrtc_version}", +) + +cc_import( + name = "nvrtc_builtins", + shared_library = "lib/libnvrtc-builtins.so.%{libnvrtc-builtins_version}", +) +%{multiline_comment} +cc_library( + name = "nvrtc", + %{comment}deps = [ + %{comment}":nvrtc_main", + %{comment}":nvrtc_builtins", + %{comment}], + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl new file mode 100644 index 00000000000000..3457f41a502dee --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl @@ -0,0 +1,13 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/nvToolsExt*.h", + %{comment}"include/nvtx3/**", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl new file mode 100644 index 00000000000000..d2015e737540c3 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl @@ -0,0 +1,491 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA repositories initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_PATH_PREFIX", + "CUDNN_REDIST_PATH_PREFIX", + "REDIST_VERSIONS_TO_BUILD_TEMPLATES", +) + +OS_ARCH_DICT = { + "amd64": "x86_64-unknown-linux-gnu", + "aarch64": "aarch64-unknown-linux-gnu", +} +_REDIST_ARCH_DICT = { + "linux-x86_64": "x86_64-unknown-linux-gnu", + "linux-sbsa": "aarch64-unknown-linux-gnu", +} + +SUPPORTED_ARCHIVE_EXTENSIONS = [ + ".zip", + ".jar", + ".war", + ".aar", + ".tar", + ".tar.gz", + ".tgz", + ".tar.xz", + ".txz", + ".tar.zst", + ".tzst", + ".tar.bz2", + ".tbz", + ".ar", + ".deb", + ".whl", +] + +def get_env_var(ctx, name): + return ctx.os.environ.get(name) + +def _get_file_name(url): + last_slash_index = url.rfind("/") + return url[last_slash_index + 1:] + +def get_archive_name(url): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns the archive name without extension.""" + filename = _get_file_name(url) + for extension in SUPPORTED_ARCHIVE_EXTENSIONS: + if filename.endswith(extension): + return filename[:-len(extension)] + return filename + +LIB_EXTENSION = ".so." + +def _get_lib_name_and_version(path): + extension_index = path.rfind(LIB_EXTENSION) + last_slash_index = path.rfind("/") + lib_name = path[last_slash_index + 1:extension_index] + lib_version = path[extension_index + len(LIB_EXTENSION):] + return (lib_name, lib_version) + +def _get_libraries_by_redist_name_in_dir(repository_ctx): + lib_dir_path = repository_ctx.path("lib") + if not lib_dir_path.exists: + return [] + main_lib_name = "lib{}".format(repository_ctx.name.split("_")[1]).lower() + lib_dir_content = lib_dir_path.readdir() + return [ + str(f) + for f in lib_dir_content + if (LIB_EXTENSION in str(f) and + main_lib_name in str(f).lower()) + ] + +def get_lib_name_to_version_dict(repository_ctx): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns a dict of library names and major versions.""" + lib_name_to_version_dict = {} + for path in _get_libraries_by_redist_name_in_dir(repository_ctx): + lib_name, lib_version = _get_lib_name_and_version(path) + key = "%%{%s_version}" % lib_name.lower() + + # We need to find either major or major.minor version if there is no + # file with major version. E.g. if we have the following files: + # libcudart.so + # libcudart.so.12 + # libcudart.so.12.3.2, + # we will save save {"%{libcudart_version}": "12"}. + if len(lib_version.split(".")) == 1: + lib_name_to_version_dict[key] = lib_version + if (len(lib_version.split(".")) == 2 and + key not in lib_name_to_version_dict): + lib_name_to_version_dict[key] = lib_version + return lib_name_to_version_dict + +def create_dummy_build_file(repository_ctx, use_comment_symbols = True): + repository_ctx.template( + "BUILD", + repository_ctx.attr.build_templates[0], + { + "%{multiline_comment}": "'''" if use_comment_symbols else "", + "%{comment}": "#" if use_comment_symbols else "", + }, + ) + +def _get_build_template(repository_ctx, major_lib_version): + template = None + for i in range(0, len(repository_ctx.attr.versions)): + for dist_version in repository_ctx.attr.versions[i].split(","): + if dist_version == major_lib_version: + template = repository_ctx.attr.build_templates[i] + break + if not template: + fail("No build template found for {} version {}".format( + repository_ctx.name, + major_lib_version, + )) + return template + +def get_major_library_version(repository_ctx, lib_name_to_version_dict): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns the major library version provided the versions dict.""" + major_version = "" + if len(lib_name_to_version_dict) == 0: + return major_version + main_lib_name = "lib{}".format(repository_ctx.name.split("_")[1]) + key = "%%{%s_version}" % main_lib_name + major_version = lib_name_to_version_dict[key] + return major_version + +def create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_lib_version): + # buildifier: disable=function-docstring-args + """Creates a BUILD file for the repository.""" + if len(major_lib_version) == 0: + build_template_content = repository_ctx.read( + repository_ctx.attr.build_templates[0], + ) + if "_version}" not in build_template_content: + create_dummy_build_file(repository_ctx, use_comment_symbols = False) + else: + create_dummy_build_file(repository_ctx) + return + build_template = _get_build_template( + repository_ctx, + major_lib_version.split(".")[0], + ) + repository_ctx.template( + "BUILD", + build_template, + lib_name_to_version_dict | { + "%{multiline_comment}": "", + "%{comment}": "", + }, + ) + +def _create_symlinks(repository_ctx, local_path, dirs): + for dir in dirs: + repository_ctx.symlink( + "{path}/{dir}".format( + path = local_path, + dir = dir, + ), + dir, + ) + +def use_local_path(repository_ctx, local_path, dirs): + # buildifier: disable=function-docstring-args + """Creates repository using local redistribution paths.""" + _create_symlinks( + repository_ctx, + local_path, + dirs, + ) + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + repository_ctx.file("version.txt", major_version) + +def _use_local_cuda_path(repository_ctx, local_cuda_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic CUDA repository.""" + use_local_path( + repository_ctx, + local_cuda_path, + ["include", "lib", "bin", "nvvm"], + ) + +def _use_local_cudnn_path(repository_ctx, local_cudnn_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic CUDNN repository.""" + use_local_path(repository_ctx, local_cudnn_path, ["include", "lib"]) + +def _download_redistribution(repository_ctx, arch_key, path_prefix): + (url, sha256) = repository_ctx.attr.url_dict[arch_key] + + # If url is not relative, then appending prefix is not needed. + if not (url.startswith("http") or url.startswith("file:///")): + url = path_prefix + url + archive_name = get_archive_name(url) + file_name = _get_file_name(url) + + print("Downloading and extracting {}".format(url)) # buildifier: disable=print + repository_ctx.download( + url = tf_mirror_urls(url), + output = file_name, + sha256 = sha256, + ) + if repository_ctx.attr.override_strip_prefix: + strip_prefix = repository_ctx.attr.override_strip_prefix + else: + strip_prefix = archive_name + repository_ctx.extract( + archive = file_name, + stripPrefix = strip_prefix, + ) + repository_ctx.delete(file_name) + +def _use_downloaded_cuda_redistribution(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads CUDA redistribution and initializes hermetic CUDA repository.""" + major_version = "" + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + if not cuda_version: + # If no CUDA version is found, comment out all cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + if len(repository_ctx.attr.url_dict) == 0: + print("{} is not found in redistributions list.".format( + repository_ctx.name, + )) # buildifier: disable=print + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch_key = OS_ARCH_DICT[repository_ctx.os.arch] + if arch_key not in repository_ctx.attr.url_dict.keys(): + fail( + ("The supported platforms are {supported_platforms}." + + " Platform {platform} is not supported for {dist_name}.") + .format( + supported_platforms = repository_ctx.attr.url_dict.keys(), + platform = arch_key, + dist_name = repository_ctx.name, + ), + ) + _download_redistribution( + repository_ctx, + arch_key, + repository_ctx.attr.cuda_redist_path_prefix, + ) + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version(repository_ctx, lib_name_to_version_dict) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + repository_ctx.file("version.txt", major_version) + +def _cuda_repo_impl(repository_ctx): + local_cuda_path = get_env_var(repository_ctx, "LOCAL_CUDA_PATH") + if local_cuda_path: + _use_local_cuda_path(repository_ctx, local_cuda_path) + else: + _use_downloaded_cuda_redistribution(repository_ctx) + +cuda_repo = repository_rule( + implementation = _cuda_repo_impl, + attrs = { + "url_dict": attr.string_list_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "override_strip_prefix": attr.string(), + "cuda_redist_path_prefix": attr.string(), + }, + environ = [ + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "LOCAL_CUDA_PATH", + ], +) + +def _use_downloaded_cudnn_redistribution(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads CUDNN redistribution and initializes hermetic CUDNN repository.""" + cudnn_version = None + major_version = "" + cudnn_version = (get_env_var(repository_ctx, "HERMETIC_CUDNN_VERSION") or + get_env_var(repository_ctx, "TF_CUDNN_VERSION")) + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + if not cudnn_version: + # If no CUDNN version is found, comment out cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + if len(repository_ctx.attr.url_dict) == 0: + print("{} is not found in redistributions list.".format( + repository_ctx.name, + )) # buildifier: disable=print + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch_key = OS_ARCH_DICT[repository_ctx.os.arch] + if arch_key not in repository_ctx.attr.url_dict.keys(): + arch_key = "cuda{version}_{arch}".format( + version = cuda_version.split(".")[0], + arch = arch_key, + ) + if arch_key not in repository_ctx.attr.url_dict.keys(): + fail( + ("The supported platforms are {supported_platforms}." + + " Platform {platform} is not supported for {dist_name}.") + .format( + supported_platforms = repository_ctx.attr.url_dict.keys(), + platform = arch_key, + dist_name = repository_ctx.name, + ), + ) + + _download_redistribution( + repository_ctx, + arch_key, + repository_ctx.attr.cudnn_redist_path_prefix, + ) + + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + + repository_ctx.file("version.txt", major_version) + +def _cudnn_repo_impl(repository_ctx): + local_cudnn_path = get_env_var(repository_ctx, "LOCAL_CUDNN_PATH") + if local_cudnn_path: + _use_local_cudnn_path(repository_ctx, local_cudnn_path) + else: + _use_downloaded_cudnn_redistribution(repository_ctx) + +cudnn_repo = repository_rule( + implementation = _cudnn_repo_impl, + attrs = { + "url_dict": attr.string_list_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "override_strip_prefix": attr.string(), + "cudnn_redist_path_prefix": attr.string(), + }, + environ = [ + "HERMETIC_CUDNN_VERSION", + "TF_CUDNN_VERSION", + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "LOCAL_CUDNN_PATH", + ], +) + +def _get_redistribution_urls(dist_info): + url_dict = {} + for arch in _REDIST_ARCH_DICT.keys(): + if "relative_path" in dist_info[arch]: + url_dict[_REDIST_ARCH_DICT[arch]] = [ + dist_info[arch]["relative_path"], + dist_info[arch].get("sha256", ""), + ] + continue + + if "full_path" in dist_info[arch]: + url_dict[_REDIST_ARCH_DICT[arch]] = [ + dist_info[arch]["full_path"], + dist_info[arch].get("sha256", ""), + ] + continue + + for cuda_version, data in dist_info[arch].items(): + # CUDNN JSON might contain paths for each CUDA version. + path_key = "relative_path" + if path_key not in data.keys(): + path_key = "full_path" + url_dict["{cuda_version}_{arch}".format( + cuda_version = cuda_version, + arch = _REDIST_ARCH_DICT[arch], + )] = [data[path_key], data.get("sha256", "")] + return url_dict + +def get_version_and_template_lists(version_to_template): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns lists of versions and templates provided in the dict.""" + template_to_version_map = {} + for version, template in version_to_template.items(): + if template not in template_to_version_map.keys(): + template_to_version_map[template] = [version] + else: + template_to_version_map[template].append(version) + version_list = [] + template_list = [] + for template, versions in template_to_version_map.items(): + version_list.append(",".join(versions)) + template_list.append(Label(template)) + return (version_list, template_list) + +def cudnn_redist_init_repository( + cudnn_redistributions, + cudnn_redist_path_prefix = CUDNN_REDIST_PATH_PREFIX, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes CUDNN repository.""" + if "cudnn" in cudnn_redistributions.keys(): + url_dict = _get_redistribution_urls(cudnn_redistributions["cudnn"]) + else: + url_dict = {} + repo_data = redist_versions_to_build_templates["cudnn"] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cudnn_repo( + name = repo_data["repo_name"], + versions = versions, + build_templates = templates, + url_dict = url_dict, + cudnn_redist_path_prefix = cudnn_redist_path_prefix, + ) + +def cuda_redist_init_repositories( + cuda_redistributions, + cuda_redist_path_prefix = CUDA_REDIST_PATH_PREFIX, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes CUDA repositories.""" + for redist_name, _ in redist_versions_to_build_templates.items(): + if redist_name in ["cudnn", "cuda_nccl"]: + continue + if redist_name in cuda_redistributions.keys(): + url_dict = _get_redistribution_urls(cuda_redistributions[redist_name]) + else: + url_dict = {} + repo_data = redist_versions_to_build_templates[redist_name] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cuda_repo( + name = repo_data["repo_name"], + versions = versions, + build_templates = templates, + url_dict = url_dict, + cuda_redist_path_prefix = cuda_redist_path_prefix, + ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl new file mode 100644 index 00000000000000..d7ccff736a4801 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl @@ -0,0 +1,243 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA redistribution versions.""" + +CUDA_REDIST_PATH_PREFIX = "https://developer.download.nvidia.com/compute/cuda/redist/" +CUDNN_REDIST_PATH_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redist/" + +CUDA_REDIST_JSON_DICT = { + "11.8": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_11.8.0.json", + "941a950a4ab3b95311c50df7b3c8bca973e0cdda76fc2f4b456d2d5e4dac0281", + ], + "12.1.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.1.1.json", + "bafea3cb83a4cf5c764eeedcaac0040d0d3c5db3f9a74550da0e7b6ac24d378c", + ], + "12.2.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.2.0.json", + "d883762c6339c8ebb3ffb072facc8f7265cd257d2db16a475fff9a9306ecea89", + ], + "12.3.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.3.1.json", + "b3cc4181d711cf9b6e3718f323b23813c24f9478119911d7b4bceec9b437dbc3", + ], + "12.3.2": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.3.2.json", + "1b6eacf335dd49803633fed53ef261d62c193e5a56eee5019e7d2f634e39e7ef", + ], + "12.4.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.4.0.json", + "a4f496b8d5299939b34c9ef88dc4274821f8c9451b2d7c9bcee53166932da067", + ], + "12.4.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.4.1.json", + "9cd815f3b71c2e3686ef2219b7794b81044f9dcefaa8e21dacfcb5bc4d931892", + ], + "12.5.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.5.0.json", + "166664b520bfe51f27abcc8c7a934f4cb6ea287f8c399b5f8255f6f4d214569a", + ], + "12.5.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.5.1.json", + "7ab9c76014ae4907fa1b51738af599607a5fd8ca3a5c4bb4c3b31338cc642a93", + ], + "12.6.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.0.json", + "87740b01676b3d18982982ab96ec7fa1a626d03a96df070a6b0f258d01ff5fab", + ], +} + +CUDNN_REDIST_JSON_DICT = { + "8.6": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.6.0.json", + "7f6f50bed4fd8216dc10d6ef505771dc0ecc99cce813993ab405cb507a21d51d", + ], + "8.9.4.25": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.4.25.json", + "02258dba8384860c9230fe3c78522e7bd8e350e461ccd37a8d932cb64127ba57", + ], + "8.9.6": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.6.json", + "6069ef92a2b9bb18cebfbc944964bd2b024b76f2c2c35a43812982e0bc45cf0c", + ], + "8.9.7.29": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.7.29.json", + "a0734f26f068522464fa09b2f2c186dfbe6ad7407a88ea0c50dd331f0c3389ec", + ], + "9.1.1": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.1.1.json", + "d22d569405e5683ff8e563d00d6e8c27e5e6a902c564c23d752b22a8b8b3fe20", + ], + "9.2.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.2.0.json", + "6852eb279b95d2b5775f7a7737ec133bed059107f863cdd8588f3ae6f13eadd7", + ], + "9.2.1": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.2.1.json", + "9a4198c59b2e66b2b115a736ebe4dc8f3dc6d78161bb494702f824da8fc77b99", + ], + "9.3.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.3.0.json", + "d17d9a7878365736758550294f03e633a0b023bec879bf173349bfb34781972e", + ], +} + +# The versions are different for x86 and aarch64 architectures because only +# NCCL release versions 2.20.3 and 2.20.5 have the wheels for aarch64. +CUDA_12_NCCL_WHEEL_DICT = { + "x86_64-unknown-linux-gnu": { + "version": "2.21.5", + "url": "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", + "sha256": "8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", + }, + "aarch64-unknown-linux-gnu": { + "version": "2.20.5", + "url": "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", + "sha256": "1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", + }, +} + +CUDA_11_NCCL_WHEEL_DICT = { + "x86_64-unknown-linux-gnu": { + "version": "2.21.5", + "url": "https://files.pythonhosted.org/packages/ac/9a/8b6a28b3b87d5fddab0e92cd835339eb8fbddaa71ae67518c8c1b3d05bae/nvidia_nccl_cu11-2.21.5-py3-none-manylinux2014_x86_64.whl", + "sha256": "49d8350629c7888701d1fd200934942671cb5c728f49acc5a0b3a768820bed29", + }, +} + +CUDA_NCCL_WHEELS = { + "11.8": CUDA_11_NCCL_WHEEL_DICT, + "12.1.1": CUDA_12_NCCL_WHEEL_DICT, + "12.2.0": CUDA_12_NCCL_WHEEL_DICT, + "12.3.1": CUDA_12_NCCL_WHEEL_DICT, + "12.3.2": CUDA_12_NCCL_WHEEL_DICT, + "12.4.0": CUDA_12_NCCL_WHEEL_DICT, + "12.1.0": CUDA_12_NCCL_WHEEL_DICT, + "12.5.0": CUDA_12_NCCL_WHEEL_DICT, + "12.5.1": CUDA_12_NCCL_WHEEL_DICT, + "12.6.0": CUDA_12_NCCL_WHEEL_DICT, +} + +REDIST_VERSIONS_TO_BUILD_TEMPLATES = { + "cuda_nccl": { + "repo_name": "cuda_nccl", + "version_to_template": { + "2": "//third_party/nccl/hermetic:cuda_nccl.BUILD.tpl", + }, + }, + "cudnn": { + "repo_name": "cuda_cudnn", + "version_to_template": { + "9": "//third_party/gpus/cuda/hermetic:cuda_cudnn9.BUILD.tpl", + "8": "//third_party/gpus/cuda/hermetic:cuda_cudnn.BUILD.tpl", + }, + }, + "libcublas": { + "repo_name": "cuda_cublas", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cublas.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cublas.BUILD.tpl", + }, + }, + "cuda_cudart": { + "repo_name": "cuda_cudart", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl", + }, + }, + "libcufft": { + "repo_name": "cuda_cufft", + "version_to_template": { + "11": "//third_party/gpus/cuda/hermetic:cuda_cufft.BUILD.tpl", + "10": "//third_party/gpus/cuda/hermetic:cuda_cufft.BUILD.tpl", + }, + }, + "cuda_cupti": { + "repo_name": "cuda_cupti", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cupti.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cupti.BUILD.tpl", + }, + }, + "libcurand": { + "repo_name": "cuda_curand", + "version_to_template": { + "10": "//third_party/gpus/cuda/hermetic:cuda_curand.BUILD.tpl", + }, + }, + "libcusolver": { + "repo_name": "cuda_cusolver", + "version_to_template": { + "11": "//third_party/gpus/cuda/hermetic:cuda_cusolver.BUILD.tpl", + }, + }, + "libcusparse": { + "repo_name": "cuda_cusparse", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cusparse.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cusparse.BUILD.tpl", + }, + }, + "libnvjitlink": { + "repo_name": "cuda_nvjitlink", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvjitlink.BUILD.tpl", + }, + }, + "cuda_nvrtc": { + "repo_name": "cuda_nvrtc", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", + }, + }, + "cuda_cccl": { + "repo_name": "cuda_cccl", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", + }, + }, + "cuda_nvcc": { + "repo_name": "cuda_nvcc", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", + }, + }, + "cuda_nvml_dev": { + "repo_name": "cuda_nvml", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", + }, + }, + "cuda_nvprune": { + "repo_name": "cuda_nvprune", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", + }, + }, + "cuda_nvtx": { + "repo_name": "cuda_nvtx", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", + }, + }, +} diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl index fefbf081c87e1c..8bf1db2b0f8f9f 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl @@ -1,5 +1,7 @@ """Repository rule for CUDA autoconfiguration. +NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead. + `cuda_configure` depends on the following environment variables: * `TF_NEED_CUDA`: Whether to enable building with CUDA. @@ -53,6 +55,11 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "get_cxx_inc_directories", + "to_list_of_strings", +) _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" _GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX" @@ -67,20 +74,6 @@ _TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO" _TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG" _PYTHON_BIN_PATH = "PYTHON_BIN_PATH" -def to_list_of_strings(elements): - """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'. - - This is to be used to put a list of strings into the bzl file templates - so it gets interpreted as list of strings in Starlark. - - Args: - elements: list of string elements - - Returns: - single string of elements wrapped in quotes separated by a comma.""" - quoted_strings = ["\"" + element + "\"" for element in elements] - return ", ".join(quoted_strings) - def verify_build_defines(params): """Verify all variables that crosstool/BUILD.tpl expects are substituted. @@ -238,156 +231,6 @@ def find_cc(repository_ctx, use_cuda_clang): " environment variable").format(target_cc_name, cc_path_envvar)) return cc -_INC_DIR_MARKER_BEGIN = "#include <...>" - -# OSX add " (framework directory)" at the end of line, strip it. -_OSX_FRAMEWORK_SUFFIX = " (framework directory)" -_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX) - -def _cxx_inc_convert(path): - """Convert path returned by cc -E xc++ in a complete path.""" - path = path.strip() - if path.endswith(_OSX_FRAMEWORK_SUFFIX): - path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip() - return path - -def _normalize_include_path(repository_ctx, path): - """Normalizes include paths before writing them to the crosstool. - - If path points inside the 'crosstool' folder of the repository, a relative - path is returned. - If path points outside the 'crosstool' folder, an absolute path is returned. - """ - path = str(repository_ctx.path(path)) - crosstool_folder = str(repository_ctx.path(".").get_child("crosstool")) - - if path.startswith(crosstool_folder): - # We drop the path to "$REPO/crosstool" and a trailing path separator. - return path[len(crosstool_folder) + 1:] - return path - -def _is_compiler_option_supported(repository_ctx, cc, option): - """Checks that `option` is supported by the C compiler. Doesn't %-escape the option.""" - result = repository_ctx.execute([ - cc, - option, - "-o", - "/dev/null", - "-c", - str(repository_ctx.path("tools/cpp/empty.cc")), - ]) - return result.stderr.find(option) == -1 - -def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot): - """Compute the list of default C or C++ include directories.""" - if lang_is_cpp: - lang = "c++" - else: - lang = "c" - sysroot = [] - if tf_sysroot: - sysroot += ["--sysroot", tf_sysroot] - result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] + - sysroot) - stderr = err_out(result) - index1 = stderr.find(_INC_DIR_MARKER_BEGIN) - if index1 == -1: - return [] - index1 = stderr.find("\n", index1) - if index1 == -1: - return [] - index2 = stderr.rfind("\n ") - if index2 == -1 or index2 < index1: - return [] - index2 = stderr.find("\n", index2 + 1) - if index2 == -1: - inc_dirs = stderr[index1 + 1:] - else: - inc_dirs = stderr[index1 + 1:index2].strip() - - print_resource_dir_supported = _is_compiler_option_supported( - repository_ctx, - cc, - "-print-resource-dir", - ) - - if print_resource_dir_supported: - resource_dir = repository_ctx.execute( - [cc, "-print-resource-dir"], - ).stdout.strip() + "/share" - inc_dirs += "\n" + resource_dir - - compiler_includes = [ - _normalize_include_path(repository_ctx, _cxx_inc_convert(p)) - for p in inc_dirs.split("\n") - ] - - # The compiler might be on a symlink, e.g. /symlink -> /opt/gcc - # The above keeps only the resolved paths to the default includes (e.g. /opt/gcc/include/c++/11) - # but Bazel might encounter either (usually reported by the compiler) - # especially when a compiler wrapper (e.g. ccache) is used. - # So we need to also include paths where symlinks are not resolved. - - # Try to find real path to CC installation to "see through" compiler wrappers - # GCC has the path to g++ - index1 = result.stderr.find("COLLECT_GCC=") - if index1 != -1: - index1 = result.stderr.find("=", index1) - index2 = result.stderr.find("\n", index1) - cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname.dirname - else: - # Clang has the directory - index1 = result.stderr.find("InstalledDir: ") - if index1 != -1: - index1 = result.stderr.find(" ", index1) - index2 = result.stderr.find("\n", index1) - cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname - else: - # Fallback to the CC path - cc_topdir = repository_ctx.path(cc).dirname.dirname - - # We now have the compiler installation prefix, e.g. /symlink/gcc - # And the resolved installation prefix, e.g. /opt/gcc - cc_topdir_resolved = str(realpath(repository_ctx, cc_topdir)).strip() - cc_topdir = str(cc_topdir).strip() - - # If there is (any!) symlink involved we add paths where the unresolved installation prefix is kept. - # e.g. [/opt/gcc/include/c++/11, /opt/gcc/lib/gcc/x86_64-linux-gnu/11/include, /other/path] - # adds [/symlink/include/c++/11, /symlink/lib/gcc/x86_64-linux-gnu/11/include] - if cc_topdir_resolved != cc_topdir: - unresolved_compiler_includes = [ - cc_topdir + inc[len(cc_topdir_resolved):] - for inc in compiler_includes - if inc.startswith(cc_topdir_resolved) - ] - compiler_includes = compiler_includes + unresolved_compiler_includes - return compiler_includes - -def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot): - """Compute the list of default C and C++ include directories.""" - - # For some reason `clang -xc` sometimes returns include paths that are - # different from the ones from `clang -xc++`. (Symlink and a dir) - # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists - includes_cpp = _get_cxx_inc_directories_impl( - repository_ctx, - cc, - True, - tf_sysroot, - ) - includes_c = _get_cxx_inc_directories_impl( - repository_ctx, - cc, - False, - tf_sysroot, - ) - - return includes_cpp + [ - inc - for inc in includes_c - if inc not in includes_cpp - ] - def auto_configure_fail(msg): """Output failure message when cuda configuration fails.""" red = "\033[0;31m" @@ -1293,6 +1136,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" cuda_defines["%{unfiltered_compile_flags}"] = "" + cuda_defines["%{cuda_nvcc_files}"] = "[]" if is_cuda_clang and not is_nvcc_and_clang: cuda_defines["%{host_compiler_path}"] = str(cc) cuda_defines["%{host_compiler_warnings}"] = """ diff --git a/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py b/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py index b88694af5c014d..68623bf671da71 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py +++ b/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py @@ -14,6 +14,9 @@ # ============================================================================== """Prints CUDA library and header directories and versions found on the system. +NB: DEPRECATED! This script is a part of the deprecated `cuda_configure` rule. +Please use `hermetic/cuda_configure` instead. + The script searches for CUDA library and header files on the system, inspects them to determine their version and prints the configuration to stdout. The paths to inspect and the required versions are specified through environment diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl index ff9b53b407be44..fb63d4db886c1c 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl @@ -22,12 +22,15 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "to_list_of_strings", +) load( ":cuda_configure.bzl", "enable_cuda", "make_copy_dir_rule", "make_copy_files_rule", - "to_list_of_strings", ) load( ":sycl_configure.bzl", diff --git a/third_party/xla/third_party/tsl/third_party/gpus/sycl_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/sycl_configure.bzl index 05330b2fe53195..dd80694e7274f5 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/sycl_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/sycl_configure.bzl @@ -16,11 +16,14 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "to_list_of_strings", +) load( ":cuda_configure.bzl", "make_copy_dir_rule", "make_copy_files_rule", - "to_list_of_strings", ) _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" diff --git a/third_party/xla/third_party/tsl/third_party/nccl/build_defs.bzl.tpl b/third_party/xla/third_party/tsl/third_party/nccl/build_defs.bzl.tpl index 53a6d4e1e41890..a0930df34ecec8 100644 --- a/third_party/xla/third_party/tsl/third_party/nccl/build_defs.bzl.tpl +++ b/third_party/xla/third_party/tsl/third_party/nccl/build_defs.bzl.tpl @@ -5,7 +5,6 @@ load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") # CUDA toolkit version as tuple (e.g. '(11, 1)'). _cuda_version = %{cuda_version} -_cuda_clang = %{cuda_clang} def _rdc_copts(): """Returns copts for compiling relocatable device code.""" @@ -121,25 +120,25 @@ _device_link = rule( "gpu_archs": attr.string_list(), "nvlink_args": attr.string_list(), "_nvlink": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/nvlink"), + default = Label("%{nvlink_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_fatbinary": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/fatbinary"), + default = Label("%{fatbinary_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_bin2c": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/bin2c"), + default = Label("%{bin2c_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_link_stub": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/crt/link.stub"), + default = Label("%{link_stub_label}"), allow_single_file = True, ), }, @@ -189,7 +188,7 @@ _prune_relocatable_code = rule( "input": attr.label(mandatory = True, allow_files = True), "gpu_archs": attr.string_list(), "_nvprune": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/nvprune"), + default = Label("%{nvprune_label}"), allow_single_file = True, executable = True, cfg = "host", diff --git a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/BUILD b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/BUILD new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl new file mode 100644 index 00000000000000..61d7809bcdaad1 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl @@ -0,0 +1,30 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "nccl_shared_library", + shared_library = "lib/libnccl.so.%{libnccl_version}", + hdrs = [":headers"], + deps = ["@local_config_cuda//cuda:cuda_headers", ":headers"], +) +%{multiline_comment} +cc_library( + name = "nccl", + %{comment}deps = [":nccl_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/nccl*.h", + %{comment}]), + include_prefix = "third_party/nccl", + includes = ["include/"], + strip_include_prefix = "include", + visibility = ["//visibility:public"], + deps = ["@local_config_cuda//cuda:cuda_headers"], +) diff --git a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl new file mode 100644 index 00000000000000..75f5a10b6fe24e --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl @@ -0,0 +1,183 @@ +"""Repository rule for hermetic NCCL configuration. + +`nccl_configure` depends on the following environment variables: + + * `TF_NCCL_USE_STUB`: "1" if a NCCL stub that loads NCCL dynamically should + be used, "0" if NCCL should be linked in statically. + * `HERMETIC_CUDA_VERSION`: The version of the CUDA toolkit. If not specified, + the version will be determined by the `TF_CUDA_VERSION`. + +""" + +load( + "//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "TF_NEED_CUDA", + "enable_cuda", + "get_cuda_version", +) +load( + "//third_party/remote_config:common.bzl", + "get_cpu_value", + "get_host_environ", +) + +_TF_NCCL_USE_STUB = "TF_NCCL_USE_STUB" + +_NCCL_DUMMY_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nccl", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) +""" + +_NCCL_ARCHIVE_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + data = ["@nccl_archive//:LICENSE.txt"], + visibility = ["//visibility:public"], +) + +alias( + name = "nccl", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:nccl", + "//conditions:default": "@nccl_archive//:nccl", + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hermetic_nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_config", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": ":hermetic_nccl_config", + "//conditions:default": "@nccl_archive//:nccl_config", + }), + visibility = ["//visibility:public"], +) +""" + +_NCCL_ARCHIVE_STUB_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + data = ["@nccl_archive//:LICENSE.txt"], + visibility = ["//visibility:public"], +) + +alias( + name = "nccl", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:nccl", + "//conditions:default": "@nccl_archive//:nccl_via_stub", + }), + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_headers", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:headers", + "//conditions:default": "@nccl_archive//:nccl_headers", + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hermetic_nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_config", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": ":hermetic_nccl_config", + "//conditions:default": "@nccl_archive//:nccl_config", + }), + visibility = ["//visibility:public"], +) +""" + +def _create_local_nccl_repository(repository_ctx): + cuda_version = get_cuda_version(repository_ctx).split(".")[:2] + nccl_version = repository_ctx.read(repository_ctx.attr.nccl_version) + + if get_host_environ(repository_ctx, _TF_NCCL_USE_STUB, "0") == "0": + repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT) + else: + repository_ctx.file("BUILD", _NCCL_ARCHIVE_STUB_BUILD_CONTENT) + + repository_ctx.template("generated_names.bzl", repository_ctx.attr.generated_names_tpl, {}) + repository_ctx.template( + "build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_version}": "(%s, %s)" % tuple(cuda_version), + "%{nvlink_label}": "@cuda_nvcc//:nvlink", + "%{fatbinary_label}": "@cuda_nvcc//:fatbinary", + "%{bin2c_label}": "@cuda_nvcc//:bin2c", + "%{link_stub_label}": "@cuda_nvcc//:link_stub", + "%{nvprune_label}": "@cuda_nvprune//:nvprune", + }, + ) + repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"%s\"" % nccl_version) + +def _nccl_autoconf_impl(repository_ctx): + if (not enable_cuda(repository_ctx) or + get_cpu_value(repository_ctx) != "Linux"): + # Add a dummy build file to make bazel query happy. + repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT) + repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"\"") + else: + _create_local_nccl_repository(repository_ctx) + +_ENVIRONS = [ + TF_NEED_CUDA, + TF_CUDA_VERSION, + _TF_NCCL_USE_STUB, + HERMETIC_CUDA_VERSION, + "LOCAL_NCCL_PATH", +] + +nccl_configure = repository_rule( + environ = _ENVIRONS, + implementation = _nccl_autoconf_impl, + attrs = { + "environ": attr.string_dict(), + "nccl_version": attr.label(default = Label("@cuda_nccl//:version.txt")), + "generated_names_tpl": attr.label(default = Label("//third_party/nccl:generated_names.bzl.tpl")), + "build_defs_tpl": attr.label(default = Label("//third_party/nccl:build_defs.bzl.tpl")), + }, +) +"""Downloads and configures the hermetic NCCL configuration. + +Add the following to your WORKSPACE file: + +```python +nccl_configure(name = "local_config_nccl") +``` + +Args: + name: A unique name for this workspace rule. +""" # buildifier: disable=no-effect diff --git a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_redist_init_repository.bzl b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_redist_init_repository.bzl new file mode 100644 index 00000000000000..244cb851ddf591 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_redist_init_repository.bzl @@ -0,0 +1,145 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic NCCL repositories initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "OS_ARCH_DICT", + "create_build_file", + "create_dummy_build_file", + "get_archive_name", + "get_env_var", + "get_lib_name_to_version_dict", + "get_major_library_version", + "get_version_and_template_lists", + "use_local_path", +) +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_NCCL_WHEELS", + "REDIST_VERSIONS_TO_BUILD_TEMPLATES", +) + +def _use_downloaded_nccl_wheel(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads NCCL wheel and inits hermetic NCCL repository.""" + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + major_version = "" + if not cuda_version: + # If no CUDA version is found, comment out cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch = OS_ARCH_DICT[repository_ctx.os.arch] + dict_key = "{cuda_version}-{arch}".format( + cuda_version = cuda_version, + arch = arch, + ) + supported_versions = repository_ctx.attr.url_dict.keys() + if dict_key not in supported_versions: + fail( + ("The supported NCCL versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDA_VERSION" + + " environment variable or add NCCL distribution for" + + " CUDA version={version}, OS={arch}.") + .format( + supported_versions = supported_versions, + version = cuda_version, + arch = arch, + ), + ) + sha256 = repository_ctx.attr.sha256_dict[dict_key] + url = repository_ctx.attr.url_dict[dict_key] + + archive_name = get_archive_name(url) + file_name = archive_name + ".zip" + + print("Downloading and extracting {}".format(url)) # buildifier: disable=print + repository_ctx.download( + url = tf_mirror_urls(url), + output = file_name, + sha256 = sha256, + ) + repository_ctx.extract( + archive = file_name, + stripPrefix = repository_ctx.attr.strip_prefix, + ) + repository_ctx.delete(file_name) + + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + + repository_ctx.file("version.txt", major_version) + +def _use_local_nccl_path(repository_ctx, local_nccl_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic NCCL repository.""" + use_local_path(repository_ctx, local_nccl_path, ["include", "lib"]) + +def _cuda_nccl_repo_impl(repository_ctx): + local_nccl_path = get_env_var(repository_ctx, "LOCAL_NCCL_PATH") + if local_nccl_path: + _use_local_nccl_path(repository_ctx, local_nccl_path) + else: + _use_downloaded_nccl_wheel(repository_ctx) + +cuda_nccl_repo = repository_rule( + implementation = _cuda_nccl_repo_impl, + attrs = { + "sha256_dict": attr.string_dict(mandatory = True), + "url_dict": attr.string_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "strip_prefix": attr.string(), + }, + environ = ["HERMETIC_CUDA_VERSION", "TF_CUDA_VERSION", "LOCAL_NCCL_PATH"], +) + +def nccl_redist_init_repository( + cuda_nccl_wheels = CUDA_NCCL_WHEELS, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes NCCL repository.""" + nccl_artifacts_dict = {"sha256_dict": {}, "url_dict": {}} + for cuda_version, nccl_wheel_info in cuda_nccl_wheels.items(): + for arch in OS_ARCH_DICT.values(): + if arch in nccl_wheel_info.keys(): + cuda_version_to_arch_key = "%s-%s" % (cuda_version, arch) + nccl_artifacts_dict["sha256_dict"][cuda_version_to_arch_key] = nccl_wheel_info[arch].get("sha256", "") + nccl_artifacts_dict["url_dict"][cuda_version_to_arch_key] = nccl_wheel_info[arch]["url"] + repo_data = redist_versions_to_build_templates["cuda_nccl"] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cuda_nccl_repo( + name = repo_data["repo_name"], + sha256_dict = nccl_artifacts_dict["sha256_dict"], + url_dict = nccl_artifacts_dict["url_dict"], + versions = versions, + build_templates = templates, + strip_prefix = "nvidia/nccl", + ) diff --git a/third_party/xla/third_party/tsl/third_party/nccl/nccl_configure.bzl b/third_party/xla/third_party/tsl/third_party/nccl/nccl_configure.bzl index 22cf64d4771062..59f8b5c08ef0ee 100644 --- a/third_party/xla/third_party/tsl/third_party/nccl/nccl_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/nccl/nccl_configure.bzl @@ -1,5 +1,7 @@ """Repository rule for NCCL configuration. +NB: DEPRECATED! Use `hermetic/nccl_configure` rule instead. + `nccl_configure` depends on the following environment variables: * `TF_NCCL_VERSION`: Installed NCCL version or empty to build from source. @@ -8,7 +10,6 @@ files. * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is `/usr/local/cuda,usr/`. - * `TF_CUDA_CLANG`: "1" if using Clang, "0" if using NVCC. * `TF_NCCL_USE_STUB`: "1" if a NCCL stub that loads NCCL dynamically should be used, "0" if NCCL should be linked in statically. @@ -33,7 +34,6 @@ _TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" _TF_NCCL_VERSION = "TF_NCCL_VERSION" _TF_NEED_CUDA = "TF_NEED_CUDA" _TF_CUDA_PATHS = "TF_CUDA_PATHS" -_TF_CUDA_CLANG = "TF_CUDA_CLANG" _TF_NCCL_USE_STUB = "TF_NCCL_USE_STUB" _DEFINE_NCCL_MAJOR = "#define NCCL_MAJOR" @@ -129,7 +129,11 @@ def _create_local_nccl_repository(repository_ctx): _label("build_defs.bzl.tpl"), { "%{cuda_version}": "(%s, %s)" % tuple(cuda_version), - "%{cuda_clang}": repr(get_host_environ(repository_ctx, _TF_CUDA_CLANG)), + "%{nvlink_label}": "@local_config_cuda//cuda:cuda/bin/nvlink", + "%{fatbinary_label}": "@local_config_cuda//cuda:cuda/bin/fatbinary", + "%{bin2c_label}": "@local_config_cuda//cuda:cuda/bin/bin2c", + "%{link_stub_label}": "@local_config_cuda//cuda:cuda/bin/crt/link.stub", + "%{nvprune_label}": "@local_config_cuda//cuda:cuda/bin/nvprune", }, ) else: @@ -181,7 +185,6 @@ _ENVIRONS = [ _TF_CUDA_COMPUTE_CAPABILITIES, _TF_NEED_CUDA, _TF_CUDA_PATHS, - _TF_CUDA_CLANG, ] remote_nccl_configure = repository_rule( diff --git a/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl b/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl index f8fdd1033b5e2f..13aed2b687129f 100644 --- a/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl +++ b/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl @@ -255,8 +255,12 @@ def _basic_wildcard_match(name, patterns, expected_match_result, match_all): def _custom_python_interpreter_impl(ctx): version = ctx.attr.version - strip_prefix = ctx.attr.strip_prefix.format(version = version) - urls = [url.format(version = version) for url in ctx.attr.urls] + version_variant = ctx.attr.version_variant + strip_prefix = ctx.attr.strip_prefix.format( + version = version, + version_variant = version_variant, + ) + urls = [url.format(version = version, version_variant = version_variant) for url in ctx.attr.urls] binary_name = ctx.attr.binary_name if not binary_name: ver_chunks = version.split(".") @@ -272,13 +276,12 @@ def _custom_python_interpreter_impl(ctx): output = srcs_dir, ) - configure_params = [] + configure_params = list(ctx.attr.configure_params) if "CC" in ctx.os.environ: configure_params.append("CC={}".format(ctx.os.environ["CC"])) if "CXX" in ctx.os.environ: configure_params.append("CXX={}".format(ctx.os.environ["CXX"])) - configure_params.append("--enable-optimizations") configure_params.append("--prefix=%s" % install_path.realpath) _exec_and_check( ctx, @@ -361,6 +364,11 @@ custom_python_interpreter = repository_rule( "strip_prefix": attr.string(), "binary_name": attr.string(mandatory = False), "version": attr.string(), + "version_variant": attr.string(), + "configure_params": attr.string_list( + mandatory = False, + default = ["--enable-optimizations"], + ), }, ) diff --git a/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/BUILD b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/BUILD new file mode 100644 index 00000000000000..8d626dc7635d1a --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/BUILD @@ -0,0 +1,7 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +# spirv_llvm_translator license placeholder diff --git a/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD new file mode 100644 index 00000000000000..557e2e8f50edd2 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD @@ -0,0 +1,34 @@ +cc_library( + name = "spirv_llvm_translator", + srcs = glob([ + "lib/SPIRV/libSPIRV/*.cpp", + "lib/SPIRV/libSPIRV/*.hpp", + "lib/SPIRV/libSPIRV/*.h", + "lib/SPIRV/Mangler/*.cpp", + "lib/SPIRV/Mangler/*.h", + "lib/SPIRV/*.cpp", + "lib/SPIRV/*.hpp", + "lib/SPIRV/*.h", + ]), + hdrs = glob(["include/*"]), + includes = [ + "include/", + "lib/SPIRV/", + "lib/SPIRV/Mangler/", + "lib/SPIRV/libSPIRV/", + ], + visibility = ["//visibility:public"], + deps = [ + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BitWriter", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Demangle", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TransformUtils", + "@spirv_headers//:spirv_cpp_headers", + ], +) diff --git a/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.patch b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.patch new file mode 100644 index 00000000000000..fc843b1b039b09 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.patch @@ -0,0 +1,25 @@ +diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h +index a828add8..924e13b4 100644 + +Spir backend uses different addrspace representations link with nvptx backend link. +We reorder the enum value here so that we can make XLA LLVM codegen simple(avoiding +changing addrspace based on device backend everywhere) + +--- a/lib/SPIRV/SPIRVInternal.h ++++ b/lib/SPIRV/SPIRVInternal.h +@@ -179,11 +179,12 @@ typedef SPIRVMap IntBoolOpMap; + "-v512:512:512-v1024:1024:1024" + + enum SPIRAddressSpace { +- SPIRAS_Private, ++ SPIRAS_Generic, + SPIRAS_Global, +- SPIRAS_Constant, ++ SPIRAS_Internal, + SPIRAS_Local, +- SPIRAS_Generic, ++ SPIRAS_Constant, ++ SPIRAS_Private, + SPIRAS_GlobalDevice, + SPIRAS_GlobalHost, + SPIRAS_Input, \ No newline at end of file diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl index 0c28198f980b95..9a4dfa2aafdc51 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl @@ -225,8 +225,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -236,8 +236,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "9.1", + cuda_version = "12.3.2", + cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -248,8 +248,8 @@ def initialize_rbe_configs(): name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -258,8 +258,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -270,8 +270,8 @@ def initialize_rbe_configs(): name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -479,7 +479,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -558,7 +558,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -710,11 +710,11 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.6", }, @@ -749,11 +749,11 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.6", }, @@ -788,12 +788,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "8.6", }, ) @@ -826,12 +826,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "8.6", }, ) @@ -864,12 +864,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "9.1", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "9.1.1", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "10.0", }, ) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl index 18a84d96c39f82..dbd7bad8d855c6 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl @@ -1,8 +1,8 @@ """Macro that creates external repositories for remote config.""" -load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure") -load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") +load("//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure") +load("//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure") load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") load("//third_party/tensorrt:tensorrt_configure.bzl", "remote_tensorrt_configure") @@ -42,7 +42,7 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_CUDNN_VERSION": cudnn_version, "TF_CUDA_VERSION": cuda_version, "CUDNN_INSTALL_PATH": cudnn_install_path if cudnn_install_path != None else "/usr/lib/x86_64-linux-gnu", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": tensorrt_version if tensorrt_version != None else "", "TENSORRT_INSTALL_PATH": tensorrt_install_path if tensorrt_install_path != None else "/usr/lib/x86_64-linux-gnu", "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "", @@ -51,20 +51,26 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_SYSROOT": sysroot if sysroot else "", }) - container_name = "cuda%s-cudnn%s-%s" % (cuda_version, cudnn_version, os) + cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) + cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) + container_name = "cuda%s-cudnn%s-%s" % ( + cuda_version_in_container, + cudnn_version_in_container, + os, + ) container_image = _container_image_uri(container_name) exec_properties = { "container-image": container_image, "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, @@ -175,13 +181,13 @@ def sigbuild_tf_configs(name_container_map, env): "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, diff --git a/third_party/xla/third_party/tsl/tsl/lib/core/BUILD b/third_party/xla/third_party/tsl/tsl/lib/core/BUILD index b0d6e94cd0330b..a5b0791d5d28ce 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/core/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/core/BUILD @@ -36,7 +36,6 @@ filegroup( srcs = [ "bitmap.h", "bits.h", - "status_test_util.h", "@local_xla//xla/tsl/lib/core:legacy_lib_core_status_test_util_header", ], compatible_with = get_compatible_with_portable(), @@ -68,7 +67,6 @@ filegroup( filegroup( name = "legacy_lib_core_status_test_util_header", srcs = [ - "status_test_util.h", "@local_xla//xla/tsl/lib/core:legacy_lib_core_status_test_util_header", ], compatible_with = get_compatible_with_portable(), @@ -90,17 +88,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "status_test_util", - testonly = 1, - hdrs = ["status_test_util.h"], - deps = [ - "//tsl/platform:status_matchers", - "//tsl/platform:test", - "@local_xla//xla/tsl/lib/core:status_test_util", - ], -) - cc_library( name = "bits", hdrs = ["bits.h"], diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD b/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD index c15c9293dbdbf6..302c0c412ef11b 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD @@ -71,7 +71,6 @@ cc_library( deps = [ ":collection_registry", ":metric_def", - "//tsl/lib/histogram", "//tsl/platform", "//tsl/platform:macros", "//tsl/platform:mutex", @@ -81,6 +80,7 @@ cc_library( "//tsl/protobuf:histogram_proto_cc", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@local_xla//xla/tsl/lib/histogram", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/sampler.h b/third_party/xla/third_party/tsl/tsl/lib/monitoring/sampler.h index 63f583bd4f1b44..e17f3cdae00d91 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/sampler.h +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/sampler.h @@ -122,7 +122,7 @@ class Sampler { #include #include -#include "tsl/lib/histogram/histogram.h" +#include "xla/tsl/lib/histogram/histogram.h" #include "tsl/lib/monitoring/collection_registry.h" #include "tsl/lib/monitoring/metric_def.h" #include "tsl/platform/macros.h" diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD index f777279540bc2b..20d43489eefa67 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD @@ -3,6 +3,7 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "@local_xla//xla/tsl:tsl.bzl", + "if_hermetic_cuda_tools", "if_not_fuchsia", "if_not_windows", "if_oss", @@ -59,6 +60,9 @@ cc_library( srcs = ["cuda_libdevice_path.cc"], hdrs = ["//tsl/platform:cuda_libdevice_path.h"], compatible_with = [], + data = if_hermetic_cuda_tools([ + "@cuda_nvcc//:nvvm", + ]), tags = [ "manual", "no_oss", @@ -66,6 +70,7 @@ cc_library( ], deps = [ "//tsl/platform", + "//tsl/platform:env", "//tsl/platform:logging", "//tsl/platform:path", "//tsl/platform:types", diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc b/third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc index 46321e74b5dc38..ac0a804b4dfd42 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc @@ -31,6 +31,7 @@ limitations under the License. #if !defined(PLATFORM_GOOGLE) #include "third_party/gpus/cuda/cuda_config.h" +#include "tsl/platform/env.h" #endif #include "tsl/platform/logging.h" @@ -38,8 +39,25 @@ namespace tsl { std::vector CandidateCudaRoots() { #if !defined(PLATFORM_GOOGLE) - auto roots = std::vector{TF_CUDA_TOOLKIT_PATH, - std::string("/usr/local/cuda")}; + auto roots = std::vector{}; + std::string runfiles_suffix = "runfiles"; + + // The CUDA candidate root for c++ targets. + std::string executable_path = tsl::Env::Default()->GetExecutablePath(); + std::string cuda_nvcc_dir = + io::JoinPath(executable_path + "." + runfiles_suffix, "cuda_nvcc"); + roots.emplace_back(cuda_nvcc_dir); + + // The CUDA candidate root for python targets. + std::string runfiles_dir = tsl::Env::Default()->GetRunfilesDir(); + std::size_t runfiles_ind = runfiles_dir.rfind(runfiles_suffix); + cuda_nvcc_dir = io::JoinPath( + runfiles_dir.substr(0, runfiles_ind + runfiles_suffix.length()), + "cuda_nvcc"); + roots.emplace_back(cuda_nvcc_dir); + + roots.emplace_back(TF_CUDA_TOOLKIT_PATH); + roots.emplace_back(std::string("/usr/local/cuda")); #if defined(PLATFORM_POSIX) && !defined(__APPLE__) Dl_info info; @@ -53,13 +71,17 @@ std::vector CandidateCudaRoots() { // relative to the current binary for the wheel-based nvcc package. for (auto path : {"../nvidia/cuda_nvcc", "../../nvidia/cuda_nvcc"}) roots.emplace_back(io::JoinPath(dir, path)); + + // Also add the path to the copy of libdevice.10.bc that we include within + // the Python wheel. + roots.emplace_back(io::JoinPath(dir, "cuda")); } #endif // defined(PLATFORM_POSIX) && !defined(__APPLE__) for (auto root : roots) VLOG(3) << "CUDA root = " << root; return roots; #else // !defined(PLATFORM_GOOGLE) - return {std::string("/usr/local/cuda")}; + return {}; #endif //! defined(PLATFORM_GOOGLE) } diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/port.cc b/third_party/xla/third_party/tsl/tsl/platform/default/port.cc index 868fb35f887dab..e5dbff497ad710 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/port.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/port.cc @@ -411,6 +411,13 @@ void* AlignedMalloc(size_t size, int minimum_alignment) { void AlignedFree(void* aligned_memory) { Free(aligned_memory); } +void AlignedSizedFree(void* aligned_memory, size_t alignment, size_t size) { + (void)alignment; + (void)size; + + Free(aligned_memory); +} + void* Malloc(size_t size) { return malloc(size); } void* Realloc(void* ptr, size_t size) { return realloc(ptr, size); } diff --git a/third_party/xla/third_party/tsl/tsl/platform/mem.h b/third_party/xla/third_party/tsl/tsl/platform/mem.h index 0f32727f0f753d..6d0dc803e93b80 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/mem.h +++ b/third_party/xla/third_party/tsl/tsl/platform/mem.h @@ -28,6 +28,7 @@ namespace port { // and a multiple of sizeof(void*). void* AlignedMalloc(size_t size, int minimum_alignment); void AlignedFree(void* aligned_memory); +void AlignedSizedFree(void* aligned_memory, size_t alignment, size_t size); void* Malloc(size_t size); void* Realloc(void* ptr, size_t size); diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc b/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc index f8e19503edb305..57600173577329 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc @@ -211,6 +211,13 @@ void* AlignedMalloc(size_t size, int minimum_alignment) { void AlignedFree(void* aligned_memory) { _aligned_free(aligned_memory); } +void AlignedSizedFree(void* aligned_memory, size_t alignment, size_t size) { + (void)alignment; + (void)size; + + _aligned_free(aligned_memory); +} + void* Malloc(size_t size) { return malloc(size); } void* Realloc(void* ptr, size_t size) { return realloc(ptr, size); } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc index 438f98c2b3ef24..4943fba0c1bfea 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc @@ -36,7 +36,7 @@ namespace { // nvtxNameOsThreadA: // https://nvidia.github.io/NVTX/doxygen/group___r_e_s_o_u_r_c_e___n_a_m_i_n_g.html // This convention may not match the one in tsl::Env::GetCurrentThreadId(). -std::optional GetCurrentThreadId() { +std::optional MaybeGetCurrentThreadId() { #ifdef __linux__ return syscall(SYS_gettid); #else @@ -57,7 +57,8 @@ ProfilerDomainHandle DefaultProfilerDomain() { } void NameCurrentThread(const std::string& thread_name) { - if (std::optional tid = GetCurrentThreadId(); tid.has_value()) { + if (std::optional tid = MaybeGetCurrentThreadId(); + tid.has_value()) { nvtxNameOsThreadA(*tid, thread_name.c_str()); } } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h index 478dae87b8a399..ef303663b3d142 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h @@ -46,9 +46,9 @@ class ProfilerLock { ProfilerLock& operator=(const ProfilerLock&) = delete; // Movable. - ProfilerLock(ProfilerLock&& other) + ProfilerLock(ProfilerLock&& other) noexcept : active_(std::exchange(other.active_, false)) {} - ProfilerLock& operator=(ProfilerLock&& other) { + ProfilerLock& operator=(ProfilerLock&& other) noexcept { active_ = std::exchange(other.active_, false); return *this; } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h index 75c2902f323d05..da9fe210737dd9 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h @@ -146,8 +146,8 @@ class TraceMe { } // Movable. - TraceMe(TraceMe&& other) { *this = std::move(other); } - TraceMe& operator=(TraceMe&& other) { + TraceMe(TraceMe&& other) noexcept { *this = std::move(other); } + TraceMe& operator=(TraceMe&& other) noexcept { #if !defined(IS_MOBILE_PLATFORM) if (TF_PREDICT_FALSE(other.start_time_ != kUntracedActivity)) { name_.Emplace(std::move(other.name_).Consume()); diff --git a/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/BUILD index fdf0979d82ea69..141f1b6e6edf82 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/BUILD @@ -29,7 +29,6 @@ cc_library( "@local_xla//xla/python:__pkg__", "//tensorflow/core/profiler/rpc/client:__pkg__", "//tensorflow/python/profiler/internal:__pkg__", - "//learning/pathways/data_parallel:__pkg__", ]), deps = [ ":profiler_client_for_pybind", @@ -62,6 +61,7 @@ cc_library( "@local_xla//xla/python:__pkg__", "//tsl/profiler:internal", "//tsl/profiler/rpc:__pkg__", + "//learning/pathways/data_parallel:__pkg__", ]), deps = [ "//tsl/lib/io:zlib_compression_options", diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD index 203657d0744e82..39113ebb7fc07f 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD @@ -363,6 +363,7 @@ tsl_cc_test( ":tpu_xplane_utils", ":xplane_schema", ":xplane_utils", + ":xplane_visitor", "//tsl/platform:test", "//tsl/platform:test_main", "//tsl/profiler/protobuf:xplane_proto_cc", diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc index 19841f53ce7fdb..9274a1da941743 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tsl/platform/regexp.h" #include "tsl/profiler/protobuf/xplane.pb.h" #include "tsl/profiler/utils/xplane_schema.h" @@ -48,5 +49,11 @@ std::optional GetTensorCoreId(absl::string_view plane_name) { return std::nullopt; } +std::optional GetSparseCoreId(absl::string_view plane_name) { + std::optional core_id; + RE2::FullMatch(plane_name, {kSparseCorePlaneRegex}, &core_id); + return core_id; +} + } // namespace profiler } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h index f3a150ca37e607..2fb7c677e3a058 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h @@ -36,6 +36,10 @@ std::vector FindMutableTensorCorePlanes( // TensorCore plane name. std::optional GetTensorCoreId(absl::string_view plane_name); +// Get Sparsecore Id from SparseCore plane name if plane name is a valid +// SparseCore plane name. +std::optional GetSparseCoreId(absl::string_view plane_name); + } // namespace profiler } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc index a385c77821c347..e5bcd73c339be9 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc @@ -21,11 +21,13 @@ limitations under the License. #include "tsl/profiler/protobuf/xplane.pb.h" #include "tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/utils/xplane_utils.h" +#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { namespace { +using ::testing::Optional; using ::testing::UnorderedElementsAre; TEST(TpuXPlaneUtilsTest, GetTensorCoreXPlanesFromXSpace) { @@ -65,6 +67,22 @@ TEST(TpuXPlaneUtilsTest, IsNotTensorCorePlaneNameWithPrefix) { GetTensorCoreId(absl::StrCat("/prefix", TpuPlaneName(0))).has_value()); } +TEST(TpuXplaneUtilsTest, GetSparseCorePlanesFromXSpace) { + XSpace space; + XPlane* p1 = FindOrAddMutablePlaneWithName(&space, TpuPlaneName(0)); + XPlane* p2 = FindOrAddMutablePlaneWithName(&space, TpuPlaneName(1)); + XPlane* p3 = FindOrAddMutablePlaneWithName( + &space, absl::StrCat(TpuPlaneName(0), " SparseCore 0")); + XPlane* p4 = FindOrAddMutablePlaneWithName( + &space, absl::StrCat(TpuPlaneName(0), " SparseCore 1")); + + EXPECT_THAT(FindTensorCorePlanes(space), UnorderedElementsAre(p1, p2)); + EXPECT_THAT(FindPlanesWithPrefix(space, kTpuPlanePrefix), + UnorderedElementsAre(p1, p2, p3, p4)); + EXPECT_THAT(GetSparseCoreId(p3->name()), Optional(0)); + EXPECT_THAT(GetSparseCoreId(p4->name()), Optional(1)); +} + } // namespace } // namespace profiler } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc index 33de2b0f6c3e19..2cd8aaa74b55b0 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc @@ -33,6 +33,8 @@ const absl::string_view kGpuPlanePrefix = "/device:GPU:"; const absl::string_view kTpuPlanePrefix = "/device:TPU:"; const absl::string_view kTpuNonCorePlaneNamePrefix = "#Chip"; const char kTpuPlaneRegex[] = {"/device:TPU:([0-9]*)$"}; +const char kSparseCorePlaneRegex[] = { + "/device:TPU:[0-9]+ SparseCore ([0-9]+)$"}; // TODO(b/195582092): change it to /device:custom once all literals are // migrated. const absl::string_view kCustomPlanePrefix = "/device:CUSTOM:"; diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h index 2e693b4474b92d..edf808b864648e 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h @@ -41,6 +41,8 @@ TF_CONST_INIT extern const absl::string_view kGpuPlanePrefix; TF_CONST_INIT extern const absl::string_view kTpuPlanePrefix; // Regex for XPlanes that contain TensorCore planes. TF_CONST_INIT extern const char kTpuPlaneRegex[]; +// Regex for XPlanes that contain TPU Core planes. +TF_CONST_INIT extern const char kSparseCorePlaneRegex[]; // Name prefix of XPlane that contains custom device events. TF_CONST_INIT extern const absl::string_view kCustomPlanePrefix; // Name prefix of XPlane that contains TPU non-core events such as HBM, ICI etc. diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD index 65000ff408801c..10e1dac5abc717 100644 --- a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD +++ b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD @@ -20,13 +20,6 @@ package( licenses = ["notice"], ) -tf_proto_library( - name = "bfc_memory_map_proto", - srcs = ["bfc_memory_map.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - tf_proto_library( name = "dnn_proto", srcs = ["dnn.proto"], @@ -123,7 +116,7 @@ tf_proto_library( protodeps = [ # TODO(tlongeri): Conceptually, these fit into protos_all but adding them currently causes # breakages (and they are not actually used). - ":bfc_memory_map_proto", + "@local_xla//xla/tsl/protobuf:bfc_memory_map_proto", ":coordination_config_proto", ":distributed_runtime_payloads_proto", ":error_codes_proto_impl", diff --git a/third_party/xla/third_party/tsl/workspace2.bzl b/third_party/xla/third_party/tsl/workspace2.bzl index 0a2993f3542ba4..7b85e735b1f880 100644 --- a/third_party/xla/third_party/tsl/workspace2.bzl +++ b/third_party/xla/third_party/tsl/workspace2.bzl @@ -17,14 +17,12 @@ load("//third_party/eigen3:workspace.bzl", eigen3 = "repo") load("//third_party/farmhash:workspace.bzl", farmhash = "repo") load("//third_party/gemmlowp:workspace.bzl", gemmlowp = "repo") load("//third_party/git:git_configure.bzl", "git_configure") -load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") load("//third_party/gpus:sycl_configure.bzl", "sycl_configure") load("//third_party/hwloc:workspace.bzl", hwloc = "repo") load("//third_party/implib_so:workspace.bzl", implib_so = "repo") load("//third_party/llvm:setup.bzl", "llvm_setup") load("//third_party/nasm:workspace.bzl", nasm = "repo") -load("//third_party/nccl:nccl_configure.bzl", "nccl_configure") load("//third_party/py:python_configure.bzl", "python_configure") load("//third_party/py/ml_dtypes:workspace.bzl", ml_dtypes = "repo") load("//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo") @@ -69,9 +67,7 @@ def _tf_toolchains(): # Note that we check the minimum bazel version in WORKSPACE. clang6_configure(name = "local_config_clang6") cc_download_clang_toolchain(name = "local_config_download_clang") - cuda_configure(name = "local_config_cuda") tensorrt_configure(name = "local_config_tensorrt") - nccl_configure(name = "local_config_nccl") git_configure(name = "local_config_git") syslibs_configure(name = "local_config_syslibs") python_configure(name = "local_config_python") @@ -160,13 +156,13 @@ def _tf_repositories(): tf_http_archive( name = "mkl_dnn_acl_compatible", - build_file = "//tensorflow/third_party/mkl_dnn:mkldnn_acl.BUILD", + build_file = "//third_party/mkl_dnn:mkldnn_acl.BUILD", patch_file = [ - "//tensorflow/third_party/mkl_dnn:onednn_acl_threadcap.patch", - "//tensorflow/third_party/mkl_dnn:onednn_acl_reorder.patch", - "//tensorflow/third_party/mkl_dnn:onednn_acl_thread_local_scheduler.patch", - "//tensorflow/third_party/mkl_dnn:onednn_acl_fp32_bf16_reorder.patch", - "//tensorflow/third_party/mkl_dnn:onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch", + "//third_party/mkl_dnn:onednn_acl_threadcap.patch", + "//third_party/mkl_dnn:onednn_acl_reorder.patch", + "//third_party/mkl_dnn:onednn_acl_thread_local_scheduler.patch", + "//third_party/mkl_dnn:onednn_acl_fp32_bf16_reorder.patch", + "//third_party/mkl_dnn:onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch", ], sha256 = "2f76b407ef8893cca71340f88cd800019a1f14f8ac1bbdbb89a84be1370b52e3", strip_prefix = "oneDNN-3.2.1", @@ -560,9 +556,9 @@ def _tf_repositories(): tf_http_archive( name = "pybind11", - urls = tf_mirror_urls("https://github.com/pybind/pybind11/archive/v2.10.0.tar.gz"), - sha256 = "eacf582fa8f696227988d08cfc46121770823839fe9e301a20fbce67e7cd70ec", - strip_prefix = "pybind11-2.10.0", + urls = tf_mirror_urls("https://github.com/pybind/pybind11/archive/v2.13.4.tar.gz"), + sha256 = "efc901aa0aab439a3fea6efeaf930b5a349fb06394bf845c64ce15a9cf8f0240", + strip_prefix = "pybind11-2.13.4", build_file = "//third_party:pybind11.BUILD", system_build_file = "//third_party/systemlibs:pybind11.BUILD", ) @@ -591,6 +587,22 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/google/glog/archive/refs/tags/v0.4.0.tar.gz"), ) + tf_http_archive( + name = "spirv_headers", + sha256 = "11d835c60297b26532c05c3f3b581ba7a2787b5ae7399e94f72c392169216f11", + strip_prefix = "SPIRV-Headers-b73e168ca5e123dcf3dea8a34b19a5130f421ae1", + urls = tf_mirror_urls("https://github.com/KhronosGroup/SPIRV-Headers/archive/b73e168ca5e123dcf3dea8a34b19a5130f421ae1.tar.gz"), + ) + + tf_http_archive( + name = "spirv_llvm_translator", + sha256 = "d499769f4fd1e0ce9d4dbd3622ee7e3e641b5623dcdf811521e3e7c0bdb1e6c2", + strip_prefix = "SPIRV-LLVM-Translator-dad1f0eaab8047a4f73c50ed5f3d1694b78aae97", + build_file = "//third_party/spirv_llvm_translator:spirv_llvm_translator.BUILD", + patch_file = ["//third_party/spirv_llvm_translator:spirv_llvm_translator.patch"], + urls = tf_mirror_urls("https://github.com/KhronosGroup/SPIRV-LLVM-Translator/archive/dad1f0eaab8047a4f73c50ed5f3d1694b78aae97.tar.gz"), + ) + # buildifier: disable=unnamed-macro def workspace(): # Check the bazel version before executing any repository rules, in case diff --git a/third_party/xla/third_party/uv/uv.BUILD b/third_party/xla/third_party/uv/uv.BUILD index b04383ad3487e7..43c194a53ea516 100644 --- a/third_party/xla/third_party/uv/uv.BUILD +++ b/third_party/xla/third_party/uv/uv.BUILD @@ -55,7 +55,19 @@ cc_library( # TODO: Add Linux, etc. as in https://github.com/libuv/libuv/blob/v1.38.0/CMakeLists.txt. hdrs = [ "include/uv.h", - ], + "src/heap-inl.h", + "src/idna.h", + "src/queue.h", + "src/strscpy.h", + "src/unix/atomic-ops.h", + "src/unix/internal.h", + "src/unix/spinlock.h", + "src/uv-common.h", + ] + select({ + "@platforms//os:osx": [ + "src/unix/darwin-stub.h", + ], + }) + glob(["include/uv/*.h"]), copts = [ "-fexceptions", "-Wno-unused-variable", diff --git a/third_party/xla/tools/toolchains/remote_config/configs.bzl b/third_party/xla/tools/toolchains/remote_config/configs.bzl index 0c28198f980b95..9a4dfa2aafdc51 100644 --- a/third_party/xla/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/tools/toolchains/remote_config/configs.bzl @@ -225,8 +225,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -236,8 +236,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "9.1", + cuda_version = "12.3.2", + cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -248,8 +248,8 @@ def initialize_rbe_configs(): name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -258,8 +258,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -270,8 +270,8 @@ def initialize_rbe_configs(): name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -479,7 +479,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -558,7 +558,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -710,11 +710,11 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.6", }, @@ -749,11 +749,11 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.6", }, @@ -788,12 +788,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "8.6", }, ) @@ -826,12 +826,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "8.6", }, ) @@ -864,12 +864,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "9.1", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "9.1.1", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "10.0", }, ) diff --git a/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl b/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl index 18a84d96c39f82..ec2ac4cc8ea430 100644 --- a/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl +++ b/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl @@ -1,8 +1,8 @@ """Macro that creates external repositories for remote config.""" -load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure") +load("@local_config_cuda//cuda/hermetic:cuda_configure.bzl", "cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure") -load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") +load("//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure") load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") load("//third_party/tensorrt:tensorrt_configure.bzl", "remote_tensorrt_configure") @@ -42,7 +42,7 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_CUDNN_VERSION": cudnn_version, "TF_CUDA_VERSION": cuda_version, "CUDNN_INSTALL_PATH": cudnn_install_path if cudnn_install_path != None else "/usr/lib/x86_64-linux-gnu", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": tensorrt_version if tensorrt_version != None else "", "TENSORRT_INSTALL_PATH": tensorrt_install_path if tensorrt_install_path != None else "/usr/lib/x86_64-linux-gnu", "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "", @@ -51,20 +51,26 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_SYSROOT": sysroot if sysroot else "", }) - container_name = "cuda%s-cudnn%s-%s" % (cuda_version, cudnn_version, os) + cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) + cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) + container_name = "cuda%s-cudnn%s-%s" % ( + cuda_version_in_container, + cudnn_version_in_container, + os, + ) container_image = _container_image_uri(container_name) exec_properties = { "container-image": container_image, "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, @@ -175,13 +181,13 @@ def sigbuild_tf_configs(name_container_map, env): "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, diff --git a/third_party/xla/warnings.bazelrc b/third_party/xla/warnings.bazelrc index ac219136da23c1..ae92c8c9db8472 100644 --- a/third_party/xla/warnings.bazelrc +++ b/third_party/xla/warnings.bazelrc @@ -4,13 +4,8 @@ build:warnings --copt=-Werror --host_copt=-Werror # ...and silence them outside of the workspace. build:warnings --per_file_copt=external/.*@-w -# ...and silence them on host builds. There is no host_per_file_copt and -# everything we build in the host configuration we either also build in the -# target configuration or is external, so we can't control it. -# If/when Bazel supports --host_per_file_copt, we could use that instead: -# https://github.com/bazelbuild/bazel/issues/12406. -# Would need to then make all the --copt below duplicated with --host_copt. -build:warnings --host_copt=-w +# ...and silence them on host builds. +build:warnings --host_per_file_copt=external/.*@-w build:warnings --copt=-Wall build:warnings --copt=-Werror diff --git a/third_party/xla/workspace0.bzl b/third_party/xla/workspace0.bzl index 76b8ed2bbae1f2..f0b37ee94921f4 100644 --- a/third_party/xla/workspace0.bzl +++ b/third_party/xla/workspace0.bzl @@ -5,6 +5,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies") load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies") load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies") +load("@com_google_benchmark//:bazel/benchmark_deps.bzl", "benchmark_deps") load("@local_tsl//:workspace0.bzl", "tsl_workspace0") def _tf_bind(): @@ -125,6 +126,9 @@ def workspace(): swift_rules_dependencies() apple_support_dependencies() + # We only need `benchmark_deps` to be able to have bazel query to work and not complain about missing `@libpfm`. + benchmark_deps() + # If a target is bound twice, the later one wins, so we have to do tf bindings # at the end of the WORKSPACE file. _tf_bind() diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl index e2244c1ae9d216..dea8d378e31806 100644 --- a/third_party/xla/workspace2.bzl +++ b/third_party/xla/workspace2.bzl @@ -16,6 +16,7 @@ load("//third_party/robin_map:workspace.bzl", robin_map = "repo") load("//third_party/shardy:workspace.bzl", shardy = "repo") load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo") load("//third_party/triton:workspace.bzl", triton = "repo") +load("//third_party/uv:workspace.bzl", uv = "repo") def _initialize_third_party(): """ Load third party repositories. See above load() statements. """ @@ -27,6 +28,7 @@ def _initialize_third_party(): shardy() stablehlo() triton() + uv() # Define all external repositories required by TensorFlow def _tf_repositories(): diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 5745c8d953e3b2..bd161a52b8757e 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -516,6 +516,7 @@ xla_cc_test( ":test", ":util", ":xla_data_proto_cc", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1133,6 +1134,7 @@ cc_library( [ ":parse_flags_from_env", ":xla_proto_cc", + "//xla/stream_executor/cuda:nvjitlink_support", "//xla/stream_executor/cuda:ptx_compiler_support", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/algorithm:container", @@ -1260,10 +1262,10 @@ cc_library( deps = [ ":autotune_results_proto_cc", ":autotuning_proto_cc", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/lib/strings:proto_serialization", ], ) diff --git a/third_party/xla/xla/array.h b/third_party/xla/xla/array.h index 6a6f50574e1d9e..03c5f3b9760c4b 100644 --- a/third_party/xla/xla/array.h +++ b/third_party/xla/xla/array.h @@ -603,12 +603,12 @@ class Array { std::fill(data.get(), data.get() + size, init); } - OwnedBuffer(OwnedBuffer&& other) + OwnedBuffer(OwnedBuffer&& other) noexcept : data(std::move(other.data)), size(other.size) { other.size = 0; } - OwnedBuffer& operator=(OwnedBuffer&& other) { + OwnedBuffer& operator=(OwnedBuffer&& other) noexcept { data = std::move(other.data); size = other.size; other.size = 0; diff --git a/third_party/xla/xla/autotune_result_wrapper.cc b/third_party/xla/xla/autotune_result_wrapper.cc index 855c8aaeb13f5d..ee92f173d6a4d4 100644 --- a/third_party/xla/xla/autotune_result_wrapper.cc +++ b/third_party/xla/xla/autotune_result_wrapper.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" -#include "tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/lib/strings/proto_serialization.h" namespace xla { diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/BUILD rename to third_party/xla/xla/backends/cpu/runtime/BUILD index a7e97b03e7b40c..56ded99d3af407 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -1,7 +1,7 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla:xla.bzl", "xla_cc_test") load("//xla/service/cpu:build_defs.bzl", "runtime_copts") -load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.bzl", "if_windows", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup") package( @@ -19,13 +19,16 @@ package_group( filegroup( name = "runtime_srcs", - srcs = ["conv_impl.cc"], + srcs = [ + "convolution_thunk_f16.cc", + "convolution_thunk_f32.cc", + ], visibility = internal_visibility([":friends"]), ) filegroup( name = "runtime_hdrs", - srcs = ["conv_impl.h"], + srcs = ["convolution_thunk_internal.h"], visibility = internal_visibility([":friends"]), ) @@ -124,6 +127,7 @@ cc_library( name = "thunk_executor", srcs = ["thunk_executor.cc"], hdrs = ["thunk_executor.h"], + defines = if_windows(["_ENABLE_EXTENDED_ALIGNED_STORAGE"]), deps = [ ":resource_use", ":thunk", @@ -271,16 +275,18 @@ cc_library( ) cc_library( - name = "conv_impl", - srcs = ["conv_impl.cc"], - hdrs = ["conv_impl.h"], + name = "convolution_thunk_internal", + srcs = [ + "convolution_thunk_f16.cc", + "convolution_thunk_f32.cc", + ], + hdrs = ["convolution_thunk_internal.h"], copts = runtime_copts(), visibility = internal_visibility([":friends"]), deps = [ "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:mutex", # build_cleaner: keep ], ) @@ -290,7 +296,7 @@ cc_library( hdrs = ["convolution_thunk.h"], copts = runtime_copts(), deps = [ - ":conv_impl", + ":convolution_thunk_internal", ":thunk", "//xla:executable_run_options", "//xla:shape_util", @@ -556,7 +562,9 @@ cc_library( "//xla:util", "//xla/ffi:attribute_map", "//xla/ffi:call_frame", + "//xla/ffi:execution_state", "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:custom_call_status", @@ -567,6 +575,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", @@ -777,7 +786,9 @@ cc_library( "//xla/stream_executor/host:host_kernel", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/numeric:bits", @@ -808,6 +819,7 @@ xla_cc_test( "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/backends/cpu/runtime/README.md b/third_party/xla/xla/backends/cpu/runtime/README.md new file mode 100644 index 00000000000000..84d313e5a2afe4 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/README.md @@ -0,0 +1,16 @@ +# XLA:CPU Runtime + +XLA:CPU runtime is implemented as a collection of `Thunks` that are responsible +for executing individual operations. XLA fusions, for example are jit-compiled +to executables using LLVM, and executed at run time by `KernelThunk`. Operations +that are not compiled have corresponding thunks, i.e., `FFT` operations is +executed as `FftThunk` and relies on DUCC FFT implementation. + +Thunks are executed concurrently using `ThunkExecutor`, which launches thunks +when all data dependencies are ready. We rely on buffer assignment to track read +and write conflicts, and compute a directed acyclic graph that defines execution +order. + +Conceptually, XLA:CPU runtime is similar to XLA:GPU, which also has thunks. +However, for CPU backend we do a lot more multi-threading to be able to +efficiently use all available cores on the host CPU. diff --git a/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc similarity index 95% rename from third_party/xla/xla/service/cpu/runtime/all_gather_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc index 3bb705ebf9fcd2..fa55bbc48dbffc 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/all_gather_thunk.h" +#include "xla/backends/cpu/runtime/all_gather_thunk.h" #include #include @@ -24,11 +24,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.h b/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.h similarity index 85% rename from third_party/xla/xla/service/cpu/runtime/all_gather_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.h index 28ba6c6ace84a1..2d2dca9a7eac9d 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_ALL_GATHER_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_ALL_GATHER_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_ALL_GATHER_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_ALL_GATHER_THUNK_H_ #include #include "absl/status/statusor.h" -#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -40,4 +40,4 @@ class AllGatherThunk final : public CollectiveThunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_ALL_GATHER_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_ALL_GATHER_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc index 923d03ce7fd464..a5d9d283867c2d 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/all_reduce_thunk.h" +#include "xla/backends/cpu/runtime/all_reduce_thunk.h" #include #include @@ -26,12 +26,12 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h b/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.h similarity index 87% rename from third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.h index f4580b0f63be45..77866382353e02 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/runtime/collective_thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -45,4 +45,4 @@ class AllReduceThunk final : public CollectiveThunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc index d55486602d6546..8badd0c4e7e232 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/all_to_all_thunk.h" +#include "xla/backends/cpu/runtime/all_to_all_thunk.h" #include #include @@ -23,11 +23,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.h b/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.h similarity index 85% rename from third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.h index 0c24627354829b..b58afe94394572 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ #include #include "absl/status/statusor.h" -#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -40,4 +40,4 @@ class AllToAllThunk final : public CollectiveThunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h b/third_party/xla/xla/backends/cpu/runtime/buffer_allocations.h similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/buffer_allocations.h rename to third_party/xla/xla/backends/cpu/runtime/buffer_allocations.h index 4d757261c5a39e..44d71712a9c19c 100644 --- a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h +++ b/third_party/xla/xla/backends/cpu/runtime/buffer_allocations.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ -#define XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ +#define XLA_BACKENDS_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ #include #include @@ -141,4 +141,4 @@ BufferAllocations::GetDeviceAddressUnchecked( } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc b/third_party/xla/xla/backends/cpu/runtime/buffer_allocations_test.cc similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc rename to third_party/xla/xla/backends/cpu/runtime/buffer_allocations_test.cc index 9fd7d447825de7..c92be6205ac910 100644 --- a/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/buffer_allocations_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" #include #include diff --git a/third_party/xla/xla/service/cpu/runtime/call_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/call_thunk.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/call_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/call_thunk.cc index a0a4d2bf5c9673..0473ad78e40f49 100644 --- a/third_party/xla/xla/service/cpu/runtime/call_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/call_thunk.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/call_thunk.h" +#include "xla/backends/cpu/runtime/call_thunk.h" #include #include #include "absl/memory/memory.h" #include "absl/status/statusor.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" diff --git a/third_party/xla/xla/service/cpu/runtime/call_thunk.h b/third_party/xla/xla/backends/cpu/runtime/call_thunk.h similarity index 84% rename from third_party/xla/xla/service/cpu/runtime/call_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/call_thunk.h index e6c9ecbd3544e8..b7addf7297c392 100644 --- a/third_party/xla/xla/service/cpu/runtime/call_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/call_thunk.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_CALL_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_CALL_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_CALL_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_CALL_THUNK_H_ #include #include "absl/status/statusor.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -45,4 +45,4 @@ class CallThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_CALL_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_CALL_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc index 1908c3ff66e40c..a830c0f7fd4ea1 100644 --- a/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/collective_permute_thunk.h" +#include "xla/backends/cpu/runtime/collective_permute_thunk.h" #include #include @@ -28,12 +28,12 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.h b/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.h similarity index 86% rename from third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.h index 6478ced6f1e939..702b2f2b15f3dd 100644 --- a/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ #include #include @@ -23,7 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -51,4 +51,4 @@ class CollectivePermuteThunk final : public CollectiveThunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/collective_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc index 32a452a6bcdd0d..4bebdd09cd31c1 100644 --- a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" #include #include @@ -32,13 +32,13 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/global_device_id.h" #include "xla/shape.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.h similarity index 94% rename from third_party/xla/xla/service/cpu/runtime/collective_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/collective_thunk.h index 5ae9c98844f887..8efc767838806d 100644 --- a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_COLLECTIVE_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_COLLECTIVE_THUNK_H_ #include #include @@ -27,11 +27,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/global_device_id.h" #include "xla/shape.h" #include "xla/stream_executor/device_memory.h" @@ -122,4 +122,4 @@ class CollectiveThunk : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_COLLECTIVE_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/conditional_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/conditional_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/conditional_thunk.cc index 4ee46a975e6217..42246dd1d3df51 100644 --- a/third_party/xla/xla/service/cpu/runtime/conditional_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/conditional_thunk.h" +#include "xla/backends/cpu/runtime/conditional_thunk.h" #include #include @@ -22,10 +22,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/cpu/runtime/conditional_thunk.h b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk.h similarity index 85% rename from third_party/xla/xla/service/cpu/runtime/conditional_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/conditional_thunk.h index 6185b6dad9b27b..0b01d8517a6ff4 100644 --- a/third_party/xla/xla/service/cpu/runtime/conditional_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_CONDITIONAL_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_CONDITIONAL_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_CONDITIONAL_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_CONDITIONAL_THUNK_H_ #include #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -48,4 +48,4 @@ class ConditionalThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_CONDITIONAL_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_CONDITIONAL_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk_test.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/conditional_thunk_test.cc index d24a58dec3edcc..a5222a8de6bb3d 100644 --- a/third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk_test.cc @@ -13,18 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/conditional_thunk.h" +#include "xla/backends/cpu/runtime/conditional_thunk.h" #include #include #include #include +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_testlib.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_testlib.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk.cc similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/convolution_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/convolution_thunk.cc index c7bdd0a2ccf18e..e4dd0ef3f98ce2 100644 --- a/third_party/xla/xla/service/cpu/runtime/convolution_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/convolution_thunk.h" +#include "xla/backends/cpu/runtime/convolution_thunk.h" #define EIGEN_USE_THREADS @@ -31,10 +31,10 @@ limitations under the License. #include "absl/types/span.h" #include "Eigen/Core" #include "unsupported/Eigen/CXX11/Tensor" +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/executable_run_options.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/conv_impl.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/runtime_conv2d_acl.h" #include "xla/shape.h" #include "xla/status_macros.h" @@ -328,7 +328,7 @@ ConvolutionThunk::HandleEigen2DConvolution(const ExecuteParams& params, std::optional> done_callback = std::nullopt) { using scalar_type = decltype(type_tag); - ::tensorflow::xla::EigenConv2DImpl( + internal::EigenConv2D( eigen_device, static_cast(output.opaque()), static_cast(input.opaque()), static_cast(kernel.opaque()), input_batch_, input_dims_.x, @@ -368,7 +368,7 @@ ConvolutionThunk::HandleEigen3DConvolution(const ExecuteParams& params, std::optional> done_callback = std::nullopt) { using scalar_type = decltype(type_tag); - ::tensorflow::xla::EigenConv3DImpl( + internal::EigenConv3D( eigen_device, static_cast(output.opaque()), static_cast(input.opaque()), static_cast(kernel.opaque()), input_batch_, input_dims_.x, diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_thunk.h b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk.h similarity index 94% rename from third_party/xla/xla/service/cpu/runtime/convolution_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/convolution_thunk.h index d3ba1173369827..de4f7629ae48dd 100644 --- a/third_party/xla/xla/service/cpu/runtime/convolution_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk.h @@ -13,18 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_H_ #include #include #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" +#include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -123,4 +124,4 @@ class ConvolutionThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/core/status_test_util.h b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_f16.cc similarity index 62% rename from third_party/xla/third_party/tsl/tsl/lib/core/status_test_util.h rename to third_party/xla/xla/backends/cpu/runtime/convolution_thunk_f16.cc index a15aa79a181ad8..7b6e2ae17d1855 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/core/status_test_util.h +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_f16.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ -#define TENSORFLOW_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" -#include "xla/tsl/lib/core/status_test_util.h" +CONV2D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, Eigen::half); +CONV2D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half); -#endif // TENSORFLOW_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ +CONV3D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, Eigen::half); +CONV3D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half); diff --git a/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_f32.cc b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_f32.cc new file mode 100644 index 00000000000000..b93314b8474444 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_f32.cc @@ -0,0 +1,27 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" +#include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h" // IWYU pragma: keep + +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" // IWYU pragma: keep +#endif + +CONV2D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, float); +CONV2D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, float); + +CONV3D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, float); +CONV3D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, float); diff --git a/third_party/xla/xla/service/cpu/runtime/conv_impl.h b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_internal.h similarity index 63% rename from third_party/xla/xla/service/cpu/runtime/conv_impl.h rename to third_party/xla/xla/backends/cpu/runtime/convolution_thunk_internal.h index b97bc85a4edc73..3275f9d8fa8455 100644 --- a/third_party/xla/xla/service/cpu/runtime/conv_impl.h +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_internal.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,36 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_CONV_IMPL_H_ -#define XLA_SERVICE_CPU_RUNTIME_CONV_IMPL_H_ + +#ifndef XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_INTERNAL_H_ +#define XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_INTERNAL_H_ + +#define EIGEN_USE_THREADS #include #include +#include "Eigen/Core" #include "unsupported/Eigen/CXX11/Tensor" -#include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h" -#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) -#include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" -#endif +namespace xla::cpu::internal { -// 'tensorflow' namespace is used so that types don't require qualification. -namespace tensorflow { -namespace xla { +// TODO(ezhulenev): Make internal implementation a private static method of +// ConvolutionThunk (for consistency with DotThunk). Today we keep it as a free +// function to use it in the legacy XLA CPU runtime. template -void EigenConv2DImpl( - const EigenDevice& device, ScalarType* out, ScalarType* lhs, - ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, - Eigen::Index input_y, Eigen::Index input_channels, Eigen::Index kernel_x, - Eigen::Index kernel_y, Eigen::Index kernel_channels, - Eigen::Index kernel_filters, Eigen::Index output_x, Eigen::Index output_y, - Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index padding_x_before, - Eigen::Index padding_x_after, Eigen::Index padding_y_before, - Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, - Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, - Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, - std::optional> done_callback) { +void EigenConv2D(const EigenDevice& device, ScalarType* out, ScalarType* lhs, + ScalarType* rhs, Eigen::Index input_batch, + Eigen::Index input_x, Eigen::Index input_y, + Eigen::Index input_channels, Eigen::Index kernel_x, + Eigen::Index kernel_y, Eigen::Index kernel_channels, + Eigen::Index kernel_filters, Eigen::Index output_x, + Eigen::Index output_y, Eigen::Index x_stride, + Eigen::Index y_stride, Eigen::Index padding_x_before, + Eigen::Index padding_x_after, Eigen::Index padding_y_before, + Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, + Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, + Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, + std::optional> done_callback) { const Eigen::TensorMap, Eigen::Aligned> input(lhs, input_batch, input_x, input_y, input_channels); @@ -114,22 +116,23 @@ void EigenConv2DImpl( } template -void EigenConv3DImpl( - const EigenDevice& device, ScalarType* out, ScalarType* lhs, - ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, - Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels, - Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z, - Eigen::Index kernel_channels, Eigen::Index kernel_filters, - Eigen::Index output_x, Eigen::Index output_y, Eigen::Index output_z, - Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index z_stride, - Eigen::Index padding_x_before, Eigen::Index padding_x_after, - Eigen::Index padding_y_before, Eigen::Index padding_y_after, - Eigen::Index padding_z_before, Eigen::Index padding_z_after, - Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, - Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, - Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, - Eigen::Index feature_group_count, - std::optional> done_callback) { +void EigenConv3D(const EigenDevice& device, ScalarType* out, ScalarType* lhs, + ScalarType* rhs, Eigen::Index input_batch, + Eigen::Index input_x, Eigen::Index input_y, + Eigen::Index input_z, Eigen::Index input_channels, + Eigen::Index kernel_x, Eigen::Index kernel_y, + Eigen::Index kernel_z, Eigen::Index kernel_channels, + Eigen::Index kernel_filters, Eigen::Index output_x, + Eigen::Index output_y, Eigen::Index output_z, + Eigen::Index x_stride, Eigen::Index y_stride, + Eigen::Index z_stride, Eigen::Index padding_x_before, + Eigen::Index padding_x_after, Eigen::Index padding_y_before, + Eigen::Index padding_y_after, Eigen::Index padding_z_before, + Eigen::Index padding_z_after, Eigen::Index lhs_x_dilation, + Eigen::Index lhs_y_dilation, Eigen::Index lhs_z_dilation, + Eigen::Index rhs_x_dilation, Eigen::Index rhs_y_dilation, + Eigen::Index rhs_z_dilation, Eigen::Index feature_group_count, + std::optional> done_callback) { using ConstTType = Eigen::TensorMap, Eigen::Aligned>; @@ -210,10 +213,10 @@ void EigenConv3DImpl( } // Extern Conv2D template for all supported devices and data types. -#define CONV2D_EXTERN_TEMPLATE(EigenDevice, ScalarType) \ - extern template void EigenConv2DImpl( \ - const EigenDevice& device, ScalarType* out, ScalarType* lhs, \ - ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ +#define CONV2D_EXTERN_TEMPLATE(DEVICE, SCALAR_TYPE) \ + extern template void EigenConv2D( \ + const DEVICE& device, SCALAR_TYPE* out, SCALAR_TYPE* lhs, \ + SCALAR_TYPE* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ Eigen::Index input_y, Eigen::Index input_channels, \ Eigen::Index kernel_x, Eigen::Index kernel_y, \ Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ @@ -233,10 +236,10 @@ CONV2D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float); #undef CONV2D_EXTERN_TEMPLATE // Extern Conv3D template for all supported devices and data types. -#define CONV3D_EXTERN_TEMPLATE(EigenDevice, ScalarType) \ - extern template void EigenConv3DImpl( \ - const EigenDevice& device, ScalarType* out, ScalarType* lhs, \ - ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ +#define CONV3D_EXTERN_TEMPLATE(DEVICE, SCALAR_TYPE) \ + extern template void EigenConv3D( \ + const DEVICE& device, SCALAR_TYPE* out, SCALAR_TYPE* lhs, \ + SCALAR_TYPE* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels, \ Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z, \ Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ @@ -258,7 +261,39 @@ CONV3D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float); #undef CONV3D_EXTERN_TEMPLATE -} // namespace xla -} // namespace tensorflow +} // namespace xla::cpu::internal + +#define CONV2D_INSTANTIATE_TEMPLATE(DEVICE, SCALAR_TYPE) \ + template void xla::cpu::internal::EigenConv2D( \ + const DEVICE& device, SCALAR_TYPE* out, SCALAR_TYPE* lhs, \ + SCALAR_TYPE* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ + Eigen::Index input_y, Eigen::Index input_channels, \ + Eigen::Index kernel_x, Eigen::Index kernel_y, \ + Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ + Eigen::Index output_x, Eigen::Index output_y, Eigen::Index x_stride, \ + Eigen::Index y_stride, Eigen::Index padding_x_before, \ + Eigen::Index padding_x_after, Eigen::Index padding_y_before, \ + Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \ + Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \ + Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, \ + std::optional> done_callback) + +#define CONV3D_INSTANTIATE_TEMPLATE(DEVICE, SCALAR_TYPE) \ + template void xla::cpu::internal::EigenConv3D( \ + const DEVICE& device, SCALAR_TYPE* out, SCALAR_TYPE* lhs, \ + SCALAR_TYPE* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ + Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels, \ + Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z, \ + Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ + Eigen::Index output_x, Eigen::Index output_y, Eigen::Index output_z, \ + Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index z_stride, \ + Eigen::Index padding_x_before, Eigen::Index padding_x_after, \ + Eigen::Index padding_y_before, Eigen::Index padding_y_after, \ + Eigen::Index padding_z_before, Eigen::Index padding_z_after, \ + Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, \ + Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \ + Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \ + Eigen::Index feature_group_count, \ + std::optional> done_callback) -#endif // XLA_SERVICE_CPU_RUNTIME_CONV_IMPL_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_INTERNAL_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_test.cc similarity index 99% rename from third_party/xla/xla/service/cpu/runtime/convolution_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/convolution_thunk_test.cc index 3671431333d595..20a75d1f97ebcc 100644 --- a/third_party/xla/xla/service/cpu/runtime/convolution_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/convolution_thunk.h" +#include "xla/backends/cpu/runtime/convolution_thunk.h" #include #include @@ -25,10 +25,10 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "Eigen/Core" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/cpu/runtime/copy_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/copy_thunk.cc similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/copy_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/copy_thunk.cc index 1ea16dbdbf4d53..67b4d557256950 100644 --- a/third_party/xla/xla/service/cpu/runtime/copy_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/copy_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/copy_thunk.h" +#include "xla/backends/cpu/runtime/copy_thunk.h" #define EIGEN_USE_THREADS @@ -34,10 +34,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "unsupported/Eigen/CXX11/Tensor" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/pjrt/transpose.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" diff --git a/third_party/xla/xla/service/cpu/runtime/copy_thunk.h b/third_party/xla/xla/backends/cpu/runtime/copy_thunk.h similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/copy_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/copy_thunk.h index a65425c7f5427d..ed2cd68df5137a 100644 --- a/third_party/xla/xla/service/cpu/runtime/copy_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/copy_thunk.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_COPY_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_COPY_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_COPY_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_COPY_THUNK_H_ #include #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/pjrt/transpose.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -69,4 +69,4 @@ class CopyThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_COPY_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_COPY_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/copy_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/copy_thunk_test.cc similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/copy_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/copy_thunk_test.cc index 406d6b1a8aa7dc..8a8e4fb4debd27 100644 --- a/third_party/xla/xla/service/cpu/runtime/copy_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/copy_thunk_test.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/copy_thunk.h" +#include "xla/backends/cpu/runtime/copy_thunk.h" #include #include +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/layout_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc similarity index 88% rename from third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc index 8c6deca2d24064..8c774ba7759c35 100644 --- a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/custom_call_thunk.h" +#include "xla/backends/cpu/runtime/custom_call_thunk.h" #include #include @@ -26,6 +26,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -35,18 +36,21 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" #include "xla/ffi/call_frame.h" +#include "xla/ffi/execution_state.h" #include "xla/ffi/ffi_api.h" #include "xla/primitive_util.h" #include "xla/runtime/buffer_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" #include "xla/service/custom_call_target_registry.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" @@ -111,6 +115,36 @@ absl::StatusOr BuildCallFrameForTypedFFI( return builder.Build(); } +absl::Status InstantiateHandlerState(absl::string_view target_name, + ffi::ExecutionState* execution_state) { + // Find the registered FFI handler for this target. + auto handler = ffi::FindHandler(target_name, "Host"); + if (!handler.ok()) { + return NotFound( + "No registered implementation for FFI custom call to %s for Host", + target_name); + } + + // Initialize FFI handler state if it has an instantiate callback. + if (handler->bundle.instantiate) { + // At FFI handler instantiation time, we don't have any arguments or + // results or access to the underlying device (stream, etc.) + ffi::CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); + + // TODO(abanas): Add attributes support. All attributes should be accessible + // at all phases, namely instantiation and execution. Also add tests for CPU + // and GPU backends (GPU supports it, but tests are missing there). + ffi::CallFrame instantiate_call_frame = builder.Build(); + + ffi::CallOptions options; + options.execution_state = execution_state; + TF_RETURN_IF_ERROR(Call(handler->bundle.instantiate, instantiate_call_frame, + options, XLA_FFI_ExecutionStage_INSTANTIATE)); + } + + return absl::OkStatus(); +} + } // namespace absl::StatusOr> CustomCallThunk::Create( @@ -121,7 +155,13 @@ absl::StatusOr> CustomCallThunk::Create( TF_ASSIGN_OR_RETURN( call_frame, BuildCallFrameForTypedFFI(api_version, op_buffers, backend_config)); + + // TODO(abanas): Pass execution state to thunk. + auto execution_state = std::make_unique(); + TF_RETURN_IF_ERROR( + InstantiateHandlerState(target_name, execution_state.get())); } + return absl::WrapUnique(new CustomCallThunk( std::move(info), target_name, std::move(op_buffers), api_version, std::move(backend_config), std::move(call_frame))); diff --git a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.h b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.h similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/custom_call_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.h index 901545fa9f5d1f..bfea5368f7cb9b 100644 --- a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ #include #include @@ -25,9 +25,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/ffi/call_frame.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/custom_call_status.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -80,4 +80,4 @@ class CustomCallThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc similarity index 99% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc index c92307c52f064c..418ed65ce1cbb2 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" #include #include @@ -30,10 +30,10 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/layout_util.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk.h b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.h similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk.h index acaa94d5bf7779..61bcb8194e1150 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_DOT_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_DOT_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_DOT_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_DOT_THUNK_H_ #define EIGEN_USE_THREADS @@ -29,9 +29,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "Eigen/Core" #include "unsupported/Eigen/CXX11/Tensor" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -175,4 +175,4 @@ DOT_THUNK_EXTERN_MATMUL_TEMPLATE(std::complex); } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_DOT_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_DOT_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk_c128.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_c128.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk_c128.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk_c128.cc index cd2852e26aa980..1c791bd6fac78c 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk_c128.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_c128.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" // NOLINT IWYU pragma: keep template void ::xla::cpu::DotThunk::TypedMatMul>( const Eigen::ThreadPoolDevice* device, void* out, void* lhs, void* rhs, diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk_c64.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_c64.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk_c64.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk_c64.cc index 55f21cceb344bf..957e2d6d855630 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk_c64.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_c64.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" // NOLINT IWYU pragma: keep template void ::xla::cpu::DotThunk::TypedMatMul>( const Eigen::ThreadPoolDevice* device, void* out, void* lhs, void* rhs, diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk_f16.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_f16.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk_f16.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk_f16.cc index df04b0d1272a1a..35d85c89154187 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk_f16.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_f16.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" // NOLINT IWYU pragma: keep template void ::xla::cpu::DotThunk::TypedMatMul( const Eigen::ThreadPoolDevice* device, void* out, void* lhs, void* rhs, diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk_f32.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_f32.cc similarity index 93% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk_f32.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk_f32.cc index d98c5d940ed3b1..f3aee5501ac413 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk_f32.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_f32.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" // NOLINT IWYU pragma: keep #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) #include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk_f64.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_f64.cc similarity index 91% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk_f64.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk_f64.cc index f782cc7045ff7e..bcb8bd676af8db 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk_f64.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_f64.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" // NOLINT IWYU pragma: keep template void ::xla::cpu::DotThunk::TypedMatMul( const Eigen::ThreadPoolDevice* device, void* out, void* lhs, void* rhs, diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk_s32.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_s32.cc similarity index 91% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk_s32.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk_s32.cc index 59186ec8a5669a..0851e01b539c0a 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk_s32.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_s32.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" // NOLINT IWYU pragma: keep template void ::xla::cpu::DotThunk::TypedMatMul( const Eigen::ThreadPoolDevice* device, void* out, void* lhs, void* rhs, diff --git a/third_party/xla/xla/service/cpu/runtime/fft_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/fft_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc index 5d792c2fc8c163..b7c898b26d177c 100644 --- a/third_party/xla/xla/service/cpu/runtime/fft_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/fft_thunk.h" +#include "xla/backends/cpu/runtime/fft_thunk.h" #include #include @@ -21,10 +21,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/layout_util.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/runtime_fft.h" #include "xla/service/cpu/runtime_single_threaded_fft.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/service/cpu/runtime/fft_thunk.h b/third_party/xla/xla/backends/cpu/runtime/fft_thunk.h similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/fft_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/fft_thunk.h index b63ed5e9b744e7..64d4063d828cf7 100644 --- a/third_party/xla/xla/service/cpu/runtime/fft_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/fft_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_FFT_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_FFT_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_FFT_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_FFT_THUNK_H_ #include #include @@ -22,8 +22,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -68,4 +68,4 @@ class FftThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_FFT_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_FFT_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/infeed_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/infeed_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/infeed_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/infeed_thunk.cc index 9e8acff4ecb271..e1a601565c69d3 100644 --- a/third_party/xla/xla/service/cpu/runtime/infeed_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/infeed_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/infeed_thunk.h" +#include "xla/backends/cpu/runtime/infeed_thunk.h" #include #include @@ -24,10 +24,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/infeed_thunk.h b/third_party/xla/xla/backends/cpu/runtime/infeed_thunk.h similarity index 87% rename from third_party/xla/xla/service/cpu/runtime/infeed_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/infeed_thunk.h index 622046f2e3785d..1d4225d1ddd008 100644 --- a/third_party/xla/xla/service/cpu/runtime/infeed_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/infeed_thunk.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_INFEED_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_INFEED_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_INFEED_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_INFEED_THUNK_H_ #include #include #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -61,4 +61,4 @@ class InfeedThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_INFEED_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_INFEED_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/infeed_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/infeed_thunk_test.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/infeed_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/infeed_thunk_test.cc index 53394e242c56a0..3bbb4272f22834 100644 --- a/third_party/xla/xla/service/cpu/runtime/infeed_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/infeed_thunk_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/infeed_thunk.h" +#include "xla/backends/cpu/runtime/infeed_thunk.h" #include +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc similarity index 70% rename from third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc index 5ab801cc42aab2..4656bf8ef73a39 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/kernel_thunk.h" +#include "xla/backends/cpu/runtime/kernel_thunk.h" #define EIGEN_USE_THREADS @@ -25,9 +25,12 @@ limitations under the License. #include #include #include +#include +#include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/base/optimization.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" @@ -35,10 +38,10 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/host/host_kernel.h" #include "xla/stream_executor/host/host_kernel_c_api.h" @@ -109,9 +112,11 @@ template KernelThunk::KernelThunk( Info info, absl::Span arguments_buffers, absl::Span results_buffers, + absl::flat_hash_set invariant_buffers, std::string kernel_name, se::ThreadDim thread_dim, std::optional min_alignment) : Thunk(Kind::kKernel, std::move(info)), + invariant_buffers_(std::move(invariant_buffers)), num_kernel_args_(arguments_buffers.size() + results_buffers.size()), kernel_name_(std::move(kernel_name)), thread_dim_(thread_dim), @@ -192,10 +197,13 @@ KernelThunk::ExecuteInternal( VlogKernelArgs(arguments_buffers_, results_buffers_, kernel_args); } - // Сheck that all resolved buffers are properly aligned. + // Сheck that all resolved buffers are properly aligned, and that invariant + // property holds. if constexpr (ShouldCheckBufferSlices()) { TF_RETURN_IF_ERROR( CheckBufferAlignment(info(), min_alignment_.value_or(0), kernel_args)); + TF_RETURN_IF_ERROR(CheckInvariantBufferSlices()); + TF_RETURN_IF_ERROR(CheckInvariantBuffersMemory(*allocations)); } // TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk @@ -235,6 +243,106 @@ KernelThunk::ExecuteInternal( return OkExecuteEvent(); } +template +absl::Status +KernelThunk::CheckInvariantBufferSlices() const { + // We can use absl::c_contains here when we have C++20 support. + // TODO(abanas): Check for overlapping buffers. + auto contains = [](const auto& container, + const BufferAllocation::Slice& buffer) { + return absl::c_find(container, buffer) != container.end(); + }; + + // Verify all argument buffers. + for (const BufferAllocation::Slice& buffer : arguments_buffers_) { + if (invariant_buffers_.contains(buffer)) { + // This argument should be read only, i.e. not one of the results. + if (contains(results_buffers_, buffer)) { + return Internal( + "Mismatch in invariant buffers metadata, invariant buffer %s " + "should not be one of the results", + buffer.ToString()); + } + } else { + // For completeness, we check that a read write buffer is one of the + // results. + if (!contains(results_buffers_, buffer)) { + return Internal( + "Mismatch in invariant buffers metadata, read-write buffer %s " + "is not one of the results", + buffer.ToString()); + } + } + } + + // Verify that there are no extra buffers in invariant buffers set. + for (auto& buffer : invariant_buffers_) { + if (!contains(arguments_buffers_, buffer)) { + return Internal( + "Mismatch in invariant buffers metadata, unknown buffer found: %s", + buffer.ToString()); + } + } + return absl::OkStatus(); +} + +// TODO(abanas): Return absl::flat_hash_set. This requires implementing a hash +// function for DeviceMemoryBase. +template +static absl::StatusOr> ToDeviceMemorySet( + const Iterable& buffers, const BufferAllocations& allocations) { + std::vector result; + for (const BufferAllocation::Slice& slice : buffers) { + TF_ASSIGN_OR_RETURN(auto memory, allocations.GetDeviceAddress(slice)); + result.push_back(std::move(memory)); + } + return result; +} + +// The logic here is similar to CheckInvariantBufferSlices, but we check +// memory addresses instead of buffer slices. +template +absl::Status +KernelThunk::CheckInvariantBuffersMemory( + const BufferAllocations& allocations) const { + // We can use absl::c_contains here when we have C++20 support. + auto contains = [](const std::vector& container, + const se::DeviceMemoryBase& memory) { + return absl::c_find(container, memory) != container.end(); + }; + + TF_ASSIGN_OR_RETURN(auto results_memory_set, + ToDeviceMemorySet(results_buffers_, allocations)); + TF_ASSIGN_OR_RETURN(auto invariant_memory_set, + ToDeviceMemorySet(invariant_buffers_, allocations)); + + // Verify all argument buffers. + for (const BufferAllocation::Slice& argument_slice : arguments_buffers_) { + TF_ASSIGN_OR_RETURN(auto argument_memory, + allocations.GetDeviceAddress(argument_slice)); + if (contains(invariant_memory_set, argument_memory)) { + // This argument should be read only, i.e. not one of the results. + if (contains(results_memory_set, argument_memory)) { + return Internal( + "Mismatch in invariant buffers metadata, device memory of " + "invariant buffer %s should not be one of the results", + argument_slice.ToString()); + } + } else { + // For completeness, we check that a read write buffer is one of the + // results. + if (!contains(results_memory_set, argument_memory)) { + return Internal( + "Mismatch in invariant buffers metadata, device memory of " + "read-write buffer %s is not one of the results", + argument_slice.ToString()); + } + } + } + + return absl::OkStatus(); +} + template Thunk::BufferUses KernelThunk::buffer_uses() const { return KernelBufferUses(arguments_buffers_, results_buffers_); @@ -259,6 +367,7 @@ absl::StatusOr> KernelThunk::Create( absl::Span arguments_buffers, absl::Span results_buffers, std::string kernel_name, se::ThreadDim thread_dim, + absl::flat_hash_set invariant_buffers, std::optional min_alignment) { if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) { return Internal("Host kernel %s minimum alignment %d is not a power of 2", @@ -269,7 +378,8 @@ absl::StatusOr> KernelThunk::Create( return absl::WrapUnique( new SmallKernelThunk( std::move(info), arguments_buffers, results_buffers, - std::move(kernel_name), thread_dim, min_alignment)); + std::move(invariant_buffers), std::move(kernel_name), thread_dim, + min_alignment)); }; static constexpr auto _0 = std::integral_constant{}; @@ -295,7 +405,8 @@ absl::StatusOr> KernelThunk::Create( // Return a generic KernelThunk for dynamic numbers of arguments and results. return absl::WrapUnique( new KernelThunk(std::move(info), arguments_buffers, results_buffers, - std::move(kernel_name), thread_dim, min_alignment)); + std::move(invariant_buffers), std::move(kernel_name), + thread_dim, min_alignment)); } } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h similarity index 89% rename from third_party/xla/xla/service/cpu/runtime/kernel_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h index 134602f99537b5..fd0567ae1e62e9 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_KERNEL_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_KERNEL_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_KERNEL_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_KERNEL_THUNK_H_ #include #include @@ -28,12 +28,14 @@ limitations under the License. #include #include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/stream_executor/host/host_kernel.h" #include "xla/stream_executor/host/host_kernel_c_api.h" #include "xla/stream_executor/launch_dim.h" @@ -94,12 +96,20 @@ class KernelThunk : public Thunk { KernelThunk(Info info, absl::Span arguments_buffers, absl::Span results_buffers, + absl::flat_hash_set invariant_buffers, std::string kernel_name, se::ThreadDim thread_dim, std::optional min_alignment); + absl::Status CheckInvariantBufferSlices() const; + + absl::Status CheckInvariantBuffersMemory( + const BufferAllocations& buffer_allocations) const; + ArgumentsBuffers arguments_buffers_; ResultsBuffers results_buffers_; + absl::flat_hash_set invariant_buffers_; + size_t num_kernel_args_; std::string kernel_name_; @@ -149,6 +159,7 @@ class KernelThunk final : public internal::KernelThunk<> { absl::Span arguments_buffers, absl::Span results_buffers, std::string kernel_name, se::ThreadDim thread_dim, + absl::flat_hash_set invariant_buffers, std::optional min_alignment = std::nullopt); tsl::AsyncValueRef Execute( @@ -157,4 +168,4 @@ class KernelThunk final : public internal::KernelThunk<> { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_KERNEL_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_KERNEL_THUNK_H_ diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk_test.cc new file mode 100644 index 00000000000000..1599694f8c7896 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk_test.cc @@ -0,0 +1,294 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/kernel_thunk.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/maybe_owning_device_memory.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +class AddF32HostKernel : public Thunk::FunctionRegistry { + public: + absl::StatusOr FindKernel(std::string_view name) override { + return +[](const SE_HOST_KernelCallFrame* call_frame) { + const SE_HOST_KernelArg& in = call_frame->args[0]; + const SE_HOST_KernelArg& out = call_frame->args[1]; + + float* in_ptr = reinterpret_cast(in.data); + float* out_ptr = reinterpret_cast(out.data); + + uint64_t i = call_frame->thread->x; + *(out_ptr + i) = *(in_ptr + i) + *(in_ptr + i); + + return static_cast(nullptr); + }; + } +}; + +TEST(KernelThunkTest, CheckAlignment) { + auto thunk = + KernelThunk::Create({"test"}, {}, {}, "test", se::ThreadDim(), {}, + /*min_alignment=*/3); + EXPECT_TRUE(absl::StrContains(thunk.status().message(), + "minimum alignment 3 is not a power of 2")); +} + +TEST(KernelThunkTest, AddF32) { + std::vector buffers; + std::vector in = {1.0, 2.0, 3.0, 4.0}; + std::vector out(4, 0.0); + + size_t size_in_bytes = in.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(in.data(), size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(out.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + + BufferAllocation in_alloc(0, size_in_bytes, 0); + BufferAllocation out_alloc(1, size_in_bytes, 0); + + BufferAllocation::Slice in_slice(&in_alloc, 0, size_in_bytes); + BufferAllocation::Slice out_slice(&out_alloc, 0, size_in_bytes); + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, KernelThunk::Create({"add_f32"}, {in_slice}, {out_slice}, + "add_f32", se::ThreadDim(4), {in_slice})); + + AddF32HostKernel host_kernels; + Thunk::ExecuteParams params = {&host_kernels, &allocations}; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()) << execute_event.GetError(); + + std::vector expected = {2.0, 4.0, 6.0, 8.0}; + EXPECT_EQ(out, expected); +} + +TEST(KernelThunkTest, AddF32Inline) { + std::vector buffers; + std::vector in_out = {1.0, 2.0, 3.0, 4.0}; + + size_t size_in_bytes = in_out.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(in_out.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + BufferAllocation in_out_alloc(0, size_in_bytes, 0); + BufferAllocation::Slice in_out_slice(&in_out_alloc, 0, size_in_bytes); + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, + KernelThunk::Create({"add_f32"}, {in_out_slice}, {in_out_slice}, + "add_f32", se::ThreadDim(4), {})); + + AddF32HostKernel host_kernels; + Thunk::ExecuteParams params = {&host_kernels, &allocations}; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + + std::vector expected = {2.0, 4.0, 6.0, 8.0}; + EXPECT_EQ(in_out, expected); +} + +TEST(KernelThunkInvariantBuffersTest, MissingBufferSlice) { +#ifdef NDEBUG + GTEST_SKIP() << "Invariant buffers check is disabled in optimized build."; +#endif + + std::vector buffers; + std::vector in = {1.0, 2.0, 3.0, 4.0}; + std::vector out(4, 0.0); + + size_t size_in_bytes = in.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(in.data(), size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(out.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + + BufferAllocation in_alloc(0, size_in_bytes, 0); + BufferAllocation out_alloc(1, size_in_bytes, 0); + + BufferAllocation::Slice in_slice(&in_alloc, 0, size_in_bytes); + BufferAllocation::Slice out_slice(&out_alloc, 0, size_in_bytes); + + // Invariant buffer set is incorrect - should include in_slice, but is empty. + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, + KernelThunk::Create({"add_f32"}, {in_slice}, {out_slice}, "add_f32", + se::ThreadDim(4), /*invariant_buffers=*/{})); + + AddF32HostKernel host_kernels; + Thunk::ExecuteParams params = {&host_kernels, &allocations}; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_TRUE(execute_event.IsError()); + + auto status = execute_event.GetError(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_TRUE(absl::StrContains(status.message(), + "Mismatch in invariant buffers metadata")); +} + +TEST(KernelThunkInvariantBuffersTest, ExtraInputOutputBufferSlice) { +#ifdef NDEBUG + GTEST_SKIP() << "Invariant buffers check is disabled in optimized build."; +#endif + + std::vector buffers; + std::vector in_out = {1.0, 2.0, 3.0, 4.0}; + + size_t size_in_bytes = in_out.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(in_out.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + BufferAllocation in_out_alloc(0, size_in_bytes, 0); + BufferAllocation::Slice in_out_slice(&in_out_alloc, 0, size_in_bytes); + + // Invariant buffer set is incorrect - should be empty, but contains input + // buffer that's not invariant. + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, KernelThunk::Create( + {"add_f32"}, {in_out_slice}, {in_out_slice}, "add_f32", + se::ThreadDim(4), /*invariant_buffers=*/{in_out_slice})); + + AddF32HostKernel host_kernels; + Thunk::ExecuteParams params = {&host_kernels, &allocations}; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_TRUE(execute_event.IsError()); + + auto status = execute_event.GetError(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_TRUE(absl::StrContains(status.message(), + "Mismatch in invariant buffers metadata")); +} + +TEST(KernelThunkInvariantBuffersTest, ExtraIncorrectBufferSlice) { +#ifdef NDEBUG + GTEST_SKIP() << "Invariant buffers check is disabled in optimized build."; +#endif + + std::vector buffers; + std::vector in = {1.0, 2.0, 3.0, 4.0}; + std::vector out(4, 0.0); + std::vector unrelated(4, 0.0); + + size_t size_in_bytes = in.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(in.data(), size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(out.data(), size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(unrelated.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + + BufferAllocation in_alloc(0, size_in_bytes, 0); + BufferAllocation out_alloc(1, size_in_bytes, 0); + BufferAllocation unrelated_alloc(2, size_in_bytes, 0); + + BufferAllocation::Slice in_slice(&in_alloc, 0, size_in_bytes); + BufferAllocation::Slice out_slice(&out_alloc, 0, size_in_bytes); + BufferAllocation::Slice unrelated_slice(&unrelated_alloc, 0, size_in_bytes); + + // Invariant buffer set contains all invariant buffers, but still it is + // incorrect - it contains a buffer that's unrelated to the kernel. + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, + KernelThunk::Create({"add_f32"}, {in_slice}, {out_slice}, "add_f32", + se::ThreadDim(4), + /*invariant_buffers=*/{in_slice, unrelated_slice})); + + AddF32HostKernel host_kernels; + Thunk::ExecuteParams params = {&host_kernels, &allocations}; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_TRUE(execute_event.IsError()); + + auto status = execute_event.GetError(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_TRUE(absl::StrContains(status.message(), + "Mismatch in invariant buffers metadata")); +} + +// This case should never happen in practice, it simulates a bug in the code +// that incorrectly sets up aliases. +TEST(KernelThunkInvariantBuffersTest, + MemorySectionIncorrectlyMarkedAsInvariant) { +#ifdef NDEBUG + GTEST_SKIP() << "Invariant buffers check is disabled in optimized build."; +#endif + + // We've got only one memory section + std::vector buffers; + std::vector in_out = {1.0, 2.0, 3.0, 4.0}; + + // We've got two buffer slices with different indexes, but both pointing to + // the same memory section. + size_t size_in_bytes = in_out.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(in_out.data(), size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(in_out.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + + BufferAllocation in_0_alloc(0, size_in_bytes, 0); + BufferAllocation in_1_alloc(1, size_in_bytes, 0); + + BufferAllocation::Slice in_0_slice(&in_0_alloc, 0, size_in_bytes); + BufferAllocation::Slice in_1_slice(&in_1_alloc, 0, size_in_bytes); + + // Invariant buffer set is incorrect. in_1_slice is not aliased to any output, + // but it points to the same memory section as in_0_slice (which is not + // invariant, because is aliased with the output). + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, KernelThunk::Create({"add_f32"}, {in_0_slice, in_1_slice}, + {in_0_slice}, "add_f32", se::ThreadDim(4), + /*invariant_buffers=*/{in_1_slice})); + + AddF32HostKernel host_kernels; + Thunk::ExecuteParams params = {&host_kernels, &allocations}; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_TRUE(execute_event.IsError()); + + auto status = execute_event.GetError(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_TRUE(absl::StrContains(status.message(), + "Mismatch in invariant buffers metadata")); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk.cc similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/logical_id_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/logical_id_thunk.cc index 61c8f4c801db32..ace52302dc953d 100644 --- a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/logical_id_thunk.h" +#include "xla/backends/cpu/runtime/logical_id_thunk.h" #include #include @@ -22,10 +22,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/computation_placer.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/global_device_id.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" diff --git a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.h b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk.h similarity index 90% rename from third_party/xla/xla/service/cpu/runtime/logical_id_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/logical_id_thunk.h index bb4d2fd12840ff..6a42fe69963d1a 100644 --- a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ #include #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/computation_placer.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/global_device_id.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -68,4 +68,4 @@ class PartitionIdThunk final } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk_test.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/logical_id_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/logical_id_thunk_test.cc index 72ce59f85dad5c..c8dd0a60782fed 100644 --- a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/logical_id_thunk.h" +#include "xla/backends/cpu/runtime/logical_id_thunk.h" #include #include @@ -22,10 +22,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/executable_run_options.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/outfeed_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/outfeed_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/outfeed_thunk.cc index a56ae0c437a7ec..b541953a403dee 100644 --- a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/outfeed_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/outfeed_thunk.h" +#include "xla/backends/cpu/runtime/outfeed_thunk.h" #include #include @@ -23,10 +23,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk.h b/third_party/xla/xla/backends/cpu/runtime/outfeed_thunk.h similarity index 87% rename from third_party/xla/xla/service/cpu/runtime/outfeed_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/outfeed_thunk.h index ff05339002ffc5..74920899255d46 100644 --- a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/outfeed_thunk.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_OUTFEED_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_OUTFEED_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_OUTFEED_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_OUTFEED_THUNK_H_ #include #include #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -60,4 +60,4 @@ class OutfeedThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_OUTFEED_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_OUTFEED_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/outfeed_thunk_test.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/outfeed_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/outfeed_thunk_test.cc index 2c6b9b9a91123f..0139a95f777e47 100644 --- a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/outfeed_thunk_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/outfeed_thunk.h" +#include "xla/backends/cpu/runtime/outfeed_thunk.h" #include +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc index 701ac3243ebd90..920aa3dc545b19 100644 --- a/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/reduce_scatter_thunk.h" +#include "xla/backends/cpu/runtime/reduce_scatter_thunk.h" #include #include @@ -24,12 +24,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h b/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.h similarity index 86% rename from third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.h index d37e1b22db5566..104d6c354dfa88 100644 --- a/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/runtime/collective_thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -44,4 +44,4 @@ class ReduceScatterThunk final : public CollectiveThunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/resource_use.cc b/third_party/xla/xla/backends/cpu/runtime/resource_use.cc similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/resource_use.cc rename to third_party/xla/xla/backends/cpu/runtime/resource_use.cc index 3e5ceabb9ac53a..a3c03849b5178a 100644 --- a/third_party/xla/xla/service/cpu/runtime/resource_use.cc +++ b/third_party/xla/xla/backends/cpu/runtime/resource_use.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/resource_use.h" #include diff --git a/third_party/xla/xla/service/cpu/runtime/resource_use.h b/third_party/xla/xla/backends/cpu/runtime/resource_use.h similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/resource_use.h rename to third_party/xla/xla/backends/cpu/runtime/resource_use.h index 6ee1f1bfd6ac95..1442a2895a02bf 100644 --- a/third_party/xla/xla/service/cpu/runtime/resource_use.h +++ b/third_party/xla/xla/backends/cpu/runtime/resource_use.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_RESOURCE_USE_H_ -#define XLA_SERVICE_CPU_RUNTIME_RESOURCE_USE_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_RESOURCE_USE_H_ +#define XLA_BACKENDS_CPU_RUNTIME_RESOURCE_USE_H_ #include #include @@ -111,4 +111,4 @@ class ResourceUse { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_RESOURCE_USE_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_RESOURCE_USE_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/resource_use_test.cc b/third_party/xla/xla/backends/cpu/runtime/resource_use_test.cc similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/resource_use_test.cc rename to third_party/xla/xla/backends/cpu/runtime/resource_use_test.cc index 4d3c9bbaf4cecc..dd5115bcaf2ae5 100644 --- a/third_party/xla/xla/service/cpu/runtime/resource_use_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/resource_use_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/resource_use.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/cpu/runtime/rng_state_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/rng_state_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/rng_state_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/rng_state_thunk.cc index df611bd5fe169f..39a3de9b9429dc 100644 --- a/third_party/xla/xla/service/cpu/runtime/rng_state_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/rng_state_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/rng_state_thunk.h" +#include "xla/backends/cpu/runtime/rng_state_thunk.h" #include #include @@ -26,8 +26,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/cpu/runtime/rng_state_thunk.h b/third_party/xla/xla/backends/cpu/runtime/rng_state_thunk.h similarity index 89% rename from third_party/xla/xla/service/cpu/runtime/rng_state_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/rng_state_thunk.h index 9798ed7c105f4b..d00bf4523e5dea 100644 --- a/third_party/xla/xla/service/cpu/runtime/rng_state_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/rng_state_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_RNG_STATE_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_RNG_STATE_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_RNG_STATE_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_RNG_STATE_THUNK_H_ #include #include @@ -23,9 +23,9 @@ limitations under the License. #include "absl/numeric/int128.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" namespace xla::cpu { @@ -56,4 +56,4 @@ class RngGetAndUpdateStateThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_RNG_STATE_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_RNG_STATE_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/sort_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc similarity index 99% rename from third_party/xla/xla/service/cpu/runtime/sort_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc index 041bf030d52abe..8d2df6f298cbcf 100644 --- a/third_party/xla/xla/service/cpu/runtime/sort_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/sort_thunk.h" +#include "xla/backends/cpu/runtime/sort_thunk.h" #include #include @@ -38,11 +38,11 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/layout_util.h" #include "xla/primitive_util.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" diff --git a/third_party/xla/xla/service/cpu/runtime/sort_thunk.h b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.h similarity index 93% rename from third_party/xla/xla/service/cpu/runtime/sort_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/sort_thunk.h index 049fa062cff603..a1c2b5eda242ee 100644 --- a/third_party/xla/xla/service/cpu/runtime/sort_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_SORT_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_SORT_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_SORT_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_SORT_THUNK_H_ #include #include @@ -28,8 +28,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -81,4 +81,4 @@ class SortThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_SORT_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_SORT_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/sort_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/sort_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc index 4c7b2514a1c709..1f450f77548d70 100644 --- a/third_party/xla/xla/service/cpu/runtime/sort_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/sort_thunk.h" +#include "xla/backends/cpu/runtime/sort_thunk.h" #include #include @@ -21,11 +21,11 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/backends/cpu/runtime/thunk.cc similarity index 99% rename from third_party/xla/xla/service/cpu/runtime/thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/thunk.cc index 9228de3d5f156a..41a02a5ca3a413 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include #include diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/backends/cpu/runtime/thunk.h similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/thunk.h rename to third_party/xla/xla/backends/cpu/runtime/thunk.h index 9141da74628691..cfc60597e6ac6d 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_THUNK_H_ #include #include @@ -31,12 +31,12 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/resource_use.h" #include "xla/executable_run_options.h" #include "xla/ffi/execution_context.h" #include "xla/runtime/buffer_use.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/resource_use.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/service/global_device_id.h" #include "xla/stream_executor/host/host_kernel_c_api.h" @@ -379,4 +379,4 @@ class ThunkSequence : public std::vector> { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc similarity index 99% rename from third_party/xla/xla/service/cpu/runtime/thunk_executor.cc rename to third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc index 9b4c735703fdc1..eb32b508b3a1b1 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/thunk_executor.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include #include @@ -32,9 +32,9 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor.h b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/thunk_executor.h rename to third_party/xla/xla/backends/cpu/runtime/thunk_executor.h index f0df6cfafb3d8a..5ba15b0432b504 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_executor.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_THUNK_EXECUTOR_H_ -#define XLA_SERVICE_CPU_RUNTIME_THUNK_EXECUTOR_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_THUNK_EXECUTOR_H_ +#define XLA_BACKENDS_CPU_RUNTIME_THUNK_EXECUTOR_H_ #include #include @@ -33,7 +33,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "xla/service/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -259,4 +259,4 @@ class ThunkExecutor { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_THUNK_EXECUTOR_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_THUNK_EXECUTOR_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc similarity index 99% rename from third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc rename to third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc index 60996ebd7ed61b..ebe98304b9f6f4 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/thunk_executor.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include #include @@ -30,11 +30,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_test.cc similarity index 99% rename from third_party/xla/xla/service/cpu/runtime/thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/thunk_test.cc index 3b975750be6f1d..1b20de023d91f8 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include #include diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_testlib.h b/third_party/xla/xla/backends/cpu/runtime/thunk_testlib.h similarity index 88% rename from third_party/xla/xla/service/cpu/runtime/thunk_testlib.h rename to third_party/xla/xla/backends/cpu/runtime/thunk_testlib.h index 154c2b28972701..4da0650efee7c4 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_testlib.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_testlib.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_THUNK_TESTLIB_H_ -#define XLA_SERVICE_CPU_RUNTIME_THUNK_TESTLIB_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_THUNK_TESTLIB_H_ +#define XLA_BACKENDS_CPU_RUNTIME_THUNK_TESTLIB_H_ #include "absl/status/status.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -59,4 +59,4 @@ class ResourceUseThunk : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_THUNK_TESTLIB_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_THUNK_TESTLIB_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/topk_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/topk_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/topk_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/topk_thunk.cc index 6c238224166a52..0c72933dc1a3aa 100644 --- a/third_party/xla/xla/service/cpu/runtime/topk_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/topk_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/topk_thunk.h" +#include "xla/backends/cpu/runtime/topk_thunk.h" #include #include @@ -21,8 +21,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/runtime_topk.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/topk_thunk.h b/third_party/xla/xla/backends/cpu/runtime/topk_thunk.h similarity index 90% rename from third_party/xla/xla/service/cpu/runtime/topk_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/topk_thunk.h index 7b2bfb63502bfe..7e7fadb03852e7 100644 --- a/third_party/xla/xla/service/cpu/runtime/topk_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/topk_thunk.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_TOPK_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_TOPK_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_TOPK_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_TOPK_THUNK_H_ #include #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -56,4 +56,4 @@ class TopKThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_TOPK_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_TOPK_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/while_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/while_thunk.cc similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/while_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/while_thunk.cc index 486a0b93e72f58..6c1e81f5dee0d6 100644 --- a/third_party/xla/xla/service/cpu/runtime/while_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/while_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/while_thunk.h" +#include "xla/backends/cpu/runtime/while_thunk.h" #include #include @@ -26,11 +26,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/service/cpu/runtime/while_thunk.h b/third_party/xla/xla/backends/cpu/runtime/while_thunk.h similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/while_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/while_thunk.h index e631e54842a52a..c1de07de86ad52 100644 --- a/third_party/xla/xla/service/cpu/runtime/while_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/while_thunk.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_WHILE_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_WHILE_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_WHILE_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_WHILE_THUNK_H_ #include #include #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -83,4 +83,4 @@ class WhileThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_WHILE_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_WHILE_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/while_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/while_thunk_test.cc similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/while_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/while_thunk_test.cc index fc6a32c8bd715e..d4b874a72b380f 100644 --- a/third_party/xla/xla/service/cpu/runtime/while_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/while_thunk_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/while_thunk.h" +#include "xla/backends/cpu/runtime/while_thunk.h" #include #include @@ -22,12 +22,12 @@ limitations under the License. #include #include +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_testlib.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_testlib.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/backends/interpreter/platform.cc b/third_party/xla/xla/backends/interpreter/platform.cc index 8b77eb1c801101..0b5756d4e3e175 100644 --- a/third_party/xla/xla/backends/interpreter/platform.cc +++ b/third_party/xla/xla/backends/interpreter/platform.cc @@ -47,31 +47,27 @@ XlaInterpreterPlatform::DescriptionForDevice(int ordinal) const { return XlaInterpreterExecutor::CreateDeviceDescription(ordinal); } -absl::StatusOr XlaInterpreterPlatform::ExecutorForDevice( +absl::StatusOr XlaInterpreterPlatform::FindExisting( int ordinal) { - StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); + return executor_cache_.Get(ordinal); } -absl::StatusOr XlaInterpreterPlatform::GetExecutor( - const StreamExecutorConfig& config) { +absl::StatusOr XlaInterpreterPlatform::ExecutorForDevice( + int ordinal) { return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } absl::StatusOr> -XlaInterpreterPlatform::GetUncachedExecutor( - const StreamExecutorConfig& config) { - auto executor = - std::make_unique(config.ordinal, this); +XlaInterpreterPlatform::GetUncachedExecutor(int ordinal) { + auto executor = std::make_unique(ordinal, this); auto init_status = executor->Init(); if (!init_status.ok()) { return absl::Status{ absl::StatusCode::kInternal, absl::StrFormat( "failed initializing StreamExecutor for device ordinal %d: %s", - config.ordinal, init_status.ToString())}; + ordinal, init_status.ToString())}; } return std::move(executor); diff --git a/third_party/xla/xla/backends/interpreter/platform.h b/third_party/xla/xla/backends/interpreter/platform.h index a2a7690e1d3f8f..50a69504ae0139 100644 --- a/third_party/xla/xla/backends/interpreter/platform.h +++ b/third_party/xla/xla/backends/interpreter/platform.h @@ -47,14 +47,13 @@ class XlaInterpreterPlatform : public Platform { absl::StatusOr ExecutorForDevice(int ordinal) override; - absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) override; + absl::StatusOr FindExisting(int ordinal) override; - // Returns a device constructed with the options specified in "config" without + // Returns a device constructed with ordinal without // looking in or storing to the Platform's executor cache. // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config); + int ordinal); private: // This platform's name. diff --git a/third_party/xla/xla/backends/profiler/gpu/BUILD b/third_party/xla/xla/backends/profiler/gpu/BUILD index 98ac9b38010be7..7a822f518e2921 100644 --- a/third_party/xla/xla/backends/profiler/gpu/BUILD +++ b/third_party/xla/xla/backends/profiler/gpu/BUILD @@ -311,7 +311,10 @@ tsl_gpu_library( "@local_tsl//tsl/profiler/utils:xplane_builder", "@local_tsl//tsl/profiler/utils:xplane_schema", "@local_tsl//tsl/profiler/utils:xplane_utils", - ] + if_cuda(["//xla/tsl/cuda:cupti"]), + ] + if_cuda([ + "//xla/tsl/cuda:cupti", + "//xla/tsl/cuda", + ]), ) tsl_gpu_library( diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD index 28f6c274bd3e42..7e2cc0dea11a71 100644 --- a/third_party/xla/xla/client/lib/BUILD +++ b/third_party/xla/xla/client/lib/BUILD @@ -227,6 +227,7 @@ cc_library( xla_test( name = "math_test", + timeout = "long", srcs = ["math_test.cc"], backend_tags = { # Times out. diff --git a/third_party/xla/xla/core/host_offloading/README.md b/third_party/xla/xla/core/host_offloading/README.md new file mode 100644 index 00000000000000..22f6449bce3b09 --- /dev/null +++ b/third_party/xla/xla/core/host_offloading/README.md @@ -0,0 +1,8 @@ +# XLA Host Offloading + +XLA host offloading allows us to run part of the HLO module on the host attached +to the accelerator device (TPU or GPU) using the XLA:CPU compiler. On JAX side +it is available as `jax.experimental.compute_on` API. + +With `compute_on` annotation, JAX + XLA can be used to implement +[ZeRO-Offload](https://arxiv.org/abs/2101.06840) host offloading. \ No newline at end of file diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 890bd25f205558..8c7603ab7a7bfa 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -38,6 +38,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/debug_options_parsers.h" #include "xla/parse_flags_from_env.h" +#include "xla/stream_executor/cuda/nvjitlink_support.h" #include "xla/stream_executor/cuda/ptx_compiler_support.h" #include "xla/tsl/util/command_line_flags.h" #include "xla/xla.pb.h" @@ -166,9 +167,9 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_highest_priority_async_stream(true); opts.set_xla_gpu_enable_pipelined_collectives(false); - opts.set_xla_gpu_enable_pipelined_all_reduce(false); - opts.set_xla_gpu_enable_pipelined_all_gather(false); - opts.set_xla_gpu_enable_pipelined_reduce_scatter(false); + opts.set_xla_gpu_enable_pipelined_all_reduce(true); + opts.set_xla_gpu_enable_pipelined_all_gather(true); + opts.set_xla_gpu_enable_pipelined_reduce_scatter(true); opts.set_xla_gpu_enable_pipelined_p2p(false); opts.set_xla_gpu_run_post_layout_collective_pipeliner(false); @@ -231,13 +232,11 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_triton_hopper(false); - // We disable this until b/319271534 is fixed due to errors during linking. - // - // TODO(b/319271534): Re-enable once we use libnvjitlink. opts.set_xla_gpu_enable_llvm_module_compilation_parallelism(false); - - opts.set_xla_gpu_enable_libnvptxcompiler(false); - opts.set_xla_gpu_enable_libnvjitlink(false); + opts.set_xla_gpu_enable_libnvptxcompiler( + stream_executor::IsLibNvPtxCompilerSupported()); + opts.set_xla_gpu_enable_libnvjitlink( + stream_executor::IsLibNvJitLinkSupported()); opts.set_xla_gpu_enable_dot_strength_reduction(true); @@ -272,7 +271,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_nccl_terminate_on_error(false); - opts.set_xla_gpu_shard_autotuning(false); + opts.set_xla_gpu_shard_autotuning(true); opts.set_xla_syntax_sugar_async_ops(false); @@ -284,6 +283,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_cudnn_gemm_max_plans(5); + opts.set_xla_gpu_enable_triton_gemm_int4(false); return opts; } @@ -1447,8 +1447,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "xla_gpu_enable_pipelined_collectives", bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_collectives), debug_options->xla_gpu_enable_pipelined_collectives(), - "Enable pipelinling of collective instructions (all-reduce, all-gather, " - "and reduce-scatter).")); + "Enable pipelinling of collective instructions. It has the same effect " + "as setting xla_gpu_enable_pipelined_all_reduce, " + "xla_gpu_enable_pipelined_all_gather, " + "xla_gpu_enable_pipelined_reduce_scatter and " + "xla_gpu_enable_pipelined_p2p flags to true.")); flag_list->push_back(tsl::Flag( "xla_gpu_enable_pipelined_all_reduce", bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_all_reduce), @@ -1843,6 +1846,17 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_cudnn_gemm_max_plans(), "Limit for the number of kernel configurations (plans) to use during " "autotuning of cuDNN GEMM fusions.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_triton_gemm_int4", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_gemm_int4), + debug_options->xla_gpu_enable_triton_gemm_int4(), + "Experimental: Enable Triton gemm for int4 inputs.")); + flag_list->push_back( + tsl::Flag("xla_gpu_async_dot", + bool_setter_for(&DebugOptions::set_xla_gpu_async_dot), + debug_options->xla_gpu_async_dot(), + "Wrap `dot` operations into async computations in an effort to " + "parallelize matrix operations.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/examples/axpy/README.md b/third_party/xla/xla/examples/axpy/README.md index 397dd21c8fb6d8..39bacfb18c5659 100644 --- a/third_party/xla/xla/examples/axpy/README.md +++ b/third_party/xla/xla/examples/axpy/README.md @@ -72,10 +72,8 @@ LocalClient* local_client = xla::ClientLibrary::LocalClientOrDie(); // PlatformUtil::GetPlatform("CUDA")); TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, PlatformUtil::GetPlatform("cpu")); -se::StreamExecutorConfig config; -config.ordinal = 0; TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, - platform->GetExecutor(config)); + platform->ExecutorForDevice(0)); // LocalDeviceState and PjRtStreamExecutorDevice describes the state of a // device which can do computation or transfer buffers. Could represent a GPU diff --git a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc index 897a1e953d20f8..49a99ee88a679c 100644 --- a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc +++ b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc @@ -62,10 +62,8 @@ TEST(StableHloAxpyTest, LoadAndRunCpuExecutable) { // PlatformUtil::GetPlatform("CUDA")); TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, PlatformUtil::GetPlatform("cpu")); - se::StreamExecutorConfig config; - config.ordinal = 0; TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, - platform->GetExecutor(config)); + platform->ExecutorForDevice(/*ordinal=*/0)); // LocalDeviceState and PjRtStreamExecutorDevice describes the state of a // device which can do computation or transfer buffers. This could represent a diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index 082189ed765cdf..dccbcc60d25936 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -231,16 +231,17 @@ class Ffi { template static std::string StrCat(Args... args); - static inline XLA_FFI_Error* MakeError(const XLA_FFI_Api* api, - XLA_FFI_Error_Code errc, - std::string message); + static XLA_FFI_Error* Sucess(); - static inline XLA_FFI_Error* InvalidArgument(const XLA_FFI_Api* api, - std::string message); + static XLA_FFI_Error* MakeError(const XLA_FFI_Api* api, + XLA_FFI_Error_Code errc, std::string message); - static inline XLA_FFI_Error* CheckStructSize(const XLA_FFI_Api* api, - std::string_view struct_name, - size_t expected, size_t actual); + static XLA_FFI_Error* InvalidArgument(const XLA_FFI_Api* api, + std::string message); + + static XLA_FFI_Error* CheckStructSize(const XLA_FFI_Api* api, + std::string_view struct_name, + size_t expected, size_t actual); }; XLA_FFI_Error* Ffi::RegisterStaticHandler(const XLA_FFI_Api* api, @@ -265,8 +266,11 @@ std::string Ffi::StrCat(Args... args) { return ss.str(); } -XLA_FFI_Error* Ffi::MakeError(const XLA_FFI_Api* api, XLA_FFI_Error_Code errc, - std::string message) { +inline XLA_FFI_Error* Ffi::Sucess() { return nullptr; } + +inline XLA_FFI_Error* Ffi::MakeError(const XLA_FFI_Api* api, + XLA_FFI_Error_Code errc, + std::string message) { XLA_FFI_Error_Create_Args args; args.struct_size = XLA_FFI_Error_Create_Args_STRUCT_SIZE; args.priv = nullptr; @@ -275,15 +279,15 @@ XLA_FFI_Error* Ffi::MakeError(const XLA_FFI_Api* api, XLA_FFI_Error_Code errc, return api->XLA_FFI_Error_Create(&args); } -XLA_FFI_Error* Ffi::InvalidArgument(const XLA_FFI_Api* api, - std::string message) { +inline XLA_FFI_Error* Ffi::InvalidArgument(const XLA_FFI_Api* api, + std::string message) { return MakeError(api, XLA_FFI_Error_Code_INVALID_ARGUMENT, std::move(message)); } -XLA_FFI_Error* Ffi::CheckStructSize(const XLA_FFI_Api* api, - std::string_view struct_name, - size_t expected, size_t actual) { +inline XLA_FFI_Error* Ffi::CheckStructSize(const XLA_FFI_Api* api, + std::string_view struct_name, + size_t expected, size_t actual) { if (expected != actual) { return InvalidArgument( api, StrCat("Unexpected ", struct_name, " size: expected ", expected, @@ -305,12 +309,13 @@ namespace internal { // parameter packs. We need this to be able to pattern match FFI handler // signature at compile time. +// A type tag for decoding optional argument. +template +struct OptionalArgTag {}; + // A type tag to forward all remaining args as `RemainingArgs`. struct RemainingArgsTag {}; -// A type tag to forward all remaining results as `RemainingRets`. -struct RemainingRetsTag {}; - // A type tag to distinguish parameters tied to results in the `Binding` // variadic template. In XLA FFI we use destination passing style APIs and don't // return anything from the handler, but instead pass a destination where the @@ -318,6 +323,13 @@ struct RemainingRetsTag {}; template struct RetTag {}; +// A type tag for decoding optional result. +template +struct OptionalRetTag {}; + +// A type tag to forward all remaining results as `RemainingRets`. +struct RemainingRetsTag {}; + // A type tag to distinguish parameters tied to the attributes in the // `Binding` variadic template. template @@ -356,12 +368,30 @@ struct NumTagged { //----------------------------------------------------------------------------// -// Checks if remaining arguments are in the parameter pack. +template +struct IsOptionalArgTag : std::false_type {}; +template +struct IsOptionalArgTag> : std::true_type {}; + +template +struct IsOptionalRetTag : std::false_type {}; +template +struct IsOptionalRetTag> : std::true_type {}; + +// Checks if parameter pack has an optional argument. +template +using HasOptionalArgTag = std::disjunction...>; + +// Checks if parameter pack has remaining arguments. template using HasRemainingArgsTag = std::disjunction...>; -// Checks if remaining results are in the parameter pack. +// Checks if parameter pack has an optional result. +template +using HasOptionalRetTag = std::disjunction...>; + +// Checks if parameter pack has remaining results. template using HasRemainingRetsTag = std::disjunction...>; @@ -412,11 +442,34 @@ class Binding { public: template Binding Arg() && { + static_assert(!internal::HasOptionalArgTag::value, + "argument can't be passed after optional argument"); + static_assert(!internal::HasRemainingArgsTag::value, + "argument can't be passed after remaining arguments"); return {std::move(*this)}; } template Binding> Ret() && { + static_assert(!internal::HasOptionalRetTag::value, + "result can't be passed after optional result"); + static_assert(!internal::HasRemainingRetsTag::value, + "result can't be passed after remaining results"); + return {std::move(*this)}; + } + + template + Binding> OptionalArg() && { + static_assert( + !internal::HasRemainingArgsTag::value, + "optional argument can't be passed after remaining arguments"); + return {std::move(*this)}; + } + + template + Binding> OptionalRet() && { + static_assert(!internal::HasRemainingRetsTag::value, + "optional result can't be passed after remaining results"); return {std::move(*this)}; } @@ -899,10 +952,20 @@ struct Decode { } }; -} // namespace internal +template +struct Decode> { + static std::optional> call(DecodingOffsets& offsets, + DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + if (offsets.args >= ctx.call_frame->args.size) { + return std::optional(std::nullopt); + } + return Decode::call(offsets, ctx, diagnostic); + } +}; template -struct internal::Decode> { +struct Decode> { static std::optional> call(DecodingOffsets& offsets, DecodingContext& ctx, DiagnosticEngine& diagnostic) { @@ -913,7 +976,19 @@ struct internal::Decode> { }; template -struct internal::Decode> { +struct Decode> { + static std::optional>> call( + DecodingOffsets& offsets, DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + if (offsets.rets >= ctx.call_frame->rets.size) { + return std::optional>(std::nullopt); + } + return Decode>::call(offsets, ctx, diagnostic); + } +}; + +template +struct Decode> { using R = typename AttrDecoding::Type; static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx, @@ -945,7 +1020,7 @@ struct internal::Decode> { }; template -struct internal::Decode> { +struct Decode> { using R = typename CtxDecoding::Type; static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx, @@ -955,6 +1030,8 @@ struct internal::Decode> { } }; +} // namespace internal + //===----------------------------------------------------------------------===// // Type-safe wrapper for accessing a variable number of arguments. //===----------------------------------------------------------------------===// @@ -1099,23 +1176,31 @@ struct FnArgType { using Type = T; }; -template <> -struct FnArgType { - using Type = RemainingArgs; +template +struct FnArgType> { + using Type = std::optional; }; template <> -struct FnArgType { - using Type = RemainingRets; +struct FnArgType { + using Type = RemainingArgs; }; -// Extracts the underlying type from the returned result type tag. template struct FnArgType> { using Type = Result; }; -// Extracts the underlying type from the attribute type tag. +template +struct FnArgType> { + using Type = std::optional>; +}; + +template <> +struct FnArgType { + using Type = RemainingRets; +}; + template struct FnArgType> { using Type = typename AttrDecoding::Type; @@ -1126,7 +1211,6 @@ struct FnArgType> { using Type = T; }; -// Extracts the underlying type from the context type tag. template struct FnArgType> { using Type = typename CtxDecoding::Type; @@ -1136,20 +1220,27 @@ struct FnArgType> { // a special decoding rule defined by template specialization. template struct IsTagged : std::false_type {}; + +template +struct IsTagged> : std::true_type {}; template struct IsTagged> : std::true_type {}; template +struct IsTagged> : std::true_type {}; +template struct IsTagged> : std::true_type {}; template struct IsTagged> : std::true_type {}; template struct IsTagged> : std::true_type {}; + template <> struct IsTagged : std::true_type {}; template <> struct IsTagged : std::true_type {}; -// A template for counting regular arguments in the Ts pack. +// A template for counting regular arguments in the Ts pack (arguments that are +// not wrapped into a special tag). template struct NumArgs; @@ -1175,9 +1266,15 @@ class Handler : public Ffi { static constexpr int64_t kNumArgs = internal::NumArgs::value; + static constexpr int64_t kNumOptionalArgs = + internal::NumTagged::value; + static constexpr int64_t kNumRets = internal::NumTagged::value; + static constexpr int64_t kNumOptionalRets = + internal::NumTagged::value; + static constexpr int64_t kNumAttrs = internal::NumTagged::value; @@ -1232,7 +1329,16 @@ class Handler : public Ffi { return InvalidArgument( call_frame->api, StrCat("Wrong number of arguments: expected at least ", - kNumArgs - 1, " but got ", call_frame->args.size)); + kNumArgs - kNumOptionalArgs - 1, " but got ", + call_frame->args.size)); + } + } else if constexpr (internal::HasOptionalArgTag::value) { + if (XLA_FFI_PREDICT_FALSE(call_frame->args.size < kNumArgs)) { + return InvalidArgument( + call_frame->api, + StrCat("Wrong number of arguments: expected at least ", + kNumArgs - kNumOptionalArgs, " but got ", + call_frame->args.size)); } } else { if (XLA_FFI_PREDICT_FALSE(call_frame->args.size != kNumArgs)) { @@ -1249,8 +1355,17 @@ class Handler : public Ffi { if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size < kNumRets)) { return InvalidArgument( call_frame->api, - StrCat("Wrong number of results: expected at least ", kNumRets - 1, - " but got ", call_frame->rets.size)); + StrCat("Wrong number of results: expected at least ", + kNumRets - kNumOptionalRets - 1, " but got ", + call_frame->rets.size)); + } + } else if constexpr (internal::HasOptionalRetTag::value) { + if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size < kNumRets)) { + return InvalidArgument( + call_frame->api, + StrCat("Wrong number of results: expected at least ", + kNumRets - kNumOptionalRets, " but got ", + call_frame->rets.size)); } } else { if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size != kNumRets)) { diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index 113c9b9c38d71b..2bbfd048688bc8 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -383,6 +384,129 @@ TEST(FfiTest, RemainingRets) { TF_ASSERT_OK(status); } +TEST(FfiTest, OptionalArgs) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + { // Single optional argument. + auto fn = [&](std::optional arg0) { + EXPECT_TRUE(arg0.has_value()); + return Error::Success(); + }; + + auto handler = Ffi::Bind().OptionalArg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Two optional arguments. + auto fn = [&](std::optional arg0, + std::optional arg1) { + EXPECT_TRUE(arg0.has_value()); + EXPECT_FALSE(arg1.has_value()); + return Error::Success(); + }; + + auto handler = + Ffi::Bind().OptionalArg().OptionalArg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Optional argument after a regular one. + auto fn = [&](AnyBuffer arg0, std::optional arg1) { + EXPECT_FALSE(arg1.has_value()); + return Error::Success(); + }; + + auto handler = Ffi::Bind().Arg().OptionalArg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Remaining arguments after optional one. + auto fn = [&](std::optional arg0, RemainingArgs args) { + EXPECT_TRUE(arg0.has_value()); + EXPECT_EQ(args.size(), 0); + return Error::Success(); + }; + + auto handler = Ffi::Bind().OptionalArg().RemainingArgs().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } +} + +TEST(FfiTest, OptionalRets) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1); + builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + { // Single optional result. + auto fn = [&](std::optional> ret0) { + EXPECT_TRUE(ret0.has_value()); + return Error::Success(); + }; + + auto handler = Ffi::Bind().OptionalRet().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Two optional results. + auto fn = [&](std::optional> ret0, + std::optional> ret1) { + EXPECT_TRUE(ret0.has_value()); + EXPECT_FALSE(ret1.has_value()); + return Error::Success(); + }; + + auto handler = + Ffi::Bind().OptionalRet().OptionalRet().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Optional result after a regular one. + auto fn = [&](Result ret0, + std::optional> ret1) { + EXPECT_FALSE(ret1.has_value()); + return Error::Success(); + }; + + auto handler = Ffi::Bind().Ret().OptionalRet().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Remaining results after optional one. + auto fn = [&](std::optional> ret0, RemainingRets rets) { + EXPECT_TRUE(ret0.has_value()); + EXPECT_EQ(rets.size(), 0); + return Error::Success(); + }; + + auto handler = Ffi::Bind().OptionalRet().RemainingRets().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } +} + TEST(FfiTest, AutoBinding) { static constexpr char kI32[] = "i32"; diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index 15e918cf3f875e..9fb4ff8e249600 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -676,6 +676,129 @@ TEST(FfiTest, RemainingRets) { TF_ASSERT_OK(status); } +TEST(FfiTest, OptionalArgs) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + { // Single optional argument. + auto fn = [&](std::optional arg0) { + EXPECT_TRUE(arg0.has_value()); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().OptionalArg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Two optional arguments. + auto fn = [&](std::optional arg0, + std::optional arg1) { + EXPECT_TRUE(arg0.has_value()); + EXPECT_FALSE(arg1.has_value()); + return absl::OkStatus(); + }; + + auto handler = + Ffi::Bind().OptionalArg().OptionalArg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Optional argument after a regular one. + auto fn = [&](AnyBuffer arg0, std::optional arg1) { + EXPECT_FALSE(arg1.has_value()); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().Arg().OptionalArg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Remaining arguments after optional one. + auto fn = [&](std::optional arg0, RemainingArgs args) { + EXPECT_TRUE(arg0.has_value()); + EXPECT_EQ(args.size(), 0); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().OptionalArg().RemainingArgs().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } +} + +TEST(FfiTest, OptionalRets) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1); + builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + { // Single optional result. + auto fn = [&](std::optional> ret0) { + EXPECT_TRUE(ret0.has_value()); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().OptionalRet().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Two optional results. + auto fn = [&](std::optional> ret0, + std::optional> ret1) { + EXPECT_TRUE(ret0.has_value()); + EXPECT_FALSE(ret1.has_value()); + return absl::OkStatus(); + }; + + auto handler = + Ffi::Bind().OptionalRet().OptionalRet().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Optional result after a regular one. + auto fn = [&](Result ret0, + std::optional> ret1) { + EXPECT_FALSE(ret1.has_value()); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().Ret().OptionalRet().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Remaining results after optional one. + auto fn = [&](std::optional> ret0, RemainingRets rets) { + EXPECT_TRUE(ret0.has_value()); + EXPECT_EQ(rets.size(), 0); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().OptionalRet().RemainingRets().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } +} + TEST(FfiTest, RunOptionsCtx) { auto call_frame = CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build(); auto* expected = reinterpret_cast(0x01234567); diff --git a/third_party/xla/xla/hlo/evaluator/BUILD b/third_party/xla/xla/hlo/evaluator/BUILD index b857a8a15ad532..47574f0ff59f38 100644 --- a/third_party/xla/xla/hlo/evaluator/BUILD +++ b/third_party/xla/xla/hlo/evaluator/BUILD @@ -135,10 +135,12 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "//xla/hlo/ir:hlo", + "//xla/service:call_graph", "//xla/service:dynamic_dimension_inference", "//xla/service:hlo_element_type_converter", "//xla/service:hlo_module_config", "//xla/service:shape_inference", + "//xla/service:tuple_points_to_analysis", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:test_utils", diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index b91e50f23052d6..9fe65193b84d97 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -146,8 +146,9 @@ absl::StatusOr Compare(const Shape& shape, Comparison comparison, std::optional GetInstructionStaticValueAsBool( const HloInstruction* instruction) { HloEvaluator evaluator; - absl::StatusOr static_value = evaluator.Evaluate( - instruction, /*recursively_evaluate_nonconstant_operands=*/true); + absl::StatusOr static_value = + evaluator.Evaluate(instruction, /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true); if (static_value.ok()) { return static_value->GetFirstElement(); } @@ -232,10 +233,12 @@ struct DynamicOrStaticInteger { }; std::optional GetInstructionValueAsInteger( - const HloInstruction* instruction) { + const HloInstruction* instruction, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { HloEvaluator evaluator; - absl::StatusOr static_value = evaluator.Evaluate( - instruction, /*recursively_evaluate_nonconstant_operands=*/true); + absl::StatusOr static_value = + evaluator.Evaluate(instruction, precomputed_analyses, + /*recursively_evaluate_nonconstant_operands=*/true); if (static_value.ok()) { if (instruction->shape().element_type() == PrimitiveType::PRED) { return DynamicOrStaticInteger{ @@ -274,14 +277,16 @@ struct ParamIndexAndValue { }; std::optional TryParsingInstructionAsParameterAndInteger( - const HloInstruction* instruction) { + const HloInstruction* instruction, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { // Skip copies. if (instruction->opcode() == HloOpcode::kCopy) { - return TryParsingInstructionAsParameterAndInteger(instruction->operand(0)); + return TryParsingInstructionAsParameterAndInteger(instruction->operand(0), + precomputed_analyses); } if (instruction->opcode() == HloOpcode::kCopyDone) { return TryParsingInstructionAsParameterAndInteger( - instruction->operand(0)->operand(1)); + instruction->operand(0)->operand(1), precomputed_analyses); } ParamIndexAndValue result; if (Match(instruction, match::GetTupleElement().WithOperand( @@ -289,7 +294,7 @@ std::optional TryParsingInstructionAsParameterAndInteger( result.param_index = instruction->tuple_index(); } std::optional integer_value = - GetInstructionValueAsInteger(instruction); + GetInstructionValueAsInteger(instruction, precomputed_analyses); result.value = std::move(integer_value); if (!result.IsValid()) { return std::nullopt; @@ -318,11 +323,12 @@ using WhileCondComparisonOrNoOp = std::variant; std::optional ParseComparisonOperand( - const HloInstruction* operand) { + const HloInstruction* operand, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { if (operand->opcode() == HloOpcode::kCopy || operand->opcode() == HloOpcode::kCopyStart || operand->opcode() == HloOpcode::kCopyDone) { - return ParseComparisonOperand(operand->operand(0)); + return ParseComparisonOperand(operand->operand(0), precomputed_analyses); } std::optional param_index; if (Match(operand, match::GetTupleElement().WithOperand( @@ -330,7 +336,7 @@ std::optional ParseComparisonOperand( param_index = operand->tuple_index(); } std::optional operand_value = - GetInstructionValueAsInteger(operand); + GetInstructionValueAsInteger(operand, precomputed_analyses); if (!param_index.has_value() && !operand_value.has_value()) { return std::nullopt; } @@ -338,12 +344,13 @@ std::optional ParseComparisonOperand( } std::optional PatternMatchLoopCondComparison( - const HloInstruction* comparison) { + const HloInstruction* comparison, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { CHECK_EQ(comparison->opcode(), HloOpcode::kCompare); std::optional lhs = - ParseComparisonOperand(comparison->operand(0)); + ParseComparisonOperand(comparison->operand(0), precomputed_analyses); std::optional rhs = - ParseComparisonOperand(comparison->operand(1)); + ParseComparisonOperand(comparison->operand(1), precomputed_analyses); if (!lhs.has_value() || !rhs.has_value()) { return std::nullopt; } @@ -353,18 +360,21 @@ std::optional PatternMatchLoopCondComparison( // Finds the while loop condition comparison by matching the loop condition root // with known patterns. std::optional PatternMatchLoopCondRoot( - const HloInstruction* loop_cond_root) { + const HloInstruction* loop_cond_root, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { if (loop_cond_root->opcode() == HloOpcode::kCopy) { - return PatternMatchLoopCondRoot(loop_cond_root->operand(0)); + return PatternMatchLoopCondRoot(loop_cond_root->operand(0), + precomputed_analyses); } if (loop_cond_root->opcode() == HloOpcode::kCopyDone) { - return PatternMatchLoopCondRoot(loop_cond_root->operand(0)->operand(1)); + return PatternMatchLoopCondRoot(loop_cond_root->operand(0)->operand(1), + precomputed_analyses); } if (loop_cond_root->opcode() == HloOpcode::kCompare) { // Base pattern #1: gte-0 comp gte-1 // Base pattern #2: constant comp gte // Base pattern #3: gte comp constant - return PatternMatchLoopCondComparison(loop_cond_root); + return PatternMatchLoopCondComparison(loop_cond_root, precomputed_analyses); } // Base pattern #4: gte is a boolean scalar and it was return immediately. if (Match(loop_cond_root, match::GetTupleElement().WithOperand( @@ -390,7 +400,8 @@ std::optional PatternMatchLoopCondRoot( const HloInstruction* to_apply_root = to_apply->root_instruction(); if (Match(to_apply_root, match::Tuple())) { return PatternMatchLoopCondRoot( - to_apply_root->operand(loop_cond_root->tuple_index())); + to_apply_root->operand(loop_cond_root->tuple_index()), + precomputed_analyses); } } // Recursive pattern #2: @@ -400,23 +411,26 @@ std::optional PatternMatchLoopCondRoot( match::GetTupleElement().WithOperand(0, match::Tuple()))) { const HloInstruction* new_cond_root = loop_cond_root->operand(0)->operand(loop_cond_root->tuple_index()); - return PatternMatchLoopCondRoot(new_cond_root); + return PatternMatchLoopCondRoot(new_cond_root, precomputed_analyses); } return std::nullopt; } std::optional PatternMatchInductionVarUpdate( - const HloInstruction* induction_var_update, int64_t tuple_index) { + const HloInstruction* induction_var_update, int64_t tuple_index, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { if (induction_var_update->opcode() == HloOpcode::kCopy) { return PatternMatchInductionVarUpdate(induction_var_update->operand(0), - tuple_index); + tuple_index, precomputed_analyses); } if (induction_var_update->opcode() == HloOpcode::kCopyDone) { return PatternMatchInductionVarUpdate( - induction_var_update->operand(0)->operand(1), tuple_index); + induction_var_update->operand(0)->operand(1), tuple_index, + precomputed_analyses); } std::optional update_param_index_and_value = - TryParsingInstructionAsParameterAndInteger(induction_var_update); + TryParsingInstructionAsParameterAndInteger(induction_var_update, + precomputed_analyses); if (update_param_index_and_value.has_value()) { if (update_param_index_and_value->param_index.has_value()) { @@ -450,12 +464,14 @@ std::optional PatternMatchInductionVarUpdate( const HloInstruction* update_lhs = induction_var_update->operand(0); VLOG(3) << "PatternMatchInductionVarUpdate, LHS: " << update_lhs->ToString(); std::optional update_lhs_param_index_and_value = - TryParsingInstructionAsParameterAndInteger(update_lhs); + TryParsingInstructionAsParameterAndInteger(update_lhs, + precomputed_analyses); const HloInstruction* update_rhs = induction_var_update->operand(1); VLOG(3) << "PatternMatchInductionVarUpdate, RHS: " << update_rhs->ToString(); std::optional update_rhs_param_index_and_value = - TryParsingInstructionAsParameterAndInteger(update_rhs); + TryParsingInstructionAsParameterAndInteger(update_rhs, + precomputed_analyses); if (!update_lhs_param_index_and_value.has_value() || !update_lhs_param_index_and_value->value.has_value() || @@ -496,14 +512,16 @@ std::optional PatternMatchInductionVarUpdate( // using pattern matching. std::optional PatternMatchInductionVarUpdateFromLoopBodyRoot( - const HloInstruction* loop_body_root, int64_t tuple_index) { + const HloInstruction* loop_body_root, int64_t tuple_index, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { if (loop_body_root->opcode() != HloOpcode::kTuple || loop_body_root->operand_count() <= tuple_index) { return std::nullopt; } const HloInstruction* induction_var_update = loop_body_root->operand(tuple_index); - return PatternMatchInductionVarUpdate(induction_var_update, tuple_index); + return PatternMatchInductionVarUpdate(induction_var_update, tuple_index, + precomputed_analyses); } std::optional PatternMatchLoopCondVarOverride( @@ -528,16 +546,15 @@ std::optional EvaluateWhileLoopParamInitValue( } const HloInstruction* element_instruction = param_instruction->operand(tuple_index); - return GetInstructionValueAsInteger(element_instruction); + return GetInstructionValueAsInteger(element_instruction, + /*precomputed_analyses=*/{}); } } // namespace namespace internal { -#if !defined(_MSC_VER) constexpr absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl"; -#endif std::optional ParseEvalErrorDetail(const absl::Status& error) { auto error_detail = error.GetPayload(kEvalErrorDetailUrl); @@ -636,14 +653,16 @@ std::optional HandleStaticLoopComparison( } std::optional PatternMatchParseWhileLoop( - const HloInstruction* while_op) { + const HloInstruction* while_op, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { VLOG(3) << "PatternMatchParseWhileLoop, while_op: " << while_op->name(); const HloComputation* while_cond = while_op->while_condition(); const HloComputation* while_body = while_op->while_body(); const HloInstruction* while_operand = while_op->operand(0); // Try to parse the loop condition comparison. std::optional loop_comparison_or_noop = - PatternMatchLoopCondRoot(while_cond->root_instruction()); + PatternMatchLoopCondRoot(while_cond->root_instruction(), + precomputed_analyses); if (!loop_comparison_or_noop.has_value()) { return std::nullopt; } @@ -706,7 +725,8 @@ std::optional PatternMatchParseWhileLoop( induction_var_init = EvaluateWhileLoopParamInitValue( while_operand, *loop_comparison.lhs.param_index); induction_var_update = PatternMatchInductionVarUpdateFromLoopBodyRoot( - while_body->root_instruction(), *loop_comparison.lhs.param_index); + while_body->root_instruction(), *loop_comparison.lhs.param_index, + precomputed_analyses); lhs_is_induction_var = true; } } else { @@ -716,7 +736,8 @@ std::optional PatternMatchParseWhileLoop( induction_var_init = EvaluateWhileLoopParamInitValue( while_operand, *loop_comparison.rhs.param_index); induction_var_update = PatternMatchInductionVarUpdateFromLoopBodyRoot( - while_body->root_instruction(), *loop_comparison.rhs.param_index); + while_body->root_instruction(), *loop_comparison.rhs.param_index, + precomputed_analyses); lhs_is_induction_var = false; } } @@ -922,7 +943,7 @@ absl::StatusOr HloEvaluator::Evaluate( } absl::StatusOr HloEvaluator::Evaluate( - const HloInstruction* instruction, + const HloInstruction* instruction, PrecomputedAnalyses precomputed_analyses, bool recursively_evaluate_nonconstant_operands) { arg_literals_.clear(); evaluated_.clear(); @@ -932,7 +953,7 @@ absl::StatusOr HloEvaluator::Evaluate( absl::MakeCleanup([this] { enable_partial_evaluation_ = false; }); enable_partial_evaluation_ = recursively_evaluate_nonconstant_operands; TF_RETURN_IF_ERROR( - EvaluateInternal(instruction, /*shape_index=*/{}, + EvaluateInternal(instruction, precomputed_analyses, /*shape_index=*/{}, recursively_evaluate_nonconstant_operands)); const Literal& result = GetEvaluatedLiteralFor(instruction); if (!result.IsKnown()) { @@ -945,8 +966,8 @@ bool HloEvaluator::TryEvaluate(const HloInstruction* instruction, Literal* result, bool recursively_evaluate_nonconstant_operands) { CHECK(result != nullptr); - auto result_or = - Evaluate(instruction, recursively_evaluate_nonconstant_operands); + auto result_or = Evaluate(instruction, /*precomputed_analyses=*/{}, + recursively_evaluate_nonconstant_operands); if (!result_or.ok()) { VLOG(1) << "TryEvaluate failed:" << result_or.status(); return false; @@ -1068,11 +1089,12 @@ absl::StatusOr HloEvaluator::EvaluateDotOp( } absl::Status HloEvaluator::EvaluateParameterFromCallerArgument( - const HloInstruction* parameter, const ShapeIndex& shape_index) { + const HloInstruction* parameter, const ShapeIndex& shape_index, + PrecomputedAnalyses analyses) { CHECK(!evaluated_.contains(parameter)); const HloComputation* parent_computation = parameter->parent(); std::vector computation_callers = - call_graph_cache_->GetComputationCallers(parent_computation); + analyses.call_graph->GetComputationCallers(parent_computation); // If the parent computation has multiple callers, we cannot determine from // which caller the arguments are passed. if (computation_callers.size() != 1) { @@ -1095,11 +1117,11 @@ absl::Status HloEvaluator::EvaluateParameterFromCallerArgument( HloComputation* while_body = computation_caller->while_body(); TF_ASSIGN_OR_RETURN( const LogicalBuffer* logical_buffer, - tuple_points_to_analysis_cache_->GetBufferDefinedAt( + analyses.tuple_points_to->GetBufferDefinedAt( while_body->parameter_instruction(parameter->parameter_number()), shape_index)); const TuplePointsToAnalysis::BufferAliasVector& buffer_aliases = - tuple_points_to_analysis_cache_->GetBufferAliases(*logical_buffer); + analyses.tuple_points_to->GetBufferAliases(*logical_buffer); bool unchanged_in_return = false; for (const BufferAlias& buffer_alias : buffer_aliases) { if (buffer_alias.instruction() == while_body->root_instruction() && @@ -1111,7 +1133,8 @@ absl::Status HloEvaluator::EvaluateParameterFromCallerArgument( return MakeEvalErrorDueToParamOrInfeed(*parameter); } } - TF_RETURN_IF_ERROR(EvaluateInternal(caller_operand, shape_index, true)); + TF_RETURN_IF_ERROR( + EvaluateInternal(caller_operand, analyses, shape_index, true)); const Literal& caller_operand_literal = GetEvaluatedLiteralFor(caller_operand); evaluated_[parameter] = @@ -1156,7 +1179,8 @@ DimensionVector HloEvaluator::MakeDimMultipliers(const Shape& shape) { } absl::Status HloEvaluator::EvaluateInternal( - const HloInstruction* instruction, const ShapeIndex& shape_index, + const HloInstruction* instruction, PrecomputedAnalyses precomputed_analyses, + const ShapeIndex& shape_index, bool recursively_evaluate_nonconstant_operands) { // Don't need to evaluate this instruction again if it has already been // evaluated. @@ -1172,34 +1196,44 @@ absl::Status HloEvaluator::EvaluateInternal( if (instruction->opcode() == HloOpcode::kGetTupleElement) { ShapeIndex new_shape_index = shape_index; new_shape_index.push_front(instruction->tuple_index()); - TF_RETURN_IF_ERROR( - EvaluateInternal(instruction->operand(0), new_shape_index, - /*recursively_evaluate_nonconstant_operands=*/true)); + TF_RETURN_IF_ERROR(EvaluateInternal( + instruction->operand(0), precomputed_analyses, new_shape_index, + /*recursively_evaluate_nonconstant_operands=*/true)); } else if (instruction->opcode() == HloOpcode::kTuple && !shape_index.empty()) { ShapeIndex new_shape_index = shape_index; int64_t tuple_index = new_shape_index.front(); new_shape_index.pop_front(); TF_RETURN_IF_ERROR( - EvaluateInternal(instruction->operand(tuple_index), new_shape_index, + EvaluateInternal(instruction->operand(tuple_index), + precomputed_analyses, new_shape_index, /*recursively_evaluate_nonconstant_operands=*/true)); } else if (instruction->opcode() == HloOpcode::kParameter) { - if (!call_graph_cache_) { - HloModule* module = instruction->GetModule(); - call_graph_cache_ = CallGraph::Build(module); - } - if (!tuple_points_to_analysis_cache_) { - HloModule* module = instruction->GetModule(); - absl::StatusOr> - tuple_points_to_analysis = TuplePointsToAnalysis::Run(module); - if (tuple_points_to_analysis.ok()) { - tuple_points_to_analysis_cache_ = - *std::move(tuple_points_to_analysis); - } - } - if (call_graph_cache_ && tuple_points_to_analysis_cache_) { - absl::Status argument_eval_status = - EvaluateParameterFromCallerArgument(instruction, shape_index); + CallGraph* call_graph = + (precomputed_analyses.call_graph != nullptr) + ? precomputed_analyses.call_graph + : std::invoke([this, instruction]() -> CallGraph* { + call_graph_cache_ = + CallGraph::Build(instruction->GetModule()); + return call_graph_cache_.get(); + }); + TuplePointsToAnalysis* tuple_points_to_analysis = + (precomputed_analyses.tuple_points_to != nullptr) + ? precomputed_analyses.tuple_points_to + : std::invoke([this, instruction]() -> TuplePointsToAnalysis* { + absl::StatusOr> + tuple_points_to_analysis = + TuplePointsToAnalysis::Run(instruction->GetModule()); + if (!tuple_points_to_analysis.ok()) { + return nullptr; + } + tuple_points_to_analysis_cache_ = + *std::move(tuple_points_to_analysis); + return tuple_points_to_analysis_cache_.get(); + }); + if (call_graph && tuple_points_to_analysis) { + absl::Status argument_eval_status = EvaluateParameterFromCallerArgument( + instruction, shape_index, {tuple_points_to_analysis, call_graph}); if (!argument_eval_status.ok()) { VLOG(4) << "Failed to evaluate parameter " << instruction->name() << " from caller. Reason: " << argument_eval_status.message(); @@ -1211,7 +1245,7 @@ absl::Status HloEvaluator::EvaluateInternal( } else { for (HloInstruction* operand : instruction->operands()) { TF_RETURN_IF_ERROR(EvaluateInternal( - operand, /*shape_index=*/{}, + operand, precomputed_analyses, /*shape_index=*/{}, /*recursively_evaluate_nonconstant_operands=*/true)); // Except for the above and following cases, we do not support handling // unknown operands for other HLOs. So mark the result as unknown. @@ -3448,7 +3482,7 @@ absl::StatusOr CreateScalarLiteral(int64_t value, absl::StatusOr TryParseAndEvaluateWhileInductionVar( const HloInstruction* while_hlo) { std::optional parsed_while_loop = - PatternMatchParseWhileLoop(while_hlo); + PatternMatchParseWhileLoop(while_hlo, /*precomputed_analyses=*/{}); if (!parsed_while_loop.has_value() || parsed_while_loop->is_dynamic()) { return FailedPrecondition( "Cannot evaluate a while loop's induction variable since the loop " @@ -3489,7 +3523,8 @@ absl::Status HloEvaluator::HandleWhile(const HloInstruction* while_hlo) { auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone(); if (!lcv.IsKnown()) { std::optional parsed_while_loop = - PatternMatchParseWhileLoop(while_hlo); + PatternMatchParseWhileLoop(while_hlo, + /*precomputed_analyses=*/{}); evaluated_[while_hlo] = Literal::CreateFromShapeWithUnknownLeafArrays(while_hlo->shape()); if (!parsed_while_loop.has_value() || parsed_while_loop->is_dynamic() || diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h index 0eab57a0d68de1..5f004073b7a3a4 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h @@ -55,43 +55,18 @@ limitations under the License. namespace xla { -// Represents a parsed static while loop. We normalize the loop representation -// so that it starts from the induction_var_init_value and increments by -// step_size until it exceeds or goes below loop_bound. -struct ParsedStaticWhileLoop { - // The number of iterations to be executed. - int64_t trip_count = -1; - // The tuple index of the induction variable in the while argument tuple. - int64_t induction_var_index = -1; - // The induction variable's initial value. - int64_t induction_var_init_value = -1; - // The induction variable is incremented by this number (could be negative) - // in each iteration. - int64_t step_size = -1; - int64_t loop_bound = -1; -}; - -// Indicates whether a parsed while loop is static or dynamic. If the loop is -// static, it contains a value for StaticLoopInfo; otherwise the loop is -// dynamic. We consider a loop dynamic if its induction variable's initial -// value or the loop bound's value depends on the while's parent computation's -// parameter. -struct ParsedWhileLoop { - std::optional static_while_loop; - bool is_dynamic() const { return !static_while_loop.has_value(); } -}; -constexpr ParsedWhileLoop kParsedDynamicWhileLoop = ParsedWhileLoop(); - -// Tries to parse a while loop using a set of predefined patterns. -// Returns the parsing result. -std::optional PatternMatchParseWhileLoop( - const HloInstruction* while_op); - // Responsible for evaluating HLO and obtain literal as the evaluation results. // // This class is not thread-safe. class HloEvaluator : public ConstDfsHloVisitorWithDefault { public: + // Precomputed analyses that can be passed to Evaluate functions to avoid + // recomputation during evaluation. + struct PrecomputedAnalyses { + TuplePointsToAnalysis* tuple_points_to; + CallGraph* call_graph; + }; + // Only evaluate up to max_loop_iterations per while-loop execution if // specified. explicit HloEvaluator(int64_t max_loop_iterations = -1); @@ -167,8 +142,12 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // within its parent computation until it encounters something that cannot be // evaluated, such as an Infeed or a Parameter instruction. // It makes best effort to partially evaluate a dependency if possible. + // The caller may pass in non-null `precomputed_analyses` to avoid + // recomputation during evaluation; the caller must ensure that any + // precomputed analyses were performed on the module containing `instruction`. absl::StatusOr Evaluate( const HloInstruction* instruction, + PrecomputedAnalyses precomputed_analyses = {}, bool recursively_evaluate_nonconstant_operands = false); // Same as Evaluate, except returning false on error and accepts an output @@ -270,13 +249,20 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // marked as undetermined unless it has been previously evaluated using // EvaluateInternal. Such partial evaluation reduces the computation and // memory overhead in cases where we need only one tuple element by avoiding - // the evaluation of a full tuple. + // the evaluation of a full tuple. Any non-null `precomputed_analyses` will be + // used instead of recomputing. absl::Status EvaluateInternal( - const HloInstruction* instruction, const ShapeIndex& shape_index = {}, + const HloInstruction* instruction, + PrecomputedAnalyses precomputed_analyses, + const ShapeIndex& shape_index = {}, bool recursively_evaluate_nonconstant_operands = false); + // Evaluates the result of a `parameter` instruction by traversing the call + // graph as given in `analyses`. `shape_index` has the same effect as in + // EvaluateInternal above. absl::Status EvaluateParameterFromCallerArgument( - const HloInstruction* parameter, const ShapeIndex& shape_index); + const HloInstruction* parameter, const ShapeIndex& shape_index, + PrecomputedAnalyses analyses); // Helper method to extract a list of int64_t from evaluated instruction for // start_indices for DynamicSlice and DynamicUpdateSlice. @@ -518,6 +504,41 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { std::unique_ptr> MatmulArray2D(const Array2D& lhs, const Array2D& rhs); +// Represents a parsed static while loop. We normalize the loop representation +// so that it starts from the induction_var_init_value and increments by +// step_size until it exceeds or goes below loop_bound. +struct ParsedStaticWhileLoop { + // The number of iterations to be executed. + int64_t trip_count = -1; + // The tuple index of the induction variable in the while argument tuple. + int64_t induction_var_index = -1; + // The induction variable's initial value. + int64_t induction_var_init_value = -1; + // The induction variable is incremented by this number (could be negative) + // in each iteration. + int64_t step_size = -1; + int64_t loop_bound = -1; +}; + +// Indicates whether a parsed while loop is static or dynamic. If the loop is +// static, it contains a value for StaticLoopInfo; otherwise the loop is +// dynamic. We consider a loop dynamic if its induction variable's initial +// value or the loop bound's value depends on the while's parent computation's +// parameter. +struct ParsedWhileLoop { + std::optional static_while_loop; + bool is_dynamic() const { return !static_while_loop.has_value(); } +}; +constexpr ParsedWhileLoop kParsedDynamicWhileLoop = ParsedWhileLoop(); + +// Tries to parse a while loop using a set of predefined patterns. +// Returns the parsing result. Any non-null `precompute_analyses` will be used +// instead of recomputing, and it is the caller's responsibility to ensure that +// the analyses are valid for the module that contains `while_op`. +std::optional PatternMatchParseWhileLoop( + const HloInstruction* while_op, + HloEvaluator::PrecomputedAnalyses precomputed_analyses = {}); + // Functionality exposed for testing. Do not rely on anything in this namespace // outside this file. namespace internal { @@ -530,11 +551,7 @@ enum class EvalErrorDetail : uint32_t { kDynamicValueDependence = 0, }; -#if defined(_MSC_VER) -extern const absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl"; -#else extern const absl::string_view kEvalErrorDetailUrl; -#endif std::optional ParseEvalErrorDetail(const absl::Status& error); diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc index 72dc6f84c4ade6..901c99fe1b66d3 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc @@ -50,10 +50,12 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/permutation_util.h" #include "xla/primitive_util.h" +#include "xla/service/call_graph.h" #include "xla/service/dynamic_dimension_inference.h" #include "xla/service/hlo_element_type_converter.h" #include "xla/service/hlo_module_config.h" #include "xla/service/shape_inference.h" +#include "xla/service/tuple_points_to_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" @@ -167,14 +169,15 @@ class HloEvaluatorTest : public HloTestBase { TF_ASSERT_OK_AND_ASSIGN( Literal result, evaluator_.Evaluate( - instruction, + instruction, /*precomputed_analyses=*/{}, /*recursively_evaluate_nonconstant_operands=*/true)); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } void TestRecursiveEvaluationFailure(HloInstruction* instruction) { - absl::StatusOr result = evaluator_.Evaluate( - instruction, /*recursively_evaluate_nonconstant_operands=*/true); + absl::StatusOr result = + evaluator_.Evaluate(instruction, /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true); EXPECT_TRUE(!result.ok()); } @@ -5035,6 +5038,79 @@ TEST_F(HloEvaluatorTest, GetTupleElementInterleavedWithTupleSucceeds) { TestRecursivelyEvaluateInstruction(gte2, expected); } +// Tests that we can evaluate a parameter instruction through the call graph. +TEST_F(HloEvaluatorTest, ParameterThroughCallSucceeds) { + constexpr absl::string_view kHloModule = R"( + HloModule parameter_through_call + + %identity { + ROOT %param = s32[] parameter(0) + } + + ENTRY parameter_through_call { + %constant = s32[] constant(42) + ROOT %call = s32[] call(s32[] %constant), to_apply=%identity + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloModule)); + const HloInstruction* parameter_instruction = nullptr; + for (const auto* computation : hlo_module->computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kParameter) { + parameter_instruction = instruction; + } + } + } + ASSERT_NE(parameter_instruction, nullptr); + + Literal expected = LiteralUtil::CreateR0(42); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, + evaluator_.Evaluate(parameter_instruction, /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +// As above, but with analyses precomputed. +TEST_F(HloEvaluatorTest, ParameterThroughCallSucceedsWithPrecomputation) { + constexpr absl::string_view kHloModule = R"( + HloModule parameter_through_call + + %identity { + ROOT %param = s32[] parameter(0) + } + + ENTRY parameter_through_call { + %constant = s32[] constant(42) + ROOT %call = s32[] call(s32[] %constant), to_apply=%identity + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloModule)); + const HloInstruction* parameter_instruction = nullptr; + for (const auto* computation : hlo_module->computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kParameter) { + parameter_instruction = instruction; + } + } + } + ASSERT_NE(parameter_instruction, nullptr); + + Literal expected = LiteralUtil::CreateR0(42); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr tuple_points_to, + TuplePointsToAnalysis::Run(hlo_module.get())); + std::unique_ptr call_graph = CallGraph::Build(hlo_module.get()); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, + evaluator_.Evaluate(parameter_instruction, + {tuple_points_to.get(), call_graph.get()}, + /*recursively_evaluate_nonconstant_operands=*/true)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + class PatternMatchParseWhileLoopTest : public HloTestBase {}; TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDefinedInsideOfCond) { @@ -5084,6 +5160,59 @@ TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDefinedInsideOfCond) { EXPECT_EQ(parsed_while_loop->static_while_loop->loop_bound, 5); } +TEST_F(PatternMatchParseWhileLoopTest, + LoopBoundDefinedInsideOfCondWithPrecomputation) { + constexpr absl::string_view kHloModule = R"( + HloModule accumulated_all_reduce + + %while_condition { + %param = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %loop_bound = s32[] constant(5) + ROOT result = pred[] compare(%gte.0, %loop_bound), direction=LT + } + + %while_body { + %param = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = f32[1024, 1024] get-tuple-element(%param), index=1 + %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2 + %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.1, f32[1024, 1024] %gte.2) + %constant = s32[] constant(1) + %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant) + ROOT %loop_result = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %accumulation) + } + + ENTRY accumulated_all_reduce { + %param.1 = f32[1024, 1024] parameter(0) + %constant.0 = s32[] constant(0) + %accumulation_buffer_init = f32[] constant(0) + %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} + %while_init = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer) + %while = (s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body + ROOT %result = f32[1024, 1024] get-tuple-element((s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=2 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloModule)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr tuple_points_to, + TuplePointsToAnalysis::Run(hlo_module.get())); + std::unique_ptr call_graph = CallGraph::Build(hlo_module.get()); + + HloInstruction* while_op = + hlo_module->entry_computation()->root_instruction()->mutable_operand(0); + std::optional parsed_while_loop = PatternMatchParseWhileLoop( + while_op, {tuple_points_to.get(), call_graph.get()}); + ASSERT_TRUE(parsed_while_loop.has_value()); + EXPECT_FALSE(parsed_while_loop->is_dynamic()); + EXPECT_EQ(parsed_while_loop->static_while_loop->trip_count, 5); + EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_index, 0); + EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_init_value, 0); + EXPECT_EQ(parsed_while_loop->static_while_loop->step_size, 1); + EXPECT_EQ(parsed_while_loop->static_while_loop->loop_bound, 5); +} + TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDefinedOutsideOfCond) { constexpr absl::string_view kHloModule = R"( HloModule accumulated_all_reduce diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index 30a0283e10ee6d..09b40eaa7127cd 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -33,6 +33,7 @@ cc_library( compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", + ":auto_sharding_device_mesh", ":auto_sharding_option", ":auto_sharding_solver", ":auto_sharding_strategy", @@ -110,6 +111,7 @@ cc_library( hdrs = ["auto_sharding_memory.h"], compatible_with = get_compatible_with_libtpu_portable(), deps = [ + ":auto_sharding_device_mesh", ":auto_sharding_proto_cc", ":auto_sharding_strategy", "//xla:status_macros", @@ -148,6 +150,7 @@ cc_library( ], compatible_with = get_compatible_with_libtpu_portable(), deps = [ + ":auto_sharding_device_mesh", ":auto_sharding_proto_cc", "//xla:shape_util", "//xla/hlo/ir:hlo", @@ -170,6 +173,7 @@ cc_library( hdrs = ["auto_sharding_cost_graph.h"], compatible_with = get_compatible_with_libtpu_portable(), deps = [ + ":auto_sharding_device_mesh", ":auto_sharding_strategy", ":matrix", "//xla:shape_util", @@ -188,7 +192,9 @@ cc_library( hdrs = ["auto_sharding_option.h"], compatible_with = get_compatible_with_libtpu_portable(), deps = [ + ":auto_sharding_device_mesh", ":auto_sharding_util", + "//xla:array", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -219,6 +225,7 @@ cc_library( compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", + ":auto_sharding_device_mesh", ":auto_sharding_option", ":auto_sharding_strategy", ":auto_sharding_wrapper", @@ -247,6 +254,7 @@ cc_library( hdrs = ["cluster_environment.h"], compatible_with = get_compatible_with_libtpu_portable(), deps = [ + ":auto_sharding_device_mesh", ":auto_sharding_option", ":auto_sharding_strategy", ":auto_sharding_util", @@ -273,6 +281,7 @@ cc_library( hdrs = ["auto_sharding_util.h"], compatible_with = get_compatible_with_libtpu_portable(), deps = [ + ":auto_sharding_device_mesh", ":auto_sharding_strategy", "//xla:array", "//xla:shape_tree", @@ -328,6 +337,20 @@ tf_proto_library( visibility = ["//visibility:public"], ) +cc_library( + name = "auto_sharding_device_mesh", + srcs = ["auto_sharding_device_mesh.cc"], + hdrs = [ + "auto_sharding_device_mesh.h", + ], + compatible_with = get_compatible_with_libtpu_portable(), + deps = [ + "//xla:array", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/types:span", + ], +) + build_test( name = "auto_sharding_runner_build_test", targets = [ @@ -345,6 +368,8 @@ xla_cc_test( ], deps = [ ":auto_sharding", + ":auto_sharding_cost_graph", + ":auto_sharding_device_mesh", ":auto_sharding_option", ":auto_sharding_strategy", ":auto_sharding_util", @@ -359,6 +384,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 85efaa8588220b..042e4547a28e09 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -50,6 +50,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_memory.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" @@ -121,7 +122,7 @@ std::vector CommunicationReshardingCostVector( double ComputeMemoryReshardingCost(const Shape& shape, const HloSharding& src_sharding, const HloSharding& dst_sharding, - const Array& device_mesh) { + const DeviceMesh& device_mesh) { int64_t src_n_dim = NumTileDimensions(src_sharding); int64_t dst_n_dim = NumTileDimensions(dst_sharding); @@ -889,7 +890,7 @@ double ComputeSortCommunicationCost(const int64_t sort_dim, // Enumerate all 1d partition strategies. void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, @@ -961,7 +962,7 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, } void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, @@ -969,7 +970,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, absl::Span tensor_dims); void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, @@ -1012,7 +1013,7 @@ void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, } void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, @@ -1075,7 +1076,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, } void EnumerateAll1DPartitionReshape( - const HloInstruction* ins, const Array& device_mesh, + const HloInstruction* ins, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, bool only_allow_divisible, const std::string& suffix) { @@ -1129,14 +1130,14 @@ void EnumerateAll1DPartitionReshape( } void BuildStrategyAndCostForReshape( - const HloInstruction* ins, const Array& device_mesh, + const HloInstruction* ins, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, absl::Span tensor_dims); // Enumerate all partitions for reshape. Batch dim is always partitioned. void EnumeratePartitionReshape(const HloInstruction* ins, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const InstructionBatchDimMap& batch_dim_map, @@ -1181,7 +1182,7 @@ void EnumeratePartitionReshape(const HloInstruction* ins, } void BuildStrategyAndCostForReshape( - const HloInstruction* ins, const Array& device_mesh, + const HloInstruction* ins, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, absl::Span tensor_dims) { @@ -1876,7 +1877,7 @@ std::unique_ptr CreateReshapeStrategies( const InstructionBatchDimMap& batch_dim_map, const AutoShardingOption& option, StrategyGroups& strategy_groups, const CallGraph& call_graph) { - const Array& device_mesh = cluster_env.device_mesh_; + const DeviceMesh& device_mesh = cluster_env.device_mesh_; int mesh_nn_dims = VectorGreaterThanOneElementCount(device_mesh.dimensions()); std::unique_ptr strategy_group = CreateLeafStrategyGroup( @@ -1989,6 +1990,7 @@ AutoShardingSolverResult CallSolver( request.mutable_max_cost()->set_coeff(*max_cost); } for (const auto& [edge, edge_cost] : cost_graph.edge_costs_) { + const auto normalized_edge_cost = Normalize(edge_cost); AutoShardingSolverRequest_Pair raw_edge; raw_edge.set_first(edge.first); raw_edge.set_second(edge.second); @@ -1997,8 +1999,8 @@ AutoShardingSolverResult CallSolver( AutoShardingSolverRequest_Costs mij; for (NodeStrategyIdx i = 0; i < edge_cost.n_; i++) { for (NodeStrategyIdx j = 0; j < edge_cost.m_; j++) { - rij.add_costs(edge_cost(i, j).communication_cost); - mij.add_costs(edge_cost(i, j).memory_cost); + rij.add_costs(normalized_edge_cost(i, j).communication_cost); + mij.add_costs(normalized_edge_cost(i, j).memory_cost); } } request.mutable_resharding_costs()->Add(std::move(rij)); @@ -2335,7 +2337,7 @@ absl::Status InsertReshardReshapes( absl::flat_hash_map>& preserve_shardings) { const std::vector& instructions = sequence.instructions(); - const Array& device_mesh = cluster_env.device_mesh_; + const DeviceMesh& device_mesh = cluster_env.device_mesh_; // Post process: fix some corner cases. ReshardingCache resharding_cache_entity; ReshardingCache* resharding_cache = &resharding_cache_entity; @@ -3291,8 +3293,8 @@ absl::Status GenerateReduceScatter( void AnnotateShardingWithSimpleHeuristic( HloModule* module, const std::string& heuristic, const AliasMap& alias_map, const ClusterEnvironment& cluster_env) { - const Array& device_mesh = cluster_env.device_mesh_; - const Array& device_mesh_1d = cluster_env.device_mesh_1d_; + const DeviceMesh& device_mesh = cluster_env.device_mesh_; + const DeviceMesh& device_mesh_1d = cluster_env.device_mesh_1d_; int64_t num_devices = device_mesh.num_elements(); // Count the non-one mesh dimension. @@ -3413,7 +3415,7 @@ absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, const AutoShardingOption& option) { int mesh_dim = option.force_batch_dim_to_mesh_dim; int batch_dim = batch_map.at(GetBatchDimMapKey(ins)); - const Array& device_mesh = cluster_env.device_mesh_; + const DeviceMesh& device_mesh = cluster_env.device_mesh_; if (shape.dimensions(batch_dim) % device_mesh.dim(mesh_dim) != 0) { return absl::InvalidArgumentError( @@ -3451,8 +3453,8 @@ absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, HloSharding GetReduceScatterOutput(const HloInstruction* ins, const ShardingStrategy& strategy, const ClusterEnvironment& cluster_env) { - const Array& device_mesh = cluster_env.device_mesh_; - const Array& device_mesh_1d = cluster_env.device_mesh_1d_; + const DeviceMesh& device_mesh = cluster_env.device_mesh_; + const DeviceMesh& device_mesh_1d = cluster_env.device_mesh_1d_; if (ins->opcode() == HloOpcode::kDot) { const DotDimensionNumbers& dot_dnums = ins->dot_dimension_numbers(); @@ -3963,7 +3965,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( // batch_dim_map = spmd::BuildInstructionBatchDimMap(sequence); // ----- Read parameters of device mesh ----- - Array original_device_mesh(option_.device_mesh_shape); + spmd::DeviceMesh original_device_mesh(option_.device_mesh_shape); original_device_mesh.SetValues(option_.device_mesh_ids); const int64_t original_memory_budget = option_.memory_budget_per_device; @@ -3990,7 +3992,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( std::vector mesh_shape = partial_mesh_shapes[mesh_idx]; LOG(INFO) << "Processing partial mesh shape: " << spmd::ToString(mesh_shape); - Array device_mesh(mesh_shape); + spmd::DeviceMesh device_mesh(mesh_shape); int64_t total_devices = 1; for (int64_t i : mesh_shape) { @@ -4014,10 +4016,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( // use the actual device order only for the final full mesh. device_mesh.SetValues(option_.device_mesh_ids); } else { - std::vector device_mesh_ids = - std::vector(total_devices); - std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0); - device_mesh.SetValues(device_mesh_ids); + device_mesh.FillIota(0); } // TODO (zhuohan): Include the prof result as an option. @@ -4630,24 +4629,4 @@ absl::StatusOr AutoSharding::Run( return module_is_changed; } -absl::StatusOr DummyAutoSharding::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - // ----- Set Dummy Replicated Sharding ----- - HloComputation* entry = module->entry_computation(); - - for (HloInstruction* inst : entry->instructions()) { - const Shape& out_shape = inst->shape(); - if (out_shape.IsTuple()) { - ShapeTree tuple_sharding(out_shape, - HloSharding::Replicate()); - inst->set_sharding(HloSharding::Tuple(tuple_sharding)); - } else { - inst->set_sharding(HloSharding::Replicate()); - } - } - - return true; -} - } // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index 4695efc60d0dea..bdc137a0e462c6 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -31,8 +31,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/array.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" @@ -49,18 +49,6 @@ limitations under the License. namespace xla { -class DummyAutoSharding : public HloModulePass { - public: - DummyAutoSharding() = default; - ~DummyAutoSharding() override = default; - absl::string_view name() const override { return "dummy_auto_sharding"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - enum class AutoShardingResult { kModuleUnchanged, kModuleChangedShardingPerformed, @@ -140,7 +128,7 @@ namespace spmd { // Their comments can be found in their definitions in *.cc files. HloSharding Tile(const Shape& shape, absl::Span tensor_dims, absl::Span mesh_dims, - const Array& device_mesh); + const DeviceMesh& device_mesh); std::vector CommunicationReshardingCostVector( const StrategyGroup* strategy_group, const Shape& shape, @@ -319,7 +307,7 @@ std::unique_ptr CreateTupleStrategyGroup(size_t instruction_id); // Enumerate all 1d partition strategies. void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, @@ -329,7 +317,7 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, // Enumerate all partitions recursively. void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc index 85127883e21937..9d28df32b04f5c 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -34,6 +35,24 @@ limitations under the License. namespace xla { namespace spmd { +EdgeReshardingCostMatrix Normalize(const EdgeReshardingCostMatrix& edge_cost) { + double min_communication_cost = std::numeric_limits::max(); + for (int i = 0; i < edge_cost.n_; ++i) { + for (int j = 0; j < edge_cost.m_; ++j) { + min_communication_cost = + std::min(min_communication_cost, edge_cost(i, j).communication_cost); + } + } + if (min_communication_cost >= 0) return edge_cost; + EdgeReshardingCostMatrix normalized_edge_cost = edge_cost; + for (int i = 0; i < edge_cost.n_; ++i) { + for (int j = 0; j < edge_cost.m_; ++j) { + normalized_edge_cost(i, j).communication_cost -= min_communication_cost; + } + } + return normalized_edge_cost; +} + CostGraph::CostGraph(const StrategyGroups& strategy_groups, const AssociativeDotPairs& associative_dot_pairs) { node_lens_.reserve(strategy_groups.size()); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index fda06ee8ec1e7b..3d6bac1b139196 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -55,6 +55,10 @@ struct EdgeReshardingCost { using EdgeReshardingCostMatrix = Matrix; +// Normalizes the edge cost matrix by a fixed constant to ensure there are no +// negative communication costs. +EdgeReshardingCostMatrix Normalize(const EdgeReshardingCostMatrix& edge_cost); + // A graph data structure to simplify the edge cost graph. It merges nodes and // performs path compression. class CostGraph { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.cc new file mode 100644 index 00000000000000..07ab282f0bfa38 --- /dev/null +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.cc @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" + +#include + +#include "absl/types/span.h" +#include "xla/array.h" + +namespace xla { +namespace spmd { + +namespace { +bool AreValuesIota(const absl::Span values) { + for (int i = 1; i < values.size(); ++i) { + if (values[i] - values[i - 1] != 1) { + return false; + } + } + return true; +} +} // namespace + +void DeviceMesh::SetValues(absl::Span values) { + device_array.SetValues(values); + is_iota = AreValuesIota(values); +} +} // namespace spmd +} // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h new file mode 100644 index 00000000000000..919ea64027833b --- /dev/null +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h @@ -0,0 +1,86 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_DEVICE_MESH_H_ +#define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_DEVICE_MESH_H_ + +#include +#include + +#include "absl/functional/function_ref.h" +#include "absl/types/span.h" +#include "xla/array.h" + +namespace xla { +namespace spmd { +struct DeviceMesh { + Array device_array; + bool is_iota; + + explicit DeviceMesh(absl::Span sizes) + : device_array(sizes), is_iota(false) {} + + void FillIota(const int64_t value) { + device_array.FillIota(value); + is_iota = true; + } + + void SetValues(absl::Span values); + + int64_t num_dimensions() const { return device_array.num_dimensions(); } + + // Returns the size of the dimension at the given index. + int64_t dim(int64_t n) const { return device_array.dim(n); } + + // Returns a vector containing the dimensions of the array. + absl::Span dimensions() const { + return device_array.dimensions(); + } + + // Returns the total number of elements in the array. + int64_t num_elements() const { return device_array.num_elements(); } + + std::string ToString() const { return device_array.ToString(); } + + void Reshape(absl::Span new_dimensions) { + device_array.Reshape(new_dimensions); + } + + void TransposeDimensions(absl::Span permutation) { + device_array.TransposeDimensions(permutation); + is_iota = false; + } + + const int64_t& operator()(absl::Span indexes) const { + return device_array(indexes); + } + + int64_t& operator()(absl::Span indexes) { + return device_array(indexes); + } + + void Each(absl::FunctionRef, int64_t*)> f) { + device_array.Each(f); + } + + void Each( + absl::FunctionRef, int64_t)> f) const { + device_array.Each(f); + } +}; +} // namespace spmd +} // namespace xla + +#endif // XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_DEVICE_MESH_H_ diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index dbd161b365d698..9224da821db47b 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -36,6 +36,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" @@ -92,7 +93,7 @@ class HandlerBase { double compute_cost, double communication_cost); HloSharding CreateInputSpec(const HloInstruction* ins, const DimMap& dim_map, - const Array& device_mesh) const { + const DeviceMesh& device_mesh) const { if (dim_map.empty()) return HloSharding::Replicate(); std::vector tensor_dims; std::vector> mesh_dims; @@ -179,7 +180,7 @@ class HandlerBase { } bool IsFullyReplicatedSharding(const DimMap& dim_map, - const Array& device_mesh) { + const DeviceMesh& device_mesh) { if (dim_map.empty()) { return true; } @@ -194,7 +195,7 @@ class HandlerBase { bool IsFullyReplicatedStrategy(const DimMap& output_dim_map, const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, - const Array& device_mesh) { + const DeviceMesh& device_mesh) { return IsFullyReplicatedSharding(output_dim_map, device_mesh) && IsFullyReplicatedSharding(lhs_dim_map, device_mesh) && IsFullyReplicatedSharding(rhs_dim_map, device_mesh); @@ -223,7 +224,7 @@ class HandlerBase { const AutoShardingOption& option_; const CallGraph& call_graph_; - const Array& device_mesh_; + const DeviceMesh& device_mesh_; const HloInstruction* lhs_; const HloInstruction* rhs_; }; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index f204ff43496d61..67a6fed7149280 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -67,6 +67,10 @@ using ::operations_research::MPVariable; // solver cannot guarantee exact numerical precision. constexpr double kMaxCostEpsilon = 1.0001; +// Memory contributions in the Mixed ILP are converted to units in this range; +// beware that significantly larger / smaller values can cause numerical issues. +constexpr double kMemoryMultiplier = 1e6; + bool AutoShardingSolverOutput::operator==( const AutoShardingSolverOutput& other) const { return s_val == other.s_val && cost == other.cost && @@ -261,7 +265,7 @@ std::optional> ReduceMemoryTerms( reduced_groups.push_back({group.prims().begin(), group.prims().end()}); } } - solver.MakeIntVarArray(reduced_groups.size(), 0.0, MPSolver::infinity(), + solver.MakeNumVarArray(reduced_groups.size(), 0.0, MPSolver::infinity(), absl::StrCat("group_", prim_type), &group_vars); for (int64_t group_idx = 0; group_idx < group_vars.size(); ++group_idx) { MPConstraint* constraint = solver.MakeRowConstraint( @@ -271,7 +275,7 @@ std::optional> ReduceMemoryTerms( for (const int64_t prim_idx : reduced_groups[group_idx]) { for (int64_t j = 0; j < prim_vars[prim_idx].size(); ++j) { double memory_cost = memory_costs.at(prim_idx).costs(j); - memory_cost /= request.memory_budget() / 100.0; + memory_cost /= request.memory_budget() / kMemoryMultiplier; const double accumulated_coefficient = constraint->GetCoefficient(prim_vars[prim_idx][j]); constraint->SetCoefficient(prim_vars[prim_idx][j], @@ -302,9 +306,12 @@ void AddMemoryTerms( time_idx <= intervals[prim_idx].second; ++time_idx) { if (!reduced_times.contains(time_idx)) continue; if (!constraints.contains(time_idx)) { - MPConstraint* constraint = solver.MakeRowConstraint( - -MPSolver::infinity(), 100.0, absl::StrCat("mem[", time_idx, "]")); - if (overbudget_var) constraint->SetCoefficient(overbudget_var, -100.0); + MPConstraint* constraint = + solver.MakeRowConstraint(-MPSolver::infinity(), kMemoryMultiplier, + absl::StrCat("mem[", time_idx, "]")); + if (overbudget_var) { + constraint->SetCoefficient(overbudget_var, -kMemoryMultiplier); + } constraints[time_idx] = constraint; } MPConstraint* constraint = constraints[time_idx]; @@ -314,7 +321,7 @@ void AddMemoryTerms( } for (int64_t j = 0; j < prim_vars[prim_idx].size(); ++j) { double memory_cost = memory_costs.at(prim_idx).costs(j); - memory_cost /= request.memory_budget() / 100.0; + memory_cost /= request.memory_budget() / kMemoryMultiplier; const double accumulated_coefficient = constraint->GetCoefficient(prim_vars[prim_idx][j]); constraint->SetCoefficient(prim_vars[prim_idx][j], diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index 5a237b2f979a67..05631bc2090b1a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -441,6 +441,46 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) { EXPECT_EQ(result, expected_result); } +TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) { + AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); + const std::vector> node_intervals = + {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; + const std::vector> edge_intervals = + {{1, 2}, {2, 3}}; + const std::vector> node_groups = {{0, 1}}; + const std::vector> edge_groups = {}; + const CostMatrix memory_costs = {{1, 1, 1, 1}, // These values are tiny and + {2, 2, 2}, // shouldn't be rounded up. + {300, 300, 300, 300, 300, 300, 300}, + {4000, 4000, 4000, 4000, 4000, 4000, 4000}, + {50000, 50000, 50000}}; + const CostMatrix memory_edge_costs = {{0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0}, + {0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0}}; + request.clear_live(); + request.clear_memory_costs(); + AddIntervals(request.mutable_node_intervals(), node_intervals); + AddIntervals(request.mutable_edge_intervals(), edge_intervals); + AddGroups(request.mutable_node_groups(), node_groups); + AddGroups(request.mutable_edge_groups(), edge_groups); + AddCosts(request.mutable_memory_costs(), memory_costs); + AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); + request.set_enable_memory_edge_costs(true); + request.set_memory_budget(4321); + + const AutoShardingSolverResult result = CallORToolsSolver(request); + + const std::vector s_val = {0, 0, 0, 0, 0}; + const double objective_value = 7650.0; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; + const AutoShardingSolverResult expected_result = {expected_output, false}; + EXPECT_EQ(result, expected_result); +} + TEST(CallORToolsSolverTest, SolvesWithEquivalences) { const AutoShardingSolverRequest request = AutoShardingSolverRequestWithEquivalences(); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 27ddc790be9b99..4414c4f7340dbc 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h" @@ -77,7 +78,7 @@ std::optional ConstructImprovedSharding( std::pair ComputeSliceShardingAndCommunicationCostFromOperand( const HloSharding& input_spec, const Shape& old_shape, - const Shape& new_shape, const Array& device_mesh, + const Shape& new_shape, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env) { if (input_spec.IsReplicated()) { return std::make_pair(input_spec, 0); @@ -135,7 +136,7 @@ BuildStrategyAndCost( const ClusterEnvironment& cluster_env, AutoShardingOption& option, const CallGraph& call_graph, const HloCostAnalysis& hlo_cost_analysis, bool trying_multiple_mesh_shapes) { - // const Array& device_mesh = cluster_env.device_mesh_; + // const DeviceMesh& device_mesh = cluster_env.device_mesh_; StrategyMap strategy_map; // This map stores all of the trimmed strategies due to user specified // sharding. The key is the instruction id, the value is the strategies. This diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index b1c0bb12a665e0..ab040e9c77edd7 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -23,12 +23,15 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" @@ -68,25 +71,72 @@ using ::testing::Pair; using ::testing::ResultOf; using ::testing::UnorderedElementsAre; -using DummyAutoShardingTest = HloTestBase; - -TEST_F(DummyAutoShardingTest, ReplicatedShardingDummy) { - constexpr absl::string_view kHloString = R"( -HloModule module -ENTRY %elementwise { - %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0) - %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1) - %add = f32[5,7,11,13]{3,2,1,0} add(%param0, %param1) - ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%add) -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, DummyAutoSharding().Run(module.get())); - EXPECT_TRUE(changed); - auto* instruction = FindInstruction(module.get(), "param0"); - ASSERT_NE(instruction, nullptr); - EXPECT_THAT(instruction, op::Sharding("{replicated}")); +TEST(DeviceMeshTest, IotaDeviceMesh2DStartsWith0) { + DeviceMesh device_mesh({2, 4}); + device_mesh.FillIota(0); + EXPECT_TRUE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 4)); + EXPECT_EQ(device_mesh.num_elements(), 8); +} + +TEST(DeviceMeshTest, IotaDeviceMesh3DStartsWithNonZero) { + DeviceMesh device_mesh({2, 4, 8}); + device_mesh.FillIota(55); + EXPECT_TRUE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 4, 8)); + EXPECT_EQ(device_mesh.num_elements(), 64); +} + +TEST(DeviceMeshTest, ExplicitSetValuesInferIotaIotaValues) { + DeviceMesh device_mesh({2, 4, 8}); + std::vector device_mesh_values(64); + absl::c_iota(device_mesh_values, 34); + device_mesh.SetValues(device_mesh_values); + EXPECT_TRUE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 4, 8)); + EXPECT_EQ(device_mesh.num_elements(), 64); +} + +TEST(DeviceMeshTest, ExplicitSetValuesInferIotaNonIotaValues) { + DeviceMesh device_mesh({2, 4, 8}); + std::vector device_mesh_values(64); + absl::c_iota(device_mesh_values, 34); + device_mesh_values[54] = 54; + device_mesh.SetValues(device_mesh_values); + EXPECT_FALSE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 4, 8)); + EXPECT_EQ(device_mesh.num_elements(), 64); +} + +TEST(DeviceMeshTest, ReshapeTestWithoutIota) { + DeviceMesh device_mesh({2, 4, 8}); + std::vector device_mesh_values(64); + absl::c_iota(device_mesh_values, 34); + device_mesh_values[54] = 54; + device_mesh.SetValues(device_mesh_values); + EXPECT_FALSE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 4, 8)); + EXPECT_EQ(device_mesh.num_elements(), 64); + + device_mesh.Reshape({2, 32}); + EXPECT_FALSE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 32)); + EXPECT_EQ(device_mesh.num_elements(), 64); +} + +TEST(DeviceMeshTest, ReshapeTestWithIota) { + DeviceMesh device_mesh({2, 4, 8}); + std::vector device_mesh_values(64); + absl::c_iota(device_mesh_values, 34); + device_mesh.SetValues(device_mesh_values); + EXPECT_TRUE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 4, 8)); + EXPECT_EQ(device_mesh.num_elements(), 64); + + device_mesh.Reshape({2, 32}); + EXPECT_TRUE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 32)); + EXPECT_EQ(device_mesh.num_elements(), 64); } class AutoShardingTest : public HloTestBase { @@ -2517,6 +2567,36 @@ ENTRY entry { input_output_alias_config_after.ToString()); } +TEST(NormalizeTest, NormalizeHandlesNegativeCosts) { + EdgeReshardingCostMatrix edge_cost(2, 2); + edge_cost(0, 0).communication_cost = -100; + edge_cost(0, 1).communication_cost = 200; + edge_cost(1, 0).communication_cost = 300; + edge_cost(1, 1).communication_cost = 400; + + const EdgeReshardingCostMatrix normalized_edge_cost = Normalize(edge_cost); + + EXPECT_EQ(normalized_edge_cost(0, 0).communication_cost, 0); + EXPECT_EQ(normalized_edge_cost(0, 1).communication_cost, 300); + EXPECT_EQ(normalized_edge_cost(1, 0).communication_cost, 400); + EXPECT_EQ(normalized_edge_cost(1, 1).communication_cost, 500); +} + +TEST(NormalizeTest, NormalizeHandlesPositiveCosts) { + EdgeReshardingCostMatrix edge_cost(2, 2); + edge_cost(0, 0).communication_cost = 100; + edge_cost(0, 1).communication_cost = 200; + edge_cost(1, 0).communication_cost = 300; + edge_cost(1, 1).communication_cost = 400; + + const EdgeReshardingCostMatrix normalized_edge_cost = Normalize(edge_cost); + + EXPECT_EQ(normalized_edge_cost(0, 0).communication_cost, 100); + EXPECT_EQ(normalized_edge_cost(0, 1).communication_cost, 200); + EXPECT_EQ(normalized_edge_cost(1, 0).communication_cost, 300); + EXPECT_EQ(normalized_edge_cost(1, 1).communication_cost, 400); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 3611c579b1455a..6cb0c6b5c2ef9f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -37,12 +37,12 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "json/json.h" #include "xla/array.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" @@ -59,7 +59,6 @@ limitations under the License. #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" namespace xla { namespace spmd { @@ -871,7 +870,7 @@ void RemoveDuplicatedStrategy(std::unique_ptr& strategy_group) { } } -bool IsDivisible(const HloInstruction* ins, const Array& device_mesh, +bool IsDivisible(const HloInstruction* ins, const DeviceMesh& device_mesh, absl::Span tensor_dims, absl::Span mesh_dims) { CHECK_EQ(tensor_dims.size(), mesh_dims.size()); @@ -1081,7 +1080,7 @@ int64_t NumTileDimensions(const HloSharding& spec) { } bool TileAssignmentMatchesMesh(const HloSharding& spec, - const Array& mesh) { + const DeviceMesh& mesh) { int sharded_dims = 0; for (int i = 0; i < spec.tile_assignment().num_dimensions(); ++i) { if (spec.tile_assignment().dim(i) > 1) { @@ -1096,39 +1095,23 @@ bool TileAssignmentMatchesMesh(const HloSharding& spec, return sharded_dims <= 0; } -absl::StatusOr> GetTensorDimToMeshDimNoCrash( - int64_t tensor_shape_rank, const HloSharding& spec, - const Array& device_mesh, bool consider_reverse_device_meshes) { - if (spec.IsReplicated()) { - return std::vector(tensor_shape_rank, -1); - } - // Check the compatibility of tensor_shape_rank and spec - if (tensor_shape_rank != spec.TiledDataRank()) { - return absl::InvalidArgumentError( - "Tensor shape rank should be equal to the tiled data rank of the input " - "spec."); - } - +absl::StatusOr> GetMeshDimPermutationOrderInShardingSpec( + const HloSharding& spec, const DeviceMesh& device_mesh, + bool consider_reverse_device_meshes) { auto check_mesh = [&](const Array& mesh) -> std::optional> { // Permute the dimensions (or axes in numpy term), find the transform that // makes tile_assignment == device_mesh. std::vector axes(mesh.num_dimensions()); absl::c_iota(axes, 0); - bool found = false; do { Array transposed_mesh = Transpose(mesh, axes); if (std::equal(transposed_mesh.begin(), transposed_mesh.end(), spec.tile_assignment().array().begin())) { - found = true; - break; + return axes; } } while (absl::c_next_permutation(axes)); - if (found) { - return std::optional>(axes); - } else { - return std::nullopt; - } + return std::nullopt; }; // This is an expensive search, as we try all possible meshes obtained by @@ -1136,7 +1119,6 @@ absl::StatusOr> GetTensorDimToMeshDimNoCrash( // the somewhat rare kReverse HLO op. The hope therefore is that most calls to // the function that reach here will find a mapping within the first iteration // of the loop below. - bool found = false; std::vector axes(device_mesh.num_dimensions()); size_t num_subsets = consider_reverse_device_meshes ? (1 << device_mesh.num_dimensions()) : 1; @@ -1157,24 +1139,35 @@ absl::StatusOr> GetTensorDimToMeshDimNoCrash( *device = device_mesh(original_indices); }); if (auto result = check_mesh(new_mesh); result.has_value()) { - axes = result.value(); - found = true; - break; + return result.value(); } } + return absl::NotFoundError(absl::StrCat("Could not find mapping for ", + spec.ToString(), " with device mesh ", + device_mesh.ToString())); +} - if (!found) { - return absl::NotFoundError( - absl::StrCat("Could not find mapping for ", spec.ToString(), - " with device mesh ", device_mesh.ToString())); +absl::StatusOr> GetTensorDimToMeshDimNoCrash( + int64_t tensor_shape_rank, const HloSharding& spec, + const DeviceMesh& device_mesh, bool consider_reverse_device_meshes) { + if (spec.IsReplicated()) { + return std::vector(tensor_shape_rank, -1); + } + // Check the compatibility of tensor_shape_rank and spec + if (tensor_shape_rank != spec.TiledDataRank()) { + return absl::InvalidArgumentError( + "Tensor shape rank should be equal to the tiled data rank of the input " + "spec."); } - if (!TileAssignmentMatchesMesh(spec, device_mesh)) { return absl::InvalidArgumentError( "Device mesh and tile assignment need to have the same number of " "sharded dims."); } + TF_ASSIGN_OR_RETURN(std::vector axes, + GetMeshDimPermutationOrderInShardingSpec( + spec, device_mesh, consider_reverse_device_meshes)); // Transform tile_assignment_dimensions using found transformation (axes). std::vector tensor_dim_to_device_dim(tensor_shape_rank, -1); int mesh_index = 0; @@ -1192,7 +1185,7 @@ absl::StatusOr> GetTensorDimToMeshDimNoCrash( std::vector GetTensorDimToMeshDim( int64_t tensor_shape_rank, const HloSharding& spec, - const Array& device_mesh, bool consider_reverse_device_meshes) { + const DeviceMesh& device_mesh, bool consider_reverse_device_meshes) { auto mapping_or = GetTensorDimToMeshDimNoCrash( tensor_shape_rank, spec, device_mesh, consider_reverse_device_meshes); if (mapping_or.ok()) { @@ -1202,9 +1195,10 @@ std::vector GetTensorDimToMeshDim( } } -absl::StatusOr ComputeIntermediateShape( - const HloSharding& src_sharding, const HloSharding& dst_sharding, - const Shape& shape, const Array& device_mesh) { +absl::StatusOr ComputeIntermediateShape(const HloSharding& src_sharding, + const HloSharding& dst_sharding, + const Shape& shape, + const DeviceMesh& device_mesh) { int64_t src_n_dim = NumTileDimensions(src_sharding); const HloSharding* sharding_1d; @@ -1240,7 +1234,7 @@ absl::StatusOr ComputeIntermediateShape( HloInstruction* ReshardTensor(HloInstruction* tensor, const HloSharding& src_sharding, const HloSharding& dst_sharding, - const Array& device_mesh) { + const DeviceMesh& device_mesh) { const Shape& shape = tensor->shape(); HloComputation* computation = tensor->parent(); @@ -1288,7 +1282,7 @@ HloInstruction* ReshardTensor(HloInstruction* tensor, absl::Status FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( HloInstruction* inst, const std::vector>& dst_shardings, - const Array& device_mesh) { + const DeviceMesh& device_mesh) { size_t tuple_size = inst->shape().tuple_shapes_size(); const HloSharding& current_sharding = inst->sharding(); @@ -1352,7 +1346,7 @@ absl::Status FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( absl::Status FixMixedMeshShapeReshardingGetTupleElement( HloInstruction* inst, const HloSharding& dst_sharding, - const Array& device_mesh, + const DeviceMesh& device_mesh, absl::flat_hash_map>& preserve_shardings) { const HloInstruction* operand = inst->operand(0); @@ -1394,7 +1388,7 @@ absl::Status FixMixedMeshShapeReshardingGetTupleElement( absl::Status FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num, const HloSharding& dst_sharding, - const Array& device_mesh, + const DeviceMesh& device_mesh, ReshardingCache* resharding_cache) { HloInstruction* operand = inst->mutable_operand(operand_num); if (operand->opcode() == HloOpcode::kOutfeed || @@ -1493,7 +1487,7 @@ bool IsDivisible(int64_t numerator, int64_t denominator) { } std::vector> GetReplicaGroupsAlongOneDimension( - const Array& device_mesh, int32_t communication_dim) { + const DeviceMesh& device_mesh, int32_t communication_dim) { CHECK_LT(communication_dim, device_mesh.num_dimensions()); std::vector indices(device_mesh.num_dimensions(), 0); std::vector> replica_groups; @@ -1514,10 +1508,10 @@ std::vector> GetReplicaGroupsAlongOneDimension( } // Create a HloSharding that tiles some tensor dims on some device mesh dims. -HloSharding Tile(const Shape& tensor_shape, - absl::Span tensor_dims, - const std::vector>& mesh_dims, - const Array& device_mesh) { +HloSharding TileV1(const Shape& tensor_shape, + absl::Span tensor_dims, + const std::vector>& mesh_dims, + const DeviceMesh& device_mesh) { CHECK_EQ(tensor_dims.size(), mesh_dims.size()); CHECK(tensor_shape.IsArray()); std::vector tile_assignment_dimensions(tensor_shape.rank(), 1); @@ -1562,7 +1556,7 @@ HloSharding Tile(const Shape& tensor_shape, if (proceed_to_next_tensor_dim && current_tensor_dim == tensor_shape.rank() - 1) { - AppendFlattenElements(&tile_assignment_devices, device_mesh, + AppendFlattenElements(&tile_assignment_devices, device_mesh.device_array, mesh_indices); return; } @@ -1610,15 +1604,91 @@ HloSharding Tile(const Shape& tensor_shape, : HloSharding::Tile(std::move(tile_assignment)); } +HloSharding TileV2(const Shape& tensor_shape, + absl::Span tensor_dims, + const std::vector>& mesh_dims, + const DeviceMesh& device_mesh) { + CHECK_EQ(tensor_dims.size(), mesh_dims.size()); + CHECK(tensor_shape.IsArray()); + std::vector tile_assignment_dimensions(tensor_shape.rank(), 1); + std::vector transpose_perm; + absl::Span reshape_dims = device_mesh.dimensions(); + + struct TensorDimWithIndex { + int64_t tensor_dim; + int64_t idx_in_vector; + }; + + std::vector sorted_tensor_dims(tensor_dims.size()); + for (size_t i = 0; i < tensor_dims.size(); ++i) { + sorted_tensor_dims[i].tensor_dim = tensor_dims[i]; + sorted_tensor_dims[i].idx_in_vector = i; + } + + absl::c_sort(sorted_tensor_dims, + [](const TensorDimWithIndex& a, const TensorDimWithIndex& b) { + return a.tensor_dim < b.tensor_dim; + }); + + // Split on certain mesh dimensions + int64_t split_prod = 1; + for (const TensorDimWithIndex& tensor_dim_with_index : sorted_tensor_dims) { + int64_t tensor_dim = tensor_dim_with_index.tensor_dim; + const std::vector& mesh_dims_for_this_tensor_dim = + mesh_dims[tensor_dim_with_index.idx_in_vector]; + int64_t num_devices_for_tensor_dim = 1; + for (int64_t mesh_dim_idx : mesh_dims_for_this_tensor_dim) { + num_devices_for_tensor_dim *= device_mesh.dim(mesh_dim_idx); + transpose_perm.push_back(mesh_dim_idx); + } + tile_assignment_dimensions[tensor_dim] = num_devices_for_tensor_dim; + split_prod *= num_devices_for_tensor_dim; + } + // Replicate on remaining mesh dimensions + bool replicate_on_last_tile_dim = false; + if (split_prod < device_mesh.num_elements()) { + tile_assignment_dimensions.push_back(device_mesh.num_elements() / + split_prod); + replicate_on_last_tile_dim = true; + } + + for (int i = 0; i < device_mesh.num_dimensions(); ++i) { + if (absl::c_find(transpose_perm, i) == transpose_perm.end()) { + transpose_perm.push_back(i); + } + } + + // Make HloSharding + TileAssignment tile_assignment(tile_assignment_dimensions, reshape_dims, + transpose_perm); + + return replicate_on_last_tile_dim + ? HloSharding::PartialTile(std::move(tile_assignment)) + : HloSharding::Tile(std::move(tile_assignment)); +} + +HloSharding Tile(const Shape& tensor_shape, + absl::Span tensor_dims, + const std::vector>& mesh_dims, + const DeviceMesh& device_mesh) { + if (device_mesh.is_iota) { + return TileV2(tensor_shape, tensor_dims, mesh_dims, device_mesh); + } + return TileV1(tensor_shape, tensor_dims, mesh_dims, device_mesh); +} + HloSharding Tile(const Shape& tensor_shape, absl::Span tensor_dims, absl::Span mesh_dims, - const Array& device_mesh) { + const DeviceMesh& device_mesh) { std::vector> mesh_dims_general(mesh_dims.size()); for (int i = 0; i < mesh_dims.size(); ++i) { mesh_dims_general[i].push_back(mesh_dims[i]); } - return Tile(tensor_shape, tensor_dims, mesh_dims_general, device_mesh); + if (device_mesh.is_iota) { + return TileV2(tensor_shape, tensor_dims, mesh_dims_general, device_mesh); + } + return TileV1(tensor_shape, tensor_dims, mesh_dims_general, device_mesh); } AliasMap BuildAliasMap(const HloModule* module, @@ -2207,30 +2277,37 @@ std::vector> InferMeshShapesToTry( const HloModule& module) { int64_t sharding_1d = -1; absl::flat_hash_set> shardings_nd; + int max_shardings_nd_dimension = -1; std::function process_sharding; - process_sharding = [&sharding_1d, &shardings_nd, - &process_sharding](const HloSharding& sharding) { + process_sharding = [&](const HloSharding& sharding) { if (sharding.IsTuple()) { for (const HloSharding& child : sharding.tuple_elements()) { process_sharding(child); } - } else if (!sharding.IsReplicated() && !sharding.IsTileMaximal() && - !sharding.IsManual()) { - absl::Span dims = sharding.tile_assignment().dimensions(); - std::vector dims_greater_than_one; - for (const int64_t dim : dims) { - if (dim > 1) { - dims_greater_than_one.push_back(dim); - } - } - if (dims_greater_than_one.size() == 1) { - CHECK(sharding_1d == -1 || sharding_1d == dims_greater_than_one[0]); - sharding_1d = dims_greater_than_one[0]; - } else { - std::sort(dims_greater_than_one.begin(), dims_greater_than_one.end()); - shardings_nd.insert(dims_greater_than_one); + return; + } + if (sharding.IsReplicated() || sharding.IsTileMaximal() || + sharding.IsManual()) { + return; + } + absl::Span dims = sharding.tile_assignment().dimensions(); + std::vector dims_greater_than_one; + for (const int64_t dim : dims) { + if (dim > 1) { + dims_greater_than_one.push_back(dim); } } + if (dims_greater_than_one.size() == 1) { + CHECK(sharding_1d == -1 || sharding_1d == dims_greater_than_one[0]); + sharding_1d = dims_greater_than_one[0]; + } else { + std::sort(dims_greater_than_one.begin(), dims_greater_than_one.end()); + shardings_nd.insert(dims_greater_than_one); + + max_shardings_nd_dimension = + std::max(max_shardings_nd_dimension, + static_cast(dims_greater_than_one.size())); + } }; for (const HloComputation* comp : module.computations()) { @@ -2241,20 +2318,29 @@ std::vector> InferMeshShapesToTry( } } + for (auto mesh_shape_it = shardings_nd.begin(), end = shardings_nd.end(); + mesh_shape_it != end;) { + // `erase()` will invalidate `mesh_shape_it`, so advance `mesh_shape_it` + // first. + auto copy_it = mesh_shape_it++; + if (copy_it->size() < max_shardings_nd_dimension) { + shardings_nd.erase(copy_it); + } + } + if (shardings_nd.empty() && sharding_1d < 0) { return {}; - } else if (shardings_nd.empty()) { - CHECK_GE(sharding_1d, 0); + } + if (shardings_nd.empty()) { return {{1, sharding_1d}}; - } else { - std::vector> result; - for (std::vector mesh : shardings_nd) { - do { - result.push_back(std::vector(mesh)); - } while (std::next_permutation(std::begin(mesh), std::end(mesh))); - } - return result; } + std::vector> result; + for (std::vector mesh : shardings_nd) { + do { + result.push_back(std::vector(mesh)); + } while (std::next_permutation(std::begin(mesh), std::end(mesh))); + } + return result; } std::vector> InferOrEnumerateMeshShapesToTry( @@ -2271,9 +2357,7 @@ std::vector> InferOrEnumerateMeshShapesToTry( dedup_result.insert( absl::btree_multiset(mesh_shape.begin(), mesh_shape.end())); } - mesh_shapes.clear(); - for (const absl::btree_multiset& mesh_shape_set : dedup_result) { mesh_shapes.push_back( std::vector(mesh_shape_set.begin(), mesh_shape_set.end())); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index f114a3d9057281..678030f3520fb4 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -33,6 +33,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -82,7 +83,7 @@ inline std::string ToAdaptiveString(const HloInstruction* ins) { // Return whether the tensor shape is divisible by // the number of devices along multiple dimensions. -bool IsDivisible(const HloInstruction* ins, const Array& device_mesh, +bool IsDivisible(const HloInstruction* ins, const DeviceMesh& device_mesh, absl::Span tensor_dims, absl::Span mesh_dims); @@ -376,26 +377,27 @@ int64_t NumTileDimensions(const HloSharding& spec); // When fixing mixed mesh resharding (see below), compute the correct // intermediate shape in order to insert copies. -absl::StatusOr ComputeIntermediateShape( - const HloSharding& src_sharding, const HloSharding& dst_sharding, - const Shape& shape, const Array& device_mesh); +absl::StatusOr ComputeIntermediateShape(const HloSharding& src_sharding, + const HloSharding& dst_sharding, + const Shape& shape, + const DeviceMesh& device_mesh); // Forcibly set the sharding of the operand of inst. // Also fix the resharding between 1d and 2d logical mesh. absl::Status FixMixedMeshShapeReshardingGetTupleElement( HloInstruction* inst, const HloSharding& dst_sharding, - const Array& device_mesh, + const DeviceMesh& device_mesh, absl::flat_hash_map>& preserve_shardings); absl::Status FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( HloInstruction* inst, const std::vector>& dst_sharding, - const Array& device_mesh); + const DeviceMesh& device_mesh); absl::Status FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num, const HloSharding& dst_sharding, - const Array& device_mesh, + const DeviceMesh& device_mesh, ReshardingCache* resharding_cache); // Gets the mapping vector from dim_from to dim_to. @@ -410,7 +412,7 @@ bool IsDivisible(int64_t numerator, int64_t denominator); // be any number of dimensions. |communication_dim| has to be one of // |device_mesh|'s dimension. std::vector> GetReplicaGroupsAlongOneDimension( - const Array& device_mesh, int32_t communication_dim); + const DeviceMesh& device_mesh, int32_t communication_dim); // Gets values in |array| along |dim| while keeping indices at other // dimensions at 0, e.g., array is 2D and dim = 1, this returns array[0, 1], @@ -424,8 +426,7 @@ absl::StatusOr CheckArithmeticSequence( // Checks if the number of sharded dimensions in the tile assignment matches the // device mesh. -bool TileAssignmentMatchesMesh(const HloSharding& spec, - const Array& mesh); +bool TileAssignmentMatchesMesh(const HloSharding& spec, const DeviceMesh& mesh); // Get the mapped mesh dimension for every tensor dimension. // The returned value maps ith tensor dim to one mesh dim. -1 means the tensor @@ -434,23 +435,21 @@ bool TileAssignmentMatchesMesh(const HloSharding& spec, // mesh dim, and 1st tensor dim maps to the 2nd mesh dim. std::vector GetTensorDimToMeshDim( int64_t tensor_shape_rank, const HloSharding& spec, - const Array& device_mesh, - bool consider_reverse_device_meshes = false); + const DeviceMesh& device_mesh, bool consider_reverse_device_meshes = false); absl::StatusOr> GetTensorDimToMeshDimNoCrash( int64_t tensor_shape_rank, const HloSharding& spec, - const Array& device_mesh, - bool consider_reverse_device_meshes = false); + const DeviceMesh& device_mesh, bool consider_reverse_device_meshes = false); HloSharding Tile(const Shape& tensor_shape, absl::Span tensor_dims, const std::vector>& mesh_dims, - const Array& device_mesh); + const DeviceMesh& device_mesh); HloSharding Tile(const Shape& tensor_shape, absl::Span tensor_dims, absl::Span mesh_dims, - const Array& device_mesh); + const DeviceMesh& device_mesh); AliasMap BuildAliasMap(const HloModule* module, const HloInputOutputAliasConfig& alias_config); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc index c2d82d1766e5c5..42402e39a1496f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc @@ -195,7 +195,7 @@ double ClusterEnvironment::CollectivePermuteCost( // operation as an all-gather on all mesh dimensions. double ClusterEnvironment::OverestimateReplicationCost( const Shape& shape, const HloSharding& src_spec, - const Array& device_mesh) const { + const DeviceMesh& device_mesh) const { if (src_spec.IsTileMaximal() || src_spec.IsManual()) { // TODO(b/238210866) Do not use kInfinityCost. return kInfinityCost; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h index a70570209350b2..d17b026dd8ffb4 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -27,7 +28,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "xla/array.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include "xla/hlo/experimental/auto_sharding/profiling_result.h" @@ -44,8 +45,8 @@ namespace spmd { // the real profiling result. class ClusterEnvironment { public: - ClusterEnvironment(const Array& original_device_mesh, - const Array& device_mesh, + ClusterEnvironment(const DeviceMesh& original_device_mesh, + const DeviceMesh& device_mesh, absl::Span mesh_alpha, absl::Span mesh_beta, const ProfilingResult& prof_result, @@ -160,7 +161,7 @@ class ClusterEnvironment { // shape `shape` sharded according to `src_spec`. double OverestimateReplicationCost(const Shape& shape, const HloSharding& src_spec, - const Array& device_mesh) const; + const DeviceMesh& device_mesh) const; double ReshardingCost(const Shape& shape, const HloSharding& src_spec, const HloSharding& dst_spec) const; @@ -176,11 +177,11 @@ class ClusterEnvironment { } // The original, complete device mesh shape that describes the hardware. - const Array original_device_mesh_; + const DeviceMesh original_device_mesh_; // When solve_nd_sharding_iteratively is true, it is a partial mesh shape from // the original_device_mesh_. When solve_nd_sharding_iteratively is false, it // is the same as original_device_mesh_. - const Array device_mesh_; + const DeviceMesh device_mesh_; // Bandwidth of the device mesh const std::vector mesh_alpha_; const std::vector mesh_beta_; @@ -190,11 +191,11 @@ class ClusterEnvironment { // Cache a flatten 1d version of the device mesh. // Used for mixed mesh shape strategies. - Array device_mesh_1d_; + DeviceMesh device_mesh_1d_; // Cache a flatten 1d version of the original device mesh. // Used for mixed mesh shape strategies. - Array original_device_mesh_1d_; + DeviceMesh original_device_mesh_1d_; // The option may override the cost of communication primitives const AutoShardingOption& auto_sharding_option_; diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index 429501362c52e1..e65c48d982d89d 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -78,6 +78,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/service:compilation_environments", + "//xla/service:computation_layout", "//xla/service:computation_placer_hdr", "//xla/service:hlo_lexer", "//xla/service:hlo_module_config", @@ -99,13 +100,13 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/lib/gtl:iterator_range", "@local_tsl//tsl/lib/gtl:map_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:fingerprint", - "@local_tsl//tsl/platform:human_readable_json", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index 1047082f47baa9..3e73a68762e74f 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -17,18 +17,20 @@ limitations under the License. #define XLA_HLO_IR_HLO_COMPUTATION_H_ #include -#include #include #include #include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" @@ -42,9 +44,14 @@ limitations under the License. #include "xla/printer.h" #include "xla/service/hlo.pb.h" #include "xla/service/name_uniquer.h" +#include "xla/shape.h" #include "xla/shape_tree.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/lib/gtl/iterator_range.h" +#include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 98eb6744c2564c..37d7a39d8ee0e0 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -81,8 +81,9 @@ limitations under the License. #include "tsl/lib/gtl/iterator_range.h" #include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" -#include "tsl/platform/human_readable_json.h" #include "tsl/platform/logging.h" // IWYU pragma: keep +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -1214,6 +1215,9 @@ absl::StatusOr> HloInstruction::CreateFromProto( for (const int64_t computation_id : proto.called_computation_ids()) { instruction->AppendComputation(computation_map.at(computation_id)); } + if (instruction->opcode() == HloOpcode::kWhile) { + instruction->while_body()->SetWhileCallInstruction(instruction.get()); + } TF_RET_CHECK(!proto.has_precision_config()) << instruction->opcode() << proto.DebugString(); @@ -1261,7 +1265,6 @@ absl::StatusOr> HloInstruction::CreateFromProto( const xla::OriginalValueProto& original_value_proto = proto.original_value(); auto original_value = std::make_shared(shape); - std::cerr << __func__ << ", shape: " << shape.ToString() << "\n"; for (const auto& leaf : original_value_proto.leaves()) { *original_value->mutable_element(ShapeIndex(leaf.leaf_shape_index())) = { @@ -2611,6 +2614,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateWhile(shape, while_condition(), while_body(), new_operands[0]); + // Repoint the while body back at the original while instruction. + // If a context was passed, the body will be cloned and the clone will + // point to the copied instruction. + while_body()->SetWhileCallInstruction(const_cast(this)); break; case HloOpcode::kConditional: CHECK_EQ(new_operands.size(), branch_count() + 1); @@ -2654,6 +2661,9 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( ? context->module()->DeepCloneComputation(callee, context) : callee; }); + if (opcode() == HloOpcode::kWhile) { + clone->while_body()->SetWhileCallInstruction(clone.get()); + } } if (!suffix.empty()) { diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index a3346273a83770..a98f9963b9c2d4 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -25,6 +25,7 @@ limitations under the License. #include #include #include +#include #include #include #include diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index db0f43bad95e17..cff0907ba534d5 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -1953,7 +1953,7 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( return u->opcode() == HloOpcode::kGetTupleElement; }); if (called_computations().empty()) { - // New fusion instruction. It should not be a multioutput instruction. + // New fusion instruction. It should not be a multi-output instruction. CHECK(!add_output); auto builder = HloComputation::Builder(default_called_computation_name()); builder.AddInstruction(instruction_to_append->Clone(/*suffix=*/"")); diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.h b/third_party/xla/xla/hlo/ir/hlo_instructions.h index 4db70131e23918..c0f03248dbf772 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.h +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.h @@ -19,13 +19,13 @@ limitations under the License. #define XLA_HLO_IR_HLO_INSTRUCTIONS_H_ #include -#include #include #include #include #include #include +#include "absl/base/attributes.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/status/status.h" @@ -38,7 +38,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/iterator_util.h" #include "xla/layout.h" #include "xla/literal.h" #include "xla/printer.h" @@ -1474,7 +1473,7 @@ class HloFusionInstruction : public HloCallableInstruction { void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge); // Merges the fused instructions from instruction_to_merge into the fused - // instruction set of 'this' and generates multioutput fusion instructions. + // instruction set of 'this' and generates multi-output fusion instructions. // All the users of instruction_to_merge will be redirected to 'this' // instruction. instruction_to_merge will be removed from its parent // computation. diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index a0adc80e882f1a..cc8dda9a321ee5 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include +#include #include -#include #include #include #include @@ -30,23 +30,36 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/map_util.h" #include "xla/printer.h" #include "xla/service/compilation_environments.h" +#include "xla/service/computation_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/mapped_ptr_container_sorter.h" +#include "xla/service/name_uniquer.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/gtl/map_util.h" +#include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/logging.h" @@ -404,8 +417,8 @@ void HloModule::Print(Printer* printer, const HloPrintOptions& options) const { ? MakeComputationSorted() : MakeComputationPostOrder(); for (const HloComputation* computation : computations) { - // Don't print async computations when the sytax sugar is enabled since that - // is redundant information. + // Don't print async computations when the syntax sugar is enabled since + // that is redundant information. if (options.syntax_sugar_async_ops() && computation->IsAsyncComputation() && computation->CanExpandIntoSingleInstruction()) { continue; @@ -847,7 +860,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( outlined_instruction); // Mark instruction_to_outline an output if it is used outside the - // subcomputation or is the output of the original computation (i.e. used + // sub-computation or is the output of the original computation (i.e. used // externally). if (instruction_to_outline->user_count() == 0 || IsUsedOutsideSubcomputation(*instruction_to_outline, @@ -916,7 +929,7 @@ std::vector HloModule::MakeComputationPostOrder( if (computations_.empty()) { return {}; } - // First determine all root computations by building a set of nonroot + // First determine all root computations by building a set of non-root // computations (computations which are called by an instruction in the // module). absl::flat_hash_set nonroot_computations; diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD index 0bae42a231b7d4..8f20f63bfc9b2a 100644 --- a/third_party/xla/xla/hlo/utils/BUILD +++ b/third_party/xla/xla/hlo/utils/BUILD @@ -158,6 +158,7 @@ cc_library( "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) diff --git a/third_party/xla/xla/hlo/utils/hlo_query.cc b/third_party/xla/xla/hlo/utils/hlo_query.cc index 69a6fef79857fe..147f54822aef97 100644 --- a/third_party/xla/xla/hlo/utils/hlo_query.cc +++ b/third_party/xla/xla/hlo/utils/hlo_query.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "absl/algorithm/container.h" #include "absl/strings/string_view.h" @@ -269,22 +270,46 @@ HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand, return gte; } -bool IsBeforeInComputation(const HloComputation* computation, - absl::string_view inst1, absl::string_view inst2) { - int index1 = -1; - int index2 = -1; +HloComputation* FindComputation(HloModule* module, absl::string_view name) { + auto computations = module->computations(); + auto it = absl::c_find_if( + computations, [&](HloComputation* c) { return c->name() == name; }); + if (it == computations.end()) { + return nullptr; + } + return *it; +} + +std::pair FindFirstInstruction( + const HloComputation* computation, absl::string_view name) { int current_index = 0; - for (auto instruction : computation->instructions()) { - if (instruction->name() == inst1) { - index1 = current_index; + for (auto* instruction : computation->instructions()) { + if (instruction->name() == name) { + return {instruction, current_index}; + break; } - if (instruction->name() == inst2) { - index2 = current_index; + current_index++; + } + return {nullptr, -1}; +} + +std::pair FindFirstInstruction( + const HloComputation* computation, HloOpcode opcode) { + int current_index = 0; + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == opcode) { + return {instruction, current_index}; + break; } current_index++; } - current_index++; - return index1 < index2; + return {nullptr, -1}; +} + +bool IsBeforeInComputation(const HloComputation* computation, + absl::string_view inst1, absl::string_view inst2) { + return FindFirstInstruction(computation, inst1).second < + FindFirstInstruction(computation, inst2).second; } } // namespace hlo_query } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_query.h b/third_party/xla/xla/hlo/utils/hlo_query.h index 950082accf14f0..ec5c0b25804d10 100644 --- a/third_party/xla/xla/hlo/utils/hlo_query.h +++ b/third_party/xla/xla/hlo/utils/hlo_query.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_HLO_UTILS_HLO_QUERY_H_ #include +#include #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" @@ -153,7 +154,19 @@ bool HasX64TransformedHostTransfer(const HloModule& module); HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand, int64_t index); -// TODO: b/356153995 - refactor hlo_test_base +// Gets the computation from the given module with the given name. +HloComputation* FindComputation(HloModule* module, absl::string_view name); +// Gets the first instruction and its index from the given computation with the +// given instruction name. The function returns {nullptr, -1} if the instruction +// cannot be found. +std::pair FindFirstInstruction( + const HloComputation* computation, absl::string_view name); +// Gets the first instruction and its index from the given computation with the +// given instruction opcode. The function returns {nullptr, -1} if the +// instruction cannot be found. +std::pair FindFirstInstruction( + const HloComputation* computation, HloOpcode opcode); + // Check that one instruction comes before another one for a given computation. // The function returns true if the first instruction comes before the second // one, and false otherwise. This is useful for partial checks on the diff --git a/third_party/xla/xla/hlo/utils/hlo_query_test.cc b/third_party/xla/xla/hlo/utils/hlo_query_test.cc index acefa21aa9e2f4..e4dad1007fa685 100644 --- a/third_party/xla/xla/hlo/utils/hlo_query_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_query_test.cc @@ -40,6 +40,14 @@ int CountInstructions(Hlo& module, HloOpcode opcode) { return counter; } +constexpr absl::string_view kConstantAdditionHloString = R"( +HloModule test +ENTRY main { + zero = f32[] constant(0) + five = f32[] constant(5) + ROOT out = f32[] add(zero, five) +})"; + TEST_F(HloQueryTest, GetInstructionWithOpCodeReturnsMatchingInstructionForModule) { constexpr absl::string_view kHloString = R"( @@ -132,5 +140,66 @@ TEST_F(HloQueryTest, GetUniqueGteTest) { EXPECT_EQ(gte2, nullptr); } +TEST_F(HloQueryTest, FindComputationTest) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); + EXPECT_NE(hlo_query::FindComputation(module.get(), "main"), nullptr); + EXPECT_EQ(hlo_query::FindComputation(module.get(), "foo"), nullptr); +} + +TEST_F(HloQueryTest, FindInstructionUsingNameTest) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); + const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); + EXPECT_NE(hlo_query::FindFirstInstruction(main, "zero").first, nullptr); + EXPECT_NE(hlo_query::FindFirstInstruction(main, "five").first, nullptr); + EXPECT_NE(hlo_query::FindFirstInstruction(main, "out").first, nullptr); + EXPECT_EQ(hlo_query::FindFirstInstruction(main, "foo").first, nullptr); +} + +TEST_F(HloQueryTest, FindInstructionUsingOpcodeTest) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); + const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); + EXPECT_NE( + hlo_query::FindFirstInstruction(main, StringToHloOpcode("add").value()) + .first, + nullptr); + EXPECT_NE(hlo_query::FindFirstInstruction( + main, StringToHloOpcode("constant").value()) + .first, + nullptr); + EXPECT_EQ( + hlo_query::FindFirstInstruction(main, StringToHloOpcode("select").value()) + .first, + nullptr); +} + +TEST_F(HloQueryTest, FindInstructionDoesNotExistTest) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); + const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); + EXPECT_NE(main, nullptr); + auto find_beef = hlo_query::FindFirstInstruction(main, "deadbeef"); + auto find_nothing = hlo_query::FindFirstInstruction(main, ""); + EXPECT_EQ(find_beef.first, nullptr); + EXPECT_EQ(find_beef.second, -1); + EXPECT_EQ(find_nothing.first, nullptr); + EXPECT_EQ(find_nothing.second, -1); +} + +TEST_F(HloQueryTest, IsBeforeInComputationTest) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); + const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); + EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "zero", "five")); + EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "five", "out")); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index 28c9f0f82151e1..d6fe4946fbf237 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -786,8 +786,11 @@ std::optional ReshapeSharding(const Shape& source_shape, sharding_tile_dims_stack.pop_back(); } - if (s_partitions > 1 && s_size % s_partitions == 0 && - t_size % s_partitions == 0) { + if (s_size == t_size) { + // Same dimension. + append_sharding_dim(s_partitions); + } else if (s_partitions > 1 && s_size % s_partitions == 0 && + t_size % s_partitions == 0) { // If s_partitions evenly divides both s_size and t_size, we can add this // sharding dim and work on shard sized shapes in the next iteration. source_dims_stack.push_back(s_size / s_partitions); @@ -795,9 +798,6 @@ std::optional ReshapeSharding(const Shape& source_shape, sharding_tile_dims_stack.push_back(1); append_sharding_dim(s_partitions); inplace_add_sharding_dim = true; - } else if (s_size == t_size) { - // Same dimension. - append_sharding_dim(s_partitions); } else if (t_size == 1) { // Trivial dimension added. append_sharding_dim(1); diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc index 44ec60cca97172..fcbc4a4cd4bbdf 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc @@ -295,6 +295,18 @@ TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne3) { EXPECT_EQ(result.value(), output_sharding); } +TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne4) { + Shape input_shape = ShapeUtil::MakeShape(F32, {4, 2, 1}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 2}); + HloSharding input_sharding = HloSharding::IotaTile({4, 2, 4}); + HloSharding output_sharding = + HloSharding::PartialTile(TileAssignment({4, 2, 4})); + std::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + TEST(HloShardingUtilTest, ReshapeShardingPrefixShapeSizeOne1) { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 64}); Shape output_shape = ShapeUtil::MakeShape(F32, {1, 64}); diff --git a/third_party/xla/xla/lit.bzl b/third_party/xla/xla/lit.bzl index d6ec58096671f3..bbee57e4246e46 100644 --- a/third_party/xla/xla/lit.bzl +++ b/third_party/xla/xla/lit.bzl @@ -1,7 +1,7 @@ """Helper rules for writing LIT tests.""" load("@bazel_skylib//lib:paths.bzl", "paths") -load("//xla/tsl:tsl.bzl", "if_oss") +load("//xla/tsl:tsl.bzl", "if_hermetic_cuda_tools", "if_oss") def enforce_glob(files, **kwargs): """A utility to enforce that a list matches a glob expression. @@ -50,6 +50,7 @@ def lit_test_suite( timeout = None, default_tags = None, tags_override = None, + hermetic_cuda_data_dir = None, **kwargs): """Creates one lit test per source file and a test suite that bundles them. @@ -74,6 +75,8 @@ def lit_test_suite( timeout: timeout argument passed to the individual tests. default_tags: string list. Tags applied to all tests. tags_override: string_dict. Tags applied in addition to only select tests. + hermetic_cuda_data_dir: string. If set, the tests will be run with a + `--xla_gpu_cuda_data_dir` flag set to the hermetic CUDA data directory. **kwargs: additional keyword arguments to pass to all generated rules. See https://llvm.org/docs/CommandGuide/lit.html for details on lit @@ -105,6 +108,7 @@ def lit_test_suite( env = env, timeout = timeout, tags = default_tags + tags_override.get(test_file, []), + hermetic_cuda_data_dir = hermetic_cuda_data_dir, **kwargs ) @@ -114,6 +118,23 @@ def lit_test_suite( **kwargs ) +def lit_script_with_xla_gpu_cuda_data_dir( + name, + input_file, + output_file, + xla_gpu_cuda_data_dir): + """Adds a line to the LIT script to set the XLA_FLAGS environment variable.""" + return native.genrule( + name = name, + srcs = [input_file], + outs = [output_file], + cmd = if_hermetic_cuda_tools( + """echo -e '// RUN: export XLA_FLAGS=\"--xla_gpu_cuda_data_dir={}\"' > $@; +cat $< >> $@;""".format(xla_gpu_cuda_data_dir), + "cat $< >> $@;", + ), + ) + def lit_test( name, test_file, @@ -124,6 +145,7 @@ def lit_test( visibility = None, env = None, timeout = None, + hermetic_cuda_data_dir = None, **kwargs): """Runs a single test file with LLVM's lit tool. @@ -146,6 +168,8 @@ def lit_test( env: string_dict. Environment variables available during test execution. See the common Bazel test attribute. timeout: bazel test timeout string, as per common bazel definitions. + hermetic_cuda_data_dir: string. If set, the tests will be run with a + `--xla_gpu_cuda_data_dir` flag set to the hermetic CUDA data directory. **kwargs: additional keyword arguments to pass to all generated rules. See https://llvm.org/docs/CommandGuide/lit.html for details on lit @@ -170,12 +194,19 @@ def lit_test( tools_on_path_target_name, "lit_bin", ) + lib_dir = paths.join( + native.package_name(), + tools_on_path_target_name, + "lit_lib", + ) _tools_on_path( name = tools_on_path_target_name, testonly = True, srcs = tools, bin_dir = bin_dir, + lib_dir = lib_dir, + deps = ["//xla/stream_executor/cuda:all_runtime"], visibility = ["//visibility:private"], **kwargs ) @@ -195,6 +226,18 @@ def lit_test( ) # copybara:comment_end + + if hermetic_cuda_data_dir: + output_file = "with_xla_gpu_cuda_data_dir_{}".format(test_file) + rule_name = "script_{}".format(output_file) + lit_script_with_xla_gpu_cuda_data_dir( + rule_name, + test_file, + output_file, + hermetic_cuda_data_dir, + ) + test_file = output_file + native_test( name = name, src = lit_name, @@ -275,6 +318,22 @@ def _tools_on_path_impl(ctx): " {} and {} conflict".format(runfiles_symlinks[bin_path], exe)) runfiles_symlinks[bin_path] = exe + # The loop below symlinks the libraries that are used by the tools. + for dep in ctx.attr.deps: + linker_inputs = dep[CcInfo].linking_context.linker_inputs.to_list() + for linker_input in linker_inputs: + if len(linker_input.libraries) == 0: + continue + lib = linker_input.libraries[0].dynamic_library + if not lib: + continue + lib_path = paths.join(ctx.attr.lib_dir, lib.basename) + if lib_path in runfiles_symlinks: + fail("All libs used by lit tests must have unique basenames, as" + + " they are added to the path." + + " {} and {} conflict".format(runfiles_symlinks[lib_path], lib)) + runfiles_symlinks[lib_path] = lib + return [ DefaultInfo(runfiles = ctx.runfiles( symlinks = runfiles_symlinks, @@ -286,6 +345,8 @@ _tools_on_path = rule( attrs = { "srcs": attr.label_list(allow_files = True, mandatory = True), "bin_dir": attr.string(mandatory = True), + "lib_dir": attr.string(mandatory = True), + "deps": attr.label_list(), }, doc = "Symlinks srcs into a single lit_bin directory. All basenames must be unique.", ) diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index 34d70133ab2411..4ce52706ed97a1 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -252,7 +252,7 @@ Literal::Literal(const Shape& shape) void Literal::SetShape(const Shape& shape) { Shape shape_storage; const Shape* shape_ptr = &shape; - if (LayoutUtil::HasCustomElementSizeInBits(shape)) { + if (shape.IsArray() && LayoutUtil::HasCustomElementSizeInBits(shape)) { shape_storage = shape; shape_storage.mutable_layout()->set_element_size_in_bits(0); shape_ptr = &shape_storage; diff --git a/third_party/xla/xla/literal_test.cc b/third_party/xla/xla/literal_test.cc index 36a3c263e27c36..42b4340d2ddf82 100644 --- a/third_party/xla/xla/literal_test.cc +++ b/third_party/xla/xla/literal_test.cc @@ -2583,6 +2583,14 @@ TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { EXPECT_FALSE(c1.IsKnown()); } +TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArraysS4Tuple) { + auto inner_shape = ShapeUtil::MakeShape(S4, {4, 4}); + inner_shape.mutable_layout()->set_element_size_in_bits(4); + Literal c1 = Literal::CreateFromShapeWithUnknownLeafArrays( + ShapeUtil::MakeTupleShape({inner_shape})); + EXPECT_FALSE(c1.IsKnown()); +} + TEST_F(LiteralUtilTest, CreatePartiallyKnownTuple) { Literal c1 = Literal::CreateFromShapeWithUnknownLeafArrays( ShapeUtil::MakeShape(F32, {4, 4})); diff --git a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/vector.cc b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/vector.cc index b0223f3e6ed532..6716b1660fa960 100644 --- a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/vector.cc +++ b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/vector.cc @@ -553,7 +553,7 @@ InterpreterValue MultiReduction(InterpreterState& state, const InterpreterValue& acc) { auto element_ty = getElementTypeOrSelf(reduction->getResultTypes()[0]); return {ReductionImpl(state, source, &acc, reduction.getKind(), - ExtractVector(reduction.getReductionDims()), + SmallVector(reduction.getReductionDims()), element_ty)}; } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 0de939289537b9..df3f11a2fc03f1 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -415,7 +415,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(XorOp) //===----------------------------------------------------------------------===// // Follow async operation use-def chain to find the start of the async chain. -AsyncStartOp findAsyncChainStart(Operation* op) { +static AsyncStartOp findAsyncChainStart(Operation* op) { Operation* start = op; while (start != nullptr && !isa(start)) { start = start->getOperand(0).getDefiningOp(); @@ -423,8 +423,8 @@ AsyncStartOp findAsyncChainStart(Operation* op) { return dyn_cast_or_null(start); } -Type maybeTupleFromTypes(MLIRContext* ctx, ArrayRef types, - bool expectsTuple = false) { +static Type maybeTupleFromTypes(MLIRContext* ctx, ArrayRef types, + bool expectsTuple = false) { if (!expectsTuple && types.size() == 1 && !isa(types[0])) return types[0]; return TupleType::get(ctx, TypeRange(types)); @@ -903,13 +903,31 @@ LogicalResult DotOp::verify() { //===----------------------------------------------------------------------===// LogicalResult DotGeneralOp::verify() { + bool isDefaultPrecisionConfig = + !getPrecisionConfig().has_value() || + llvm::all_of(getPrecisionConfig().value(), [](Attribute attr) { + return cast(attr).getValue() == Precision::DEFAULT; + }); + bool hasAlgorithmSpecified = getAlgorithm().has_value(); + if (hasAlgorithmSpecified) { + DotAlgorithmAttr attr = getAlgorithm().value(); + if (failed(DotAlgorithmAttr::verify( + [&] { return this->emitError(); }, attr.getLhsPrecisionType(), + attr.getRhsPrecisionType(), attr.getAccumulationType(), + attr.getLhsComponentCount(), attr.getRhsComponentCount(), + attr.getNumPrimitiveOperations(), + attr.getAllowImpreciseAccumulation()))) + return failure(); + } + return hlo::verifyDotGeneralOp( getLoc(), getLhs(), getRhs(), getDotDimensionNumbersAttr().getLhsBatchingDimensions(), getDotDimensionNumbersAttr().getRhsBatchingDimensions(), getDotDimensionNumbersAttr().getLhsContractingDimensions(), getDotDimensionNumbersAttr().getRhsContractingDimensions(), - getPrecisionConfig(), getResult()); + getPrecisionConfig(), isDefaultPrecisionConfig, hasAlgorithmSpecified, + getResult()); } LogicalResult DotGeneralOp::reifyReturnTypeShapes( @@ -949,6 +967,17 @@ LogicalResult DotGeneralOp::reifyReturnTypeShapes( return success(); } +LogicalResult DotAlgorithmAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType, + int64_t lhsComponentCount, int64_t rhsComponentCount, + int64_t numPrimitiveOperations, bool allowImpreciseAccumulation) { + return hlo::verifyDotAlgorithmAttr( + emitError, lhsPrecisionType, rhsPrecisionType, accumulationType, + lhsComponentCount, rhsComponentCount, numPrimitiveOperations, + allowImpreciseAccumulation); +} + //===----------------------------------------------------------------------===// // SparseDotOp //===----------------------------------------------------------------------===// @@ -1002,8 +1031,9 @@ LogicalResult SparseDotOp::verify() { //===----------------------------------------------------------------------===// // FftOp //===----------------------------------------------------------------------===// -LogicalResult verify1dTensor(std::optional loc, - DenseIntElementsAttr attr, std::string attrName) { +static LogicalResult verify1dTensor(std::optional loc, + DenseIntElementsAttr attr, + std::string attrName) { auto rank = attr.getType().getRank(); if (rank != 1) { return emitOptionalError(loc, attrName, " has rank ", rank, @@ -1221,8 +1251,8 @@ LogicalResult GatherOp::inferReturnTypeComponents( //===----------------------------------------------------------------------===// // Canonicalize mhlo.dynamic_gather to mhlo.gather when slice_sizes is constant. -LogicalResult simplifyDynamicGatherToGather(DynamicGatherOp op, - PatternRewriter& rewriter) { +static LogicalResult simplifyDynamicGatherToGather(DynamicGatherOp op, + PatternRewriter& rewriter) { DenseIntElementsAttr dynamicGatherSliceSizes; if (!matchPattern(op.getSliceSizes(), m_Constant(&dynamicGatherSliceSizes))) { return failure(); @@ -1633,7 +1663,7 @@ struct ConvolutionIsDot : public OpRewritePattern { op.getContext(), {}, {}, {lhsContractDim}, {rhsContractDim}); auto dotOp = rewriter.create( op.getLoc(), op.getType(), lhs, rhs, dotNums, - op.getPrecisionConfig().value_or(nullptr)); + op.getPrecisionConfig().value_or(nullptr), DotAlgorithmAttr{}); rewriter.replaceOp(op, dotOp.getResult()); return success(); @@ -1669,7 +1699,7 @@ struct ConvolutionIsDot : public OpRewritePattern { {lhsContractDim + 1}, {rhsContractDim == 0 ? 2 : 0}); auto dotOp = rewriter.create( op.getLoc(), dotTy, lhs, rhs, dotNums, - op.getPrecisionConfig().value_or(nullptr)); + op.getPrecisionConfig().value_or(nullptr), DotAlgorithmAttr{}); llvm::SmallVector perms; perms.resize(3, dNums.getOutputBatchDimension() == 0 ? 0 : 2); @@ -3371,7 +3401,7 @@ Operation* ReduceWindowOp::getReductionOp(int resultIndex) { return nullptr; } -bool isSplatZero(SplatElementsAttr attr) { +static bool isSplatZero(SplatElementsAttr attr) { if (!attr) return false; if (isa(attr.getElementType())) { return attr.getSplatValue().isZero(); @@ -3609,7 +3639,7 @@ LogicalResult ReduceOp::fold(FoldAdaptor /*adaptor*/, return failure(); } -bool hasSameOperandAndResultTypes(Operation& op) { +static bool hasSameOperandAndResultTypes(Operation& op) { Type expected; if (op.getNumResults() != 0) expected = op.getResult(0).getType(); if (op.getNumOperands() != 0) expected = op.getOperand(0).getType(); @@ -4588,9 +4618,9 @@ struct Abs { } }; -double rsqrt(double d) { return 1.0 / std::sqrt(d); } +static double rsqrt(double d) { return 1.0 / std::sqrt(d); } -double logistic(double d) { return 1.0 / (1.0 + std::exp(-d)); } +static double logistic(double d) { return 1.0 / (1.0 + std::exp(-d)); } // NOLINTBEGIN(bugprone-macro-parentheses) #define UNARY_FOLDER(Op, Func) \ @@ -4828,7 +4858,7 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) { return {}; } -bool isSplatOne(SplatElementsAttr attr) { +static bool isSplatOne(SplatElementsAttr attr) { if (!attr) return false; if (isa(attr.getElementType())) { return attr.getSplatValue().convertToDouble() == 1.0; @@ -5756,8 +5786,8 @@ LogicalResult ScatterOp::verify() { getScatterDimensionNumbers().getIndexVectorDim(), getUpdateComputation()); } -llvm::SmallVector evaluateMhloRegion(Region& region, - ArrayRef inputs) { +static llvm::SmallVector evaluateMhloRegion( + Region& region, ArrayRef inputs) { if (region.getNumArguments() != inputs.size()) return {}; llvm::DenseMap values; @@ -6950,8 +6980,8 @@ static LogicalResult verifyArgResultAliasAttr(StringAttr attrName, // Each CrossProgramPrefetchAttr specifies a parameter and a ShapeIndex // (1) the parameter must be valid // (2) there must be a subshape at the given indices -LogicalResult verifyCrossProgramPrefetchAttr(CrossProgramPrefetchAttr cpp, - ModuleOp module) { +static LogicalResult verifyCrossProgramPrefetchAttr( + CrossProgramPrefetchAttr cpp, ModuleOp module) { func::FuncOp main = module.lookupSymbol("main"); if (cpp.getParameter() >= main.getNumArguments() || cpp.getParameter() < 0) return module->emitOpError() @@ -7055,7 +7085,7 @@ Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value, return builder.create(loc, type, elementsAttr); } -int64_t getNumLeafBuffers(Type type) { +static int64_t getNumLeafBuffers(Type type) { if (auto tuple = dyn_cast(type)) { auto ans = 0; for (auto type : tuple.getTypes()) ans += getNumLeafBuffers(type); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td index bda156bbbdc5b2..3b68ad70a332e1 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -2418,7 +2418,7 @@ def MHLO_ConvolutionOp : MHLO_Op<"convolution", [Pure]> { MHLO_ConvDimensionNumbers:$dimension_numbers, ConfinedAttr:$feature_group_count, ConfinedAttr:$batch_group_count, - MHLO_PrecisionConfigAttr:$precision_config + OptionalAttr:$precision_config ); let results = (outs MHLO_Tensor); @@ -2608,7 +2608,7 @@ def MHLO_DotOp: MHLO_Op<"dot", [Pure]> { let arguments = ( ins MHLO_Tensor:$lhs, MHLO_Tensor:$rhs, - MHLO_PrecisionConfigAttr:$precision_config + OptionalAttr:$precision_config ); let results = (outs MHLO_Tensor); // Dot op required custom exporter to pass the preferred element type @@ -2643,7 +2643,8 @@ def MHLO_DotGeneralOp: MHLO_ShapedInterfaceOp<"dot_general", [Pure]> { MHLO_Tensor:$lhs, MHLO_Tensor:$rhs, MHLO_DotDimensionNumbers:$dot_dimension_numbers, - MHLO_PrecisionConfigAttr:$precision_config + OptionalAttr:$precision_config, + OptionalAttr:$algorithm ); let results = (outs MHLO_Tensor); @@ -2667,7 +2668,7 @@ def MHLO_SparseDotOp: MHLO_Op<"sparse_dot", [Pure]> { OptionalAttr:$lhs_sparsity, OptionalAttr:$rhs_sparsity, MHLO_DotDimensionNumbers:$dot_dimension_numbers, - MHLO_PrecisionConfigAttr:$precision_config + OptionalAttr:$precision_config ); let results = (outs MHLO_Tensor); // SparseDot op required custom exporter to pass the preferred element type @@ -3850,7 +3851,7 @@ def MHLO_DynamicConvOp : MHLO_Op<"dynamic_conv", [Pure]> { MHLO_ConvDimensionNumbers:$dimension_numbers, ConfinedAttr:$feature_group_count, ConfinedAttr:$batch_group_count, - MHLO_PrecisionConfigAttr:$precision_config + OptionalAttr:$precision_config ); let results = (outs MHLO_Tensor); let hasCanonicalizer = 1; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td index c43d89a34709e8..229e0d72e0437f 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td @@ -53,6 +53,32 @@ def MHLO_GatherDimensionNumbers : AttrDef { + let mnemonic = "dot_algorithm"; + let summary = "Attribute that models the algorithm constraints to use for computing dot."; + let parameters = (ins + "Type":$lhsPrecisionType, + "Type":$rhsPrecisionType, + "Type":$accumulationType, + "int64_t":$lhsComponentCount, + "int64_t":$rhsComponentCount, + "int64_t":$numPrimitiveOperations, + "bool":$allowImpreciseAccumulation + ); + let assemblyFormat = [{ + `<` + `lhs_precision_type` `=` $lhsPrecisionType `,` + `rhs_precision_type` `=` $rhsPrecisionType `,` + `accumulation_type` `=` $accumulationType `,` + `lhs_component_count` `=` $lhsComponentCount `,` + `rhs_component_count` `=` $rhsComponentCount `,` + `num_primitive_operations` `=` $numPrimitiveOperations `,` + `allow_imprecise_accumulation` `=` $allowImpreciseAccumulation + `>` + }]; + let genVerifyDecl = 1; +} + def MHLO_DotDimensionNumbers : AttrDef { let mnemonic = "dot"; let summary = "Attribute that models the dimension information for dot."; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td index 25375ac741da18..3e4039ef9598ad 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td @@ -40,8 +40,7 @@ def MHLO_PrecisionAttr : EnumAttr; // TODO(b/129153247) See if it's possible to also validate the size. def MHLO_PrecisionConfigAttr: - OptionalAttr< - TypedArrayAttrBase>; + TypedArrayAttrBase; //===----------------------------------------------------------------------===// // Custom call schedule hints diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index 159a95463fa72f..196f65d068365f 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -258,6 +258,13 @@ Attribute convertAttr(Attribute hloAttr) { } // NOTE: We cannot process CustomCallApiVersionAttr here because // `dyn_cast()` succeeds for IntegerAttr too. + if (auto attr = mlir::dyn_cast(hloAttr)) { + return stablehlo::DotAlgorithmAttr::get( + attr.getContext(), attr.getLhsPrecisionType(), + attr.getRhsPrecisionType(), attr.getAccumulationType(), + attr.getLhsComponentCount(), attr.getRhsComponentCount(), + attr.getNumPrimitiveOperations(), attr.getAllowImpreciseAccumulation()); + } if (auto attr = mlir::dyn_cast(hloAttr)) { return stablehlo::DotDimensionNumbersAttr::get( attr.getContext(), attr.getLhsBatchingDimensions(), diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc index bfeeaed83f89d1..e986bdc5ad694c 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc @@ -55,9 +55,9 @@ struct DotToDotGeneralPattern : public OpRewritePattern { /*lhsContractingDimensions=*/{lhs.getType().getRank() - 1}, /*rhsContractingDimensions=*/{0}); - rewriter.replaceOpWithNewOp(dotOp, dotOp.getType(), lhs, rhs, - dotDimensionNumbers, - dotOp.getPrecisionConfigAttr()); + rewriter.replaceOpWithNewOp( + dotOp, dotOp.getType(), lhs, rhs, dotDimensionNumbers, + dotOp.getPrecisionConfigAttr(), DotAlgorithmAttr{}); return success(); } }; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc index f8c0f9eafd7c83..c35ce560146dcb 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc @@ -159,7 +159,7 @@ struct EinsumToDotGeneralPattern : public OpRewritePattern { auto dotGeneralOp = rewriter.create( einsum.getLoc(), dotGeneralResultType, einsum.getLhs(), einsum.getRhs(), dimNumbers, - /*precision_config=*/ArrayAttr{}); + /*precision_config=*/ArrayAttr{}, /*dot_algorithm=*/DotAlgorithmAttr{}); if (isNaturalOrder) { // The dot_general is already in an appropriate result order. diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc index cd94cb58733b33..7570d34ace0bc1 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc @@ -88,6 +88,13 @@ Attribute convertAttr(Attribute stablehloAttr) { mlir::dyn_cast(stablehloAttr)) { RETURN_CONVERTED_ENUM_ATTR(CustomCallApiVersion); } + if (auto attr = mlir::dyn_cast(stablehloAttr)) { + return mhlo::DotAlgorithmAttr::get( + attr.getContext(), attr.getLhsPrecisionType(), + attr.getRhsPrecisionType(), attr.getAccumulationType(), + attr.getLhsComponentCount(), attr.getRhsComponentCount(), + attr.getNumPrimitiveOperations(), attr.getAllowImpreciseAccumulation()); + } if (auto attr = mlir::dyn_cast(stablehloAttr)) { return mhlo::DotDimensionNumbersAttr::get( diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index c25e8ff27fe486..90965f06086831 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -802,6 +802,45 @@ func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) func.return %0 : tensor<8x8x8xf32> } +// CHECK-LABEL: "op_dot_general_algorithm" +func.func @op_dot_general_algorithm(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "stablehlo.dot_general"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) <{ + // CHECK-SAME: algorithm = #stablehlo.dot_algorithm< + // CHECK-SAME: lhs_precision_type = tf32, + // CHECK-SAME: rhs_precision_type = tf32, + // CHECK-SAME: accumulation_type = f32, + // CHECK-SAME: lhs_component_count = 1, + // CHECK-SAME: rhs_component_count = 1, + // CHECK-SAME: num_primitive_operations = 1, + // CHECK-SAME: allow_imprecise_accumulation = false + // CHECK-SAME: >, + // CHECK-SAME: dot_dimension_numbers = #stablehlo.dot< + // CHECK-SAME: lhs_batching_dimensions = [0], + // CHECK-SAME: rhs_batching_dimensions = [0], + // CHECK-SAME: lhs_contracting_dimensions = [2], + // CHECK-SAME: rhs_contracting_dimensions = [1] + // CHECK-SAME: > + // CHECK-SAME: }> : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + >, + algorithm = #mhlo.dot_algorithm< + lhs_precision_type = tf32, + rhs_precision_type = tf32, + accumulation_type = f32, + lhs_component_count = 1, + rhs_component_count = 1, + num_primitive_operations = 1, + allow_imprecise_accumulation = false + > + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + // CHECK-LABEL: "op_dot" func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { // CHECK: "stablehlo.dot"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) <{ diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index 3c47a056eb638d..0f2e1b108a710f 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -786,6 +786,45 @@ func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) func.return %0 : tensor<8x8x8xf32> } +// CHECK-LABEL: "op_dot_general_algorithm" +func.func @op_dot_general_algorithm(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "mhlo.dot_general"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) <{ + // CHECK-SAME: algorithm = #mhlo.dot_algorithm< + // CHECK-SAME: lhs_precision_type = tf32, + // CHECK-SAME: rhs_precision_type = tf32, + // CHECK-SAME: accumulation_type = f32, + // CHECK-SAME: lhs_component_count = 1, + // CHECK-SAME: rhs_component_count = 1, + // CHECK-SAME: num_primitive_operations = 1, + // CHECK-SAME: allow_imprecise_accumulation = false + // CHECK-SAME: >, + // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< + // CHECK-SAME: lhs_batching_dimensions = [0], + // CHECK-SAME: rhs_batching_dimensions = [0], + // CHECK-SAME: lhs_contracting_dimensions = [2], + // CHECK-SAME: rhs_contracting_dimensions = [1] + // CHECK-SAME: > + // CHECK-SAME: }> : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + >, + algorithm = #stablehlo.dot_algorithm< + lhs_precision_type = tf32, + rhs_precision_type = tf32, + accumulation_type = f32, + lhs_component_count = 1, + rhs_component_count = 1, + num_primitive_operations = 1, + allow_imprecise_accumulation = false + > + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + // CHECK-LABEL: "op_dot" func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { // CHECK: "mhlo.dot"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) <{ diff --git a/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc b/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc index 2562c060e82f9b..0c8a4de9e98454 100644 --- a/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc +++ b/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc @@ -100,9 +100,10 @@ SmallVector calcMultiDimIndex(OpBuilder& b, Location loc, return calcMultiDimIndex(b, loc, linearIndex, shapeVec); } -SmallVector calcMultiDimIndexForFirstOperand(OpBuilder& b, Location loc, - Value linearIndex, - Operation* op) { +static SmallVector calcMultiDimIndexForFirstOperand(OpBuilder& b, + Location loc, + Value linearIndex, + Operation* op) { assert(op->getDialect()->getNamespace() == "lmhlo"); Value operandMemref = op->getOperand(0); return calcMultiDimIndex(b, loc, linearIndex, operandMemref); diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index b73a59f4cc4293..adf30e0833ba0f 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -311,6 +311,10 @@ cc_library( deps = [ ":pjrt_common", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index f3ae2c55c4b042..c8896b77f4d019 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -149,6 +149,9 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/backends/cpu/runtime:buffer_allocations", + "//xla/backends/cpu/runtime:thunk", + "//xla/backends/cpu/runtime:thunk_executor", "//xla/client:executable_build_options", "//xla/client:xla_computation", "//xla/hlo/ir:hlo", @@ -185,12 +188,10 @@ cc_library( "//xla/service/cpu:cpu_runtime", "//xla/service/cpu:cpu_xfeed", "//xla/service/cpu:simple_orc_jit", - "//xla/service/cpu/runtime:buffer_allocations", - "//xla/service/cpu/runtime:thunk", - "//xla/service/cpu/runtime:thunk_executor", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", @@ -208,7 +209,6 @@ cc_library( "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", # TODO(zhangqiaorjc): Remove if use TFRT threadpool. "@llvm-project//mlir:IR", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:denormal", "@local_tsl//tsl/platform:env", @@ -306,7 +306,7 @@ cc_library( xla_cc_test( name = "gloo_collectives_test", srcs = ["gloo_collectives_test.cc"], - tags = ["nomac"], + linkstatic = True, deps = [ ":gloo_collectives", ":gloo_kv_store", @@ -321,13 +321,21 @@ xla_cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@gloo//:transport_tcp", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", - ], + ] + select({ + # Gloo's transport_tcp is not available on MacOS + "//xla/tsl:macos": [ + "@gloo//:transport_uv", + ], + "//conditions:default": [ + "@gloo//:transport_tcp", + ], + }), ) cc_library( diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index 04d9b79ea5186a..65ed8589f4f2b6 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -47,6 +47,9 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "mlir/IR/BuiltinOps.h" #include "xla/array.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/client/executable_build_options.h" #include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" @@ -83,9 +86,6 @@ limitations under the License. #include "xla/service/cpu/cpu_executable_run_options.h" #include "xla/service/cpu/cpu_runtime.h" #include "xla/service/cpu/cpu_xfeed.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" @@ -103,10 +103,10 @@ limitations under the License. #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/casts.h" #include "tsl/platform/denormal.h" #include "tsl/platform/env.h" diff --git a/third_party/xla/xla/pjrt/cpu/gloo_collectives_test.cc b/third_party/xla/xla/pjrt/cpu/gloo_collectives_test.cc index 0b2fd8d3c66e82..b8bb7810dd3909 100644 --- a/third_party/xla/xla/pjrt/cpu/gloo_collectives_test.cc +++ b/third_party/xla/xla/pjrt/cpu/gloo_collectives_test.cc @@ -25,8 +25,12 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" #include "absl/types/span.h" +#if defined(__linux__) #include "gloo/transport/tcp/attr.h" #include "gloo/transport/tcp/device.h" +#elif defined(__APPLE__) +#include "gloo/transport/uv/device.h" +#endif // defined(__linux__) #include "xla/executable_run_options.h" #include "xla/pjrt/cpu/gloo_kv_store.h" #include "xla/pjrt/distributed/in_memory_key_value_store.h" @@ -57,7 +61,11 @@ absl::StatusOr> GetCommunicator( const std::shared_ptr& kv_store, int rank) { auto collectives = std::make_shared( std::make_unique(kv_store), +#if defined(__linux__) gloo::transport::tcp::CreateDevice(gloo::transport::tcp::attr())); +#elif defined(__APPLE__) + gloo::transport::uv::CreateDevice(gloo::transport::uv::attr())); +#endif // defined(__linux__) return collectives->GetCommunicator(global_devices, rank); } diff --git a/third_party/xla/xla/pjrt/distributed/client.cc b/third_party/xla/xla/pjrt/distributed/client.cc index d0d96c6c511d9f..ede5e27b860f0d 100644 --- a/third_party/xla/xla/pjrt/distributed/client.cc +++ b/third_party/xla/xla/pjrt/distributed/client.cc @@ -92,6 +92,8 @@ DistributedRuntimeCoordinationServiceClient:: absl::ToInt64Milliseconds(options.shutdown_timeout)); config.set_agent_destruction_without_shutdown( !options.shutdown_on_destruction); + config.set_poll_for_error_from_service_at_startup( + options.poll_for_error_from_service_at_startup); auto error_fn = [timeout_fn = options.missed_heartbeat_callback]( const absl::Status& status) { LOG(ERROR) << "Coordination service agent in error status: " << status; diff --git a/third_party/xla/xla/pjrt/distributed/client.h b/third_party/xla/xla/pjrt/distributed/client.h index 79973124485452..2387fe6dd452f5 100644 --- a/third_party/xla/xla/pjrt/distributed/client.h +++ b/third_party/xla/xla/pjrt/distributed/client.h @@ -101,6 +101,12 @@ class DistributedRuntimeClient { // For testing. Should the client explicitly Shutdown() on destruction? bool shutdown_on_destruction = true; + + // Whether the client should send a request to wait for error from the + // coordination service at the startup. + // TODO(b/355706798): Enable this by default once we confirm this works for + // all cases and eventually remove this option. + bool poll_for_error_from_service_at_startup = false; }; virtual ~DistributedRuntimeClient() = default; diff --git a/third_party/xla/xla/pjrt/distributed/client_server_test.cc b/third_party/xla/xla/pjrt/distributed/client_server_test.cc index 8c04e7608ec41d..dfd46be79b29bc 100644 --- a/third_party/xla/xla/pjrt/distributed/client_server_test.cc +++ b/third_party/xla/xla/pjrt/distributed/client_server_test.cc @@ -424,6 +424,116 @@ TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) { } } +TEST_F(ClientServerTest, + ClientsTerminateShutdownIfAnyClientGoesAway_WithErrorPolling) { + int num_nodes = 3; + StartService(num_nodes); + + auto thread_fn = [&](int node_id) -> absl::Status { + DistributedRuntimeClient::Options client_options; + client_options.shutdown_on_destruction = node_id != 0; + client_options.missed_heartbeat_callback = + [&](absl::Status status, bool coordinator_initiated) {}; + client_options.poll_for_error_from_service_at_startup = true; + auto client = GetClient(node_id, client_options); + + TF_RETURN_IF_ERROR(client->Connect()); + + if (node_id == 0) { + return absl::OkStatus(); + } + + // The call to Shutdown() should be interrupted if a worker stops issuing + // heartbeats. + return client->Shutdown(); + }; + + std::vector statuses(num_nodes); + { + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", + num_nodes); + for (int i = 0; i < num_nodes; ++i) { + thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); }); + } + } + TF_EXPECT_OK(statuses[0]); + for (int i = 1; i < num_nodes; ++i) { + // The error type depends on whether the node turns into ERROR state during + // or before the shutdown call. + EXPECT_TRUE(absl::IsInternal(statuses[i]) || + absl::IsFailedPrecondition(statuses[i])); + } +} + +TEST_F(ClientServerTest, ClientsShutdownSuccessfully_WithErrorPolling) { + int num_nodes = 3; + StartService(num_nodes); + + auto thread_fn = [&](int node_id) -> absl::Status { + DistributedRuntimeClient::Options client_options; + client_options.shutdown_on_destruction = true; + client_options.missed_heartbeat_callback = + [&](absl::Status status, bool coordinator_initiated) {}; + client_options.poll_for_error_from_service_at_startup = true; + auto client = GetClient(node_id, client_options); + + TF_RETURN_IF_ERROR(client->Connect()); + return client->Shutdown(); + // The error polling request will be cancelled automatically when the + // client is shutting down. + }; + + std::vector statuses(num_nodes); + { + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", + num_nodes); + for (int i = 0; i < num_nodes; ++i) { + thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); }); + } + } + for (int i = 0; i < num_nodes; ++i) { + TF_EXPECT_OK(statuses[i]); + } +} + +TEST_F(ClientServerTest, + MissedHeartbeatCallbackIsExecutedIfAnyClientGoesAway_WithErrorPolling) { + int num_nodes = 3; + StartService(num_nodes); + + auto thread_fn = [&](int node_id) -> absl::Status { + DistributedRuntimeClient::Options client_options; + client_options.shutdown_on_destruction = (node_id != 0); + absl::Notification shutdown; + client_options.missed_heartbeat_callback = [&](absl::Status status, + bool coordinator_initiated) { + shutdown.Notify(); + }; + client_options.poll_for_error_from_service_at_startup = true; + auto client = GetClient(node_id, client_options); + + TF_RETURN_IF_ERROR(client->Connect()); + + if (node_id == 0) { + return absl::OkStatus(); + } + shutdown.WaitForNotification(); + return absl::OkStatus(); + }; + + std::vector statuses(num_nodes); + { + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", + num_nodes); + for (int i = 0; i < num_nodes; ++i) { + thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); }); + } + } + for (int i = 0; i < num_nodes; ++i) { + TF_EXPECT_OK(statuses[i]); + } +} + TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) { int num_nodes = 3; StartService(num_nodes); diff --git a/third_party/xla/xla/pjrt/distributed/service.cc b/third_party/xla/xla/pjrt/distributed/service.cc index 238e146bf044ec..6a8a77a5fca534 100644 --- a/third_party/xla/xla/pjrt/distributed/service.cc +++ b/third_party/xla/xla/pjrt/distributed/service.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include "absl/time/clock.h" #include "absl/time/time.h" diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 90ab11efde4827..13599a42d3054a 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -91,6 +91,7 @@ cc_library( "//xla/tsl/framework:bfc_allocator", "//xla/tsl/framework:device_id", "//xla/tsl/framework:device_id_impl", + "//xla/tsl/lib/strings:proto_serialization", "//xla/tsl/util:env_var", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -108,7 +109,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", @@ -152,6 +152,7 @@ xla_cc_test( "//xla:shape_util", "//xla:status_macros", "//xla:test", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:xla_computation", "//xla/ffi", @@ -175,6 +176,7 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 27d3f18dbc72cb..79f8d7db00c93f 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -82,7 +82,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/allocator.h" -#include "tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" @@ -122,17 +122,28 @@ class AsyncHostToDeviceTransferManager : public xla::PjRtClient::AsyncHostToDeviceTransferManager { public: static absl::StatusOr> - Create(absl::Span shapes, PjRtStreamExecutorDevice* device, - PjRtStreamExecutorClient* client, PjRtMemorySpace* memory_space) { + Create(absl::Span shape_specs, + std::optional> device_layouts, + PjRtStreamExecutorDevice* device, PjRtStreamExecutorClient* client, + PjRtMemorySpace* memory_space) { + if (device_layouts != std::nullopt && + device_layouts->size() != shape_specs.size()) { + return InvalidArgument( + "Number of layouts %d does not match the number of shapes %d", + device_layouts->size(), shape_specs.size()); + } absl::InlinedVector, 4> buffers; absl::InlinedVector, 4> buffer_ptrs; absl::InlinedVector, 4> definition_events; - buffers.reserve(shapes.size()); - buffer_ptrs.reserve(shapes.size()); - definition_events.reserve(shapes.size()); - for (const auto& shape : shapes) { - if (shape.IsTuple()) { + absl::InlinedVector device_shapes; + buffers.reserve(shape_specs.size()); + buffer_ptrs.reserve(shape_specs.size()); + definition_events.reserve(shape_specs.size()); + device_shapes.reserve(shape_specs.size()); + for (int i = 0; i < shape_specs.size(); ++i) { + const PjRtClient::ShapeSpec& shape_spec = shape_specs[i]; + if (shape_spec.element_type == TUPLE) { return Unimplemented( "Async buffer transfer of tuples not implemented."); } @@ -140,16 +151,22 @@ class AsyncHostToDeviceTransferManager // event will block the buffer usage until the transfer is done. definition_events.push_back( std::make_shared(client->thread_pool())); - TF_ASSIGN_OR_RETURN(Shape compact_shape, - client->client() - ->backend() - .transfer_manager() - ->ChooseCompactLayoutForShape(shape)); + Shape& device_shape = device_shapes.emplace_back( + ShapeUtil::MakeShape(shape_spec.element_type, shape_spec.dims)); + if (device_layouts == std::nullopt) { + TF_ASSIGN_OR_RETURN(device_shape, + client->client() + ->backend() + .transfer_manager() + ->ChooseCompactLayoutForShape(device_shape)); + } else { + *device_shape.mutable_layout() = (*device_layouts)[i]; + } LocalDeviceState* local_device = device->local_device_state(); se::Stream* h2d_stream = local_device->host_to_device_stream(); TF_ASSIGN_OR_RETURN(auto buffer, AllocateDestinationBuffer( - compact_shape, device, local_device, h2d_stream, + device_shape, device, local_device, h2d_stream, /*is_uninitialized_create=*/true, client, definition_events.back(), memory_space)); // Get a temporary hold just so we can fish out a shared_ptr to the @@ -167,7 +184,7 @@ class AsyncHostToDeviceTransferManager return std::make_unique( std::move(buffers), std::move(buffer_ptrs), - std::move(definition_events), device); + std::move(definition_events), std::move(device_shapes), device); } AsyncHostToDeviceTransferManager( @@ -175,10 +192,12 @@ class AsyncHostToDeviceTransferManager absl::InlinedVector, 4> buffer_ptrs, absl::InlinedVector, 4> definition_events, + absl::InlinedVector device_shapes, PjRtStreamExecutorDevice* device) : buffers_(std::move(buffers)), buffer_ptrs_(std::move(buffer_ptrs)), definition_events_(std::move(definition_events)), + device_shapes_(std::move(device_shapes)), remaining_buffer_count_(buffer_ptrs_.size()), transfers_in_flight_(0), device_(device) { @@ -229,9 +248,6 @@ class AsyncHostToDeviceTransferManager TransferManager* transfer_manager = se_client->client()->backend().transfer_manager(); - TF_ASSIGN_OR_RETURN( - Shape compact_shape, - transfer_manager->ChooseCompactLayoutForShape(literal.shape())); std::shared_ptr buffer; { @@ -256,16 +272,6 @@ class AsyncHostToDeviceTransferManager } DCHECK_EQ(buffer->device_memory().size(), 1); - auto& buffer_memory = buffer->device_memory()[0]; - if (transfer_manager->GetByteSizeRequirement(compact_shape) != - buffer_memory.size()) { - return InvalidArgument( - "TransferLiteralToBuffer shape %s has size %lld " - "but buffer has size %lld", - ShapeUtil::HumanStringWithLayout(compact_shape), - transfer_manager->GetByteSizeRequirement(compact_shape), - buffer_memory.size()); - } ++transfers_in_flight_; } @@ -274,7 +280,7 @@ class AsyncHostToDeviceTransferManager // TODO(misard) assess if it would be preferable to introduce a heuristic to // put the transfer into the calling thread for small literals. auto transfer_h2d = [this, buffer_index, stream, transfer_manager, literal, - device_buffer = buffer.get(), compact_shape, + device_buffer = buffer.get(), local_device = std::move(device_->local_device_state()), on_done = std::move(on_done)]() mutable { @@ -285,7 +291,8 @@ class AsyncHostToDeviceTransferManager auto event = local_device->event_pool().AllocateEvent(stream->parent()); // Initiate linearization and transfer of the buffer on the stream. - ShapedBuffer buffer = device_buffer->AsShapedBuffer(compact_shape); + ShapedBuffer buffer = + device_buffer->AsShapedBuffer(device_shapes_[buffer_index]); TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( stream, literal, buffer)); local_device->event_pool().ThenRecordEvent(stream, event.value()); @@ -449,6 +456,8 @@ class AsyncHostToDeviceTransferManager // corresponding buffer transfer has completed. absl::InlinedVector, 4> definition_events_ ABSL_GUARDED_BY(mu_); + // Device shapes for all buffers with either compact or custom layout. + const absl::InlinedVector device_shapes_; // Count of buffers that have not yet been fully transferred. size_t remaining_buffer_count_ ABSL_GUARDED_BY(mu_); // Count of transfers that have been started but have not yet called cleanup. @@ -544,22 +553,56 @@ absl::string_view StreamExecutorGpuClient::platform_version() const { absl::StatusOr> StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( - absl::Span shapes, PjRtDevice* device) { + absl::Span shape_specs, + std::optional> device_layouts, + PjRtDevice* device) { auto* stream_executor_device = tensorflow::down_cast(device); return xla::AsyncHostToDeviceTransferManager::Create( - shapes, stream_executor_device, this, /*memory_space=*/nullptr); + shape_specs, std::move(device_layouts), stream_executor_device, this, + /*memory_space=*/nullptr); } absl::StatusOr> StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( - absl::Span shapes, PjRtMemorySpace* memory_space) { + absl::Span shapes, PjRtDevice* device) { + absl::InlinedVector shape_specs; + shape_specs.reserve(shapes.size()); + for (const auto& shape : shapes) { + shape_specs.emplace_back(PjRtClient::ShapeSpec{ + shape.element_type(), + DimensionVector(shape.dimensions().begin(), shape.dimensions().end())}); + } + return CreateBuffersForAsyncHostToDevice( + shape_specs, /*device_layouts=*/std::nullopt, device); +} + +absl::StatusOr> +StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional> device_layouts, + PjRtMemorySpace* memory_space) { CHECK_EQ(memory_space->devices().size(), 1); PjRtDevice* device = memory_space->devices()[0]; auto* stream_executor_device = tensorflow::down_cast(device); return xla::AsyncHostToDeviceTransferManager::Create( - shapes, stream_executor_device, this, memory_space); + shape_specs, std::move(device_layouts), stream_executor_device, this, + memory_space); +} + +absl::StatusOr> +StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( + absl::Span shapes, PjRtMemorySpace* memory_space) { + absl::InlinedVector shape_specs; + shape_specs.reserve(shapes.size()); + for (const auto& shape : shapes) { + shape_specs.emplace_back(PjRtClient::ShapeSpec{ + shape.element_type(), + DimensionVector(shape.dimensions().begin(), shape.dimensions().end())}); + } + return CreateBuffersForAsyncHostToDevice( + shape_specs, /*device_layouts=*/std::nullopt, memory_space); } absl::StatusOr @@ -1207,7 +1250,6 @@ absl::StatusOr> GetStreamExecutorGpuClient( auto host_memory_allocator = GetGpuHostAllocator(local_device_states.begin()->second->executor()); - std::vector> devices; auto gpu_run_options = std::make_unique(); if (options.enable_mock_nccl) { gpu_run_options->set_enable_mock_nccl_collectives(); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index afb624b248f863..a481e9a59ea73d 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -207,11 +207,22 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { int num_replicas, int num_partitions) const override; absl::string_view platform_version() const override; + absl::StatusOr> + CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional> device_layouts, + PjRtDevice* device) override; absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtDevice* device) override; + absl::StatusOr> + CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional> device_layouts, + PjRtMemorySpace* memory_space) override; + absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtMemorySpace* memory_space) override; diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index e034e83efd5893..54bfaf5c4b61d0 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -35,9 +35,11 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "absl/types/span.h" #include "xla/client/xla_computation.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" +#include "xla/layout.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/pjrt/distributed/in_memory_key_value_store.h" @@ -58,6 +60,7 @@ limitations under the License. #include "xla/test.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" @@ -405,6 +408,54 @@ TEST(StreamExecutorGpuClientTest, ToLiteralAsync) { literal->Relayout(src_literal.shape().layout()).data()); } +TEST(StreamExecutorGpuClientTest, ToLiteralAsyncWithNonCompactLayout) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + ASSERT_GE(client->addressable_devices().size(), 1); + + xla::Shape transposed_shape = xla::ShapeUtil::MakeShapeWithDenseLayout( + xla::S32, {2, 3}, /*minor_to_major=*/{0, 1}); + xla::Literal src_literal = xla::LiteralUtil::CreateR2WithLayout( + {{3, 14, 25}, {36, 47, 58}}, transposed_shape.layout()); + + PjRtClient::ShapeSpec spec; + spec.element_type = src_literal.shape().element_type(); + spec.dims = DimensionVector(src_literal.shape().dimensions().begin(), + src_literal.shape().dimensions().end()); + TF_ASSERT_OK_AND_ASSIGN( + auto transfer_manager, + client->CreateBuffersForAsyncHostToDevice( + {spec}, + std::make_optional>( + {transposed_shape.layout()}), + client->addressable_devices()[0]->memory_spaces()[0])); + auto buffer = transfer_manager->RetrieveBuffer(0); + + absl::Mutex mu; + auto literal = std::make_shared( + ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape())); + bool got_literal = false; + + TF_ASSERT_OK( + transfer_manager->TransferLiteralToBuffer(0, src_literal, [&]() {})); + + buffer->ToLiteral(literal.get()).OnReady([&](absl::Status s) { + absl::MutexLock l(&mu); + TF_ASSERT_OK(s); + got_literal = true; + }); + buffer.reset(); + + { + absl::MutexLock l(&mu); + mu.Await(absl::Condition(&got_literal)); + } + + ASSERT_TRUE(ShapeUtil::Compatible(src_literal.shape(), literal->shape())); + ASSERT_EQ(src_literal.data(), + literal->Relayout(src_literal.shape().layout()).data()); +} + TEST(StreamExecutorGpuClientTest, ToLiteralAsyncBeforeBufferReady) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); diff --git a/third_party/xla/xla/pjrt/pjrt_client.h b/third_party/xla/xla/pjrt/pjrt_client.h index 2f6c7a8b515792..e8607d23dd6709 100644 --- a/third_party/xla/xla/pjrt/pjrt_client.h +++ b/third_party/xla/xla/pjrt/pjrt_client.h @@ -494,6 +494,11 @@ struct PjRtPluginAttributes { // will eventually be able to make progress. class PjRtClient { public: + struct ShapeSpec { + PrimitiveType element_type; + DimensionVector dims; + }; + PjRtClient() = default; explicit PjRtClient(std::unique_ptr host_memory_for_device_manager) @@ -747,6 +752,32 @@ class PjRtClient { virtual void AddTransferMetadata(const TransferMetadata& metadata) = 0; }; + // Returns a manager for async transfers into a set of buffers with on-host + // shapes defined by 'shape_specs' and optional `device_layouts`. The + // `device_layout` is used when non-compact layouts are preferred. + virtual absl::StatusOr> + CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional> device_layouts, + PjRtDevice* device) { + return absl::UnimplementedError(absl::StrCat( + "CreateBuffersForAsyncHostToDevice with ShapeSpec and Layout is " + "not implemented on platform: ", + platform_name())); + } + + // Variant of CreateBuffersForAsyncHostToDevice with PjRtMemorySpace. + virtual absl::StatusOr> + CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional> device_layouts, + PjRtMemorySpace* memory_space) { + return absl::UnimplementedError(absl::StrCat( + "CreateBuffersForAsyncHostToDevice with ShapeSpec and Layout is " + "not implemented on platform: ", + platform_name())); + } + // Returns a manager for async transfers into a set of buffers with on-host // shapes 'shapes'. virtual absl::StatusOr> diff --git a/third_party/xla/xla/pjrt/pjrt_device_description.h b/third_party/xla/xla/pjrt/pjrt_device_description.h index ed852699e404c5..77107fdc495c71 100644 --- a/third_party/xla/xla/pjrt/pjrt_device_description.h +++ b/third_party/xla/xla/pjrt/pjrt_device_description.h @@ -20,12 +20,35 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/pjrt/pjrt_common.h" namespace xla { using PjRtDeviceAttribute = PjRtValueType; +class PjRtMemorySpaceDescription { + public: + PjRtMemorySpaceDescription(absl::string_view kind, int kind_id) + : kind_(kind), kind_id_(kind_id) {} + + // A platform-dependent string that uniquely identifies the kind of the + // memory space. + absl::string_view kind() const { return kind_; } + + // An ID uniquely identifies the kind of the memory space among those attached + // to the same `PjRtClient`. The IDs assigned to a kind is implementation + // specific. + int kind_id() const { return kind_id_; } + + private: + absl::string_view kind_; + int kind_id_; +}; + class PjRtDeviceDescription { public: virtual ~PjRtDeviceDescription() = default; @@ -60,6 +83,19 @@ class PjRtDeviceDescription { // reference will remain valid for the lifetime of the PjRtDevice. virtual const absl::flat_hash_map& Attributes() const = 0; + + // Returns all memory spaces attached to this device. + // The memory spaces are in no particular order. + virtual absl::Span memory_spaces() + const { + return {}; + } + + // Returns the default memory space attached to this device. + virtual absl::StatusOr + default_memory_space() const { + return absl::UnimplementedError("default_memory_space Not implemented."); + } }; } // namespace xla diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index f4446d410269e9..d9c04626523e31 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -427,9 +427,9 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", ] + if_cuda([ "@local_config_cuda//cuda:cuda_headers", - # TODO(b/324133505): remove this dependency after JAX OSS migrates to cuda plugin. - "//xla/stream_executor/cuda:cuda_driver", ]) + if_rocm([ + # keep sorted + "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", ]) + if_cuda_or_rocm([ ":py_client_gpu", # TODO(b/337876408): remove after migration to plugin @@ -754,11 +754,13 @@ cc_library( "@com_google_absl//absl/types:span", "@nanobind", "@local_config_python//:python_headers", # buildcleaner: keep + "//xla:shape_util", "//xla:util", "//xla/pjrt:exceptions", "//xla/pjrt:lru_cache", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_future", + "//xla/pjrt:pjrt_layout", "//xla/pjrt:status_casters", "//xla/python/ifrt", "//xla/tsl/concurrency:ref_count", @@ -1170,7 +1172,7 @@ cc_library( "//xla/service:hlo_proto_cc", "//xla/service:name_uniquer", "//xla/service:tuple_simplifier", - "@local_tsl//tsl/lib/strings:proto_serialization", + "//xla/tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -1358,8 +1360,12 @@ tsl_pybind_extension( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform/cloud:gcs_file_system", ] + select({ - # gloo transport only builds on linux - "//xla/tsl:macos": [], + # gloo tcp transport only builds on linux + "//xla/tsl:macos": [ + "//xla/pjrt/cpu:gloo_collectives", + "//xla/pjrt/cpu:gloo_kv_store", + "@gloo//:transport_uv", + ], "//xla/tsl:windows": [], "//conditions:default": [ "//xla/pjrt/cpu:gloo_collectives", diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index a3cb0ea0fb4d0a..6b751b0b079533 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -117,6 +117,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@local_tsl//tsl/lib/gtl:int_type", diff --git a/third_party/xla/xla/python/ifrt/ir/constants.h b/third_party/xla/xla/python/ifrt/ir/constants.h index 27e9d11fb6a1cf..26f8a7e999dd52 100644 --- a/third_party/xla/xla/python/ifrt/ir/constants.h +++ b/third_party/xla/xla/python/ifrt/ir/constants.h @@ -44,6 +44,15 @@ inline constexpr llvm::StringLiteral kIfrtLocalViewAttrName = "ifrt.local_view"; inline constexpr llvm::StringLiteral kIfrtCompileOptionsKey = "ifrt.compile_options_key"; +inline constexpr llvm::StringLiteral kIfrtDevicesAttrName = "ifrt.devices"; +inline constexpr llvm::StringLiteral kIfrtNumDevicesAttrName = + "ifrt.num_devices"; +inline constexpr llvm::StringLiteral kIfrtShardingAttrName = "ifrt.sharding"; +inline constexpr llvm::StringLiteral kIfrtEntryFunctionAttrName = + "ifrt.entry_function"; + +inline constexpr llvm::StringLiteral kCalleeMainFuncName = "main"; + } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td b/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td index e46ff35490c2a1..a430bcb38f1b41 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td @@ -95,7 +95,7 @@ def Ifrt_UnspecifiedShardingAttr : AttrDef { let mnemonic = "interval"; - let summary = [{ + let description = [{ Half-open interval attribute using the Python slice format `[start:end:step]`. Reverse iteration is not supported for simplicity. Therefore, `start` and `end` must be zero or positive, and `step` @@ -133,7 +133,7 @@ def Ifrt_MappingAttrArrayAttr : def Ifrt_ArrayMappingAttr : AttrDef { let mnemonic = "array_mapping"; - let summary = [{ + let description = [{ Mapping of shards from an input array to an output array. The shards are chosen from input array with index `in_array_index` and are used to assemble the output array with index `out_array_index`. diff --git a/third_party/xla/xla/python/ifrt/ir/tests/BUILD b/third_party/xla/xla/python/ifrt/ir/tests/BUILD index e1faab03f4ee64..01ca1bff5c92e8 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/tests/BUILD @@ -12,6 +12,7 @@ lit_test_suite( [ "ifrt_duplicated_callee_elimination.mlir", "ifrt_merge_reshards.mlir", + "ifrt_outline_atom_program_to_module.mlir", "ifrt_verify_donation.mlir", "ifrt_verify_sharding_specified.mlir", "spmd_expansion.mlir", diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_outline_atom_program_to_module.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_outline_atom_program_to_module.mlir new file mode 100644 index 00000000000000..c963b4ccb7a604 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_outline_atom_program_to_module.mlir @@ -0,0 +1,247 @@ +// RUN: ifrt-opt %s -ifrt-outline-atom-program-to-module -split-input-file -verify-diagnostics | FileCheck %s + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @call_hlo +module @call_hlo { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: ifrt.Call @[[MODULE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0,1] + {ifrt.compile_options_key = "fake_compile_options_key"} + : (!array) -> !array + return %0 : !array + } + + // CHECK: module @[[MODULE]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @calls_share_a_module +module @calls_share_a_module { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: %[[OUTPUT:.+]], %{{.+}} = ifrt.Call @[[MODULE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0,1] + : (!array) -> !array + // CHECK: ifrt.Call @[[MODULE:.+]]::@main(%[[OUTPUT]]) + %1, %ctrl_1 = ifrt.Call @add_one(%0) on devices [0,1] : (!array) -> !array + return %1 : !array + } + + // CHECK: module @[[MODULE]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @calls_with_ctrl_dep_share_a_module +module @calls_with_ctrl_dep_share_a_module { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: %[[OUTPUT:.+]], %[[CTRL_0:.+]] = ifrt.Call @[[MODULE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0,1] + : (!array) -> !array + // CHECK: ifrt.Call @[[MODULE:.+]]::@main(%[[OUTPUT]]) after %[[CTRL_0]] + %1, %ctrl_1 = ifrt.Call @add_one(%0) after %ctrl_0 on devices [0,1] + : (!array) -> !array + return %1 : !array + } + + // CHECK: module @[[MODULE]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0,1]> +// CHECK-LABEL: @call_with_diff_sharding_share_a_module +module @call_with_diff_sharding_share_a_module { + func.func @main(%arg0: !array) -> !array_unspecified + attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.Call @[[MODULE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0, 1] + : (!array) -> !array + // CHECK: %[[OUT_1:.+]], %{{.+}} = ifrt.Call @[[MODULE:.+]]::@main(%[[OUT_0]]) + %1, %ctrl_1 = ifrt.Call @add_one(%0) on devices [0, 1] + : (!array) -> !array_unspecified + // CHECK: return %[[OUT_1]] + return %1 : !array_unspecified + } + + // CHECK: module @[[MODULE]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + func.func private @add_one(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [2,3]> + +// CHECK-LABEL: @call_with_diff_devices_share_a_module +module @call_with_diff_devices_share_a_module { + func.func @main(%arg0: !array0, %arg1: !array1) -> (!array0, !array1) + attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.Call @[[MODULE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0, 1] + : (!array0) -> !array0 + // CHECK: %[[OUT_1:.+]], %{{.+}} = ifrt.Call @[[MODULE:.+]]::@main(%arg1) + %1, %ctrl_1 = ifrt.Call @add_one(%arg1) on devices [2, 3] + : (!array1) -> !array1 + // CHECK: return %[[OUT_0]], %[[OUT_1]] + return %0, %1 : !array0, !array1 + } + + // CHECK: module @[[MODULE]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + func.func private @add_one(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @shared_func_is_cloned +module @shared_func_is_cloned { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: %[[OUT:.+]], %{{.+}} = ifrt.Call @[[MODULE1:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0,1] + : (!array) -> !array + // CHECK: ifrt.Call @[[MODULE2:.+]]::@main(%[[OUT]]) + %1, %ctrl_1 = ifrt.Call @add_two(%0) on devices [0,1] : (!array) -> !array + return %1 : !array + } + + func.func private @add_one_internal(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + + // CHECK: module @[[MODULE1]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + // CHECK: func.func private @add_one_internal + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = func.call @add_one_internal(%arg0) : (tensor<2x2xi32>) -> (tensor<2x2xi32>) + return %0 : tensor<2x2xi32> + } + + // CHECK: module @[[MODULE2]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + // CHECK: func.func private @add_one_internal + func.func private @add_two(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = func.call @add_one_internal(%arg0) : (tensor<2x2xi32>) -> (tensor<2x2xi32>) + %1 = func.call @add_one_internal(%0) : (tensor<2x2xi32>) -> (tensor<2x2xi32>) + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +// CHECK-LABEL: @callee_with_symbol +module @callee_with_symbol { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: ifrt.Call @[[MODULE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [2] + : (!array) -> !array + return %0 : !array + } + + // CHECK: module @[[MODULE]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<2> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 {attr_sym = @add_two}: tensor<2x2xi32> + return %0 : tensor<2x2xi32> + } + + // CHECK: func.func private @add_two + // CHECK-NEXT: mhlo.constant + // CHECK-NEXT: mhlo.add + func.func private @add_two(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<2> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +module @unknown_symbol_in_callee { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [2] : (!array) -> !array + return %0 : !array + } + + func.func private @add_one(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + // expected-error @+1 {{'mhlo.add' op uses a symbol in attributes `unknown` that does not exist in the ModuleOp}} + %1 = mhlo.add %arg0, %0 {f = @unknown} : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +module @wrong_type_for_symbol_in_callee { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [2] : (!array) -> !array + return %0 : !array + } + + func.func private @add_one(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + // expected-error @+1 {{'mhlo.add' op uses a symbol in attributes `a_module` that is not a FuncOp. Cannot handle such cases for now}} + %1 = mhlo.add %arg0, %0 {f = @a_module} : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + + module @a_module {} +} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir b/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir index 28a1dda2b3f77d..4fef0876dc8bb8 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir @@ -3,7 +3,7 @@ #device = #ifrt #sharding = #ifrt.sharding_param<2x1 to [0] on 2> // CHECK-LABEL: @identity_axis0_sharded -module @identity_axis0_sharded attributes {ifrt.devices = #device} { +module @identity_axis0_sharded attributes {ifrt.num_devices = 2} { // CHECK-NEXT: func.func @main // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi32> // CHECK-NEXT: return %[[ARG]] @@ -23,7 +23,7 @@ module @identity_axis0_sharded attributes {ifrt.devices = #device} { #sharding = #ifrt.sharding_param<1x2 to [0] on 2> // CHECK-LABEL: @identity_axis1_sharded module @identity_axis1_sharded - attributes {ifrt.devices = #device, ifrt.entry_function = "entry_func"} { + attributes {ifrt.num_devices = 2, ifrt.entry_function = "entry_func"} { // CHECK-NEXT: func.func @entry_func // CHECK-SAME: %[[ARG:.*]]: tensor<2x1xi32> // CHECK-NEXT: return %[[ARG]] @@ -42,7 +42,7 @@ module @identity_axis1_sharded #device = #ifrt #sharding = #ifrt.sharding_param<3x2 to [1,0] on 2x3> // CHECK-LABEL: @identify_both_axes_sharded -module @identify_both_axes_sharded attributes {ifrt.devices = #device} { +module @identify_both_axes_sharded attributes {ifrt.num_devices = 6} { // CHECK-NEXT: func.func @main // CHECK-SAME: %[[ARG:.*]]: tensor<1x1xi32> // CHECK-NEXT: return %[[ARG]] @@ -60,7 +60,7 @@ module @identify_both_axes_sharded attributes {ifrt.devices = #device} { #device = #ifrt // CHECK-LABEL: @with_func_call -module @with_func_call attributes {ifrt.devices = #device} { +module @with_func_call attributes {ifrt.num_devices = 2} { // CHECK-NEXT: func.func @main // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi32> // CHECK-SAME: tensor<1x2xi32> @@ -94,7 +94,7 @@ module @with_func_call attributes {ifrt.devices = #device} { #device = #ifrt // CHECK-LABEL: @with_nested_func_call -module @with_nested_func_call attributes {ifrt.devices = #device} { +module @with_nested_func_call attributes {ifrt.num_devices = 2} { // CHECK-NEXT: func.func @main // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi32> // CHECK-SAME: tensor<1x2xi32> @@ -139,11 +139,10 @@ module @with_nested_func_call attributes {ifrt.devices = #device} { // ----- -#device = #ifrt #sharding = #ifrt.sharding_param<1x2 to [0] on 2> // expected-error@+1 {{cannot find entry function `main`}} module @missing_main_function - attributes {ifrt.devices = #device} { + attributes {ifrt.num_devices = 2} { } // ----- @@ -152,7 +151,7 @@ module @missing_main_function #sharding = #ifrt.sharding_param<1x2 to [0] on 2> // expected-error@+1 {{cannot find entry function `entry_func`}} module @missing_entry_function - attributes {ifrt.devices = #device, ifrt.entry_function = "entry_func"} { + attributes {ifrt.num_devices = 2, ifrt.entry_function = "entry_func"} { func.func @main( %arg0: tensor<2x2xi32> {ifrt.sharding = #sharding, ifrt.devices = #device}) @@ -166,7 +165,7 @@ module @missing_entry_function #device = #ifrt #sharding = #ifrt.sharding_param<2x1 to [0] on 2> -module @non_divisible_global_shape attributes {ifrt.devices = #device} { +module @non_divisible_global_shape attributes {ifrt.num_devices = 2} { // expected-error@+1 {{Global shape is not divisible by the number of shards in dimension 0. Global size: 3, number of shards: 2}} func.func @main( %arg0: tensor<3x2xi32> {ifrt.sharding = #sharding, diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD index ccd1919e3ccf5d..620362de4c1b50 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD @@ -31,6 +31,7 @@ cc_library( srcs = [ "ifrt_duplicated_callee_elimination_pass.cc", "ifrt_merge_reshards_pass.cc", + "ifrt_outline_atom_program_to_module_pass.cc", "ifrt_verify_donation_pass.cc", "ifrt_verify_sharding_specified_pass.cc", "spmd_expandable_interface_verification_pass.cc", @@ -39,8 +40,8 @@ cc_library( hdrs = ["passes.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":constants", ":passes_inc_gen", + ":utils", "//xla/python/ifrt/ir", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -67,8 +68,14 @@ cc_library( ) cc_library( - name = "constants", - hdrs = ["constants.h"], + name = "utils", + srcs = ["utils.cc"], + hdrs = ["utils.h"], compatible_with = get_compatible_with_portable(), - deps = ["@llvm-project//llvm:Support"], + deps = [ + "@com_google_absl//absl/log:check", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], ) diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_outline_atom_program_to_module_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_outline_atom_program_to_module_pass.cc new file mode 100644 index 00000000000000..3074e67aebcaa3 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_outline_atom_program_to_module_pass.cc @@ -0,0 +1,181 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "xla/python/ifrt/ir/constants.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/transforms/utils.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTOUTLINEATOMPROGRAMTOMODULEPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +class IfrtOutlineAtomProgramToModulePass + : public impl::IfrtOutlineAtomProgramToModulePassBase< + IfrtOutlineAtomProgramToModulePass> { + public: + using impl::IfrtOutlineAtomProgramToModulePassBase< + IfrtOutlineAtomProgramToModulePass>:: + IfrtOutlineAtomProgramToModulePassBase; + + void runOnOperation() override; +}; + +void IfrtOutlineAtomProgramToModulePass::runOnOperation() { + mlir::SymbolTableCollection symbol_table; + mlir::OpBuilder builder(&getContext()); + llvm::DenseSet visited; + llvm::SmallVector to_erase; + mlir::ModuleOp module_op = getOperation(); + mlir::func::FuncOp main_func = GetMainFunction(module_op); + auto result = + main_func.walk([&](xla::ifrt::CallOp call_op) -> mlir::WalkResult { + // Maybe visited by a previous CallOp with the same callee. + if (visited.contains(call_op)) { + return mlir::WalkResult::advance(); + } + + // Find the callee. + mlir::func::FuncOp callee = call_op.getCalleeOp(symbol_table); + if (callee.getSymName() == kCalleeMainFuncName && + llvm::isa(callee->getParentOp())) { + // Atom program is already outlined in module. Do nothing. + return mlir::WalkResult::advance(); + } + + // Create a ModuleOp and clone callee into it. + builder.setInsertionPointAfter(callee); + auto callee_module = builder.create( + callee->getLoc(), callee.getSymName()); + callee_module.setVisibility(mlir::SymbolTable::Visibility::Private); + + mlir::func::FuncOp cloned_callee; + // Find all symbols directly or indirectly referenced by callee and copy + // them to the newly created module. + { + // Setup for DFS. + llvm::DenseSet visited_funcs; + llvm::SmallVector func_stack = {callee}; + while (!func_stack.empty()) { + mlir::func::FuncOp current_func = func_stack.back(); + func_stack.pop_back(); + if (!visited_funcs.insert(current_func).second) { + continue; + } + + // Copy function into the new module. + mlir::func::FuncOp cloned_func = + llvm::cast(current_func->clone()); + if (current_func == callee) { + cloned_callee = cloned_func; + cloned_func.setSymName(kCalleeMainFuncName); + cloned_func.setVisibility(mlir::SymbolTable::Visibility::Public); + } + builder.setInsertionPointToEnd(callee_module.getBody()); + builder.insert(cloned_func); + + // Check all symbols in function. + std::optional sym_uses = + mlir::SymbolTable::getSymbolUses(current_func); + if (!sym_uses.has_value()) { + continue; + } + for (const mlir::SymbolTable::SymbolUse& sym_use : *sym_uses) { + // Ensure the symbol represents a function. + mlir::Operation* sym_op = module_op.lookupSymbol( + sym_use.getSymbolRef().getRootReference()); + if (sym_op == nullptr) { + return sym_use.getUser()->emitOpError() + << "uses a symbol in attributes `" + << sym_use.getSymbolRef().getRootReference().str() + << "` that does not exist in the ModuleOp."; + } + auto func = llvm::dyn_cast(sym_op); + if (func == nullptr) { + return sym_use.getUser()->emitOpError() + << "uses a symbol in attributes `" + << sym_use.getSymbolRef().getRootReference().str() + << "` that is not a FuncOp. Cannot handle such cases " + "for now."; + } + func_stack.push_back(func); + } + } + } + + // Replace all uses of old callee. + mlir::SymbolRefAttr new_symbol = mlir::SymbolRefAttr::get( + callee_module.getSymNameAttr(), + mlir::SymbolRefAttr::get(cloned_callee.getSymNameAttr())); + // It is sufficient to get the symbols in the main func because + // ifrt.Call nested within callees are not supported. + std::optional symbol_uses = + callee.getSymbolUses(main_func); + if (symbol_uses.has_value()) { + for (const mlir::SymbolTable::SymbolUse symbol_use : *symbol_uses) { + auto user = llvm::dyn_cast(symbol_use.getUser()); + if (user == nullptr) { + return symbol_use.getUser()->emitOpError() + << "requires symbol `" << callee.getSymName() + << "` only used by ifrt.Call. Found use by `" + << user.getOperationName() << "`"; + } + user.setCalleeAttr(new_symbol); + visited.insert(user); + } + } + + // Can't erase callee yet during iteration. + to_erase.push_back(callee); + return mlir::WalkResult::advance(); + }); + + if (result.wasInterrupted()) { + signalPassFailure(); + return; + } + for (mlir::Operation* op : to_erase) { + op->erase(); + } +} + +} // namespace + +std::unique_ptr> +CreateIfrtOutlineAtomProgramToModulePass() { + return std::make_unique(); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.h b/third_party/xla/xla/python/ifrt/ir/transforms/passes.h index a2cd1748a6c3b0..da7ec1ab599795 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.h +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.h @@ -20,10 +20,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/LogicalResult.h" namespace xla { namespace ifrt { @@ -43,6 +40,9 @@ CreateIfrtDuplicatedCalleeEliminationPass(); std::unique_ptr> CreateIfrtMergeReshardsPass(); +std::unique_ptr> +CreateIfrtOutlineAtomProgramToModulePass(); + std::unique_ptr> CreateIfrtVerifyDonationPass(); diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td index c8c8e99bdca1d4..10215b72653e0c 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td @@ -95,6 +95,47 @@ module attributes {ifrt.devices = #device} { let constructor = "CreateSpmdExpansionPass()"; } +def IfrtOutlineAtomProgramToModulePass : + Pass<"ifrt-outline-atom-program-to-module", "mlir::ModuleOp"> { + let summary = "Wraps every atom function with a ModuleOp with a @main FuncOp"; + let description = [{ +For every unique atom program this passes produces a ModuleOp with the same name +as the callee, clones the callee into the ModuleOp, and redirects all the +CallOps calling it to the new callee. + +This pass must be run if the compiler (e.g., the XLA compiler) expects each atom +program to be outlined in a ModuleOp with a @main FuncOp. + +For example, the following code + +```mlir +%0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0, 1] + : (!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> + +func.func private @callee(%arg0: tensor<2x2xi32>) -> (tensor<4x4xi32>) {} +``` + +will be replaced by + +```mlir +%0, %ctrl_0 = ifrt.Call @callee::@main(%arg0) on devices [0, 1] + : (!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> + +module @callee attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32>) -> (tensor<4x4xi32>) {} +} +``` + }]; + + let constructor = "CreateIfrtOutlineAtomProgramToModulePass()"; +} + def IfrtDuplicatedCalleeEliminationPass : Pass<"ifrt-duplicated-callee-elimination", "mlir::ModuleOp"> { let summary = "Deduplicate callees of CallOp"; diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc index 2669dfd73d2256..13d198f2dbf8c8 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc @@ -35,9 +35,8 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/constants.h" #include "xla/python/ifrt/ir/ifrt_interfaces.h" -#include "xla/python/ifrt/ir/transforms/constants.h" #include "xla/python/ifrt/ir/transforms/passes.h" namespace xla::ifrt { @@ -272,15 +271,15 @@ mlir::LogicalResult SpmdExpansionPass::spmdExpand(mlir::func::FuncOp func_op) { void SpmdExpansionPass::runOnOperation() { mlir::ModuleOp module_op = getOperation(); // Skip single-device case. - auto devices = module_op->getAttrOfType( - kIfrtDevicesAttrName); - if (devices == nullptr) { + auto num_devices = + module_op->getAttrOfType(kIfrtNumDevicesAttrName); + if (num_devices == nullptr) { module_op->emitOpError() << "`" << module_op.getName()->str() << "` requires `" - << kIfrtDevicesAttrName << "` attribute."; + << kIfrtNumDevicesAttrName << "` attribute."; return signalPassFailure(); } - if (devices.getIds().size() == 1) { + if (num_devices.getInt() == 1) { return; } diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc b/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc new file mode 100644 index 00000000000000..b1cb219e5e49fe --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc @@ -0,0 +1,34 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/ir/transforms/utils.h" + +#include "absl/log/check.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" + +namespace xla { +namespace ifrt { + +mlir::func::FuncOp GetMainFunction(mlir::ModuleOp module) { + mlir::func::FuncOp func = + mlir::dyn_cast_or_null(module.lookupSymbol("main")); + CHECK(func); + return func; +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/constants.h b/third_party/xla/xla/python/ifrt/ir/transforms/utils.h similarity index 53% rename from third_party/xla/xla/python/ifrt/ir/transforms/constants.h rename to third_party/xla/xla/python/ifrt/ir/transforms/utils.h index 98bfd12e2c19b8..81528e97f418ae 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/constants.h +++ b/third_party/xla/xla/python/ifrt/ir/transforms/utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,18 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PYTHON_IFRT_IR_TRANSFORMS_CONSTANTS_H_ -#define XLA_PYTHON_IFRT_IR_TRANSFORMS_CONSTANTS_H_ +#ifndef XLA_PYTHON_IFRT_IR_TRANSFORMS_UTILS_H_ +#define XLA_PYTHON_IFRT_IR_TRANSFORMS_UTILS_H_ -#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" -namespace xla::ifrt { +namespace xla { +namespace ifrt { -inline constexpr llvm::StringLiteral kIfrtDevicesAttrName = "ifrt.devices"; -inline constexpr llvm::StringLiteral kIfrtShardingAttrName = "ifrt.sharding"; -inline constexpr llvm::StringLiteral kIfrtEntryFunctionAttrName = - "ifrt.entry_function"; +// Retrieves the function named "main" from the given module, if it exists, and +// fails otherwise. +mlir::func::FuncOp GetMainFunction(mlir::ModuleOp module); -} // namespace xla::ifrt +} // namespace ifrt +} // namespace xla -#endif // XLA_PYTHON_IFRT_IR_TRANSFORMS_CONSTANTS_H_ +#endif // XLA_PYTHON_IFRT_IR_TRANSFORMS_UTILS_H_ diff --git a/third_party/xla/xla/python/ifrt/memory.cc b/third_party/xla/xla/python/ifrt/memory.cc index c608950e3e8aef..c04bc0bead8ec6 100644 --- a/third_party/xla/xla/python/ifrt/memory.cc +++ b/third_party/xla/xla/python/ifrt/memory.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/node_hash_set.h" -#include "xla/pjrt/pjrt_client.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "xla/python/ifrt/device.h" namespace xla { @@ -52,7 +54,7 @@ MemoryKind::MemoryKind(std::optional memory_kind) { } } -std::string MemoryKind::DebugString() const { +std::string MemoryKind::ToString() const { if (memory_kind_.has_value()) { return std::string(*memory_kind_); } diff --git a/third_party/xla/xla/python/ifrt/memory.h b/third_party/xla/xla/python/ifrt/memory.h index a3117e5e3049d7..309d49705381e3 100644 --- a/third_party/xla/xla/python/ifrt/memory.h +++ b/third_party/xla/xla/python/ifrt/memory.h @@ -62,17 +62,15 @@ class MemoryKind { template friend void AbslStringify(Sink& sink, const MemoryKind& memory_kind) { - sink.Append(memory_kind.DebugString()); + sink.Append(memory_kind.ToString()); } // Returns a platform-dependent identifier of a memory kind. std::optional memory_kind() const { return memory_kind_; } - // TODO(kedars): Rename & make private after replacing usage with - // AbslStringify. - std::string DebugString() const; - private: + std::string ToString() const; + std::optional memory_kind_; }; diff --git a/third_party/xla/xla/python/ifrt/sharding.cc b/third_party/xla/xla/python/ifrt/sharding.cc index e302535cc4f974..3cc2bcfb5668d3 100644 --- a/third_party/xla/xla/python/ifrt/sharding.cc +++ b/third_party/xla/xla/python/ifrt/sharding.cc @@ -50,6 +50,14 @@ namespace ifrt { namespace { +// Returns a canonicalized memory kind for the given devices. +// REQUIRES: !devices.empty() +MemoryKind CanonicalizeMemoryKindWithDevices(const MemoryKind& memory_kind, + const DeviceList& devices) { + CHECK(!devices.empty()); + return CanonicalizeMemoryKind(memory_kind, devices.front()); +} + // Returns if `sharding_param` indicates a fully replicated sharding. bool ComputeIsFullyReplicated(const ShardingParam& sharding_param) { return llvm::all_of(sharding_param.dim_shards(), @@ -155,6 +163,12 @@ char ShardingParamSharding::ID = 0; char DeserializeShardingOptions::ID = 0; +Sharding::Sharding(DeviceList devices, MemoryKind memory_kind, + bool is_fully_replicated) + : devices_(std::move(devices)), + memory_kind_(memory_kind), + is_fully_replicated_(is_fully_replicated) {} + bool Sharding::operator==(const Sharding& other) const { if (this == &other) { return true; @@ -184,6 +198,7 @@ std::ostream& operator<<(std::ostream& os, const Sharding& sharding) { std::unique_ptr SingleDeviceSharding::Create( Device* device, MemoryKind memory_kind) { + memory_kind = CanonicalizeMemoryKind(memory_kind, device); return std::unique_ptr( new SingleDeviceSharding(device, memory_kind)); } @@ -240,13 +255,13 @@ absl::StatusOr> SingleDeviceSharding::IndexDomains( std::string SingleDeviceSharding::DebugString() const { DCHECK(this); - return absl::StrFormat("SingleDeviceSharding(%s, memory_kind: %s)", - devices_.front()->ToString(), - memory_kind_.DebugString()); + return absl::StrFormat("SingleDeviceSharding(%s, memory_kind: %v)", + devices_.front()->ToString(), memory_kind_); } std::unique_ptr OpaqueSharding::Create(DeviceList devices, MemoryKind memory_kind) { + memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices); return std::unique_ptr( new OpaqueSharding(std::move(devices), memory_kind)); } @@ -306,18 +321,19 @@ absl::StatusOr> OpaqueSharding::IndexDomains( std::string OpaqueSharding::DebugString() const { DCHECK(this); return absl::StrFormat( - "OpaqueSharding(devices: %s, memory_kind: %s)", + "OpaqueSharding(devices: %s, memory_kind: %v)", absl::StrJoin(devices_, ",", [](std::string* out, const Device* device) { absl::StrAppend(out, device->ToString()); }), - memory_kind_.DebugString()); + memory_kind_); } std::unique_ptr ConcreteSharding::Create( DeviceList devices, MemoryKind memory_kind, Shape shape, std::vector shard_shapes) { CHECK_EQ(devices.size(), shard_shapes.size()); + memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices); return std::unique_ptr( new ConcreteSharding(std::move(devices), memory_kind, std::move(shape), std::move(shard_shapes))); @@ -327,6 +343,7 @@ std::unique_ptr ConcreteSharding::Create( DeviceList devices, MemoryKind memory_kind, DynamicShape dynamic_shape, std::vector shard_dynamic_shapes) { CHECK_EQ(devices.size(), shard_dynamic_shapes.size()); + memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices); return std::unique_ptr(new ConcreteSharding( std::move(devices), memory_kind, std::move(dynamic_shape), std::move(shard_dynamic_shapes))); @@ -454,7 +471,7 @@ std::string ConcreteSharding::DebugString() const { [this](const auto& shape, const auto& shard_shapes) { return absl::StrFormat( "ConcreteSharding(devices: %s, shape: %s, shard_shapes: %s, " - "memory_kind: %s)", + "memory_kind: %v)", absl::StrJoin(devices_, ",", [](std::string* out, const Device* device) { absl::StrAppend(out, device->ToString()); @@ -464,7 +481,7 @@ std::string ConcreteSharding::DebugString() const { [](std::string* out, const auto& shard_shape) { absl::StrAppend(out, shard_shape.DebugString()); }), - memory_kind_.DebugString()); + memory_kind_); }, shape_, shard_shapes_); } @@ -472,6 +489,7 @@ std::string ConcreteSharding::DebugString() const { std::unique_ptr ConcreteEvenSharding::Create( DeviceList devices, MemoryKind memory_kind, Shape shape, Shape shard_shape, bool is_fully_replicated) { + memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices); return std::unique_ptr(new ConcreteEvenSharding( std::move(devices), memory_kind, std::move(shape), std::move(shard_shape), is_fully_replicated)); @@ -565,13 +583,12 @@ std::string ConcreteEvenSharding::DebugString() const { DCHECK(this); return absl::StrFormat( "ConcreteEvenSharding(devices: %s, shape: %s, shard_shape: %s, " - "memory_kind: %s)", + "memory_kind: %v)", absl::StrJoin(devices_, ",", [](std::string* out, const Device* device) { absl::StrAppend(out, device->ToString()); }), - shape_.DebugString(), shard_shape_.DebugString(), - memory_kind_.DebugString()); + shape_.DebugString(), shard_shape_.DebugString(), memory_kind_); } absl::StatusOr> @@ -586,6 +603,7 @@ ShardingParamSharding::Create(ShardingParam sharding_param, DeviceList devices, "%d", device_count, devices.size()); } + memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices); return std::unique_ptr(new ShardingParamSharding( std::move(sharding_param), std::move(devices), memory_kind)); } @@ -595,7 +613,8 @@ ShardingParamSharding::ShardingParamSharding(ShardingParam sharding_param, DeviceList devices, MemoryKind memory_kind) : llvm::RTTIExtends( - devices, memory_kind, ComputeIsFullyReplicated(sharding_param)), + std::move(devices), memory_kind, + ComputeIsFullyReplicated(sharding_param)), sharding_param_(sharding_param) {} absl::StatusOr>>> @@ -710,13 +729,13 @@ absl::StatusOr> ShardingParamSharding::IndexDomains( std::string ShardingParamSharding::DebugString() const { DCHECK(this); return absl::StrFormat( - "ShardingParamSharding(%s, devices: %s, memory_kind: %s)", + "ShardingParamSharding(%s, devices: %s, memory_kind: %v)", sharding_param_.DebugString(), absl::StrJoin(devices_, ",", [](std::string* out, const Device* device) { absl::StrAppend(out, device->ToString()); }), - memory_kind_.DebugString()); + memory_kind_); } } // namespace ifrt diff --git a/third_party/xla/xla/python/ifrt/sharding.h b/third_party/xla/xla/python/ifrt/sharding.h index c7fbd258cee56d..91b8b8ad1b31cb 100644 --- a/third_party/xla/xla/python/ifrt/sharding.h +++ b/third_party/xla/xla/python/ifrt/sharding.h @@ -125,10 +125,8 @@ class Sharding : public llvm::RTTIExtends { static char ID; // NOLINT protected: - Sharding(DeviceList devices, MemoryKind memory_kind, bool is_fully_replicated) - : devices_(devices), - memory_kind_(memory_kind), - is_fully_replicated_(is_fully_replicated) {} + Sharding(DeviceList devices, MemoryKind memory_kind, + bool is_fully_replicated); DeviceList devices_; MemoryKind memory_kind_; @@ -189,6 +187,7 @@ class SingleDeviceSharding final class OpaqueSharding : public llvm::RTTIExtends { public: // Creates an opaque sharding. `Disassemble()` will fail. + // REQUIRES: !devices.empty() static std::unique_ptr Create(DeviceList devices, MemoryKind memory_kind); @@ -230,6 +229,7 @@ class ConcreteSharding : public llvm::RTTIExtends { public: // Creates a concrete sharding that may contain non-identical shard shapes. // REQUIRES: `devices`.size() == `shard_shapes`.size() + // REQUIRES: !devices.empty() static std::unique_ptr Create( DeviceList devices, MemoryKind memory_kind, Shape shape, std::vector shard_shapes); @@ -237,6 +237,7 @@ class ConcreteSharding : public llvm::RTTIExtends { // Creates a concrete sharding that may contain non-identical shard dynamic // shapes. // REQUIRES: `devices`.size() == `shard_dynamic_shapes`.size() + // REQUIRES: !devices.empty() static std::unique_ptr Create( DeviceList devices, MemoryKind memory_kind, DynamicShape dynamic_shape, std::vector shard_dynamic_shapes); @@ -321,6 +322,7 @@ class ConcreteEvenSharding // Creates a concrete even sharding. // TODO(hyeontaek): Remove the default value of `is_fully_replicated` once all // callers are updated to provide it explicitly. + // REQUIRES: !devices.empty() static std::unique_ptr Create( DeviceList devices, MemoryKind memory_kind, Shape shape, Shape shard_shape, bool is_fully_replicated = false); @@ -371,6 +373,7 @@ class ConcreteEvenSharding class ShardingParamSharding : public llvm::RTTIExtends { public: + // REQUIRES: !devices.empty() static absl::StatusOr> Create( ShardingParam sharding_param, DeviceList devices, MemoryKind memory_kind); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD index 442b8f7bf85d51..8e947a4e68beac 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -103,7 +103,11 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:random", "@local_tsl//tsl/platform:status_to_from_proto", + "@local_tsl//tsl/profiler/lib:traceme", + "@local_tsl//tsl/profiler/lib:traceme_encode", + "@local_tsl//tsl/profiler/utils:xplane_schema", ] + if_google(["@com_google_absl//absl/types:source_location"]), ) diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc index 1c334a8d7346a5..ff116a759e31a6 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc @@ -23,19 +23,23 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#if defined(PLATFORM_GOOGLE) -#include "absl/types/source_location.h" -#endif #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/client/client_session.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/random.h" #include "tsl/platform/status_to_from_proto.h" +#include "tsl/profiler/lib/traceme.h" +#include "tsl/profiler/lib/traceme_encode.h" +#include "tsl/profiler/utils/xplane_schema.h" namespace xla { namespace ifrt { namespace proxy { +using ::tsl::profiler::XFlow; + // DoRpc is a templated function that implements the logic of all RPC-wrapping // functions of `RpcHelper`, such as `RpcHelper::MakeArrayFromHostBuffer()`. template @@ -44,14 +48,28 @@ Future> DoRpc(ClientSession* session, void (IfrtRequest::*set_req)(Req*), Resp* (IfrtResponse::*get_resp)(), bool (IfrtResponse::*has_resp)() const, - std::unique_ptr req) { + std::unique_ptr req, + absl::string_view profiling_send_name, + absl::string_view profiling_recv_name) { auto ifrt_req = std::make_unique(); *ifrt_req->mutable_request_metadata() = metadata; (ifrt_req.get()->*set_req)(req.release()); + const uint64_t xflow_id = tsl::random::New64() >> 8; // XFlow IDs are 56 bits + tsl::profiler::TraceMe traceme([xflow_id, profiling_send_name]() { + const XFlow flow(xflow_id, XFlow::FlowDirection::kFlowOut); + return tsl::profiler::TraceMeEncode(profiling_send_name, + {{"flow", flow.ToStatValue()}}); + }); + auto promise = Future>::CreatePromise(); - auto on_ready = [promise, has_resp, get_resp]( + auto on_ready = [promise, has_resp, get_resp, xflow_id, profiling_recv_name]( absl::StatusOr> r) mutable { + tsl::profiler::TraceMe traceme([xflow_id, profiling_recv_name]() { + const XFlow flow(xflow_id, XFlow::FlowDirection::kFlowIn); + return tsl::profiler::TraceMeEncode(profiling_recv_name, + {{"flow", flow.ToStatValue()}}); + }); if (!r.ok()) { LOG_EVERY_N_SEC(ERROR, 10) << "Connection to IFRT proxy server was terminated: " << r.status(); @@ -127,13 +145,14 @@ void RpcHelper::Disconnect() { // TODO(b/266635130): Remove this preprocessor macro. Preprocessor macros // go against the style guide, but are convenient as we are introducing more // RPCs and are making changes to the exact signature of the DoRpc function. -#define RPC(METHOD, PROPERTY) \ - RpcHelper::ResponseFuture RpcHelper::METHOD( \ - std::unique_ptr req) { \ - return DoRpc(session_.get(), ManufactureRequestMetadata(), \ - &IfrtRequest::set_allocated_##PROPERTY##_request, \ - &IfrtResponse::mutable_##PROPERTY##_response, \ - &IfrtResponse::has_##PROPERTY##_response, std::move(req)); \ +#define RPC(METHOD, PROPERTY) \ + RpcHelper::ResponseFuture RpcHelper::METHOD( \ + std::unique_ptr req) { \ + return DoRpc(session_.get(), ManufactureRequestMetadata(), \ + &IfrtRequest::set_allocated_##PROPERTY##_request, \ + &IfrtResponse::mutable_##PROPERTY##_response, \ + &IfrtResponse::has_##PROPERTY##_response, std::move(req), \ + "" #PROPERTY "_send", "" #PROPERTY "_recv"); \ } RPC(Init, init); diff --git a/third_party/xla/xla/python/ifrt_proxy/common/types.cc b/third_party/xla/xla/python/ifrt_proxy/common/types.cc index 9d222a453c58ee..db981531c24c27 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/types.cc +++ b/third_party/xla/xla/python/ifrt_proxy/common/types.cc @@ -83,7 +83,6 @@ proto::ArrayCopySemantics ToArrayCopySemanticsProto(ArrayCopySemantics s) { absl::StatusOr FromArrayCopySemanticsProto( proto::ArrayCopySemantics s) { - MakeArrayFromHostBufferRequest req; switch (s) { case proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY: return ArrayCopySemantics::kAlwaysCopy; diff --git a/third_party/xla/xla/python/pjit.cc b/third_party/xla/xla/python/pjit.cc index 51a7bb7ff976ad..6bbe898eba48fd 100644 --- a/third_party/xla/xla/python/pjit.cc +++ b/third_party/xla/xla/python/pjit.cc @@ -46,8 +46,10 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/memory.h" @@ -90,6 +92,7 @@ struct PjitCacheEntry { // Bitvector of kept arguments from Jaxpr DCE pass. Used to drop some `args` // in PjitFunction::Call before calling into compiled computation. std::vector kept_var_bitvec; + std::vector in_device_local_layouts; // Ensures a single thread performs the compilation for a given executable. // @@ -351,11 +354,12 @@ PjitFunction::PjitFunction( PjitFunction::~PjitFunction() { GetGlobalPjitFunctionStore().Erase(this); } void CallShardArgFallback( - nb::handle arg, nb::handle sharding, const nb::callable& fallback, + nb::handle arg, nb::handle sharding, nb::handle layout, + const nb::callable& fallback, std::vector>& num_args_arrays, std::vector& keep_alive_objects) { tsl::profiler::TraceMe traceme("cpp_pjit_shard_arg_fallback"); - auto py_array_or_bufs = fallback(arg, sharding); + auto py_array_or_bufs = fallback(arg, sharding, layout); auto py_array = nb::cast(py_array_or_bufs); num_args_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); keep_alive_objects.push_back(std::move(py_array_or_bufs)); @@ -368,6 +372,7 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, absl::Span flat_dynamic_args, bool enable_x64, const std::vector& kept_args, const std::vector& in_shardings, + const std::vector& in_device_local_layouts, const nb::callable& shard_arg_fallback, std::vector& keep_alive_objects) { const auto& addressable_devices = @@ -401,11 +406,13 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, ++dce_i; const nb::object& arg = flat_dynamic_args[i]; + const nb::object& in_device_local_layout = + in_device_local_layouts[dce_index]; auto transfer_guard_formatter = [] { return std::string(""); }; if (arg.type().ptr() != xla::PyArray::type().ptr()) { - if (data_device != nullptr) { + if (data_device != nullptr && in_device_local_layout.is_none()) { TF_RETURN_IF_ERROR( jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); TF_ASSIGN_OR_RETURN( @@ -426,8 +433,8 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, continue; } else { CallShardArgFallback(arg.ptr(), in_shardings[dce_index], - shard_arg_fallback, num_args_arrays, - keep_alive_objects); + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); continue; } } @@ -442,17 +449,31 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, DCHECK(py_array.committed() || (!py_array.committed() && sharding_num_devices == 1)); + if (!in_device_local_layout.is_none()) { + TF_ASSIGN_OR_RETURN(auto arr_layout, py_array.ifrt_array()->layout()); + xla::Layout in_xc_layout = nb::cast( + in_device_local_layout.attr("_to_xla_layout")(py_array.dtype())); + if (in_xc_layout != GetXlaLayoutUnsafe(arr_layout)) { + CallShardArgFallback(arg.ptr(), in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + if (sharding.type().ptr() == jax::PmapSharding::type().ptr()) { + CHECK(in_device_local_layout.is_none()); CallShardArgFallback(arg.ptr(), in_shardings[dce_index], - shard_arg_fallback, num_args_arrays, - keep_alive_objects); + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); continue; } if (py_array.num_shards() != addressable_devices.size()) { + CHECK(in_device_local_layout.is_none()); CallShardArgFallback(arg.ptr(), in_shardings[dce_index], - shard_arg_fallback, num_args_arrays, - keep_alive_objects); + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); continue; } @@ -659,7 +680,8 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, auto num_args_arrays = PrepareIfrtInputs( *cache_entry->executable, flat_dynamic_args, call_signature.jax_enable_x64, cache_entry->kept_var_bitvec, - cache_entry->in_shardings, shard_arg_fallback_, keep_alive_objects); + cache_entry->in_shardings, cache_entry->in_device_local_layouts, + shard_arg_fallback_, keep_alive_objects); if (!num_args_arrays.ok()) { VLOG(2) << "Failed to prepare IFRT inputs: " << num_args_arrays.status(); @@ -821,6 +843,13 @@ void PjitFunction::PopulateCacheEntry(PjitCacheEntry& cache_entry, for (nb::handle k : kept_var_bitvec) { cache_entry.kept_var_bitvec.push_back(nb::cast(k)); } + + nb::sequence in_device_local_layouts = + fastpath_data.attr("in_device_local_layouts"); + cache_entry.in_device_local_layouts.reserve(nb::len(in_device_local_layouts)); + for (nb::handle dll : in_device_local_layouts) { + cache_entry.in_device_local_layouts.push_back(nb::borrow(dll)); + } } // Helper function used by the tp_clear GC method. diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc index 73324e6b1c8c91..751b00c9b37620 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc @@ -92,9 +92,8 @@ absl::Status ValidateArrayCreationInput( if (canonicalized_sharding_memory_kind != buffer_memory_kind) { return InvalidArgument( "PjRtBuffer's memory kind does not match sharding's memory kind. Got " - "PjRtBuffer's memory kind: %s vs shardings's memory kind: %s", - buffer_memory_kind.DebugString(), - canonicalized_sharding_memory_kind.DebugString()); + "PjRtBuffer's memory kind: %v vs shardings's memory kind: %v", + buffer_memory_kind, canonicalized_sharding_memory_kind); } } return absl::OkStatus(); @@ -116,8 +115,8 @@ absl::StatusOr GetMemoryKindFromPjRtBuffers( pjrt_buffer->device())) { return InvalidArgument( "Memory kind mismatch between PjRtBuffers. Got one buffer with " - "memory kind: %s and another with memory_kind: %s", - first_memory_kind.DebugString(), memory_kind.DebugString()); + "memory kind: %v and another with memory_kind: %v", + first_memory_kind, memory_kind); } } return first_memory_kind; @@ -440,11 +439,10 @@ absl::StatusOr GetMemorySpaceFromMemoryKind( } if (memory == nullptr) { return InvalidArgument( - "Invalid memory kind: %s; available memory kinds: %s", - memory_kind.DebugString(), + "Invalid memory kind: %v; available memory kinds: %s", memory_kind, absl::StrJoin(device->Memories(), ", ", [](std::string* out, Memory* m) { - absl::StrAppend(out, m->Kind().DebugString()); + absl::StrAppend(out, m->Kind()); })); } return memory; diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc index d77d1c0bf69650..42ffc9aca0353d 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -651,8 +651,8 @@ PjRtLoadedExecutable::Execute(absl::Span> args, memory_kind, pjrt_outputs[j][i]->device())) { return FailedPrecondition( "Memory kind mismatch between PjRtBuffers. Got one buffer with " - "memory kind '%s' and another with memory_kind '%s'", - first_memory_kind.DebugString(), memory_kind.DebugString()); + "memory kind '%v' and another with memory_kind '%v'", + first_memory_kind, memory_kind); } } buffers.push_back( diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc index 6f79e56502eb77..62a07724d1cc42 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc @@ -97,11 +97,20 @@ std::vector IndexDomainsSlowPath( return result; } +// Returns a canonicalized memory kind for the given devices. +// REQUIRES: !devices.empty() +MemoryKind CanonicalizeMemoryKindWithDevices(const MemoryKind& memory_kind, + const DeviceList& devices) { + CHECK(!devices.empty()); + return CanonicalizeMemoryKind(memory_kind, devices.front()); +} + } // namespace std::unique_ptr HloSharding::Create( DeviceList devices, MemoryKind memory_kind, xla::HloSharding xla_hlo_sharding) { + memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices); return std::unique_ptr(new HloSharding( std::move(devices), memory_kind, std::move(xla_hlo_sharding))); } @@ -340,9 +349,8 @@ absl::StatusOr> HloSharding::IndexDomains( } std::string HloSharding::DebugString() const { - return absl::StrFormat("HloSharding(memory_kind: %s, hlo_sharding: %s)", - memory_kind_.DebugString(), - xla_hlo_sharding_.ToString()); + return absl::StrFormat("HloSharding(memory_kind: %v, hlo_sharding: %s)", + memory_kind_, xla_hlo_sharding_.ToString()); } std::vector TEST_HloShardingIndexDomainsSlowPath( diff --git a/third_party/xla/xla/python/py_array.cc b/third_party/xla/xla/python/py_array.cc index 8fd0aa4e6d7370..b350116d53b043 100644 --- a/third_party/xla/xla/python/py_array.cc +++ b/third_party/xla/xla/python/py_array.cc @@ -87,11 +87,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/xla_data.pb.h" -// TODO(b/324133505): remove this GOOGLE_CUDA block after JAX OSS migrates -// to cuda plugin. -#if GOOGLE_CUDA -#include "xla/stream_executor/cuda/cuda_driver.h" -#endif #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" #include "tsl/platform/errors.h" @@ -183,9 +178,8 @@ tsl::RCReference CreateIfRtArrayFromSingleDeviceShardedPyArrays( throw nb::value_error( absl::StrFormat( "Memory kind mismatch between PjRtBuffers. Got one buffer with " - "memory kind '%s' and another with memory_kind '%s'", - first_memory_kind.DebugString(), - ifrt_arrays.back()->sharding().memory_kind().DebugString()) + "memory kind '%v' and another with memory_kind '%v'", + first_memory_kind, ifrt_arrays.back()->sharding().memory_kind()) .c_str()); } } @@ -638,10 +632,8 @@ absl::Status PyArray::set_arrays(nb::object obj) { throw nb::value_error( absl::StrFormat( "Memory kind mismatch between single-device arrays. Got one " - "array " - "with memory kind '%s' and another with memory_kind '%s'", - first_memory_kind.DebugString(), - ifrt_array->sharding().memory_kind().DebugString()) + "array with memory kind '%v' and another with memory_kind '%v'", + first_memory_kind, ifrt_array->sharding().memory_kind()) .c_str()); } } @@ -865,19 +857,6 @@ absl::StatusOr CudaArrayInterfaceToBuffer( PrimitiveType element_type, DtypeToPrimitiveType(nb_dtype::from_args(cai["typestr"]))); - // TODO(b/324133505): remove this GOOGLE_CUDA block after JAX OSS migrates - // to cuda plugin. -#ifdef GOOGLE_CUDA - if (!device_id.has_value()) { - // cannot determine device_id/stream when device pointer is NULL. - device_id.emplace( - (data_value == 0 - ? 0 - : stream_executor::gpu::CreatedContexts::GetDeviceOrdinal( - data_ptr))); - } -#endif // GOOGLE_CUDA - if (!device_id.has_value()) { throw XlaRuntimeError( "This operation requires CUDA support from jaxlib or jax cuda plugin."); diff --git a/third_party/xla/xla/python/py_compile_only_client.cc b/third_party/xla/xla/python/py_compile_only_client.cc index 9d9db9afccffec..6f5aff61938b88 100644 --- a/third_party/xla/xla/python/py_compile_only_client.cc +++ b/third_party/xla/xla/python/py_compile_only_client.cc @@ -19,12 +19,14 @@ limitations under the License. #include #include #include +#include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" @@ -79,6 +81,40 @@ namespace xla { namespace { +class CompileOnlyMemory + : public llvm::RTTIExtends { + public: + explicit CompileOnlyMemory( + int id, const PjRtMemorySpaceDescription* memory_description, + ifrt::Device* device) + : id_(id), + kind_(memory_description->kind()), + debug_string_(absl::StrFormat("CompileOnlyMemory(id=%d, kind=%s)", id, + memory_description->kind())), + device_(device) {} + + ifrt::MemoryId Id() const override { return ifrt::MemoryId(id_); } + + const ifrt::MemoryKind& Kind() const override { return kind_; } + + absl::string_view ToString() const override { return debug_string_; } + absl::string_view DebugString() const override { return debug_string_; } + + absl::Span Devices() const override { + return absl::Span{&device_, 1}; + } + + static char ID; // NOLINT + + private: + int id_; + ifrt::MemoryKind kind_; + std::string debug_string_; + ifrt::Device* device_; +}; + +[[maybe_unused]] char CompileOnlyMemory::ID = 0; + class CompileOnlyDevice : public llvm::RTTIExtends { public: @@ -108,16 +144,31 @@ class CompileOnlyDevice return description_->DebugString(); } - absl::Span Memories() const override { return {}; } + absl::Span Memories() const override { + return unowned_memories_; + } absl::StatusOr DefaultMemory() const override { + if (default_memory_) { + return default_memory_; + } return Unimplemented("DefaultMemory is not supported"); } const ifrt::AttributeMap& Attributes() const override { return attributes_; } + void AttachMemory(std::unique_ptr memory) { + unowned_memories_.push_back(memory.get()); + owned_memories_.push_back(std::move(memory)); + } + + void SetDefaultMemory(ifrt::Memory* memory) { default_memory_ = memory; } + private: const PjRtDeviceDescription* description_; ifrt::AttributeMap attributes_; + ifrt::Memory* default_memory_ = nullptr; + std::vector unowned_memories_; + std::vector> owned_memories_; }; class InvalidIfrtCompiler final @@ -153,10 +204,24 @@ class CompileOnlyIfRtClient final : topology_(std::move(topology)), descriptions_(topology_->DeviceDescriptions()), attributes_(ifrt::AttributeMap::Map()) { + int offset = 0; for (auto& description : descriptions_) { owned_devices_.push_back( std::make_unique(description.get())); - devices_.push_back(owned_devices_.back().get()); + auto* device = owned_devices_.back().get(); + devices_.push_back(device); + if (description->process_index() == process_index()) { + auto default_memory = description->default_memory_space(); + for (auto* memory_description : description->memory_spaces()) { + auto memory = std::make_unique( + offset, memory_description, device); + if (default_memory.ok() && memory_description == *default_memory) { + device->SetDefaultMemory(memory.get()); + } + device->AttachMemory(std::move(memory)); + ++offset; + } + } } } diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc index 68a483cd51e97f..65bfb3fe5305e4 100644 --- a/third_party/xla/xla/python/pytree.cc +++ b/third_party/xla/xla/python/pytree.cc @@ -1249,7 +1249,7 @@ nb_class_ptr PyTreeDef::MakeFromNodeDataAndChildren( nb::cast(nb::repr(node_data->first)))); } node.kind = registration->kind; - if (node.kind == PyTreeKind::kCustom) { + if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) { node.custom = registration; node.node_data = node_data->second; } else if (node.kind == PyTreeKind::kNamedTuple) { diff --git a/third_party/xla/xla/python/pytree_test.py b/third_party/xla/xla/python/pytree_test.py index 4125d7a28257a3..922a4d78fd6b56 100644 --- a/third_party/xla/xla/python/pytree_test.py +++ b/third_party/xla/xla/python/pytree_test.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== import collections +import dataclasses from absl.testing import absltest @@ -44,6 +45,15 @@ def from_iterable(state, values): registry.register_node(ExampleType2, ExampleType2.to_iterable, from_iterable) +@dataclasses.dataclass +class Custom: + a: int + b: str + + +registry.register_dataclass_node(Custom, ["a"], ["b"]) + + class PyTreeTest(absltest.TestCase): def roundtrip(self, example): @@ -92,6 +102,15 @@ def testCompose(self): y = registry.flatten((0, 0))[1] self.assertEqual((x.compose(y)).num_leaves, 2) + def testDataclassMakeFromNodeData(self): + c = Custom(1, "a") + c_leafs, c_tree = registry.flatten(c) + c_tree2 = c_tree.make_from_node_data_and_children( + registry, c_tree.node_data(), c_tree.children() + ) + self.assertEqual(c_tree2.unflatten(c_leafs), c) + self.assertEqual(str(c_tree2), str(c_tree)) + if __name__ == "__main__": absltest.main() diff --git a/third_party/xla/xla/python/sharding.cc b/third_party/xla/xla/python/sharding.cc index acbe324ac75f47..e995df9285d8b9 100644 --- a/third_party/xla/xla/python/sharding.cc +++ b/third_party/xla/xla/python/sharding.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -30,6 +31,7 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device.h" #include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" @@ -176,8 +178,7 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, nb::object memory_kind, nb::object parsed_pspec, nb::object manual_axes) : Sharding(/*num_devices=*/[&mesh]() { - xla::nb_numpy_ndarray devices = mesh.attr("devices"); - return devices.size(); + return nb::cast(mesh.attr("size")); }()), mesh_(std::move(mesh)), spec_(std::move(spec)), @@ -185,10 +186,18 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, parsed_pspec_(std::move(parsed_pspec)), manual_axes_(std::move(manual_axes)) { nb::object idl = nb::object(mesh_.attr("_internal_device_list")); - internal_device_list_ = nb::cast>( - nb::object(mesh_.attr("_internal_device_list"))); - memory_kind_ = - CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); + if (idl.is_none()) { + internal_device_list_ = std::nullopt; + } else { + internal_device_list_ = nb::cast>( + nb::object(mesh_.attr("_internal_device_list"))); + } + if (internal_device_list_) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, *internal_device_list_); + } else { + memory_kind_ = nb::none(); + } nb::module_ si = nb::module_::import_("jax._src.sharding_impls"); parsed_pspec_ = @@ -265,8 +274,9 @@ void RegisterSharding(nb::module_& m) { .def_prop_ro("_manual_axes", &NamedSharding::manual_axes) .def_prop_rw("_parsed_pspec", &NamedSharding::parsed_pspec, &NamedSharding::set_parsed_pspec) - .def_prop_ro("_internal_device_list", - &NamedSharding::internal_device_list); + .def_prop_ro("_internal_device_list", [](const NamedSharding& s) { + return xla::ValueOrThrow(s.internal_device_list()); + }); nb::class_(m, "SingleDeviceSharding", nb::dynamic_attr()) diff --git a/third_party/xla/xla/python/sharding.h b/third_party/xla/xla/python/sharding.h index d3b1211619cd7d..1e28b7aecff6b8 100644 --- a/third_party/xla/xla/python/sharding.h +++ b/third_party/xla/xla/python/sharding.h @@ -22,6 +22,7 @@ limitations under the License. // placeholder for index annotation headers #include "absl/hash/hash.h" +#include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/status_casters.h" @@ -86,8 +87,13 @@ class NamedSharding : public Sharding { return type; } - xla::nb_class_ptr internal_device_list() const { - return internal_device_list_; + absl::StatusOr> internal_device_list() const { + if (internal_device_list_) { + return *internal_device_list_; + } + return xla::InvalidArgument( + "internal_device_list is not implemented for " + "`jax.sharding.AbstractMesh`"); } private: @@ -96,7 +102,7 @@ class NamedSharding : public Sharding { nanobind::object memory_kind_; nanobind::object parsed_pspec_; nanobind::object manual_axes_; - xla::nb_class_ptr internal_device_list_; + std::optional> internal_device_list_; }; class SingleDeviceSharding : public Sharding { diff --git a/third_party/xla/xla/python/tools/BUILD b/third_party/xla/xla/python/tools/BUILD index 6d57e560d70cf6..cc0c5e0c189713 100644 --- a/third_party/xla/xla/python/tools/BUILD +++ b/third_party/xla/xla/python/tools/BUILD @@ -86,7 +86,7 @@ py_strict_test( ":types", "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", - #internal proto upb dep + # copybara:uncomment "//third_party/py/google/protobuf:use_fast_cpp_protos", "//third_party/py/numpy", "//xla:xla_data_proto_py", ], diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 4feb8cb1e18398..2136e981507f10 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -61,14 +61,18 @@ limitations under the License. #include "xla/python/py_client.h" #include "xla/python/py_program.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/tsl/python/lib/core/numpy.h" //NOLINT +#include "xla/tsl/python/lib/core/numpy.h" // NOLINT -#ifdef __linux__ +#if defined(__linux__) #include "gloo/transport/tcp/attr.h" #include "gloo/transport/tcp/device.h" #include "xla/pjrt/cpu/gloo_collectives.h" #include "xla/pjrt/cpu/gloo_kv_store.h" -#endif // __linux__ +#elif defined(__APPLE__) +#include "gloo/transport/uv/device.h" +#include "xla/pjrt/cpu/gloo_collectives.h" // NOLINT +#include "xla/pjrt/cpu/gloo_kv_store.h" // NOLINT +#endif // defined(__linux__) #if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) #include "xla/pjrt/cpu/mpi_collectives.h" @@ -254,7 +258,7 @@ NB_MODULE(xla_extension, m_nb) { std::optional hostname, std::optional interface) -> std::shared_ptr { -#ifdef __linux__ +#if defined(__linux__) std::shared_ptr kv_store = nullptr; if (distributed_client != nullptr) { kv_store = GetDistributedKeyValueStore(distributed_client, @@ -271,10 +275,27 @@ NB_MODULE(xla_extension, m_nb) { auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs); return std::make_shared(std::move(gloo_kv_store), std::move(tcp_device)); -#else // __linux__ +#elif defined(__APPLE__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto uv_attrs = gloo::transport::uv::attr(); + if (hostname) { + uv_attrs.hostname = *hostname; + } + if (interface) { + uv_attrs.iface = *interface; + } + auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs); + return std::make_shared(std::move(gloo_kv_store), + std::move(uv_device)); +#else // defined(__linux__) throw xla::XlaRuntimeError( - "make_gloo_tcp_collectives only implemented for linux"); -#endif // __linux__ + "make_gloo_tcp_collectives only implemented for linux and macos"); +#endif // defined(__linux__) }, nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt, nb::arg("interface").none() = std::nullopt); diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index 7d8a2e3c05a9fe..294f109a8d8f7e 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 280 +_version = 282 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/third_party/xla/xla/python/xla_client_test.py b/third_party/xla/xla/python/xla_client_test.py index 65d6d7f3c749d2..37484ccffec93b 100644 --- a/third_party/xla/xla/python/xla_client_test.py +++ b/third_party/xla/xla/python/xla_client_test.py @@ -65,8 +65,12 @@ # pylint: disable=invalid-name -def jax_array_convert_to_array(self): - return self._single_device_array_to_np_array() +def jax_array_convert_to_array(self, dtype=None, copy=None): + del copy + out = self._single_device_array_to_np_array() + if dtype is not None: + out = out.astype(dtype) + return out def jax_array_device(self): @@ -586,7 +590,10 @@ class ParametersTest(ComputationTest): def testScalarTimesVector(self, dtype): c = self._NewComputation() arg0 = np.array(3, dtype=dtype) - arg1 = np.array([10, 15, -2, 7], dtype=dtype) + if np.issubdtype(dtype, np.unsignedinteger): + arg1 = np.array([10, 15, 2, 7], dtype=dtype) + else: + arg1 = np.array([10, 15, -2, 7], dtype=dtype) p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) ops.Mul(p0, p1) diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc index 7e2504c7ba44e2..f58a59a3cc715c 100644 --- a/third_party/xla/xla/python/xla_compiler.cc +++ b/third_party/xla/xla/python/xla_compiler.cc @@ -80,10 +80,10 @@ limitations under the License. #include "xla/service/tuple_simplifier.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/python_api/xla_literal.py b/third_party/xla/xla/python_api/xla_literal.py index 3471f3a99cc2db..4ad7bf0a36c587 100644 --- a/third_party/xla/xla/python_api/xla_literal.py +++ b/third_party/xla/xla/python_api/xla_literal.py @@ -50,9 +50,8 @@ def ConvertLiteralToNumpyArray(literal): numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='C') else: raise NotImplementedError('Unsupported layout: {0}'.format(layout_order)) - ndarray = _np.array( + ndarray = _np.asarray( getattr(literal, type_record.literal_field_name), - copy=False, dtype=type_record.numpy_dtype) return numpy_reshaper(ndarray) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 65bee16eef0b3c..9b55155ffab786 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -256,53 +256,6 @@ xla_cc_test( ], ) -cc_library( - name = "all_reduce_splitter", - srcs = ["all_reduce_splitter.cc"], - hdrs = ["all_reduce_splitter.h"], - deps = [ - ":collective_opt_utils", - ":hlo_module_config", - ":hlo_pass", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "all_reduce_splitter_test", - srcs = ["all_reduce_splitter_test.cc"], - deps = [ - ":all_reduce_splitter", - ":hlo_module_config", - ":hlo_pass_pipeline", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:gpu_reduce_scatter_creator", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "float_support", srcs = ["float_support.cc"], @@ -603,7 +556,9 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_instruction_utils", "//xla/hlo/utils:hlo_query", + "//xla/service:call_graph", "//xla/service:hlo_parser", + "//xla/service:tuple_points_to_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -687,6 +642,7 @@ cc_library( "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_set", @@ -703,7 +659,6 @@ cc_library( "@llvm-project//mlir:Transforms", "@local_tsl//tsl/lib/io:zlib_compression_options", "@local_tsl//tsl/lib/io:zlib_outputbuffer", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", @@ -774,7 +729,6 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", @@ -1089,7 +1043,8 @@ xla_cc_test( deps = [ ":pattern_matcher", ":pattern_matcher_gmock", - "//xla:literal", + "//xla:comparison_util", + "//xla:literal_util", "//xla:protobuf_util", "//xla:shape_util", "//xla:test", @@ -1103,6 +1058,9 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", ], ) @@ -1217,17 +1175,18 @@ xla_cc_test( srcs = ["call_inliner_test.cc"], deps = [ ":call_inliner", - ":hlo_pass", + ":hlo_parser", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:test", - "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:statusor", ], @@ -1739,6 +1698,7 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory_allocator", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -1746,7 +1706,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -4551,7 +4510,9 @@ xla_cc_test( ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", - "//xla:literal", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_tree", "//xla:shape_util", "//xla:test", "//xla:test_helpers", @@ -4561,7 +4522,9 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], ) @@ -4583,10 +4546,10 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -5546,6 +5509,7 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", @@ -5571,6 +5535,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", ], ) @@ -6435,6 +6400,58 @@ xla_cc_test( ], ) +cc_library( + name = "host_offload_utils", + srcs = ["host_offload_utils.cc"], + hdrs = ["host_offload_utils.h"], + deps = [ + ":call_graph", + ":hlo_buffer", + ":host_memory_offload_annotations_hdr", + ":pattern_matcher", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "host_offload_utils_test", + srcs = ["host_offload_utils_test.cc"], + deps = [ + ":hlo_verifier", + ":host_memory_offload_annotations_hdr", + ":host_offload_utils", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "host_offloader", srcs = ["host_offloader.cc"], @@ -6447,6 +6464,7 @@ cc_library( ":hlo_pass", ":hlo_value", ":host_memory_offload_annotations_hdr", + ":host_offload_utils", ":pattern_matcher", "//xla:literal_util", "//xla:shape_util", @@ -6945,6 +6963,7 @@ cc_library( "@local_tsl//tsl/lib/gtl:map_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", ], ) @@ -6953,9 +6972,12 @@ xla_cc_test( size = "small", srcs = ["hlo_parser_test.cc"], deps = [ + ":hlo_lexer", + ":hlo_module_config", ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", + "//xla:array", "//xla:shape_util", "//xla:window_util", "//xla:xla_data_proto_cc", @@ -6963,8 +6985,13 @@ xla_cc_test( "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -7268,6 +7295,7 @@ cc_library( ":__subpackages__", "//tensorflow/compiler/tf2xla:__pkg__", "//xla/pjrt:__subpackages__", + "//xla/backends/cpu/runtime:__subpackages__", ]), deps = [ ":custom_call_status", @@ -7962,6 +7990,25 @@ cc_library( ], ) +cc_library( + name = "batched_gather_scatter_normalizer", + srcs = ["batched_gather_scatter_normalizer.cc"], + hdrs = ["batched_gather_scatter_normalizer.h"], + deps = [ + ":op_expander_pass", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "reduce_window_rewriter", srcs = ["reduce_window_rewriter.cc"], @@ -8088,6 +8135,18 @@ xla_cc_test( ], ) +xla_cc_test( + name = "batched_gather_scatter_normalizer_test", + srcs = ["batched_gather_scatter_normalizer_test.cc"], + deps = [ + ":batched_gather_scatter_normalizer", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + xla_cc_test( name = "change_op_data_type_test", srcs = ["change_op_data_type_test.cc"], @@ -8413,10 +8472,10 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index 7607b8da61957a..f54864220e2e69 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -3563,48 +3563,28 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { other_index = outer_dnums.lhs_batch_dimensions(i); } - // Once we have the inner_index, we determine whether this index - // corresponds to a dimension coming from the lhs or rhs of inner - bool from_inner_lhs = map_inner_rhs[inner_index] == -1; - - // The map we use depends on which operand of inner this dim comes from - std::vector map; - if (from_inner_lhs) { - map = map_inner_lhs; - } else { - map = map_inner_rhs; - } - - // Whether the mapped value goes into the lhs or rhs of the new dnums - // depends on whether inner was the lhs or rhs operand of outer - int64_t lhs_index, rhs_index; - if (outer_lhs_dot) { - lhs_index = map[inner_index]; - rhs_index = other_index; - } else { - lhs_index = other_index; - rhs_index = map[inner_index]; - } - - // Finally, we have to determine which dnums to add to - DotDimensionNumbers* dnums; - if (outer_lhs_dot) { - if (from_inner_lhs) { - dnums = &ac_dnums; - } else { - dnums = &bc_dnums; - } - } else { - if (from_inner_lhs) { - dnums = &ab_dnums; - } else { - dnums = &ac_dnums; + auto add_batch_dims = [](DotDimensionNumbers& dnums, int64_t lhs_ix, + int64_t rhs_ix) { + dnums.add_lhs_batch_dimensions(lhs_ix); + dnums.add_rhs_batch_dimensions(rhs_ix); + }; + + for (auto& map : {map_inner_lhs, map_inner_rhs}) { + int64_t mapped_index = map[inner_index]; + if (mapped_index != -1) { + // Whether the mapped value is the lhs or rhs of the new dnums + // depends on whether inner is the lhs or rhs operand of outer. The + // dnums itself depends on this and also on which map we are + // iterating through + if (outer_lhs_dot) { + add_batch_dims(map == map_inner_lhs ? ac_dnums : bc_dnums, + mapped_index, other_index); + } else { + add_batch_dims(map == map_inner_lhs ? ab_dnums : ac_dnums, + other_index, mapped_index); + } } } - - // Add the batch dimensions - dnums->add_lhs_batch_dimensions(lhs_index); - dnums->add_rhs_batch_dimensions(rhs_index); } // We now do the same thing for the contracting dimensions of outer @@ -3623,7 +3603,14 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // Once we have the inner_index, we determine whether this index // corresponds to a dimension coming from the lhs or rhs of inner - bool from_inner_lhs = map_inner_rhs[inner_index] == -1; + bool from_inner_lhs = map_inner_lhs[inner_index] != -1; + bool from_inner_rhs = map_inner_rhs[inner_index] != -1; + + // If a dimension of inner is the result of batching and it is + // contracted in outer, we stop trying to reorder + if (from_inner_lhs && from_inner_rhs) { + return absl::OkStatus(); + } // The map we use depends on which operand of inner this dim comes from std::vector map; @@ -3723,8 +3710,11 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { rhs_index = other_index; } - new_outer_dnums.add_lhs_batch_dimensions(lhs_index); - new_outer_dnums.add_rhs_batch_dimensions(rhs_index); + if (!absl::c_linear_search(new_outer_dnums.lhs_batch_dimensions(), + lhs_index)) { + new_outer_dnums.add_lhs_batch_dimensions(lhs_index); + new_outer_dnums.add_rhs_batch_dimensions(rhs_index); + } } for (int64_t i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { int64_t new_inner_index, other_index; @@ -6968,7 +6958,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleDynamicSlice( // Convert a dynamic slice into a slice if all offsets are constant, the // operand is not constant, and the input and output memory spaces are the // same. - if (operand->opcode() != HloOpcode::kConstant && + if (!options_.disable_dynamic_slice_to_slice_conversion() && + operand->opcode() != HloOpcode::kConstant && absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1, dynamic_slice->operands().end()), [](HloInstruction* operand) { @@ -7911,28 +7902,52 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { } } - // Replace Reduce(Broadcast(Scalar), +, init_value) with - // Broadcast(Add(Multiply(Scalar), init_value))) + // Replace Reduce(Broadcast(x), +, init_value) with Broadcast(Add(Multiply(x), + // init_value))) if all reduction dimensions were introduced by Broadcast if (arg->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(arg->operand(0)->shape())) { - if (Match(reduce->to_apply()->root_instruction(), - m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) { - int64_t reduction_dims_prod = 1; - for (auto i : reduce->dimensions()) { - reduction_dims_prod *= arg->shape().dimensions(i); + Match(reduce->to_apply()->root_instruction(), + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) { + bool only_reduce_dims_from_broadcast = true; + int64_t common_dims_prod = 1; + int64_t num_common_dims = 0; + Shape new_broadcast_shape = arg->shape(); + std::vector new_broadcast_dims; + + // Now we build up the new broadcast shape and dims vector + for (int64_t i = 0; i < arg->shape().rank(); ++i) { + bool added_by_broadcast = !absl::c_linear_search(arg->dimensions(), i); + bool removed_by_reduce = absl::c_linear_search(reduce->dimensions(), i); + + if (removed_by_reduce && !added_by_broadcast) { + only_reduce_dims_from_broadcast = false; + break; + } else if (removed_by_reduce && added_by_broadcast) { + new_broadcast_shape.DeleteDimension(i - num_common_dims); + common_dims_prod *= arg->shape().dimensions(i); + num_common_dims++; + } else if (!removed_by_reduce && !added_by_broadcast) { + new_broadcast_dims.push_back(i - num_common_dims); } + } + + if (only_reduce_dims_from_broadcast) { // HloConstantFolding will later remove any unnecessary multiply and add // instructions. HloInstruction* multiplier = - MakeScalarLike(arg->mutable_operand(0), reduction_dims_prod); + MakeScalarLike(arg->mutable_operand(0), common_dims_prod); TF_ASSIGN_OR_RETURN(HloInstruction * multiplied_scalar, MakeBinaryHlo(HloOpcode::kMultiply, arg->mutable_operand(0), multiplier)); TF_ASSIGN_OR_RETURN( HloInstruction * add, - MakeBinaryHlo(HloOpcode::kAdd, init_value, multiplied_scalar)); + MakeBinaryHlo( + HloOpcode::kAdd, + MakeBroadcastHlo(init_value, {}, multiplied_scalar->shape()), + multiplied_scalar)); + VLOG(10) << "Converting common reduce(broadcast) dimensions to multiply"; return ReplaceWithNewInstruction( - reduce, HloInstruction::CreateBroadcast(reduce->shape(), add, {})); + reduce, HloInstruction::CreateBroadcast(new_broadcast_shape, add, + new_broadcast_dims)); } } diff --git a/third_party/xla/xla/service/algebraic_simplifier.h b/third_party/xla/xla/service/algebraic_simplifier.h index 49e87fa504e9c2..185792b336cfa8 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.h +++ b/third_party/xla/xla/service/algebraic_simplifier.h @@ -287,6 +287,14 @@ class AlgebraicSimplifierOptions { executing_on_cpu_ = executing_on_cpu; } + // Option to disable conversion of dynamic-slice to slice. + void set_disable_dynamic_slice_to_slice_conversion(bool disable) { + disable_dynamic_slice_to_slice_conversion_ = disable; + } + bool disable_dynamic_slice_to_slice_conversion() const { + return disable_dynamic_slice_to_slice_conversion_; + } + private: // Metadata struct can be used to store any metadata information encapsulated // with the AlgebraicSimplifierOptions that can be later used in an @@ -325,6 +333,7 @@ class AlgebraicSimplifierOptions { bool raise_slice_and_reduce_through_dot_{false}; double raise_slice_and_reduce_through_dot_threshold_{2.0}; bool use_convert_constant_folding_{false}; + bool disable_dynamic_slice_to_slice_conversion_{false}; Metadata metadata_; }; diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index 8f4cf2ee5e2b2f..b36c9ca5b5cf79 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -6411,6 +6411,94 @@ TEST_F(AlgebraicSimplifierTest, DotAssociativeReorder) { m::Dot(m::Parameter(1), m::Parameter(2))))); } +TEST_F(AlgebraicSimplifierTest, DotLeftDotSharedBatchReorder) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + a = f32[5,150,5] parameter(0) + b = f32[5,5,5] parameter(1) + c = f32[5,5,5] parameter(2) + + inner = f32[5,150,5] dot(a,b), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={2} + ROOT outer = f32[5,150,5] dot(inner,c), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={2} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + options.set_use_associative_reordering(true); + options.set_associative_reordering_threshold(1.5); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).value()); + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Dot(m::Parameter(0), + m::Dot(m::Parameter(1), m::Parameter(2))))); +} + +TEST_F(AlgebraicSimplifierTest, DotRightDotSharedBatchReorder) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + a = f32[2,3,3] parameter(0) + b = f32[2,3,3] parameter(1) + c = f32[2,3,16] parameter(2) + + inner = f32[2,3,16] dot(b,c), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} + ROOT outer = f32[2,3,16] dot(a,inner), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + options.set_use_associative_reordering(true); + options.set_associative_reordering_threshold(1.5); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).value()); + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Dot(m::Dot(m::Parameter(0), m::Parameter(1)), + m::Parameter(2)))); +} + +TEST_F(AlgebraicSimplifierTest, DotRightDotContractBatchReorder) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + a = f32[80,38,1536] parameter(0) + b = f32[80,38,4] parameter(1) + c = f32[80,4,1536] parameter(2) + inner = f32[80,38,1536] dot(b, c), + lhs_batch_dims={0}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + ROOT outer = f32[1536,1536] dot(a, inner), + lhs_contracting_dims={0,1}, + rhs_contracting_dims={0,1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + options.set_use_associative_reordering(true); + options.set_associative_reordering_threshold(1.5); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).value()); +} + TEST_F(AlgebraicSimplifierTest, DotReverseLeftReorder) { const char* hlo_string = R"( HloModule module @@ -10092,7 +10180,8 @@ TEST_F(AlgebraicSimplifierTest, ReplaceReduceSumOfConstantBroadcast) { )"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); - ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + HloPassFix simplifier(default_options_); + EXPECT_TRUE(simplifier.Run(m.get()).value()); int64_t reduce_count = absl::c_count_if(m->entry_computation()->instructions(), HloPredicateIsOp); @@ -11767,6 +11856,35 @@ TEST_F(AlgebraicSimplifierTest, ReduceOfConstantBroadcastBF16) { EXPECT_EQ(0, reduce_count); } +TEST_F(AlgebraicSimplifierTest, ReduceOfNonScalarBroadcast) { + const std::string hlo_string = R"( + HloModule module + add { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT sum = f32[] add(a, b) + } + + ENTRY test { + a = f32[64,1001] parameter(0) + broadcast = f32[64,7,7,1001] broadcast(a), dimensions={0,3} + zero = f32[] constant(0) + ROOT reduce = f32[64,7,1001] reduce(broadcast, zero), dimensions={2}, + to_apply=add + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + HloPassFix simplifier(default_options_); + EXPECT_TRUE(simplifier.Run(m.get()).value()); + HloInstruction* root = m->entry_computation()->root_instruction(); + int64_t reduce_count = + absl::c_count_if(m->entry_computation()->instructions(), + HloPredicateIsOp); + // Expect no Reduce operation after simplification. + EXPECT_EQ(0, reduce_count); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Multiply()))); +} + TEST_F(AlgebraicSimplifierTest, RemoveConvertConstant) { const std::string hlo_string = R"( HloModule module @@ -11852,5 +11970,60 @@ TEST_F(AlgebraicSimplifierTest, SinkCbrtThroughMax) { root, GmockMatch(m::Cbrt(m::Maximum(m::Parameter(0), m::Parameter(1))))); } +TEST_F(AlgebraicSimplifierTest, + DynamicSlicePreservedWithTrivialConstantIndices) { + const char* hlo_string = R"( + HloModule module + + ENTRY f { + %operand = s32[2,2] parameter(0) + %constant = u32[] constant(0) + ROOT %dynamic-slice = s32[2,1] dynamic-slice(%operand, %constant, %constant), + dynamic_slice_sizes={2,1} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Disable dynamic-slice to slice conversion + default_options_.set_disable_dynamic_slice_to_slice_conversion(true); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_FALSE(simplifier.Run(module.get()).value()); + + // Expect the dynamic-slice to be preserved + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicSlice(m::Parameter(0), m::Constant(), + m::Constant()))); +} + +TEST_F(AlgebraicSimplifierTest, + DynamicSliceConvertedToConstantSliceWithConstantIndices) { + const char* hlo_string = R"( + HloModule module + + ENTRY f { + %operand = s32[2,2] parameter(0) + %constant = u32[] constant(0) + ROOT %dynamic-slice = s32[2,1] dynamic-slice(%operand, %constant, %constant), + dynamic_slice_sizes={2,1} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Enable dynamic-slice to slice conversion (default behavior) + ASSERT_FALSE(default_options_.disable_dynamic_slice_to_slice_conversion()); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).value()); + + // Expect the dynamic-slice to be converted to a constant slice + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc b/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc new file mode 100644 index 00000000000000..441c3b69f3da28 --- /dev/null +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc @@ -0,0 +1,190 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/batched_gather_scatter_normalizer.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" + +namespace xla { + +namespace { +bool IsBatchGather(const HloGatherInstruction* gather) { + const auto& dims = gather->gather_dimension_numbers(); + return !dims.operand_batching_dims().empty(); +} + +bool IsBatchScatter(const HloScatterInstruction* scatter) { + const auto& dims = scatter->scatter_dimension_numbers(); + return !dims.input_batching_dims().empty(); +} + +// Update gather/scater indices by adding fake batching iota dimensions. +HloInstruction* CreateConcatIndices( + HloInstruction* inst, HloInstruction* indices, int64_t index_vector_dim, + absl::Span indices_batching_dims, + BatchedGatherScatterNormalizer* normalizer) { + const bool index_vector_dim_on_last_dim = + index_vector_dim == indices->shape().rank(); + + Shape iota_shape = indices->shape(); + if (index_vector_dim_on_last_dim) { + std::vector dimensions(iota_shape.dimensions().begin(), + iota_shape.dimensions().end()); + dimensions.push_back(1); + iota_shape = ShapeUtil::MakeShape(iota_shape.element_type(), dimensions); + } + iota_shape.set_dimensions(index_vector_dim, 1); + normalizer->UpdateLayout(&iota_shape); + + std::vector indices_to_concat; + for (int64_t indices_batching_dim : indices_batching_dims) { + indices_to_concat.push_back(inst->parent()->AddInstruction( + HloInstruction::CreateIota(iota_shape, indices_batching_dim))); + } + if (index_vector_dim_on_last_dim) { + std::vector dimensions(indices->shape().dimensions().begin(), + indices->shape().dimensions().end()); + dimensions.push_back(1); + Shape reshape_shape = + ShapeUtil::MakeShape(indices->shape().element_type(), dimensions); + normalizer->UpdateLayout(&reshape_shape); + HloInstruction* reshaped_indices = inst->AddInstruction( + HloInstruction::CreateReshape(reshape_shape, indices)); + indices_to_concat.push_back(reshaped_indices); + } else { + indices_to_concat.push_back(indices); + } + Shape concat_shape = iota_shape; + concat_shape.set_dimensions( + index_vector_dim, + indices_batching_dims.size() + + (index_vector_dim_on_last_dim + ? 1 + : indices->shape().dimensions(index_vector_dim))); + normalizer->UpdateLayout(&concat_shape); + return inst->AddInstruction(HloInstruction::CreateConcatenate( + concat_shape, indices_to_concat, index_vector_dim)); +} + +absl::StatusOr NormalizeBatchGather( + HloGatherInstruction* gather, BatchedGatherScatterNormalizer* normalizer) { + HloInstruction* gather_operand = gather->mutable_operand(0); + HloInstruction* gather_indices = gather->mutable_operand(1); + const auto& dims = gather->gather_dimension_numbers(); + CHECK_EQ(dims.operand_batching_dims_size(), + dims.start_indices_batching_dims_size()); + // Update start_index_map. + std::vector start_index_map(dims.operand_batching_dims().begin(), + dims.operand_batching_dims().end()); + absl::c_copy(dims.start_index_map(), std::back_inserter(start_index_map)); + gather_indices = + CreateConcatIndices(gather, gather_indices, dims.index_vector_dim(), + dims.start_indices_batching_dims(), normalizer); + // Update collapsed_slice_dims. + std::vector collapsed_slice_dims(dims.collapsed_slice_dims().begin(), + dims.collapsed_slice_dims().end()); + absl::c_copy(dims.operand_batching_dims(), + std::back_inserter(collapsed_slice_dims)); + absl::c_sort(collapsed_slice_dims); + + GatherDimensionNumbers updated_dims = + HloGatherInstruction::MakeGatherDimNumbers( + dims.offset_dims(), collapsed_slice_dims, start_index_map, + dims.index_vector_dim()); + return gather->AddInstruction(HloInstruction::CreateGather( + gather->shape(), gather_operand, gather_indices, updated_dims, + gather->gather_slice_sizes(), gather->indices_are_sorted())); +} + +absl::StatusOr NormalizeBatchScatter( + HloScatterInstruction* scatter, + BatchedGatherScatterNormalizer* normalizer) { + auto scatter_operands = scatter->scatter_operands(); + HloInstruction* scatter_indices = scatter->scatter_indices(); + auto scatter_updates = scatter->scatter_updates(); + const auto& dims = scatter->scatter_dimension_numbers(); + CHECK_EQ(dims.input_batching_dims_size(), + dims.scatter_indices_batching_dims_size()); + // Update scatter_dims_to_operand_dims. + std::vector scatter_dims_to_operand_dims( + dims.input_batching_dims().begin(), dims.input_batching_dims().end()); + absl::c_copy(dims.scatter_dims_to_operand_dims(), + std::back_inserter(scatter_dims_to_operand_dims)); + scatter_indices = + CreateConcatIndices(scatter, scatter_indices, dims.index_vector_dim(), + dims.scatter_indices_batching_dims(), normalizer); + // Update inserted_window_dims. + std::vector inserted_window_dims(dims.inserted_window_dims().begin(), + dims.inserted_window_dims().end()); + absl::c_copy(dims.input_batching_dims(), + std::back_inserter(inserted_window_dims)); + absl::c_sort(inserted_window_dims); + + ScatterDimensionNumbers updated_dims = + HloScatterInstruction::MakeScatterDimNumbers( + dims.update_window_dims(), inserted_window_dims, + scatter_dims_to_operand_dims, dims.index_vector_dim()); + return scatter->AddInstruction(HloInstruction::CreateScatter( + scatter->shape(), scatter_operands, scatter_indices, scatter_updates, + scatter->to_apply(), updated_dims, scatter->indices_are_sorted(), + scatter->unique_indices())); +} + +} // namespace + +absl::StatusOr +BatchedGatherScatterNormalizer::ExpandInstruction(HloInstruction* inst) { + if (inst->opcode() == HloOpcode::kGather) { + auto* gather = DynCast(inst); + return NormalizeBatchGather(gather, this); + } + if (inst->opcode() == HloOpcode::kScatter) { + auto* scatter = DynCast(inst); + return NormalizeBatchScatter(scatter, this); + } + return absl::InvalidArgumentError(absl::StrFormat( + "Instruction: %s is not a batch gather or scatter.", inst->ToString())); +} + +bool BatchedGatherScatterNormalizer::InstructionMatchesPattern( + HloInstruction* inst) { + if (inst->opcode() == HloOpcode::kGather) { + auto* gather = DynCast(inst); + return IsBatchGather(gather); + } + if (inst->opcode() == HloOpcode::kScatter) { + auto* scatter = DynCast(inst); + return IsBatchScatter(scatter); + } + return false; +} + +} // namespace xla diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer.h b/third_party/xla/xla/service/batched_gather_scatter_normalizer.h new file mode 100644 index 00000000000000..4b5560d38dceec --- /dev/null +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer.h @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_BATCHED_GATHER_SCATTER_NORMALIZER_H_ +#define XLA_SERVICE_BATCHED_GATHER_SCATTER_NORMALIZER_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/op_expander_pass.h" + +namespace xla { + +// This pass rewrites normalize batch gather and scatter operations into a +// non-batch version. +class BatchedGatherScatterNormalizer : public OpExpanderPass { + public: + absl::string_view name() const override { + return "gather_scatter_normalizer"; + } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // XLA_SERVICE_BATCHED_GATHER_SCATTER_NORMALIZER_H_ diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc b/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc new file mode 100644 index 00000000000000..81f0882c977ca2 --- /dev/null +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/batched_gather_scatter_normalizer.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +class BatchedGatherScatterNormalizerTest : public HloTestBase {}; + +TEST_F(BatchedGatherScatterNormalizerTest, NormalizeBatchGather) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46,512]{5,4,3,2,1,0}, s64[10,9,8,7,5,512]{5,4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0}} + +ENTRY %Gather (input_tensor: f32[50,49,48,47,46,512], start_indices: s64[10,9,8,7,5,512]) -> f32[10,9,8,7,30,29,28,27,26,512] { + %input_tensor = f32[50,49,48,47,46,512]{5,4,3,2,1,0} parameter(0) + %start_indices = s64[10,9,8,7,5,512]{5,4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} + gather(f32[50,49,48,47,46,512]{5,4,3,2,1,0} %input_tensor, s64[10,9,8,7,5,512]{5,4,3,2,1,0} %start_indices), + offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, operand_batching_dims={5}, + start_indices_batching_dims={5}, index_vector_dim=4, slice_sizes={30,29,28,27,26,1} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA:.*]] = s64[10,9,8,7,1,512]{{.*}} iota() + CHECK: %[[INDICES_CONCAT:.*]] = s64[10,9,8,7,6,512]{{.*}} concatenate(%[[IOTA]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[10,9,8,7,30,29,28,27,26,512]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={4,5,6,7,8}, + CHECK-SAME: collapsed_slice_dims={5}, + CHECK-SAME: start_index_map={5,0,1,2,3,4}, + CHECK-SAME: index_vector_dim=4, + CHECK-SAME: slice_sizes={30,29,28,27,26,1} + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, NormalizeBatchGather2) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0}, s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26,512,1024]{10,9,8,7,6,5,4,3,2,1,0}} + +ENTRY %Gather (input_tensor: f32[50,49,48,47,46,512,1024,100], start_indices: s64[10,9,8,7,6,512,1024]) -> f32[10,9,8,7,30,29,28,27,26,512,1024] { + %input_tensor = f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0} parameter(0) + %start_indices = s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26,512,1024]{10,9,8,7,6,5,4,3,2,1,0} + gather(f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0} %input_tensor, s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} %start_indices), + offset_dims={4,5,6,7,8}, collapsed_slice_dims={7}, start_index_map={0,1,2,3,4,7}, operand_batching_dims={5,6}, + start_indices_batching_dims={5,6}, index_vector_dim=4, slice_sizes={30,29,28,27,26,1,1,1} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[10,9,8,7,1,512,1024]{{.*}} iota() + CHECK: %[[IOTA2:.*]] = s64[10,9,8,7,1,512,1024]{{.*}} iota() + CHECK: %[[INDICES_CONCAT:.*]] = s64[10,9,8,7,8,512,1024]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[10,9,8,7,30,29,28,27,26,512,1024]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={4,5,6,7,8}, + CHECK-SAME: collapsed_slice_dims={5,6,7}, + CHECK-SAME: start_index_map={5,6,0,1,2,3,4,7}, + CHECK-SAME: index_vector_dim=4, + CHECK-SAME: slice_sizes={30,29,28,27,26,1,1,1} + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, NormalizeBatchScatter) { + constexpr absl::string_view kModuleStr = R"( + +HloModule StringifyScatter, entry_computation_layout={(f32[50,49,48,47,46,512]{5,4,3,2,1,0}, s64[10,9,8,7,5,512]{5,4,3,2,1,0}, f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46,512]{5,4,3,2,1,0}} + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[50,49,48,47,46,512], scatter_indices: s64[10,9,8,7,5,512], updates: f32[10,9,8,7,30,29,28,27,26,512]) -> f32[50,49,48,47,46,512] { + %input_tensor = f32[50,49,48,47,46,512]{5,4,3,2,1,0} parameter(0) + %scatter_indices = s64[10,9,8,7,5,512]{5,4,3,2,1,0} parameter(1) + %updates = f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} parameter(2) + ROOT %scatter = f32[50,49,48,47,46,512]{5,4,3,2,1,0} scatter( + f32[50,49,48,47,46,512]{5,4,3,2,1,0} %input_tensor, + s64[10,9,8,7,5,512]{5,4,3,2,1,0} %scatter_indices, + f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} %updates), + update_window_dims={4,5,6,7,8}, inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1,2,3,4}, input_batching_dims={5}, + scatter_indices_batching_dims={5}, index_vector_dim=4, to_apply=%add_F32.v3 +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA:.*]] = s64[10,9,8,7,1,512]{{.*}} iota() + CHECK: %[[INDICES_CONCAT:.*]] = s64[10,9,8,7,6,512]{{.*}} concatenate(%[[IOTA]], %scatter_indices) + CHECK: ROOT %[[SCATTER:.*]] = f32[50,49,48,47,46,512]{{.*}} scatter( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]], %updates), + CHECK-SAME: update_window_dims={4,5,6,7,8}, + CHECK-SAME: inserted_window_dims={5}, + CHECK-SAME: scatter_dims_to_operand_dims={5,0,1,2,3,4}, + CHECK-SAME: index_vector_dim=4, + CHECK-SAME: to_apply=%add_F32.v3 + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, NormalizeBatchScatter2) { + constexpr absl::string_view kModuleStr = R"( + +HloModule StringifyScatter, entry_computation_layout={(f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0}, s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0}, f32[10,9,8,7,30,29,28,27,26,512,1024]{10,9,8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0}} + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[50,49,48,47,46,512,1024,100], scatter_indices: s64[10,9,8,7,6,512,1024], updates: f32[10,9,8,7,30,29,28,27,26,512,1024]) -> f32[50,49,48,47,46,512,1024,100] { + %input_tensor = f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0} parameter(0) + %scatter_indices = s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} parameter(1) + %updates = f32[10,9,8,7,30,29,28,27,26,512,1024]{10,9,8,7,6,5,4,3,2,1,0} parameter(2) + ROOT %scatter = f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0} scatter( + f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0} %input_tensor, + s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} %scatter_indices, + f32[10,9,8,7,30,29,28,27,26,512,1024]{10,9,8,7,6,5,4,3,2,1,0} %updates), + update_window_dims={4,5,6,7,8}, inserted_window_dims={7}, + scatter_dims_to_operand_dims={0,1,2,3,4,7}, input_batching_dims={5,6}, + scatter_indices_batching_dims={5,6}, index_vector_dim=4, to_apply=%add_F32.v3 +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[10,9,8,7,1,512,1024]{{.*}} iota() + CHECK: %[[IOTA2:.*]] = s64[10,9,8,7,1,512,1024]{{.*}} iota() + CHECK: %[[INDICES_CONCAT:.*]] = s64[10,9,8,7,8,512,1024]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %scatter_indices) + CHECK: ROOT %[[SCATTER:.*]] = f32[50,49,48,47,46,512,1024,100]{{.*}} scatter( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]], %updates), + CHECK-SAME: update_window_dims={4,5,6,7,8}, + CHECK-SAME: inserted_window_dims={5,6,7}, + CHECK-SAME: scatter_dims_to_operand_dims={5,6,0,1,2,3,4,7}, + CHECK-SAME: index_vector_dim=4, + CHECK-SAME: to_apply=%add_F32.v3 + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, IndexVectorDimOnLastDim) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[50,512,1024]{2,1,0}, s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0})->f32[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0}} + +ENTRY %Gather (input_tensor: f32[50,512,1024], start_indices: s64[10,9,8,7,6,512,1024]) -> f32[10,9,8,7,6,512,1024] { + %input_tensor = f32[50,512,1024]{2,1,0} parameter(0) + %start_indices = s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} + gather(f32[50,512,1024]{2,1,0} %input_tensor, s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} %start_indices), + offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, operand_batching_dims={1,2}, + start_indices_batching_dims={5,6}, index_vector_dim=7, slice_sizes={1,1,1} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[10,9,8,7,6,512,1024,1]{{.*}} iota() + CHECK: %[[IOTA2:.*]] = s64[10,9,8,7,6,512,1024,1]{{.*}} iota() + CHECK: %[[RESHAPE:.*]] = s64[10,9,8,7,6,512,1024,1]{{.*}} reshape(%start_indices) + CHECK: %[[INDICES_CONCAT:.*]] = s64[10,9,8,7,6,512,1024,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %[[RESHAPE]]) + CHECK: ROOT %[[GATHER:.*]] = f32[10,9,8,7,6,512,1024]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={}, + CHECK-SAME: collapsed_slice_dims={0,1,2}, + CHECK-SAME: start_index_map={1,2,0}, + CHECK-SAME: index_vector_dim=7, + CHECK-SAME: slice_sizes={1,1,1} + )"); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/call_graph.cc b/third_party/xla/xla/service/call_graph.cc index ea16ca0c57f7e7..80515e13ea7515 100644 --- a/third_party/xla/xla/service/call_graph.cc +++ b/third_party/xla/xla/service/call_graph.cc @@ -214,8 +214,8 @@ CallContext UnionContexts(CallContext a, CallContext b) { } else if (a == b) { return a; } else { - // Contexts are different and neither is kNone, ie one is kSequential and - // the other is kParallel. + // Contexts are different and neither is kNone, i.e. one is kControlFlow and + // the other is kEmbedded. return CallContext::kBoth; } } diff --git a/third_party/xla/xla/service/call_inliner_test.cc b/third_party/xla/xla/service/call_inliner_test.cc index 56a4e15ef52f64..ad6ee73eb14e8a 100644 --- a/third_party/xla/xla/service/call_inliner_test.cc +++ b/third_party/xla/xla/service/call_inliner_test.cc @@ -15,25 +15,23 @@ limitations under the License. #include "xla/service/call_inliner.h" +#include #include -#include #include -#include -#include +#include "absl/log/log.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/service/hlo_pass_fix.h" +#include "xla/literal_util.h" +#include "xla/service/hlo_parser.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/collective_ops_utils.cc b/third_party/xla/xla/service/collective_ops_utils.cc index 11095f87bee210..e3949569386bf3 100644 --- a/third_party/xla/xla/service/collective_ops_utils.cc +++ b/third_party/xla/xla/service/collective_ops_utils.cc @@ -582,7 +582,7 @@ bool ReplicaGroupsEqual(absl::Span first, return true; } -bool IsCollective(const HloInstruction* instruction) { +bool IsNonFusionCollective(const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kAllReduce: case HloOpcode::kAllReduceStart: @@ -597,24 +597,30 @@ bool IsCollective(const HloInstruction* instruction) { case HloOpcode::kCollectivePermuteDone: case HloOpcode::kReduceScatter: return true; - case HloOpcode::kFusion: - if (instruction->IsCustomFusion()) { - for (const auto* inner_inst : instruction->fused_instructions()) { - if (IsCollective(inner_inst)) { - return true; - } - } - } - return false; case HloOpcode::kAsyncStart: case HloOpcode::kAsyncUpdate: case HloOpcode::kAsyncDone: - return IsCollective(instruction->async_wrapped_instruction()); + return IsNonFusionCollective(instruction->async_wrapped_instruction()); default: return false; } } +bool IsCollective(const HloInstruction* instruction) { + if (IsNonFusionCollective(instruction)) { + return true; + } + if (instruction->opcode() == HloOpcode::kFusion && + instruction->IsCustomFusion()) { + for (const auto* inner_inst : instruction->fused_instructions()) { + if (IsCollective(inner_inst)) { + return true; + } + } + } + return false; +} + HloInstruction* IsOrHasCollectiveWithChannelId(HloInstruction* instruction) { if (instruction->opcode() == HloOpcode::kFusion) { for (auto* inner_inst : instruction->fused_instructions()) { diff --git a/third_party/xla/xla/service/collective_ops_utils.h b/third_party/xla/xla/service/collective_ops_utils.h index c611d57a6e6264..3c2ebd3d523da0 100644 --- a/third_party/xla/xla/service/collective_ops_utils.h +++ b/third_party/xla/xla/service/collective_ops_utils.h @@ -196,6 +196,10 @@ inline constexpr absl::string_view kNopCustomCallTarget = "AllocateBuffer"; inline constexpr absl::string_view kNopReturnTokenCustomCallTarget = "NopReturnToken"; +// Returns true if instruction is a collective op that is not a collective +// fusion. +bool IsNonFusionCollective(const HloInstruction* instruction); + // Returns true if instruction is a collective op or a collective fusion. bool IsCollective(const HloInstruction* instruction); diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 859b6c9b2540c2..232dae4ec7718d 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -50,10 +50,12 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/map_util.h" #include "xla/primitive_util.h" +#include "xla/service/call_graph.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/constant_value.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_parser.h" +#include "xla/service/tuple_points_to_analysis.h" #include "xla/service/value_range.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -445,7 +447,6 @@ std::vector CollectDependenciesToPipeline( ops.end()); formatting_set.insert(source_ops.begin(), source_ops.end()); std::vector to_return; - absl::flat_hash_set already_inserted; for (const HloInstruction* op : ops) { for (HloInstruction* operand : op->operands()) { if (!formatting_set.count(operand)) { @@ -697,10 +698,13 @@ class WhileLoopAnalysis { explicit WhileLoopAnalysis( HloInstruction* while_instr, int64_t max_pipelining_per_loop, bool pipeline_use_tree, bool process_different_sized_options, + TuplePointsToAnalysis* tuple_points_to_analysis, CallGraph* call_graph, std::optional known_start = std::nullopt) : while_(while_instr), loop_start_(known_start), max_pipelining_per_loop_(max_pipelining_per_loop), + tuple_points_to_analysis_(tuple_points_to_analysis), + call_graph_(call_graph), pipeline_use_tree_(pipeline_use_tree), process_different_sized_options_(process_different_sized_options) {} std::optional GetLoopIterationCount() const; @@ -796,6 +800,14 @@ class WhileLoopAnalysis { absl::flat_hash_set invariant_loop_parameters_; absl::flat_hash_set invariant_loop_instructions_; int64_t max_pipelining_per_loop_; + + // Precomputed TuplePointsToAnalysis for the HLO module containing `while_`. + // May be null, in which case the analysis will be performed from scratch. + TuplePointsToAnalysis* tuple_points_to_analysis_; + // Precomputed CallGraph analysis for the HLO module containing `while_`. + // May be null, in which case the analysis will be performed from scratch. + CallGraph* call_graph_; + bool pipeline_use_tree_; bool process_different_sized_options_; }; @@ -834,8 +846,8 @@ bool WhileLoopAnalysis::ComputeLoopStatistics() { if (loop_iteration_count_) { return true; } - std::optional parsed_loop = - PatternMatchParseWhileLoop(while_); + std::optional parsed_loop = PatternMatchParseWhileLoop( + while_, {tuple_points_to_analysis_, call_graph_}); if (!parsed_loop || !parsed_loop->static_while_loop) { return false; } @@ -1380,7 +1392,6 @@ bool IsLoopInvariant( // to still visit before visiting the HLO itself. std::vector> stack( 1, std::make_pair(instr, 0)); - absl::flat_hash_set visited; while (!stack.empty()) { auto& current = stack.back(); invariant_cache[std::get<0>(current)] = true; @@ -1796,6 +1807,8 @@ absl::Status TransformLoopForward( WhileLoopAnalysis new_loop_analysis( new_while_loop, loop_analysis.GetMaxPipeliningPerLoop(), pipeline_use_tree, process_different_sized_ops, + /*tuple_points_to_analysis=*/nullptr, + /*call_graph=*/nullptr, loop_analysis.GetLoopStart()->add(*loop_analysis.GetLoopIncrement())); new_loop_analysis.ComputeLoopStatistics(); new_loop_analysis.CollectCollectivesToMove( @@ -2035,6 +2048,17 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, << "Expected only one parameter"; HloInstruction* loop_parameter = while_body->parameter_instructions()[0]; HloInstruction* loop_init = while_loop->mutable_operand(0); + + // Clean up the SunkByPreviousStep custom calls that were inserted before. + for (HloInstruction* inst : while_body->root_instruction()->operands()) { + if (inst->opcode() == HloOpcode::kDynamicUpdateSlice && + inst->operand(1)->IsCustomCall( + CollectivePipeliner::kSunkByPreviousStep)) { + HloInstruction* cc = inst->mutable_operand(1); + TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(1, cc->mutable_operand(0))); + TF_RETURN_IF_ERROR(cc->parent()->RemoveInstruction(cc)); + } + } CHECK_EQ(while_body->root_instruction()->opcode(), HloOpcode::kTuple); for (int i = 0; i < while_body->root_instruction()->operand_count(); ++i) { is_output_instruction[while_body->root_instruction()->mutable_operand(i)] = @@ -2125,7 +2149,6 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, new_parameter_shapes.push_back(expanded_shape); new_init_operands.push_back(CreateZero(loop_computation, expanded_shape, expanded_shape.element_type())); - indices_to_insert.insert(new_root_operands.size()); Shape extra_trivial_dim_shape = ShapeUtil::PrependMajorDimension(1, pipelined->shape()); HloInstruction* reshaped = body_computation->AddInstruction( @@ -2255,8 +2278,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, TF_RETURN_IF_ERROR(output->ReplaceOperandWith(0, new_param)); TF_RETURN_IF_ERROR( old_operand_param->parent()->RemoveInstruction(old_operand_param)); - // TODO(sacer): Consider relaxing this to all inserted operands. - if (insert_non_alias_custom_call && original_to_move_indices.contains(i)) { + if (insert_non_alias_custom_call && indices_to_insert.contains(i)) { auto* old_operand = output->mutable_operand(1); auto* custom_call = cloned_body->AddInstruction(HloInstruction::CreateCustomCall( @@ -2491,17 +2513,6 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, pipelined_map[formatting_op] = expanded_transpose; continue; } - if (formatting_op->IsCustomCall( - CollectivePipeliner::kSunkByPreviousStep)) { - HloInstruction* expanded_custom_call = - loop_computation->AddInstruction(HloInstruction::CreateCustomCall( - ComputeFullOutputShape(to_move, formatting_op->shape()), - collect_operands(formatting_op), - /*custom_call_target=*/ - CollectivePipeliner::kSunkByPreviousStep)); - pipelined_map[formatting_op] = expanded_custom_call; - continue; - } CHECK(false) << "Unsupported instruction " << formatting_op->ToString(); } for (int64_t i = 0; i < to_move.output_indices.size(); ++i) { @@ -2775,8 +2786,6 @@ static absl::Status TransformLoopBackward( instruction, false, CollectivePipeliner::PipeliningDirection::kBackward, loop_analysis)); } - absl::flat_hash_map - loop_cond_replacements; auto cond_builder = HloComputation::Builder(while_loop->while_condition()->name()); HloInstruction* new_cond_param = @@ -2878,12 +2887,38 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; - std::vector while_loop_instructions; + + // Precompute module-scoped analyses. Because we are running a while-loop + // analysis over all while instructions in the module, computing them here and + // passing them in avoids recomputing them once for each while instruction. + TF_ASSIGN_OR_RETURN( + std::unique_ptr tuple_points_to_analysis, + TuplePointsToAnalysis::Run(module)); + std::unique_ptr call_graph = CallGraph::Build(module); + + std::vector>> + loop_analyses; for (HloComputation* computation : module->MakeComputationPostOrder()) { for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { - if (instruction->opcode() == HloOpcode::kWhile) { - while_loop_instructions.push_back(instruction); + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + if (std::none_of(instruction->while_body()->instructions().begin(), + instruction->while_body()->instructions().end(), + config_.should_process)) { + continue; + } + VLOG(1) << "Pipelinable while: " << instruction->name(); + auto loop_analysis = std::make_unique( + instruction, config_.max_pipelining_per_loop, + config_.pipeline_use_tree, config_.process_different_sized_ops, + tuple_points_to_analysis.get(), call_graph.get()); + loop_analysis->ComputeLoopStatistics(); + if (loop_analysis->GetLoopIterationCount() && + loop_analysis->GetLoopIterationCount()->GetUnsignedValue() > 0) { + loop_analyses.push_back( + std::make_pair(instruction, std::move(loop_analysis))); } } } @@ -2892,32 +2927,23 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( int64_t next_channel_id = hlo_query::NextChannelId(*module); VLOG(1) << "Pipelining on direction: " << GetPipelineDirectionString(config_.pipelining_direction); - for (HloInstruction* instruction : while_loop_instructions) { - VLOG(1) << "While: " << instruction->name(); - WhileLoopAnalysis loop_analysis( - instruction, config_.max_pipelining_per_loop, config_.pipeline_use_tree, - config_.process_different_sized_ops); - loop_analysis.ComputeLoopStatistics(); - if (!loop_analysis.GetLoopIterationCount() || - loop_analysis.GetLoopIterationCount()->GetUnsignedValue() == 0) { - continue; - } + for (auto& [instruction, loop_analysis] : loop_analyses) { VLOG(1) << "While iterations: " - << loop_analysis.GetLoopIterationCount()->ToString(); - loop_analysis.CollectCollectivesToMove( + << loop_analysis->GetLoopIterationCount()->ToString(); + loop_analysis->CollectCollectivesToMove( config_.level_to_operate_on, config_.pipelining_direction, config_.should_process, config_.acceptable_formatting, config_.should_allow_loop_variant_parameter_in_chain, config_.should_allow_control_dependencies, config_.should_add_loop_invariant_op_in_chain); - if (loop_analysis.GetMoveInfos().empty()) { + if (loop_analysis->GetMoveInfos().empty()) { continue; } - transformed_instructions += loop_analysis.GetMoveInfos().size(); + transformed_instructions += loop_analysis->GetMoveInfos().size(); VLOG(1) << "Found Collectives to optimize"; if (VLOG_IS_ON(1)) { int64_t id = 0; - for (auto& to_move : loop_analysis.GetMoveInfos()) { + for (auto& to_move : loop_analysis->GetMoveInfos()) { VLOG(1) << "Move info id: " << id++ << " with " << to_move.collectives_to_move.size() << " collectives " << to_move.dynamic_update_slices.size() @@ -2937,20 +2963,20 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( if (config_.pipelining_direction == PipeliningDirection::kForward) { CHECK(config_.reuse_pipelined_op_buffer); TF_RETURN_IF_ERROR(TransformLoopForward( - loop_analysis, !config_.last_run, config_.level_to_operate_on, + *loop_analysis, !config_.last_run, config_.level_to_operate_on, config_.pipeline_use_tree, config_.process_different_sized_ops, config_.should_process, config_.acceptable_formatting, config_.reuse_pipelined_op_buffer, next_channel_id)); } else if (config_.pipelining_direction == PipeliningDirection::kForwardSink) { TF_RETURN_IF_ERROR(TransformLoopForwardSink( - loop_analysis, !config_.last_run, config_.level_to_operate_on, + *loop_analysis, !config_.last_run, config_.level_to_operate_on, config_.pipeline_use_tree, config_.process_different_sized_ops, config_.should_process, next_channel_id)); } else { CHECK_EQ(config_.pipelining_direction, PipeliningDirection::kBackward); TF_RETURN_IF_ERROR(TransformLoopBackward( - loop_analysis, !config_.last_run, config_.level_to_operate_on, + *loop_analysis, !config_.last_run, config_.level_to_operate_on, config_.process_different_sized_ops, config_.should_process, config_.acceptable_formatting, config_.postprocess_backward_peeled_op, config_.postprocess_backward_rotated_op, next_channel_id)); diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index a924a9a4be3855..53529e822bf72f 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -3083,6 +3083,16 @@ ENTRY entry { const HloInstruction* all_reduce2 = find_all_reduce(all_reduce1); EXPECT_NE(all_reduce2, nullptr); EXPECT_THAT(all_reduce2, op::AllReduce(op::GetTupleElement(op::While()))); + // The root of while body should have a dynamic-update-slice operand which has + // a custom call at operand index 1. + const HloInstruction* while_instr = + FindInstruction(module.get(), HloOpcode::kWhile); + CHECK_NE(while_instr, nullptr); + const HloInstruction* dynamic_update_slice = + while_instr->while_body()->root_instruction()->operands().back(); + CHECK_EQ(dynamic_update_slice->opcode(), HloOpcode::kDynamicUpdateSlice); + const HloInstruction* custom_call = dynamic_update_slice->operand(1); + CHECK(custom_call->IsCustomCall("SunkByPreviousStep")); } TEST_F(CollectivePipelinerTest, ForwardSinkFirstDimNotMatchingLoopCount) { @@ -3375,6 +3385,7 @@ ENTRY entry { XLA_VLOG_LINES(1, module->ToString()); const HloInstruction* while_instr = FindInstruction(module.get(), HloOpcode::kWhile); + CHECK_NE(while_instr, nullptr); EXPECT_TRUE( absl::c_any_of(while_instr->users(), [](const HloInstruction* user) { return absl::c_any_of( @@ -3394,6 +3405,13 @@ ENTRY entry { return operand->opcode() == HloOpcode::kReshape; }), 2); + // The root of while body should have a dynamic-update-slice operand which has + // a custom call at operand index 1. + const HloInstruction* dynamic_update_slice = + while_instr->while_body()->root_instruction()->operand(4); + CHECK_EQ(dynamic_update_slice->opcode(), HloOpcode::kDynamicUpdateSlice); + const HloInstruction* custom_call = dynamic_update_slice->operand(1); + CHECK(custom_call->IsCustomCall("SunkByPreviousStep")); } TEST_F(CollectivePipelinerTest, CollectiveWithMultipleDUSSameBuffer) { @@ -3670,6 +3688,22 @@ ENTRY entry { op::Reshape(op::Multiply()), op::Reshape(op::Divide()), op::Reshape(op::Abs()), op::GetTupleElement(op::While()), op::GetTupleElement(op::While())))); + // The root of while body should have two dynamic-update-slice operands each + // of which has a custom call at operand index 1. + std::function is_dus_with_custom_call = + [&](const HloInstruction* inst) -> bool { + if (inst->opcode() != HloOpcode::kDynamicUpdateSlice) { + return false; + } + return inst->operand(1)->IsCustomCall("SunkByPreviousStep"); + }; + const HloInstruction* while_instr = + FindInstruction(module.get(), HloOpcode::kWhile); + CHECK_NE(while_instr, nullptr); + CHECK(is_dus_with_custom_call( + while_instr->while_body()->root_instruction()->operand(7))); + CHECK(is_dus_with_custom_call( + while_instr->while_body()->root_instruction()->operand(8))); } } // namespace diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 3ed45fed71a156..39b5b70d51de30 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -99,7 +99,7 @@ filegroup( "runtime_matmul_f64.cc", "runtime_matmul_s32.cc", "runtime_fork_join.cc", - "//xla/service/cpu/runtime:runtime_srcs", + "//xla/backends/cpu/runtime:runtime_srcs", #"runtime_handle_ffi_call.cc", # TODO(b/338344732): Add "runtime_handle_ffi_call.cc". ], visibility = internal_visibility([":friends"]), @@ -127,7 +127,7 @@ filegroup( "runtime_fork_join.h", "runtime_lightweight_check.h", "runtime_matmul.h", - "//xla/service/cpu/runtime:runtime_hdrs", + "//xla/backends/cpu/runtime:runtime_hdrs", #"runtime_handle_ffi_call.h", # TODO(b/338344732): Add "runtime_handle_ffi_call.h" ], visibility = internal_visibility([":friends"]), @@ -239,6 +239,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/backends/cpu/runtime:thunk", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/mlir_hlo", @@ -328,7 +329,6 @@ cc_library( "//xla/service:while_loop_simplifier", "//xla/service:while_loop_trip_count_annotator", "//xla/service:zero_sized_hlo_elimination", - "//xla/service/cpu/runtime:thunk", "//xla/service/llvm_ir:llvm_command_line_options", "//xla/service/llvm_ir:llvm_util", "//xla/service/spmd:stateful_rng_spmd_partitioner", @@ -557,6 +557,9 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/runtime:buffer_allocations", + "//xla/backends/cpu/runtime:thunk", + "//xla/backends/cpu/runtime:thunk_executor", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:computation_layout", @@ -570,9 +573,6 @@ cc_library( "//xla/service:maybe_owning_device_memory", "//xla/service:shaped_buffer", "//xla/service:xla_debug_info_manager", - "//xla/service/cpu/runtime:buffer_allocations", - "//xla/service/cpu/runtime:thunk", - "//xla/service/cpu/runtime:thunk_executor", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor/host:host_kernel_c_api", @@ -704,12 +704,10 @@ xla_cc_test( "//xla/service/llvm_ir:llvm_util", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Core", - "@llvm-project//llvm:OrcJIT", - "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -851,33 +849,33 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:util", + "//xla/backends/cpu/runtime:all_gather_thunk", + "//xla/backends/cpu/runtime:all_reduce_thunk", + "//xla/backends/cpu/runtime:all_to_all_thunk", + "//xla/backends/cpu/runtime:call_thunk", + "//xla/backends/cpu/runtime:collective_permute_thunk", + "//xla/backends/cpu/runtime:collective_thunk", + "//xla/backends/cpu/runtime:conditional_thunk", + "//xla/backends/cpu/runtime:convolution_thunk", + "//xla/backends/cpu/runtime:copy_thunk", + "//xla/backends/cpu/runtime:custom_call_thunk", + "//xla/backends/cpu/runtime:dot_thunk", + "//xla/backends/cpu/runtime:fft_thunk", + "//xla/backends/cpu/runtime:infeed_thunk", + "//xla/backends/cpu/runtime:kernel_thunk", + "//xla/backends/cpu/runtime:logical_id_thunk", + "//xla/backends/cpu/runtime:outfeed_thunk", + "//xla/backends/cpu/runtime:reduce_scatter_thunk", + "//xla/backends/cpu/runtime:resource_use", + "//xla/backends/cpu/runtime:rng_state_thunk", + "//xla/backends/cpu/runtime:sort_thunk", + "//xla/backends/cpu/runtime:thunk", + "//xla/backends/cpu/runtime:topk_thunk", + "//xla/backends/cpu/runtime:while_thunk", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service:hlo_module_config", - "//xla/service/cpu/runtime:all_gather_thunk", - "//xla/service/cpu/runtime:all_reduce_thunk", - "//xla/service/cpu/runtime:all_to_all_thunk", - "//xla/service/cpu/runtime:call_thunk", - "//xla/service/cpu/runtime:collective_permute_thunk", - "//xla/service/cpu/runtime:collective_thunk", - "//xla/service/cpu/runtime:conditional_thunk", - "//xla/service/cpu/runtime:convolution_thunk", - "//xla/service/cpu/runtime:copy_thunk", - "//xla/service/cpu/runtime:custom_call_thunk", - "//xla/service/cpu/runtime:dot_thunk", - "//xla/service/cpu/runtime:fft_thunk", - "//xla/service/cpu/runtime:infeed_thunk", - "//xla/service/cpu/runtime:kernel_thunk", - "//xla/service/cpu/runtime:logical_id_thunk", - "//xla/service/cpu/runtime:outfeed_thunk", - "//xla/service/cpu/runtime:reduce_scatter_thunk", - "//xla/service/cpu/runtime:resource_use", - "//xla/service/cpu/runtime:rng_state_thunk", - "//xla/service/cpu/runtime:sort_thunk", - "//xla/service/cpu/runtime:thunk", - "//xla/service/cpu/runtime:topk_thunk", - "//xla/service/cpu/runtime:while_thunk", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -1066,7 +1064,7 @@ cc_library( deps = [ ":runtime_lightweight_check", "//xla:executable_run_options", - "//xla/service/cpu/runtime:conv_impl", + "//xla/backends/cpu/runtime:convolution_thunk_internal", "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:dynamic_annotations", @@ -1084,7 +1082,7 @@ cc_library( deps = [ ":runtime_lightweight_check", "//xla:executable_run_options", - "//xla/service/cpu/runtime:conv_impl", + "//xla/backends/cpu/runtime:convolution_thunk_internal", "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:dynamic_annotations", @@ -1212,7 +1210,7 @@ cc_library( copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ - "//xla/service/cpu/runtime:conv_impl", + "//xla/backends/cpu/runtime:convolution_thunk_internal", "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:dynamic_annotations", @@ -1228,7 +1226,7 @@ cc_library( copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ - "//xla/service/cpu/runtime:conv_impl", + "//xla/backends/cpu/runtime:convolution_thunk_internal", "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:dynamic_annotations", diff --git a/third_party/xla/xla/service/cpu/benchmarks/elementwise_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/elementwise_benchmark_test.cc index 94418f2ab82aee..bf06650e96173f 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/elementwise_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/elementwise_benchmark_test.cc @@ -52,13 +52,63 @@ static void BM_AddF32(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); } -BENCHMARK(BM_AddF32) - ->MeasureProcessCPUTime() - ->Arg(128) - ->Arg(256) - ->Arg(512) - ->Arg(1024) - ->Arg(8192) - ->Arg(16384); +static void BM_AddBF16(benchmark::State& state) { + int64_t d0 = state.range(0); + + std::string_view hlo = R"( + HloModule add_bf16_$d0 + + ENTRY e { + p0 = bf16[1,2,1,$d0,256] parameter(0) + p1 = bf16[1,2,1,$d0,256] parameter(1) + ROOT add = bf16[1,2,1,$d0,256] add(p0, p1) + } + )"; + + std::minstd_rand0 engine; + + auto shape = ShapeUtil::MakeShape(BF16, {1, 2, 1, d0, 256}); + auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); + auto p1 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); + + std::vector args = {&p0, &p1}; + CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); +} + +static void BM_ConvertF32ToBF16(benchmark::State& state) { + int64_t d0 = state.range(0); + + std::string_view hlo = R"( + HloModule convert_f32_to_bf16_$d0 + + ENTRY e { + p0 = f32[1,2,1,$d0,256] parameter(0) + ROOT convert = bf16[1,2,1,$d0,256] convert(p0) + } + )"; + + std::minstd_rand0 engine; + + auto shape = ShapeUtil::MakeShape(F32, {1, 2, 1, d0, 256}); + auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); + + std::vector args = {&p0}; + CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); +} + +#define BENCHMARK_SIZES(NAME) \ + BENCHMARK(NAME) \ + ->MeasureProcessCPUTime() \ + ->Arg(128) \ + ->Arg(256) \ + ->Arg(512) \ + ->Arg(1024) \ + ->Arg(8192) \ + ->Arg(16384) \ + ->Arg(32768) + +BENCHMARK_SIZES(BM_AddF32); +BENCHMARK_SIZES(BM_AddBF16); +BENCHMARK_SIZES(BM_ConvertF32ToBF16); } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/benchmarks/reduction_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/reduction_benchmark_test.cc index 1ec04bc4cae4d8..c5399e93c8d7cd 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/reduction_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/reduction_benchmark_test.cc @@ -57,13 +57,45 @@ static void BM_ReduceAddF32(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); } -BENCHMARK(BM_ReduceAddF32) - ->MeasureProcessCPUTime() - ->Arg(128) - ->Arg(256) - ->Arg(512) - ->Arg(1024) - ->Arg(8192) - ->Arg(16384); +static void BM_ReduceAddBF16(benchmark::State& state) { + int64_t d0 = state.range(0); + + std::string_view hlo = R"( + HloModule reduce_add_bf16_$d0 + + add { + p0 = bf16[] parameter(0) + p1 = bf16[] parameter(1) + ROOT add = bf16[] add(p0, p1) + } + + ENTRY e { + p0 = bf16[1,2,1,$d0,256] parameter(0) + c0 = bf16[] constant(0) + ROOT reduce = bf16[1,2] reduce(p0, c0), dimensions={2,3,4}, to_apply=add + } + )"; + + std::minstd_rand0 engine; + + auto shape = ShapeUtil::MakeShape(BF16, {1, 2, 1, d0, 256}); + auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); + + std::vector args = {&p0}; + CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); +} + +#define BENCHMARK_SIZES(NAME) \ + BENCHMARK(NAME) \ + ->MeasureProcessCPUTime() \ + ->Arg(128) \ + ->Arg(256) \ + ->Arg(512) \ + ->Arg(1024) \ + ->Arg(8192) \ + ->Arg(16384) + +BENCHMARK_SIZES(BM_ReduceAddF32); +BENCHMARK_SIZES(BM_ReduceAddBF16); } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index fc292c9378976e..13a47eb2b36b61 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -68,6 +68,7 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/DialectConversion.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -113,7 +114,6 @@ limitations under the License. #include "xla/service/cpu/ir_emitter.h" #include "xla/service/cpu/ir_emitter2.h" #include "xla/service/cpu/parallel_task_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/cpu/target_machine_features.h" #include "xla/service/cpu/thunk_emitter.h" diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index 82f48a90d32e2b..e1f4b213170651 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -41,6 +41,9 @@ limitations under the License. #include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" #include "llvm/IR/Mangler.h" #include "llvm/Support/Error.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" @@ -48,9 +51,6 @@ limitations under the License. #include "xla/literal.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/cpu_runtime.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" @@ -183,6 +183,7 @@ absl::StatusOr> CpuExecutable::Create( std::move(hlo_profile_index_map), std::move(assignment))); executable->jit_ = std::move(jit); + executable->jit_->DoneCompiling(); executable->function_registry_ = FunctionRegistry(executable->jit_.get()); TF_ASSIGN_OR_RETURN(executable->thunks_, diff --git a/third_party/xla/xla/service/cpu/cpu_executable.h b/third_party/xla/xla/service/cpu/cpu_executable.h index 8c8883d8673684..2c2aa248bcbe5d 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.h +++ b/third_party/xla/xla/service/cpu/cpu_executable.h @@ -28,13 +28,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" diff --git a/third_party/xla/xla/service/cpu/executable.proto b/third_party/xla/xla/service/cpu/executable.proto index bca8a2cc2c64e4..d222660d0f0c35 100644 --- a/third_party/xla/xla/service/cpu/executable.proto +++ b/third_party/xla/xla/service/cpu/executable.proto @@ -17,7 +17,6 @@ syntax = "proto3"; package xla.cpu; -import "xla/service/cpu/xla_framework.proto"; import "xla/service/hlo.proto"; import "xla/xla.proto"; diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 1ee2f6997695bc..e043b5c2e13bec 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -113,7 +113,9 @@ class IrEmitter::CpuElementalIrEmitter : public ElementalIrEmitter { public: CpuElementalIrEmitter(const HloModuleConfig& module_config, IrEmitter* ir_emitter, llvm::Module* module) - : ElementalIrEmitter(module, ir_emitter->b()), + : ElementalIrEmitter( + module, ir_emitter->b(), + Options{/*xla_cpu_use_truncate_f32_to_bf16_conversion=*/true}), hlo_module_config_(module_config), ir_emitter_(ir_emitter) {} diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc index 2e64beadb55546..e7b671268093fc 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc @@ -129,7 +129,9 @@ class IrEmitter2::ElementalIrEmitter : public xla::ElementalIrEmitter { ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b, const HloModule* hlo_module, IrEmitter* nested_ir_emitter, bool fast_min_max) - : xla::ElementalIrEmitter(module, b), + : xla::ElementalIrEmitter( + module, b, + Options{/*xla_cpu_use_truncate_f32_to_bf16_conversion=*/true}), hlo_module_(hlo_module), nested_ir_emitter_(nested_ir_emitter), fast_min_max_(fast_min_max) {} @@ -222,6 +224,13 @@ IrEmitter2::IrEmitter2(const HloModule& hlo_module, llvm::Module* module, bool IrEmitter2::fast_min_max() const { return hlo_module_.config().debug_options().xla_cpu_enable_fast_min_max(); } +IrEmitter2::KernelInfo::KernelInfo(KernelPrototype prototype, + const se::BlockDim& block_dims, + const se::ThreadDim& thread_dims) + : name(prototype.function->getName().str()), + block_dims(block_dims), + thread_dims(thread_dims), + invariant_buffers(std::move(prototype.invariant_buffers)) {} absl::StatusOr IrEmitter2::EmitElementalHostKernel( const HloInstruction* instr) { @@ -250,8 +259,8 @@ absl::StatusOr IrEmitter2::EmitElementalHostKernel( se::ThreadDim thread_dims, EmitElementalLoops(b, instr, kernel_prototype, element_generator)); - return kernels_.emplace_back(KernelInfo{ - kernel_prototype.function->getName().str(), se::BlockDim(), thread_dims}); + return kernels_.emplace_back( + KernelInfo(std::move(kernel_prototype), se::BlockDim(), thread_dims)); } absl::StatusOr IrEmitter2::EmitPadHostKernel( @@ -281,8 +290,7 @@ absl::StatusOr IrEmitter2::EmitPadHostKernel( nested_ir_emitter_->PopComputeFunction(); return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr IrEmitter2::EmitFusionHostKernel( @@ -326,9 +334,8 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( const_cast(fusion), kernel_prototype.results[0], &fused_emitter, &b)); - return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + return kernels_.emplace_back(KernelInfo(std::move(kernel_prototype), + se::BlockDim(), se::ThreadDim())); } // Emit plain elemental loops for the fusion operation. @@ -340,8 +347,8 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( se::ThreadDim thread_dims, EmitElementalLoops(b, fusion, kernel_prototype, element_generator)); - return kernels_.emplace_back(KernelInfo{ - kernel_prototype.function->getName().str(), se::BlockDim(), thread_dims}); + return kernels_.emplace_back( + KernelInfo(std::move(kernel_prototype), se::BlockDim(), thread_dims)); } absl::StatusOr IrEmitter2::EmitReductionHostKernel( @@ -393,8 +400,7 @@ absl::StatusOr IrEmitter2::EmitDotHostKernel( /*allow_runtime_calls=*/false)); return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr IrEmitter2::EmitConcatenateHostKernel( @@ -414,9 +420,8 @@ absl::StatusOr IrEmitter2::EmitConcatenateHostKernel( llvm_ir::IrArray output_array = kernel_prototype.results[0]; TF_RETURN_IF_ERROR(::xla::cpu::EmitFastConcatenate( instr, kernel_prototype.arguments, output_array, module_, ir_builder)); - return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + return kernels_.emplace_back(KernelInfo(std::move(kernel_prototype), + se::BlockDim(), se::ThreadDim())); } VLOG(1) << "Could not emit fast concatenate for " << instr->ToString() << ": " << fast_impl_reason.message(); @@ -477,8 +482,7 @@ absl::StatusOr IrEmitter2::EmitDotFusionHostKernel( /*allow_runtime_calls=*/false)); return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr IrEmitter2::EmitSliceToDynamicHostKernel( @@ -496,8 +500,7 @@ absl::StatusOr IrEmitter2::EmitSliceToDynamicHostKernel( TF_RETURN_IF_ERROR(nested_ir_emitter_->EmitSliceToDynamic( instr, kernel_prototype.arguments, output_array)); return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr @@ -514,8 +517,7 @@ IrEmitter2::EmitSelectAndScatterHostKernel(const HloInstruction* instr) { output_array)); return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr @@ -535,9 +537,8 @@ IrEmitter2::EmitDynamicUpdateSliceHostKernel(const HloInstruction* instr) { kernel_prototype.arguments, kernel_prototype.results.front(), llvm_ir::IrName(instr, "in_place"), &b)); - return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + return kernels_.emplace_back(KernelInfo(std::move(kernel_prototype), + se::BlockDim(), se::ThreadDim())); } return EmitElementalHostKernel(instr); @@ -794,10 +795,6 @@ absl::StatusOr IrEmitter2::EmitKernelPrototype( // Collect a set of invariant (read-only) buffer slices. If a buffer slice is // not a part of result set, then it must be a read-only buffer. - // - // TODO(ezhulenev): Pass this information to KernelThunk and add an extra run - // time check to verify that this property holds, as otherwise it can lead to - // hard to debug errors. absl::flat_hash_set invariant_slices; for (const KernelParameter& argument : arguments) { if (!result_slices.contains(argument.slice)) { @@ -878,7 +875,8 @@ absl::StatusOr IrEmitter2::EmitKernelPrototype( kernel_thread_dims, kernel_thread, std::move(ir_arguments), - std::move(ir_results)}; + std::move(ir_results), + std::move(invariant_slices)}; } absl::StatusOr IrEmitter2::EmitKernelPrototype( diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.h b/third_party/xla/xla/service/cpu/ir_emitter2.h index a205e91c10e057..b10f9034d19d2b 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.h +++ b/third_party/xla/xla/service/cpu/ir_emitter2.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -62,6 +63,12 @@ namespace xla::cpu { // // WARNING: This is under construction and will eventually replace IrEmitter. class IrEmitter2 { + public: + friend class IrEmitter2Test; + + private: + struct KernelPrototype; + public: IrEmitter2(const HloModule& hlo_module, llvm::Module* module, IrEmitter* nested_ir_emitter); @@ -88,28 +95,16 @@ class IrEmitter2 { llvm::Value* z; }; - // A kernel function prototype with all the LLVM values that might be needed - // to emit the actual kernel body. - struct KernelPrototype { - llvm::Function* function; - llvm::BasicBlock* return_block; - - // LLVM values identifying kernel invocation thread coordinates. - KernelThreadDims thread_dims; - KernelThread thread; - - // LLVM values corresponding to the kernel arguments and results arrays. All - // tuples are flattened as we do not have any tuples at run time and only - // read and write data from/to leaf arrays. - std::vector arguments; - std::vector results; - }; - // Emitted kernel information that defines how to launch it at run time. struct KernelInfo { + explicit KernelInfo(KernelPrototype prototype, + const se::BlockDim& block_dims, + const se::ThreadDim& thread_dims); + std::string name; se::BlockDim block_dims; se::ThreadDim thread_dims; + absl::flat_hash_set invariant_buffers; }; // Emitted comparator function information (for sort operation). @@ -166,6 +161,30 @@ class IrEmitter2 { absl::StatusOr EmitSortComparator( const HloInstruction* instr); + private: + class ElementalIrEmitter; + + // A kernel function prototype with all the LLVM values that might be needed + // to emit the actual kernel body. + struct KernelPrototype { + llvm::Function* function; + llvm::BasicBlock* return_block; + + // LLVM values identifying kernel invocation thread coordinates. + KernelThreadDims thread_dims; + KernelThread thread; + + // LLVM values corresponding to the kernel arguments and results arrays. All + // tuples are flattened as we do not have any tuples at run time and only + // read and write data from/to leaf arrays. + std::vector arguments; + std::vector results; + + // Set containing all invariant (read-only) buffers. A buffer is read-only + // if it is not aliased with any result. + absl::flat_hash_set invariant_buffers; + }; + // Emits a host kernel prototype and prepares function for emitting kernel // body into it. absl::StatusOr EmitKernelPrototype( @@ -176,9 +195,6 @@ class IrEmitter2 { absl::StatusOr EmitKernelPrototype( const HloInstruction* instr); - private: - class ElementalIrEmitter; - // Parallel partition bounds for parallelized outer dimensions: // vector<[i64 lower_bound, i64 upper_bound]> using ParallelPartitionBounds = diff --git a/third_party/xla/xla/service/cpu/ir_emitter2_test.cc b/third_party/xla/xla/service/cpu/ir_emitter2_test.cc index 89c6c16863aec9..b2e8414a344983 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2_test.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2_test.cc @@ -17,8 +17,11 @@ limitations under the License. #include #include +#include #include +#include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" @@ -42,9 +45,63 @@ limitations under the License. #include "tsl/platform/test.h" namespace xla::cpu { -namespace { -using IrEmitter2Test = HloTestBase; +class IrEmitter2Test : public HloTestBase { + public: + // This is a proxy function that allows us call private method + // IrEmitter2::EmitKernelPrototype. + static auto EmitKernelPrototype( + IrEmitter2& ir_emitter, + const std::vector& arguments, + const std::vector& results) { + return ir_emitter.EmitKernelPrototype("test", arguments, results); + } + + absl::StatusOr MakeIrEmitter2(llvm::Module& module, + const HloModule& hlo) { + TF_ASSIGN_OR_RETURN( + buffer_assignment_, + BufferAssigner::Run( + &hlo, std::make_unique(&hlo), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return /*alignment=*/1; })); + + target_machine_ = + std::make_unique( + [](int64_t size) { return 1; }); + + nested_ir_emitter_ = absl::WrapUnique( + new IrEmitter(nullptr, hlo, *buffer_assignment_, &module, {}, {}, {}, + target_machine_.get(), false)); + + return IrEmitter2(hlo, &module, nested_ir_emitter_.get()); + } + + // TODO(abanas): This function could be static. It requires making the + // underlying FindInstruction function static first. + absl::StatusOr EmitElementalHostKernel( + IrEmitter2& ir_emitter, HloModule& hlo, + std::string_view instruction_name) { + HloInstruction* instruction = FindInstruction(&hlo, instruction_name); + + if (instruction == nullptr) { + return absl::InternalError("Instruction not found"); + } + TF_ASSIGN_OR_RETURN(IrEmitter2::KernelInfo kernel, + ir_emitter.EmitElementalHostKernel(instruction)); + return kernel; + } + + private: + // Dependencies of IrEmitter2. These are created in MakeIrEmitter2 and kept + // alive for the duration of the test, because IrEmitter2 does not take + // ownership of them. + std::unique_ptr buffer_assignment_; + std::unique_ptr target_machine_; + std::unique_ptr nested_ir_emitter_; +}; + +namespace { TEST_F(IrEmitter2Test, BuildKernelPrototype) { auto hlo = std::make_unique("test", HloModuleConfig()); @@ -66,9 +123,8 @@ TEST_F(IrEmitter2Test, BuildKernelPrototype) { {shape, res1}}; IrEmitter2 ir_emitter(*hlo, module.get(), /*nested_ir_emitter=*/nullptr); - TF_ASSERT_OK_AND_ASSIGN( - IrEmitter2::KernelPrototype prototype, - ir_emitter.EmitKernelPrototype("test", arguments, results)); + TF_ASSERT_OK_AND_ASSIGN(auto prototype, + EmitKernelPrototype(ir_emitter, arguments, results)); llvm::IRBuilder<> b(context); b.SetInsertPoint(prototype.function->getEntryBlock().getTerminator()); @@ -85,38 +141,38 @@ TEST_F(IrEmitter2Test, BuildKernelPrototype) { ASSERT_TRUE(*RunFileCheck(llvm_ir::DumpToString(module.get()), R"( CHECK: define ptr @test(ptr %0) #0 { - CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 0 - CHECK: getelementptr inbounds %SE_HOST_KernelThreadDim, {{.*}} i32 0 - CHECK: getelementptr inbounds %SE_HOST_KernelThreadDim, {{.*}} i32 1 - CHECK: getelementptr inbounds %SE_HOST_KernelThreadDim, {{.*}} i32 2 + CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 0 + CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 0 + CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 1 + CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 2 CHECK: load i64 CHECK: load i64 CHECK: load i64 - CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 1 - CHECK: getelementptr inbounds %SE_HOST_KernelThread, {{.*}} i32 0 - CHECK: getelementptr inbounds %SE_HOST_KernelThread, {{.*}} i32 1 - CHECK: getelementptr inbounds %SE_HOST_KernelThread, {{.*}} i32 2 + CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 1 + CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 0 + CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 1 + CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 2 CHECK: load i64 CHECK: load i64 CHECK: load i64 - CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3 + CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 CHECK: load ptr CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 0, i32 0 CHECK: %[[ARG0:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0:.+]], !dereferenceable ![[DEREF_BYTES:.*]], !align ![[ALIGNMENT:.+]] - CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3 + CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 CHECK: load ptr CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 1, i32 0 CHECK: %[[ARG1:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] - CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3 + CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 CHECK: load ptr CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 2, i32 0 CHECK: %[[ARG2:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] - CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3 + CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 CHECK: load ptr CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 3, i32 0 CHECK: %[[ARG3:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] @@ -167,25 +223,9 @@ TEST_F(IrEmitter2Test, EmitElementalKernel) { })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); - HloInstruction* convert = FindInstruction(hlo.get(), "convert"); - ASSERT_NE(convert, nullptr); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr buffer_assignment, - BufferAssigner::Run( - hlo.get(), std::make_unique(hlo.get()), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return /*alignment=*/1; })); - - TargetMachineFeaturesWithFakeAlignmentLogic target_machine( - [](int64_t size) { return 1; }); - - IrEmitter nested_ir_emitter(nullptr, *hlo, *buffer_assignment, module.get(), - {}, {}, {}, &target_machine, false); - - IrEmitter2 ir_emitter(*hlo, module.get(), &nested_ir_emitter); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, - ir_emitter.EmitElementalHostKernel(convert)); + EmitElementalHostKernel(ir_emitter, *hlo, "convert")); ASSERT_TRUE(*RunFileCheck(llvm_ir::DumpToString(module.get()), R"( CHECK: define ptr @convert(ptr %0) #0 { @@ -207,25 +247,9 @@ TEST_F(IrEmitter2Test, EmitParallelKernel) { })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); - HloInstruction* convert = FindInstruction(hlo.get(), "convert"); - ASSERT_NE(convert, nullptr); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr buffer_assignment, - BufferAssigner::Run( - hlo.get(), std::make_unique(hlo.get()), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return /*alignment=*/1; })); - - TargetMachineFeaturesWithFakeAlignmentLogic target_machine( - [](int64_t size) { return 1; }); - - IrEmitter nested_ir_emitter(nullptr, *hlo, *buffer_assignment, module.get(), - {}, {}, {}, &target_machine, false); - - IrEmitter2 ir_emitter(*hlo, module.get(), &nested_ir_emitter); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, - ir_emitter.EmitElementalHostKernel(convert)); + EmitElementalHostKernel(ir_emitter, *hlo, "convert")); ASSERT_TRUE(*RunFileCheck(llvm_ir::DumpToString(module.get()), R"( CHECK: @convert_parallel_bounds = private constant [8 x [4 x [2 x i64]]] @@ -244,5 +268,66 @@ TEST_F(IrEmitter2Test, EmitParallelKernel) { )")); } +using IrEmitter2InvariantBuffersTest = IrEmitter2Test; + +TEST_F(IrEmitter2InvariantBuffersTest, AllInvariantBuffers) { + llvm::LLVMContext context; + auto module = std::make_unique("test", context); + + const char* hlo_text = R"( + HloModule m + ENTRY main { + p0 = f32[2,2] parameter(0) + ROOT add.0 = f32[2,2] add(p0, p0) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, + EmitElementalHostKernel(ir_emitter, *hlo, "add.0")); + + ASSERT_EQ(kernel.invariant_buffers.size(), 1); +} + +TEST_F(IrEmitter2InvariantBuffersTest, NoInvariantBuffers) { + llvm::LLVMContext context; + auto module = std::make_unique("test", context); + + const char* hlo_text = R"( + HloModule m, input_output_alias={ {}: (0, {}, must-alias) } + ENTRY main { + p0 = f32[2,2] parameter(0) + ROOT add.0 = f32[2,2] add(p0, p0) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, + EmitElementalHostKernel(ir_emitter, *hlo, "add.0")); + + ASSERT_EQ(kernel.invariant_buffers.size(), 0); +} + +TEST_F(IrEmitter2InvariantBuffersTest, MixedBuffers) { + llvm::LLVMContext context; + auto module = std::make_unique("test", context); + + const char* hlo_text = R"( + HloModule m, input_output_alias={ {}: (1, {}, must-alias) } + ENTRY main { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT add.0 = f32[2,2] add(p0, p1) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, + EmitElementalHostKernel(ir_emitter, *hlo, "add.0")); + + // TODO(abanas): Verify also which buffer is read-only, not only the count. + ASSERT_EQ(kernel.invariant_buffers.size(), 1); +} + } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/conv_impl.cc b/third_party/xla/xla/service/cpu/runtime/conv_impl.cc deleted file mode 100644 index 199a97919fa53e..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/conv_impl.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#define EIGEN_USE_THREADS - -#include "xla/service/cpu/runtime/conv_impl.h" - -namespace tensorflow::xla { - -// Instantiate Conv2D template for all supported devices and data types. -#define CONV2D_INSTANTIATE_TEMPLATE(EigenDevice, ScalarType) \ - template void EigenConv2DImpl( \ - const EigenDevice& device, ScalarType* out, ScalarType* lhs, \ - ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ - Eigen::Index input_y, Eigen::Index input_channels, \ - Eigen::Index kernel_x, Eigen::Index kernel_y, \ - Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ - Eigen::Index output_x, Eigen::Index output_y, Eigen::Index x_stride, \ - Eigen::Index y_stride, Eigen::Index padding_x_before, \ - Eigen::Index padding_x_after, Eigen::Index padding_y_before, \ - Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \ - Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \ - Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, \ - std::optional> done_callback) - -CONV2D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, Eigen::half); -CONV2D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, float); -CONV2D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half); -CONV2D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, float); - -#undef CONV2D_INSTANTIATE_TEMPLATE - -// Instantiate Conv3D template for all supported devices and data types. -#define CONV3D_INSTANTIATE_TEMPLATE(EigenDevice, ScalarType) \ - template void EigenConv3DImpl( \ - const EigenDevice& device, ScalarType* out, ScalarType* lhs, \ - ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ - Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels, \ - Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z, \ - Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ - Eigen::Index output_x, Eigen::Index output_y, Eigen::Index output_z, \ - Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index z_stride, \ - Eigen::Index padding_x_before, Eigen::Index padding_x_after, \ - Eigen::Index padding_y_before, Eigen::Index padding_y_after, \ - Eigen::Index padding_z_before, Eigen::Index padding_z_after, \ - Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, \ - Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \ - Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \ - Eigen::Index feature_group_count, \ - std::optional> done_callback) - -CONV3D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, Eigen::half); -CONV3D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, float); -CONV3D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half); -CONV3D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, float); - -} // namespace tensorflow::xla diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc deleted file mode 100644 index 63696c0e83278c..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/cpu/runtime/kernel_thunk.h" - -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/match.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/maybe_owning_device_memory.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" - -namespace xla::cpu { -namespace { - -class AddF32HostKernel : public Thunk::FunctionRegistry { - public: - absl::StatusOr FindKernel(std::string_view name) override { - return +[](const SE_HOST_KernelCallFrame* call_frame) { - const SE_HOST_KernelArg& in = call_frame->args[0]; - const SE_HOST_KernelArg& out = call_frame->args[1]; - - float* in_ptr = reinterpret_cast(in.data); - float* out_ptr = reinterpret_cast(out.data); - - uint64_t i = call_frame->thread->x; - *(out_ptr + i) = *(in_ptr + i) + *(in_ptr + i); - - return static_cast(nullptr); - }; - } -}; - -TEST(KernelThunkTest, CheckAlignment) { - auto thunk = KernelThunk::Create({"test"}, {}, {}, "test", se::ThreadDim(), - /*min_alignment=*/3); - EXPECT_TRUE(absl::StrContains(thunk.status().message(), - "minimum alignment 3 is not a power of 2")); -} - -TEST(KernelThunkTest, AddF32) { - std::vector buffers; - std::vector in = {1.0, 2.0, 3.0, 4.0}; - std::vector out(4, 0.0); - - size_t size_in_bytes = in.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(in.data(), size_in_bytes)); - buffers.emplace_back(se::DeviceMemoryBase(out.data(), size_in_bytes)); - - BufferAllocations allocations(buffers); - - BufferAllocation in_alloc(0, size_in_bytes, 0); - BufferAllocation out_alloc(1, size_in_bytes, 0); - - BufferAllocation::Slice in_slice(&in_alloc, 0, size_in_bytes); - BufferAllocation::Slice out_slice(&out_alloc, 0, size_in_bytes); - - TF_ASSERT_OK_AND_ASSIGN( - auto thunk, KernelThunk::Create({"add_f32"}, {in_slice}, {out_slice}, - "add_f32", se::ThreadDim(4))); - - AddF32HostKernel host_kernels; - Thunk::ExecuteParams params = {&host_kernels, &allocations}; - - auto execute_event = thunk->Execute(params); - tsl::BlockUntilReady(execute_event); - ASSERT_FALSE(execute_event.IsError()); - - std::vector expected = {2.0, 4.0, 6.0, 8.0}; - EXPECT_EQ(out, expected); -} - -} // namespace -} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime_conv2d.cc b/third_party/xla/xla/service/cpu/runtime_conv2d.cc index 4bc0d03fe8099e..696f556b20fd7a 100644 --- a/third_party/xla/xla/service/cpu/runtime_conv2d.cc +++ b/third_party/xla/xla/service/cpu/runtime_conv2d.cc @@ -20,8 +20,8 @@ limitations under the License. #define EIGEN_USE_THREADS #include "absl/base/dynamic_annotations.h" +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" #include "xla/executable_run_options.h" -#include "xla/service/cpu/runtime/conv_impl.h" #include "xla/service/cpu/runtime_lightweight_check.h" ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF32( @@ -37,7 +37,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF32( const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); - tensorflow::xla::EigenConv2DImpl( + xla::cpu::internal::EigenConv2D( *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, kernel_filters, output_rows, output_cols, row_stride, @@ -59,7 +59,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF16( const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); - tensorflow::xla::EigenConv2DImpl( + xla::cpu::internal::EigenConv2D( *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, kernel_filters, output_rows, output_cols, row_stride, diff --git a/third_party/xla/xla/service/cpu/runtime_conv3d.cc b/third_party/xla/xla/service/cpu/runtime_conv3d.cc index 7e83269e289fdd..fee2293d73fd97 100644 --- a/third_party/xla/xla/service/cpu/runtime_conv3d.cc +++ b/third_party/xla/xla/service/cpu/runtime_conv3d.cc @@ -20,8 +20,8 @@ limitations under the License. #define EIGEN_USE_THREADS #include "absl/base/dynamic_annotations.h" +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" #include "xla/executable_run_options.h" -#include "xla/service/cpu/runtime/conv_impl.h" #include "xla/service/cpu/runtime_lightweight_check.h" ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF32( @@ -39,7 +39,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF32( const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); - tensorflow::xla::EigenConv3DImpl( + xla::cpu::internal::EigenConv3D( *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, input_x, input_y, input_z, input_channels, kernel_x, kernel_y, kernel_z, kernel_channels, kernel_filters, output_x, output_y, output_z, x_stride, @@ -64,7 +64,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF16( const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); - tensorflow::xla::EigenConv3DImpl( + xla::cpu::internal::EigenConv3D( *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, input_x, input_y, input_z, input_channels, kernel_x, kernel_y, kernel_z, kernel_channels, kernel_filters, output_x, output_y, output_z, x_stride, diff --git a/third_party/xla/xla/service/cpu/runtime_single_threaded_conv2d.cc b/third_party/xla/xla/service/cpu/runtime_single_threaded_conv2d.cc index a770681987400d..bc749f5c42be20 100644 --- a/third_party/xla/xla/service/cpu/runtime_single_threaded_conv2d.cc +++ b/third_party/xla/xla/service/cpu/runtime_single_threaded_conv2d.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/base/dynamic_annotations.h" -#include "xla/service/cpu/runtime/conv_impl.h" +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedConv2DF16( @@ -31,7 +31,7 @@ __xla_cpu_runtime_EigenSingleThreadedConv2DF16( int64_t padding_left, int64_t padding_right, int64_t lhs_row_dilation, int64_t lhs_col_dilation, int64_t rhs_row_dilation, int64_t rhs_col_dilation, int64_t feature_group_count) { - tensorflow::xla::EigenConv2DImpl( + xla::cpu::internal::EigenConv2D( Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, kernel_filters, output_rows, output_cols, row_stride, col_stride, @@ -51,7 +51,7 @@ __xla_cpu_runtime_EigenSingleThreadedConv2DF32( int64_t padding_right, int64_t lhs_row_dilation, int64_t lhs_col_dilation, int64_t rhs_row_dilation, int64_t rhs_col_dilation, int64_t feature_group_count) { - tensorflow::xla::EigenConv2DImpl( + xla::cpu::internal::EigenConv2D( Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, kernel_filters, output_rows, output_cols, row_stride, col_stride, diff --git a/third_party/xla/xla/service/cpu/runtime_single_threaded_conv3d.cc b/third_party/xla/xla/service/cpu/runtime_single_threaded_conv3d.cc index 08ff94d06e7e71..d0d807aeb26e69 100644 --- a/third_party/xla/xla/service/cpu/runtime_single_threaded_conv3d.cc +++ b/third_party/xla/xla/service/cpu/runtime_single_threaded_conv3d.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/base/dynamic_annotations.h" -#include "xla/service/cpu/runtime/conv_impl.h" +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedConv3DF32( @@ -33,7 +33,7 @@ __xla_cpu_runtime_EigenSingleThreadedConv3DF32( int64_t lhs_y_dilation, int64_t lhs_z_dilation, int64_t rhs_x_dilation, int64_t rhs_y_dilation, int64_t rhs_z_dilation, int64_t feature_group_count) { - tensorflow::xla::EigenConv3DImpl( + xla::cpu::internal::EigenConv3D( Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_x, input_y, input_z, input_channels, kernel_x, kernel_y, kernel_z, kernel_channels, kernel_filters, output_x, output_y, output_z, x_stride, y_stride, @@ -56,7 +56,7 @@ __xla_cpu_runtime_EigenSingleThreadedConv3DF16( int64_t lhs_y_dilation, int64_t lhs_z_dilation, int64_t rhs_x_dilation, int64_t rhs_y_dilation, int64_t rhs_z_dilation, int64_t feature_group_count) { - tensorflow::xla::EigenConv3DImpl( + xla::cpu::internal::EigenConv3D( Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_x, input_y, input_z, input_channels, kernel_x, kernel_y, kernel_z, kernel_channels, kernel_filters, output_x, output_y, output_z, x_stride, y_stride, diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index 49ed9d3d0b343c..d2be391fa9fe8f 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -28,6 +28,29 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/all_gather_thunk.h" +#include "xla/backends/cpu/runtime/all_reduce_thunk.h" +#include "xla/backends/cpu/runtime/all_to_all_thunk.h" +#include "xla/backends/cpu/runtime/call_thunk.h" +#include "xla/backends/cpu/runtime/collective_permute_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/conditional_thunk.h" +#include "xla/backends/cpu/runtime/convolution_thunk.h" +#include "xla/backends/cpu/runtime/copy_thunk.h" +#include "xla/backends/cpu/runtime/custom_call_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/fft_thunk.h" +#include "xla/backends/cpu/runtime/infeed_thunk.h" +#include "xla/backends/cpu/runtime/kernel_thunk.h" +#include "xla/backends/cpu/runtime/logical_id_thunk.h" +#include "xla/backends/cpu/runtime/outfeed_thunk.h" +#include "xla/backends/cpu/runtime/reduce_scatter_thunk.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/rng_state_thunk.h" +#include "xla/backends/cpu/runtime/sort_thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/topk_thunk.h" +#include "xla/backends/cpu/runtime/while_thunk.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -42,29 +65,6 @@ limitations under the License. #include "xla/service/cpu/dot_op_emitter.h" #include "xla/service/cpu/ir_emission_utils.h" #include "xla/service/cpu/ir_emitter2.h" -#include "xla/service/cpu/runtime/all_gather_thunk.h" -#include "xla/service/cpu/runtime/all_reduce_thunk.h" -#include "xla/service/cpu/runtime/all_to_all_thunk.h" -#include "xla/service/cpu/runtime/call_thunk.h" -#include "xla/service/cpu/runtime/collective_permute_thunk.h" -#include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/service/cpu/runtime/conditional_thunk.h" -#include "xla/service/cpu/runtime/convolution_thunk.h" -#include "xla/service/cpu/runtime/copy_thunk.h" -#include "xla/service/cpu/runtime/custom_call_thunk.h" -#include "xla/service/cpu/runtime/dot_thunk.h" -#include "xla/service/cpu/runtime/fft_thunk.h" -#include "xla/service/cpu/runtime/infeed_thunk.h" -#include "xla/service/cpu/runtime/kernel_thunk.h" -#include "xla/service/cpu/runtime/logical_id_thunk.h" -#include "xla/service/cpu/runtime/outfeed_thunk.h" -#include "xla/service/cpu/runtime/reduce_scatter_thunk.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/rng_state_thunk.h" -#include "xla/service/cpu/runtime/sort_thunk.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/topk_thunk.h" -#include "xla/service/cpu/runtime/while_thunk.h" #include "xla/service/cpu/target_machine_features.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" @@ -516,9 +516,9 @@ absl::StatusOr ThunkEmitter::EmitConcatenateKernelThunk( ir_emitter_.EmitConcatenateHostKernel(concatenate)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of( - ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, - kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); + return MakeKernelThunkSequence( + instruction, buffers, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); } absl::StatusOr ThunkEmitter::EmitGetDimensionSizeThunk( @@ -609,9 +609,9 @@ absl::StatusOr ThunkEmitter::EmitElementalKernelThunk( ir_emitter_.EmitElementalHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of( - ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, - kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); + return MakeKernelThunkSequence( + instruction, buffers, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); } absl::StatusOr ThunkEmitter::EmitPadKernelThunk( @@ -620,9 +620,9 @@ absl::StatusOr ThunkEmitter::EmitPadKernelThunk( TF_ASSIGN_OR_RETURN(auto kernel, ir_emitter_.EmitPadHostKernel(padInstr)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(padInstr)); - return ThunkSequence::Of( - ThunkInfo(padInstr), buffers.arguments, buffers.results, kernel.name, - kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); + return MakeKernelThunkSequence( + padInstr, buffers, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); } absl::StatusOr ThunkEmitter::EmitFusionKernelThunk( @@ -631,9 +631,9 @@ absl::StatusOr ThunkEmitter::EmitFusionKernelThunk( TF_ASSIGN_OR_RETURN(auto kernel, ir_emitter_.EmitFusionHostKernel(fusion)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of( - ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, - kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); + return MakeKernelThunkSequence( + instruction, buffers, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); } absl::StatusOr ThunkEmitter::EmitReductionKernelThunk( @@ -642,9 +642,9 @@ absl::StatusOr ThunkEmitter::EmitReductionKernelThunk( ir_emitter_.EmitReductionHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of( - ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, - kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); + return MakeKernelThunkSequence( + instruction, buffers, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); } absl::StatusOr ThunkEmitter::EmitRngThunk( @@ -799,9 +799,7 @@ absl::StatusOr ThunkEmitter::EmitDotThunk( TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of(ThunkInfo(instruction), - buffers.arguments, buffers.results, - kernel.name, kernel.thread_dims); + return MakeKernelThunkSequence(instruction, buffers, kernel); } // Emit DotThunk implementing dot instruction as a library call. @@ -980,9 +978,9 @@ absl::StatusOr ThunkEmitter::EmitSliceToDynamicThunk( ir_emitter_.EmitSliceToDynamicHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of( - ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, - kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); + return MakeKernelThunkSequence( + instruction, buffers, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); } absl::StatusOr ThunkEmitter::EmitSelectAndScatterThunk( @@ -991,9 +989,7 @@ absl::StatusOr ThunkEmitter::EmitSelectAndScatterThunk( ir_emitter_.EmitSelectAndScatterHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of(ThunkInfo(instruction), - buffers.arguments, buffers.results, - kernel.name, kernel.thread_dims); + return MakeKernelThunkSequence(instruction, buffers, kernel); } absl::StatusOr ThunkEmitter::EmitSliceThunk( @@ -1010,9 +1006,7 @@ absl::StatusOr ThunkEmitter::EmitDynamicUpdateSliceThunk( auto kernel, ir_emitter_.EmitDynamicUpdateSliceHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of(ThunkInfo(instruction), - buffers.arguments, buffers.results, - kernel.name, kernel.thread_dims); + return MakeKernelThunkSequence(instruction, buffers, kernel); } absl::StatusOr ThunkEmitter::EmitSortThunk( @@ -1098,4 +1092,14 @@ absl::Status ThunkEmitter::ElementTypesSameAndSupported( return absl::OkStatus(); } +absl::StatusOr ThunkEmitter::MakeKernelThunkSequence( + const HloInstruction* instruction, + const ThunkEmitter::HostKernelAllocationSlices& buffers, + const IrEmitter2::KernelInfo& kernel, + std::optional min_alignment) { + return ThunkSequence::Of( + ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, + kernel.thread_dims, kernel.invariant_buffers, min_alignment); +} + } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index 6921f76e75179b..ad6eb8863b5ee6 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -16,21 +16,23 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_THUNK_EMITTER_H_ #define XLA_SERVICE_CPU_THUNK_EMITTER_H_ +#include #include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/ir_emitter2.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/target_machine_features.h" #include "xla/service/hlo_module_config.h" #include "xla/shape_util.h" @@ -195,6 +197,13 @@ class ThunkEmitter { absl::Span operands, absl::Span supported_types); + // Convenience function that creates a thunk sequence containing given kernel. + static absl::StatusOr MakeKernelThunkSequence( + const HloInstruction* instruction, + const ThunkEmitter::HostKernelAllocationSlices& buffers, + const IrEmitter2::KernelInfo& kernel, + std::optional min_alignment = std::nullopt); + IrEmitter2& ir_emitter_; const BufferAssignment& buffer_assignment_; diff --git a/third_party/xla/xla/service/dump.cc b/third_party/xla/xla/service/dump.cc index 598e0e521116ad..3aa3a8862011a3 100644 --- a/third_party/xla/xla/service/dump.cc +++ b/third_party/xla/xla/service/dump.cc @@ -50,10 +50,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/hlo_proto_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" #include "tsl/lib/io/zlib_compression_options.h" #include "tsl/lib/io/zlib_outputbuffer.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index 614e709c356cf2..21b5de8dee234a 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -30,10 +30,14 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/FloatingPointMode.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -60,6 +64,11 @@ limitations under the License. namespace xla { using absl::StrCat; +using llvm::PatternMatch::m_BitCast; +using llvm::PatternMatch::m_Intrinsic; +using llvm::PatternMatch::m_Select; +using llvm::PatternMatch::m_Value; +using llvm::PatternMatch::match; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; @@ -713,6 +722,48 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType from_type = op->operand(0)->shape().element_type(); PrimitiveType to_type = op->shape().element_type(); CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type; + + // LLVM optimizes away `fpcast` and `fpext` operations and optimized + // LLVM IR has arithmetic operations on `bfloat16` that are not natively + // supported on any of the CPUs, and LLVM inserts very expensive calls to + // fp conversion functions around bf16 operations. To avoid this, we use + // bitcasts and shifts to convert bf16 to f32 and back using truncation + // with rounding, and suppress LLVM optimizations that hurt performance. + // This is enabled explicitly by a flag only for XLA:CPU backend. + if (options_.xla_cpu_use_truncate_f32_to_bf16_conversion) { + if (from_type == F32 && to_type == BF16) { + // This implementation is based on Eigen `float_to_bfloat16_rtne` with + // a special case for nans. + auto* i32 = b_->CreateBitCast(operand_value, b_->getInt32Ty()); + + // Rounding bias for non-nan values. + auto* lsb = + b_->CreateAnd(b_->CreateLShr(i32, 16), + llvm::ConstantInt::get(b_->getInt32Ty(), 1)); + auto* rounding_bias = b_->CreateAdd( + llvm::ConstantInt::get(b_->getInt32Ty(), 0x7fff), lsb); + + // For nan values, we simply truncate the original value. + auto* is_nan = + b_->createIsFPClass(operand_value, llvm::FPClassTest::fcNan); + auto* i16 = b_->CreateTrunc( + b_->CreateLShr( + b_->CreateSelect(is_nan, i32, + b_->CreateAdd(i32, rounding_bias)), + 16), + b_->getInt16Ty()); + + return b_->CreateBitCast(i16, b_->getBFloatTy()); + } + if (from_type == BF16 && to_type == F32) { + auto* i16 = b_->CreateBitCast(operand_value, b_->getInt16Ty()); + auto* i32 = b_->CreateZExt(i16, b_->getInt32Ty()); + auto* i32s = b_->CreateShl(i32, 16); + auto* f32 = b_->CreateBitCast(i32s, b_->getFloatTy()); + return f32; + } + } + if (from_type == to_type) { return operand_value; } diff --git a/third_party/xla/xla/service/elemental_ir_emitter.h b/third_party/xla/xla/service/elemental_ir_emitter.h index cff5bbb1648389..fe33977297572d 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.h +++ b/third_party/xla/xla/service/elemental_ir_emitter.h @@ -37,11 +37,21 @@ namespace xla { class ElementalIrEmitter : public IrBuilderMixin { public: + struct Options { + // Instead of relying on builtin `fpext` and `fpcast` emit a bitcast and + // truncate to convert f32 to bf16 (and emit extend to convert bf16 to f32). + bool xla_cpu_use_truncate_f32_to_bf16_conversion = false; + }; + using HloToElementGeneratorMap = absl::flat_hash_map; + ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b, + const Options& options) + : b_(b), module_(module), options_(options) {} + ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b) - : b_(b), module_(module) {} + : ElementalIrEmitter(module, b, Options()) {} virtual ~ElementalIrEmitter() = default; @@ -314,6 +324,8 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm::Module* module_; + Options options_; + friend class ElementalIrEmitterForTests; }; diff --git a/third_party/xla/xla/service/executable.cc b/third_party/xla/xla/service/executable.cc index ed86114607cf6f..aa81fce3e80e1c 100644 --- a/third_party/xla/xla/service/executable.cc +++ b/third_party/xla/xla/service/executable.cc @@ -25,7 +25,6 @@ limitations under the License. #include "xla/service/maybe_owning_device_memory.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 64d100bc332288..a0a8824815e254 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -313,7 +313,6 @@ cc_library( ":execution_stream_assignment", ":gpu_asm_opts_util", ":gpu_conv_runner", - ":gpu_fused_mha_runner", ":gpu_norm_runner", ":hlo_fusion_analysis", ":ir_emission_utils", @@ -356,9 +355,9 @@ cc_library( "//xla/service/gpu/runtime:conditional_thunk", "//xla/service/gpu/runtime:convolution_thunk", "//xla/service/gpu/runtime:copy_thunk", + "//xla/service/gpu/runtime:cudnn_thunk", "//xla/service/gpu/runtime:custom_call_thunk", "//xla/service/gpu/runtime:fft_thunk", - "//xla/service/gpu/runtime:fused_mha_thunk", "//xla/service/gpu/runtime:gemm_thunk", "//xla/service/gpu/runtime:gpublas_lt_matmul_thunk", "//xla/service/gpu/runtime:infeed_thunk", @@ -659,6 +658,7 @@ cc_library( "//xla/service/llvm_ir:llvm_type_conversion_util", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -670,7 +670,6 @@ cc_library( "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", "@llvm-project//llvm:TargetParser", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", ], @@ -687,11 +686,12 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:types", - "//xla:util", "//xla/hlo/ir:backend_config", + "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", @@ -922,44 +922,6 @@ xla_cc_test( ], ) -cc_library( - name = "softmax_rewriter_triton", - srcs = ["softmax_rewriter_triton.cc"], - hdrs = ["softmax_rewriter_triton.h"], - deps = [ - ":backend_configs_cc", - ":hlo_traversal", - ":ir_emission_utils", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_pass", - "//xla/service:instruction_fusion", - "//xla/service/gpu/fusions/triton:triton_support", - "//xla/service/gpu/model:fusion_analysis_cache", - "//xla/service/gpu/model:gpu_indexing_performance_model", - "//xla/service/gpu/model:symbolic_tile_analysis", - "//xla/service/gpu/model:tiled_hlo_computation", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "matmul_utils", srcs = ["matmul_utils.cc"], @@ -1030,78 +992,6 @@ xla_cc_test( ], ) -cc_library( - name = "gpu_async_collective_annotator", - srcs = ["gpu_async_collective_annotator.cc"], - hdrs = ["gpu_async_collective_annotator.h"], - deps = [ - ":backend_configs_cc", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "gpu_async_collective_annotator_test", - srcs = ["gpu_async_collective_annotator_test.cc"], - deps = [ - ":backend_configs_cc", - ":gpu_async_collective_annotator", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/tests:hlo_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gpu_convert_async_collectives_to_sync", - srcs = ["gpu_convert_async_collectives_to_sync.cc"], - hdrs = ["gpu_convert_async_collectives_to_sync.h"], - deps = [ - ":backend_configs_cc", - "//xla/hlo/ir:hlo", - "//xla/service:convert_async_collectives_to_sync", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "gpu_convert_async_collectives_to_sync_test", - srcs = ["gpu_convert_async_collectives_to_sync_test.cc"], - deps = [ - ":backend_configs_cc", - ":gpu_convert_async_collectives_to_sync", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "gpu_conv_runner", srcs = ["gpu_conv_runner.cc"], @@ -1156,604 +1046,141 @@ cc_library( ) cc_library( - name = "gpu_fused_mha_runner", - srcs = ["gpu_fused_mha_runner.cc"], - hdrs = ["gpu_fused_mha_runner.h"], + name = "cusolver_context", + srcs = if_gpu_is_configured(["cusolver_context.cc"]), + hdrs = if_gpu_is_configured(["cusolver_context.h"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":stream_executor_util", - "//xla:shape_util", + "//xla:comparison_util", + "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/stream_executor", - "//xla/stream_executor:dnn", - "//xla/stream_executor:lazy_op_runner", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", + "//xla/stream_executor:blas", + "//xla/stream_executor/gpu:gpu_stream", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "//xla/tsl/cuda:cusolver", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + "//xla/stream_executor/rocm:rocblas_wrapper", + "//xla/stream_executor/rocm:rocsolver_wrapper", + "//xla/stream_executor/rocm:hipsolver_wrapper", + ]), +) + +tf_proto_library( + name = "fusion_process_dump_proto", + srcs = ["fusion_process_dump.proto"], + cc_api_version = 2, + protodeps = [ + "//xla/stream_executor:device_description_proto", ], ) cc_library( - name = "gpu_conv_rewriter", - srcs = ["gpu_conv_rewriter.cc"], - hdrs = ["gpu_conv_rewriter.h"], + name = "fusion_process_dump", + srcs = ["fusion_process_dump.cc"], + hdrs = ["fusion_process_dump.h"], deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - "//xla:permutation_util", - "//xla:shape_util", + ":fusion_process_dump_proto_cc", "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", + "//xla/service:hlo_graph_dumper", "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", + "//xla/tools:hlo_module_loader", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], ) -cc_library( - name = "gpu_sort_rewriter", - srcs = if_gpu_is_configured( - ["gpu_sort_rewriter.cc"], - ["gpu_sort_rewriter_stub.cc"], - ), - hdrs = ["gpu_sort_rewriter.h"], +xla_cc_test( + name = "fusion_process_dump_test", + srcs = ["fusion_process_dump_test.cc"], deps = [ - ":cublas_cudnn", - "//xla:comparison_util", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", + ":fusion_process_dump", + ":fusion_process_dump_proto_cc", + ":gpu_device_info_for_tests", + "//xla:test", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service:stable_sort_expander", - "//xla/service/gpu/runtime:cub_sort_thunk", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", + "//xla/service:hlo_parser", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "move_copy_to_users", - srcs = ["move_copy_to_users.cc"], - hdrs = ["move_copy_to_users.h"], + name = "cudnn_support_utils", + srcs = ["cudnn_support_utils.cc"], + hdrs = ["cudnn_support_utils.h"], deps = [ + ":cublas_cudnn", "//xla:shape_util", + "//xla:util", + "//xla:window_util", "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", + "//xla/stream_executor:device_description", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "move_copy_to_users_test", - srcs = ["move_copy_to_users_test.cc"], - deps = [ - ":move_copy_to_users", - "//xla/service:layout_assignment", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_cc_test( - name = "gpu_conv_rewriter_test", - srcs = ["gpu_conv_rewriter_test.cc"], + name = "cudnn_support_utils_test", + srcs = ["cudnn_support_utils_test.cc"], deps = [ - ":cublas_cudnn", - ":gpu_conv_rewriter", - "//xla:array4d", - "//xla:literal_util", - "//xla:protobuf_util", + ":cudnn_support_utils", "//xla:shape_util", "//xla:test", - "//xla:test_helpers", + "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service:shape_inference", + "//xla/service:hlo_parser", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings:str_format", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) -xla_test( - name = "gpu_sort_rewriter_test", - srcs = if_cuda_is_configured(["gpu_sort_rewriter_test.cc"]), - backends = ["gpu"], - tags = ["no_oss"], +cc_library( + name = "cublas_padding_requirements", + srcs = ["cublas_padding_requirements.cc"], + hdrs = ["cublas_padding_requirements.h"], deps = [ - ":cublas_cudnn", - ":gpu_sort_rewriter", - "//xla:error_spec", - "//xla:xla_data_proto_cc", + ":variant_visitor", + "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ], -) - -cc_library( - name = "cusolver_context", - srcs = if_gpu_is_configured(["cusolver_context.cc"]), - hdrs = if_gpu_is_configured(["cusolver_context.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = [ - "//xla:comparison_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/stream_executor", - "//xla/stream_executor:blas", - "//xla/stream_executor/gpu:gpu_stream", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "//xla/tsl/cuda:cusolver", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - "//xla/stream_executor/rocm:rocblas_wrapper", - "//xla/stream_executor/rocm:rocsolver_wrapper", - "//xla/stream_executor/rocm:hipsolver_wrapper", - ]), -) - -cc_library( - name = "instruction_fusion", - srcs = ["instruction_fusion.cc"], - hdrs = ["instruction_fusion.h"], - deps = [ - ":gpu_fusible", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:fusion_node_indexing_evaluation", - "//xla/service:fusion_queue", - "//xla/service:hlo_pass", - "//xla/service:instruction_fusion", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "instruction_fusion_test", - srcs = ["instruction_fusion_test.cc"], - tags = [ - "nomsan", - "not_run:arm", - ], - deps = [ - ":gpu_device_info_for_tests", - ":gpu_fusible", - ":instruction_fusion", - "//xla:literal_util", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - -tf_proto_library( - name = "fusion_process_dump_proto", - srcs = ["fusion_process_dump.proto"], - cc_api_version = 2, - protodeps = [ - "//xla/stream_executor:device_description_proto", - ], -) - -cc_library( - name = "fusion_process_dump", - srcs = ["fusion_process_dump.cc"], - hdrs = ["fusion_process_dump.h"], - deps = [ - ":fusion_process_dump_proto_cc", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_graph_dumper", - "//xla/stream_executor:device_description", - "//xla/tools:hlo_module_loader", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "fusion_process_dump_test", - srcs = ["fusion_process_dump_test.cc"], - deps = [ - ":fusion_process_dump", - ":fusion_process_dump_proto_cc", - ":gpu_device_info_for_tests", - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "priority_fusion", - srcs = ["priority_fusion.cc"], - hdrs = ["priority_fusion.h"], - deps = [ - ":backend_configs_cc", - ":fusion_process_dump_proto_cc", - ":gpu_fusible", - ":hlo_fusion_analysis", - ":hlo_traversal", - "//xla:debug_options_flags", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:dump", - "//xla/service:fusion_queue", - "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_graph_dumper", - "//xla/service:hlo_pass", - "//xla/service:instruction_fusion", - "//xla/service/gpu/model:fusion_analysis_cache", - "//xla/service/gpu/model:gpu_hlo_cost_analysis", - "//xla/service/gpu/model:gpu_performance_model", - "//xla/service/gpu/model:gpu_performance_model_base", - "//xla/service/gpu/model:symbolic_tile_analysis", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:blocking_counter", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "priority_fusion_test", - srcs = ["priority_fusion_test.cc"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - tags = ["no_pip"], - deps = [ - ":backend_configs_cc", - ":gpu_device_info_for_tests", - ":gpu_fusible", - ":hlo_fusion_analysis", - ":priority_fusion", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_cost_analysis", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service/gpu/model:gpu_hlo_cost_analysis", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "multi_output_fusion", - srcs = ["multi_output_fusion.cc"], - hdrs = ["multi_output_fusion.h"], - deps = [ - ":gpu_fusible", - "//xla:debug_options_flags", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_dfs_reachability", - "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_graph_dumper", - "//xla/service:hlo_pass", - "//xla/service:instruction_fusion", - "//xla/service/gpu/model:gpu_hlo_cost_analysis", - "//xla/service/gpu/model:gpu_performance_model", - "//xla/service/gpu/model:gpu_performance_model_base", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "multi_output_fusion_test", - srcs = ["multi_output_fusion_test.cc"], - tags = [ - "nomsan", - ], - deps = [ - ":gpu_device_info_for_tests", - ":gpu_fusible", - ":multi_output_fusion", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_cost_analysis", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "rename_fusions", - srcs = ["rename_fusions.cc"], - hdrs = ["rename_fusions.h"], - deps = [ - ":hlo_traversal", - ":ir_emission_utils", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "rename_fusions_test", - srcs = ["rename_fusions_test.cc"], - deps = [ - ":rename_fusions", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - ], -) - -xla_cc_test( - name = "softmax_rewriter_triton_test", - srcs = ["softmax_rewriter_triton_test.cc"], - deps = [ - ":backend_configs_cc", - ":gpu_device_info_for_tests", - ":softmax_rewriter_triton", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:instruction_fusion", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service/gpu/fusions/triton:triton_support", - "//xla/service/gpu/model:gpu_hlo_cost_analysis", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # build_cleaner: keep - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gpu_sanitize_constant_names", - srcs = ["gpu_sanitize_constant_names.cc"], - hdrs = ["gpu_sanitize_constant_names.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service:name_uniquer", - "//xla/service/llvm_ir:buffer_assignment_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "gpu_sanitize_constant_names_test", - srcs = ["gpu_sanitize_constant_names_test.cc"], - deps = [ - ":gpu_sanitize_constant_names", - "//xla:literal_util", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ], -) - -cc_library( - name = "gpu_conv_padding_legalization", - srcs = ["gpu_conv_padding_legalization.cc"], - hdrs = ["gpu_conv_padding_legalization.h"], - deps = [ - ":cublas_cudnn", - "//xla:literal_util", - "//xla:shape_util", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/service:shape_inference", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "gpu_conv_padding_legalization_test", - srcs = ["gpu_conv_padding_legalization_test.cc"], - deps = [ - ":cublas_cudnn", - ":gpu_conv_padding_legalization", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@local_tsl//tsl/platform:test", - ], -) - -cc_library( - name = "cudnn_support_utils", - srcs = ["cudnn_support_utils.cc"], - hdrs = ["cudnn_support_utils.h"], - deps = [ - ":cublas_cudnn", - "//xla:shape_util", - "//xla:util", - "//xla:window_util", - "//xla/hlo/ir:hlo", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "cudnn_support_utils_test", - srcs = ["cudnn_support_utils_test.cc"], - deps = [ - ":cudnn_support_utils", - "//xla:shape_util", - "//xla:test", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "cublas_padding_requirements", - srcs = ["cublas_padding_requirements.cc"], - hdrs = ["cublas_padding_requirements.h"], - deps = [ - ":variant_visitor", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description", ], ) @@ -1810,43 +1237,6 @@ cc_library( alwayslink = True, # Contains per-platform transfer manager registration ) -cc_library( - name = "gpu_reduce_scatter_creator", - srcs = ["gpu_reduce_scatter_creator.cc"], - hdrs = ["gpu_reduce_scatter_creator.h"], - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:collective_opt_utils", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - ], -) - -cc_library( - name = "gpu_all_gather_optimizer", - srcs = ["gpu_all_gather_optimizer.cc"], - hdrs = ["gpu_all_gather_optimizer.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:collective_ops_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], -) - cc_library( name = "gpu_float_support", srcs = ["gpu_float_support.cc"], @@ -1881,6 +1271,7 @@ cc_library( ":metrics", ":runtime_intrinsics", "//xla:shape_util", + "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -1891,10 +1282,8 @@ cc_library( "//xla/service:hlo_ordering", "//xla/service:hlo_proto_cc", "//xla/service:logical_buffer", - "//xla/service/gpu/runtime:conditional_thunk", "//xla/service/gpu/runtime:sequential_thunk", "//xla/service/gpu/runtime:thunk", - "//xla/service/gpu/runtime:while_thunk", "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor/rocm:rocm_platform_id", @@ -1902,18 +1291,19 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:AsmParser", "@llvm-project//llvm:TransformUtils", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:scoped_annotation", ], ) @@ -1922,12 +1312,6 @@ cc_library( srcs = ["fusion_pipeline.cc"], hdrs = ["fusion_pipeline.h"], deps = [ - ":horizontal_input_fusion", - ":horizontal_loop_fusion", - ":instruction_fusion", - ":multi_output_fusion", - ":priority_fusion", - ":variadic_op_splitter", "//xla:xla_proto_cc", "//xla/service:cpu_gpu_shape_verifier", "//xla/service:hlo_cost_analysis", @@ -1939,6 +1323,12 @@ cc_library( "//xla/service:layout_assignment", "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/transforms:fusion_merger", + "//xla/service/gpu/transforms:horizontal_input_fusion", + "//xla/service/gpu/transforms:horizontal_loop_fusion", + "//xla/service/gpu/transforms:instruction_fusion", + "//xla/service/gpu/transforms:multi_output_fusion", + "//xla/service/gpu/transforms:priority_fusion", + "//xla/service/gpu/transforms:variadic_op_splitter", "//xla/stream_executor:device_description", "@local_tsl//tsl/platform:env", ], @@ -1949,8 +1339,6 @@ cc_library( srcs = ["prepare_hlo_for_ir_emitting_pipeline.cc"], hdrs = ["prepare_hlo_for_ir_emitting_pipeline.h"], deps = [ - ":gpu_sanitize_constant_names", - ":horizontal_loop_fusion", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:copy_insertion", @@ -1963,6 +1351,8 @@ cc_library( "//xla/service:loop_schedule_linearizer", "//xla/service/gpu/transforms:alias_passthrough_params", "//xla/service/gpu/transforms:copy_fusion", + "//xla/service/gpu/transforms:horizontal_loop_fusion", + "//xla/service/gpu/transforms:sanitize_constant_names", ], ) @@ -1979,31 +1369,17 @@ cc_library( ":buffer_sharing", ":compile_module_to_llvm_ir", ":conv_layout_normalization", - ":dot_operand_converter", ":executable_proto_cc", ":execution_stream_assignment", ":fusion_pipeline", - ":gpu_algebraic_simplifier", - ":gpu_all_gather_optimizer", - ":gpu_async_collective_annotator", ":gpu_constants", - ":gpu_conv_rewriter", - ":gpu_convert_async_collectives_to_sync", ":gpu_executable", ":gpu_float_support", ":gpu_hlo_schedule", ":gpu_latency_hiding_scheduler", - ":gpu_layout_assignment", ":gpu_p2p_pipeliner", - ":gpu_reduce_scatter_creator", - ":gpu_sanitize_constant_names", - ":gpu_scatter_expander", ":gpu_spmd_pipeline", - ":gpu_windowed_einsum_handler", ":hlo_fusion_stats", - ":horizontal_input_fusion", - ":horizontal_loop_fusion", - ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", ":ir_emitter_context", @@ -2011,28 +1387,10 @@ cc_library( ":kernel_reuse_cache", ":matmul_utils", ":metrics", - ":move_copy_to_users", - ":multi_output_fusion", - ":pipelined_p2p_rewriter", ":prepare_hlo_for_ir_emitting_pipeline", - ":priority_fusion", - ":reduction_degenerate_dim_remover", - ":reduction_dimension_grouper", - ":reduction_layout_normalizer", - ":reduction_splitter", ":reduction_utils", - ":rename_fusions", ":runtime_intrinsics", - ":scatter_slice_simplifier", - ":softmax_rewriter_triton", - ":stream_attribute_annotator", - ":stream_attribute_async_wrapper", ":stream_executor_util", - ":topk_specializer", - ":topk_splitter", - ":tree_reduction_rewriter", - ":triton_fusion_numerics_verifier", - ":variadic_op_splitter", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2061,14 +1419,22 @@ cc_library( "//xla/service/gpu/model:gpu_cost_model_stats_collection", "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/runtime:thunk", + "//xla/service/gpu/transforms:algebraic_simplifier", "//xla/service/gpu/transforms:algorithm_checker", + "//xla/service/gpu/transforms:all_gather_optimizer", "//xla/service/gpu/transforms:all_reduce_blueconnect", + "//xla/service/gpu/transforms:all_reduce_splitter", + "//xla/service/gpu/transforms:async_collective_annotator", + "//xla/service/gpu/transforms:async_wrapper", "//xla/service/gpu/transforms:collective_permute_cycle_decomposer", "//xla/service/gpu/transforms:collective_permute_valid_iteration_annotator", "//xla/service/gpu/transforms:command_buffer_scheduling", + "//xla/service/gpu/transforms:conv_rewriter", + "//xla/service/gpu/transforms:convert_async_collectives_to_sync", "//xla/service/gpu/transforms:cudnn_custom_call_converter", "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter", "//xla/service/gpu/transforms:dot_dimension_sorter", + "//xla/service/gpu/transforms:dot_operand_converter", "//xla/service/gpu/transforms:double_buffer_loop_unrolling", "//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter", "//xla/service/gpu/transforms:fusion_wrapper", @@ -2076,6 +1442,27 @@ cc_library( "//xla/service/gpu/transforms:gemm_fusion", "//xla/service/gpu/transforms:gemm_rewriter", "//xla/service/gpu/transforms:gemv_rewriter", + "//xla/service/gpu/transforms:layout_assignment", + "//xla/service/gpu/transforms:move_copy_to_users", + "//xla/service/gpu/transforms:pipelined_p2p_rewriter", + "//xla/service/gpu/transforms:reduce_scatter_creator", + "//xla/service/gpu/transforms:reduction_degenerate_dim_remover", + "//xla/service/gpu/transforms:reduction_dimension_grouper", + "//xla/service/gpu/transforms:reduction_layout_normalizer", + "//xla/service/gpu/transforms:reduction_splitter", + "//xla/service/gpu/transforms:rename_fusions", + "//xla/service/gpu/transforms:sanitize_constant_names", + "//xla/service/gpu/transforms:scatter_expander", + "//xla/service/gpu/transforms:scatter_slice_simplifier", + "//xla/service/gpu/transforms:softmax_rewriter_triton", + "//xla/service/gpu/transforms:stream_attribute_annotator", + "//xla/service/gpu/transforms:stream_attribute_async_wrapper", + "//xla/service/gpu/transforms:topk_specializer", + "//xla/service/gpu/transforms:topk_splitter", + "//xla/service/gpu/transforms:transpose_dimension_grouper", + "//xla/service/gpu/transforms:tree_reduction_rewriter", + "//xla/service/gpu/transforms:triton_fusion_numerics_verifier", + "//xla/service/gpu/transforms:windowed_einsum_handler", "//xla/service/llvm_ir:llvm_util", "//xla/service/spmd:collective_permute_motion", "//xla/service:algebraic_simplifier", @@ -2086,7 +1473,6 @@ cc_library( "//xla/service:all_reduce_folder", "//xla/service:all_reduce_promotion", "//xla/service:all_reduce_reassociate", - "//xla/service:all_reduce_splitter", "//xla/service:async_collective_creator", "//xla/service:batchnorm_expander", "//xla/service:bitcast_dtypes_expander", @@ -2216,7 +1602,7 @@ cc_library( xla_test( name = "gpu_compiler_test", - srcs = if_gpu_is_configured(["gpu_compiler_test.cc"]), + srcs = ["gpu_compiler_test.cc"], backends = ["gpu"], data = ["gpu_compiler_test_autotune_db.textproto"], deps = [ @@ -2225,9 +1611,13 @@ xla_test( ":metrics", "//xla:autotune_results_proto_cc", "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_module_group", + "//xla/service:compiler", "//xla/service:executable", "//xla/service:hlo_module_config", "//xla/service:pattern_matcher", @@ -2235,8 +1625,11 @@ xla_test( "//xla/service:xla_debug_info_manager", "//xla/service/gpu/autotuning:autotuner_util", "//xla/stream_executor:device_description", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", @@ -2270,7 +1663,7 @@ xla_test( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_memory_scheduler", "//xla/service:hlo_rematerialization", - "//xla/service/gpu:stream_attribute_annotator", + "//xla/service/gpu/transforms:stream_attribute_annotator", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -2334,16 +1727,11 @@ cc_library( deps = [ ":buffer_sharing", ":cublas_padding_requirements", - ":gpu_algebraic_simplifier", ":gpu_asm_opts_util", ":gpu_compiler", - ":gpu_conv_padding_legalization", - ":gpu_conv_rewriter", - ":gpu_sort_rewriter", ":ir_emission_utils", ":metrics", ":target_constants", - ":triangular_solve_rewriter", "//xla:autotune_results_proto_cc", "//xla:util", "//xla:xla_proto_cc", @@ -2371,7 +1759,11 @@ cc_library( "//xla/service/gpu/autotuning:gemm_algorithm_picker", "//xla/service/gpu/autotuning:gemm_fusion_autotuner", "//xla/service/gpu/llvm_gpu_backend", + "//xla/service/gpu/transforms:algebraic_simplifier", + "//xla/service/gpu/transforms:conv_padding_legalization", + "//xla/service/gpu/transforms:conv_rewriter", "//xla/service/gpu/transforms:cublas_pad_for_gemms", + "//xla/service/gpu/transforms:cudnn_custom_call_compiler", "//xla/service/gpu/transforms:cudnn_fused_conv_rewriter", "//xla/service/gpu/transforms:cudnn_fused_mha_rewriter", "//xla/service/gpu/transforms:cudnn_fused_mha_transpose_fusion", @@ -2380,9 +1772,10 @@ cc_library( "//xla/service/gpu/transforms:cudnn_pad_for_convolutions", "//xla/service/gpu/transforms:cudnn_simplify_padding", "//xla/service/gpu/transforms:cudnn_vectorize_convolutions", - "//xla/service/gpu/transforms:cudnn_workspace_rewriter", "//xla/service/gpu/transforms:dot_sparsity_rewriter", "//xla/service/gpu/transforms:gpusolver_rewriter", + "//xla/service/gpu/transforms:sort_rewriter", + "//xla/service/gpu/transforms:triangular_solve_rewriter", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor", "//xla/stream_executor:dnn", @@ -2438,7 +1831,6 @@ xla_test( "gpu_a100", ], tags = [ - "gpu", "no_rocm", "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. ], @@ -2477,7 +1869,6 @@ xla_test( "gpu", ], tags = [ - "gpu", "no_rocm", "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. ], @@ -2568,45 +1959,6 @@ cc_library( alwayslink = True, # Contains compiler registration ) -cc_library( - name = "gpu_algebraic_simplifier", - srcs = [ - "gpu_algebraic_simplifier.cc", - ], - hdrs = [ - "gpu_algebraic_simplifier.h", - ], - deps = [ - ":matmul_utils", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:algebraic_simplifier", - "//xla/service:hlo_pass", - "//xla/service/gpu/fusions/triton:triton_support", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - ], -) - -xla_cc_test( - name = "gpu_algebraic_simplifier_test", - srcs = ["gpu_algebraic_simplifier_test.cc"], - deps = [ - ":gpu_algebraic_simplifier", - "//xla/hlo/ir:hlo", - "//xla/service:algebraic_simplifier", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "amdgpu_compiler_impl", srcs = [ @@ -2618,13 +1970,8 @@ cc_library( tags = ["manual"], deps = [ ":cublas_padding_requirements", - ":gpu_algebraic_simplifier", ":gpu_compiler", - ":gpu_conv_padding_legalization", - ":gpu_conv_rewriter", - ":gpu_sort_rewriter", ":target_constants", - ":triangular_solve_rewriter", "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", @@ -2645,9 +1992,14 @@ cc_library( "//xla/service/gpu/autotuning:conv_algorithm_picker", "//xla/service/gpu/autotuning:gemm_algorithm_picker", "//xla/service/gpu/llvm_gpu_backend", + "//xla/service/gpu/transforms:algebraic_simplifier", + "//xla/service/gpu/transforms:conv_padding_legalization", + "//xla/service/gpu/transforms:conv_rewriter", "//xla/service/gpu/transforms:cublas_pad_for_gemms", "//xla/service/gpu/transforms:cudnn_fused_conv_rewriter", "//xla/service/gpu/transforms:gpusolver_rewriter", + "//xla/service/gpu/transforms:sort_rewriter", + "//xla/service/gpu/transforms:triangular_solve_rewriter", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:dnn", @@ -2671,136 +2023,41 @@ cc_library( hdrs = ["xfeed_queue.h"], deps = [ "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:logging", - ], -) - -cc_library( - name = "io_feed_manager", - srcs = [ - "infeed_manager.cc", - "outfeed_manager.cc", - "xla_executor_state.h", - ], - hdrs = [ - "infeed_manager.h", - "outfeed_manager.h", - ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = [ - ":xfeed_queue", - "//xla:literal", - "//xla:shape_tree", - "//xla:shape_util", - "//xla:util", - "//xla/stream_executor:device_memory_handle", - "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor/gpu:gpu_executor_header", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:notification", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gpu_layout_assignment", - srcs = ["gpu_layout_assignment.cc"], - hdrs = ["gpu_layout_assignment.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":matmul_utils", - ":reduction_utils", - ":stream_executor_util", - "//xla:shape_layout", - "//xla:shape_util", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:computation_layout", - "//xla/service:host_memory_offload_annotations_hdr", - "//xla/service:layout_assignment", - "//xla/service:logical_buffer", - "//xla/stream_executor", - "//xla/stream_executor:dnn", - "//xla/tsl/util:env_var", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "gpu_layout_assignment_test", - srcs = ["gpu_layout_assignment_test.cc"], - deps = [ - ":gpu_layout_assignment", - ":stream_executor_util", - "//xla:shape_layout", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:computation_layout", - "//xla/service:hlo_parser", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/stream_executor:device_description", - "//xla/stream_executor:dnn", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # build_cleaner: keep - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gpu_schedule_postprocessing", - srcs = ["gpu_schedule_postprocessing.cc"], - hdrs = ["gpu_schedule_postprocessing.h"], - deps = [ - ":backend_configs_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:logging", ], ) -xla_cc_test( - name = "gpu_schedule_postprocessing_test", - srcs = ["gpu_schedule_postprocessing_test.cc"], +cc_library( + name = "io_feed_manager", + srcs = [ + "infeed_manager.cc", + "outfeed_manager.cc", + "xla_executor_state.h", + ], + hdrs = [ + "infeed_manager.h", + "outfeed_manager.h", + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ - ":backend_configs_cc", - ":gpu_schedule_postprocessing", + ":xfeed_queue", + "//xla:literal", + "//xla:shape_tree", + "//xla:shape_util", "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", + "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:gpu_executor_header", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:notification", "@local_tsl//tsl/platform:statusor", ], ) @@ -2812,8 +2069,6 @@ cc_library( deps = [ ":backend_configs_cc", ":gpu_latency_hiding_scheduler", - ":gpu_schedule_postprocessing", - ":scheduling_instruction_annotator", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", @@ -2826,6 +2081,9 @@ cc_library( "//xla/service:p2p_schedule_preparation", "//xla/service:profile_guided_latency_estimator", "//xla/service/gpu/model:analytical_latency_estimator", + "//xla/service/gpu/transforms:pgle_accuracy_checker", + "//xla/service/gpu/transforms:schedule_postprocessing", + "//xla/service/gpu/transforms:scheduling_instruction_annotator", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -2840,6 +2098,7 @@ cc_library( "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", ], ) @@ -2864,6 +2123,7 @@ xla_test( "//xla/tests:test_utils", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -2919,7 +2179,6 @@ cc_library( srcs = ["gpu_spmd_pipeline.cc"], hdrs = ["gpu_spmd_pipeline.h"], deps = [ - ":gpu_algebraic_simplifier", ":runtime_intrinsics", "//xla/hlo/ir:hlo", "//xla/hlo/transforms:hlo_constant_splitter", @@ -2938,6 +2197,7 @@ cc_library( "//xla/service:tuple_simplifier", "//xla/service:while_loop_constant_sinking", "//xla/service:while_loop_simplifier", + "//xla/service/gpu/transforms:algebraic_simplifier", "//xla/service/spmd:collective_permute_motion", "//xla/service/spmd:stateful_rng_spmd_partitioner", "//xla/service/spmd/shardy:shardy_xla_pass", @@ -2994,7 +2254,8 @@ xla_cc_test( cuda_library( name = "stream_executor_util_kernel", - srcs = if_cuda_is_configured(["stream_executor_util_kernel.cu.cc"]), + srcs = ["stream_executor_util_kernel.cu.cc"], + tags = ["no_rocm"], deps = ["@local_config_cuda//cuda:cuda_headers"], ) @@ -3007,7 +2268,6 @@ cc_library( deps = [ ":cublas_cudnn", ":launch_dimensions", - ":stream_executor_util_kernel", "//xla:autotuning_proto_cc", "//xla:shape_util", "//xla:util", @@ -3036,7 +2296,9 @@ cc_library( "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - ], + ] + if_cuda_is_configured([ + ":stream_executor_util_kernel", + ]), ) xla_cc_test( @@ -3268,350 +2530,79 @@ xla_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - ], -) - -xla_test( - name = "conv_layout_normalization_test", - srcs = ["conv_layout_normalization_test.cc"], - backends = ["gpu"], - deps = [ - "//xla:error_spec", - "//xla/hlo/ir:hlo", - "//xla/service/gpu/tests:gpu_codegen_test", # fixdeps: keep - "//xla/tests:hlo_test_base", - "//xla/tests:test_macros_header", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "variadic_op_splitter", - srcs = ["variadic_op_splitter.cc"], - hdrs = ["variadic_op_splitter.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gpu_scatter_expander", - srcs = ["gpu_scatter_expander.cc"], - hdrs = ["gpu_scatter_expander.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:scatter_expander", - "@com_google_absl//absl/strings:string_view", - ], -) - -xla_cc_test( - name = "variadic_op_splitter_test", - srcs = ["variadic_op_splitter_test.cc"], - tags = [ - "nomsan", - ], - deps = [ - ":variadic_op_splitter", - "//xla:literal_util", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/service:pattern_matcher", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "hlo_algorithm_denylist", - srcs = ["hlo_algorithm_denylist.cc"], - hdrs = ["hlo_algorithm_denylist.h"], - deps = [ - ":backend_configs_cc", - "//xla:autotuning_proto_cc", - "//xla:debug_options_flags", - "//xla/hlo/ir:backend_config", - "//xla/service/gpu/autotuning:gpu_autotuning_proto_cc", - "//xla/stream_executor:dnn", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "hlo_algorithm_denylist_test", - srcs = ["hlo_algorithm_denylist_test.cc"], - data = ["data/hlo_algorithm_denylist.pbtxt"], - deps = [ - ":hlo_algorithm_denylist", - "//xla/stream_executor:dnn", - "//xla/tests:test_utils", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "horizontal_loop_fusion", - srcs = ["horizontal_loop_fusion.cc"], - hdrs = ["horizontal_loop_fusion.h"], - deps = [ - ":gpu_fusible", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/service:sub_byte_normalization", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "horizontal_loop_fusion_test", - srcs = ["horizontal_loop_fusion_test.cc"], - backends = ["gpu"], - deps = [ - ":gpu_device_info_for_tests", - ":horizontal_loop_fusion", - ":instruction_fusion", - "//xla:error_spec", - "//xla:shape_util", - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_dce", - "//xla/service:hlo_parser", - "//xla/service:hlo_pass", - "//xla/service:hlo_pass_pipeline", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - ], -) - -cc_library( - name = "horizontal_input_fusion", - srcs = ["horizontal_input_fusion.cc"], - hdrs = ["horizontal_input_fusion.h"], - deps = [ - ":gpu_fusible", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "horizontal_input_fusion_test", - srcs = ["horizontal_input_fusion_test.cc"], - backends = ["gpu"], - deps = [ - ":gpu_device_info_for_tests", - ":horizontal_input_fusion", - "//xla:error_spec", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service/gpu/tests:gpu_codegen_test", - "//xla/stream_executor:device_description", - "//xla/tests:xla_internal_test_main", - ], -) - -xla_cc_test( - name = "gpu_float_support_test", - srcs = ["gpu_float_support_test.cc"], - deps = [ - ":backend_configs_cc", - ":gpu_float_support", - ":ir_emission_utils", - "//xla:shape_util", - "//xla:test_helpers", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:float_normalization", - "//xla/service:hlo_verifier", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "reduction_degenerate_dim_remover", - srcs = ["reduction_degenerate_dim_remover.cc"], - hdrs = ["reduction_degenerate_dim_remover.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", + "@com_google_googletest//:gtest", ], ) -cc_library( - name = "reduction_dimension_grouper", - srcs = ["reduction_dimension_grouper.cc"], - hdrs = ["reduction_dimension_grouper.h"], +xla_test( + name = "conv_layout_normalization_test", + srcs = ["conv_layout_normalization_test.cc"], + backends = ["gpu"], deps = [ - "//xla:shape_util", + "//xla:error_spec", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", + "//xla/service/gpu/tests:gpu_codegen_test", # fixdeps: keep + "//xla/tests:hlo_test_base", + "//xla/tests:test_macros_header", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) cc_library( - name = "reduction_splitter", - srcs = ["reduction_splitter.cc"], - hdrs = ["reduction_splitter.h"], + name = "hlo_algorithm_denylist", + srcs = ["hlo_algorithm_denylist.cc"], + hdrs = ["hlo_algorithm_denylist.h"], deps = [ - ":reduction_utils", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", + ":backend_configs_cc", + "//xla:autotuning_proto_cc", + "//xla:debug_options_flags", + "//xla/hlo/ir:backend_config", + "//xla/service/gpu/autotuning:gpu_autotuning_proto_cc", + "//xla/stream_executor:dnn", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status", ], ) xla_cc_test( - name = "reduction_splitter_test", - srcs = ["reduction_splitter_test.cc"], - deps = [ - ":reduction_splitter", - "//xla:shape_util", - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], -) - -cc_library( - name = "reduction_layout_normalizer", - srcs = ["reduction_layout_normalizer.cc"], - hdrs = ["reduction_layout_normalizer.h"], + name = "hlo_algorithm_denylist_test", + srcs = ["hlo_algorithm_denylist_test.cc"], + data = ["data/hlo_algorithm_denylist.pbtxt"], deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", + ":hlo_algorithm_denylist", + "//xla/stream_executor:dnn", + "//xla/tests:test_utils", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) -cc_library( - name = "tree_reduction_rewriter", - srcs = ["tree_reduction_rewriter.cc"], - hdrs = ["tree_reduction_rewriter.h"], +xla_cc_test( + name = "gpu_float_support_test", + srcs = ["gpu_float_support_test.cc"], deps = [ - ":reduction_utils", + ":backend_configs_cc", + ":gpu_float_support", + ":ir_emission_utils", "//xla:shape_util", - "//xla:util", + "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:collective_ops_utils", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", + "//xla/service:float_normalization", + "//xla/service:hlo_verifier", "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", + "//xla/tests:hlo_test_base", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/numeric:bits", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", + "@com_google_googletest//:gtest_main", ], ) @@ -3626,48 +2617,6 @@ cc_library( ], ) -cc_library( - name = "dot_operand_converter", - srcs = ["dot_operand_converter.cc"], - hdrs = ["dot_operand_converter.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:op_expander_pass", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_test( - name = "dot_operand_converter_test", - srcs = if_gpu_is_configured(["dot_operand_converter_test.cc"]), - backends = [ - "gpu_a100", - "gpu_p100", - "gpu_v100", - "gpu_amd_any", - ], - deps = if_gpu_is_configured( - [ - ":dot_operand_converter", - "@com_google_googletest//:gtest", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/service:pattern_matcher", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", - ], - ["@local_tsl//tsl/platform:test_main"], # b/317293391 - ) + ["//xla:xla_data_proto_cc"], -) - cc_library( name = "make_batch_pointers", srcs = if_gpu_is_configured(["make_batch_pointers.cc"]), @@ -3698,25 +2647,6 @@ cuda_library( ], ) -cc_library( - name = "triangular_solve_rewriter", - srcs = ["triangular_solve_rewriter.cc"], - hdrs = ["triangular_solve_rewriter.h"], - deps = [ - ":cublas_cudnn", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - tsl_gpu_library( name = "runtime_intrinsics", srcs = ["runtime_intrinsics.cc"], @@ -3730,6 +2660,7 @@ tsl_gpu_library( "//xla/service:custom_call_target_registry", "//xla/service:platform_util", "//xla/stream_executor", + "//xla/stream_executor:stream_finder", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -3744,182 +2675,60 @@ xla_test( name = "runtime_intrinsics_test", srcs = ["runtime_intrinsics_test.cc"], backends = ["gpu"], - deps = [ - ":runtime_intrinsics", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "hlo_fusion_stats", - srcs = ["hlo_fusion_stats.cc"], - hdrs = ["hlo_fusion_stats.h"], - deps = [ - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "hlo_fusion_stats_test", - srcs = ["hlo_fusion_stats_test.cc"], - tags = [ - "nomsan", - ], - deps = [ - ":hlo_fusion_stats", - "//xla/service:hlo_parser", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "scatter_slice_simplifier", - srcs = ["scatter_slice_simplifier.cc"], - hdrs = ["scatter_slice_simplifier.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "scatter_slice_simplifier_test", - srcs = ["scatter_slice_simplifier_test.cc"], - deps = [ - ":scatter_slice_simplifier", - "//xla:shape_util", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "conv_layout_normalization", - srcs = ["conv_layout_normalization.cc"], - hdrs = ["conv_layout_normalization.h"], - deps = [ - ":cublas_cudnn", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "topk_specializer", - srcs = ["topk_specializer.cc"], - hdrs = ["topk_specializer.h"], - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service:hlo_proto_cc", - "//xla/service:tuple_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", + deps = [ + ":runtime_intrinsics", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "topk_splitter", - srcs = ["topk_splitter.cc"], - hdrs = ["topk_splitter.h"], + name = "hlo_fusion_stats", + srcs = ["hlo_fusion_stats.cc"], + hdrs = ["hlo_fusion_stats.h"], deps = [ - "//xla:shape_util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:errors", ], ) xla_cc_test( - name = "topk_splitter_test", - srcs = ["topk_splitter_test.cc"], + name = "hlo_fusion_stats_test", + srcs = ["hlo_fusion_stats_test.cc"], + tags = [ + "nomsan", + ], deps = [ - ":topk_splitter", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_dce", - "//xla/service:pattern_matcher", - "//xla/service:topk_rewriter", + ":hlo_fusion_stats", + "//xla/service:hlo_parser", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", + "@com_google_googletest//:gtest_main", ], ) -xla_test( - name = "topk_test", - srcs = ["topk_test.cc"], - backends = ["gpu"], +cc_library( + name = "conv_layout_normalization", + srcs = ["conv_layout_normalization.cc"], + hdrs = ["conv_layout_normalization.h"], deps = [ - ":topk_specializer", + ":cublas_cudnn", "//xla:shape_util", + "//xla:status_macros", + "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service:platform_util", - "//xla/service:topk_rewriter", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", ], ) @@ -3929,6 +2738,7 @@ xla_test( backends = [ "gpu_v100", "gpu_a100", + "gpu_h100", "gpu_amd_any", ], tags = [ @@ -4077,215 +2887,6 @@ cc_library( ], ) -cc_library( - name = "stream_attribute_annotator", - srcs = ["stream_attribute_annotator.cc"], - hdrs = ["stream_attribute_annotator.h"], - deps = [ - ":backend_configs_cc", - ":gpu_fusible", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:hlo_pass", - "//xla/service/gpu/runtime:thunk", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "stream_attribute_annotator_test", - srcs = ["stream_attribute_annotator_test.cc"], - deps = [ - ":backend_configs_cc", - ":stream_attribute_annotator", - "//xla/hlo/ir:hlo", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "stream_attribute_async_wrapper", - srcs = ["stream_attribute_async_wrapper.cc"], - hdrs = ["stream_attribute_async_wrapper.h"], - deps = [ - ":backend_configs_cc", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service/gpu/runtime:thunk", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "stream_attribute_async_wrapper_test", - srcs = ["stream_attribute_async_wrapper_test.cc"], - deps = [ - ":backend_configs_cc", - ":stream_attribute_async_wrapper", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gpu_windowed_einsum_handler", - srcs = ["gpu_windowed_einsum_handler.cc"], - hdrs = ["gpu_windowed_einsum_handler.h"], - deps = [ - ":backend_configs_cc", - "//xla:literal_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/service:pattern_matcher", - "//xla/service:shape_inference", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "gpu_windowed_einsum_handler_test", - srcs = ["gpu_windowed_einsum_handler_test.cc"], - deps = [ - ":backend_configs_cc", - ":gpu_windowed_einsum_handler", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "triton_fusion_numerics_verifier", - srcs = if_gpu_is_configured(["triton_fusion_numerics_verifier.cc"]), - hdrs = if_gpu_is_configured(["triton_fusion_numerics_verifier.h"]), - deps = if_gpu_is_configured([ - "//xla/service/gpu/autotuning:autotuner_compile_util", - "//xla/service/gpu/autotuning:autotuner_util", - ":backend_configs_cc", - ":buffer_comparator", - ":ir_emission_utils", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:executable", - "//xla/service:hlo_pass", - "//xla/service:shaped_buffer", - "//xla/service:hlo_module_config", - "//xla/stream_executor:stream", - "//xla/tools:hlo_decomposer_lib", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ]), -) - -xla_test( - name = "triton_fusion_numerics_verifier_test", - srcs = if_gpu_is_configured(["triton_fusion_numerics_verifier_test.cc"]), - backend_tags = {"gpu": [ - "requires-gpu-sm80", - ]}, - backends = ["gpu"], - deps = [ - ":triton_fusion_numerics_verifier", - "//xla:shape_util", - "//xla:test_helpers", - "//xla/hlo/ir:hlo", - "//xla/service:platform_util", - "//xla/service/gpu/autotuning:autotuner_compile_util", - "//xla/service/gpu/autotuning:autotuner_util", - "//xla/stream_executor:platform", - "//xla/tests:hlo_test_base", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "pipelined_p2p_rewriter", - srcs = ["pipelined_p2p_rewriter.cc"], - hdrs = ["pipelined_p2p_rewriter.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:collective_ops_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "pipelined_p2p_rewriter_test", - srcs = ["pipelined_p2p_rewriter_test.cc"], - deps = [ - ":pipelined_p2p_rewriter", - "//xla/hlo/ir:hlo", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - cc_library( name = "execution_stream_assignment", srcs = ["execution_stream_assignment.cc"], @@ -4357,32 +2958,3 @@ xla_cc_test( "@local_tsl//tsl/platform:statusor", ], ) - -cc_library( - name = "scheduling_instruction_annotator", - srcs = ["scheduling_instruction_annotator.cc"], - hdrs = ["scheduling_instruction_annotator.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "scheduling_instruction_annotator_test", - srcs = ["scheduling_instruction_annotator_test.cc"], - deps = [ - ":scheduling_instruction_annotator", - "//xla/hlo/ir:hlo", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc index a7d2cf47408a4a..ae541ba167f582 100644 --- a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc @@ -36,17 +36,17 @@ limitations under the License. #include "xla/service/gpu/autotuning/conv_algorithm_picker.h" #include "xla/service/gpu/autotuning/gemm_algorithm_picker.h" #include "xla/service/gpu/cublas_padding_requirements.h" -#include "xla/service/gpu/gpu_algebraic_simplifier.h" #include "xla/service/gpu/gpu_compiler.h" -#include "xla/service/gpu/gpu_conv_padding_legalization.h" -#include "xla/service/gpu/gpu_conv_rewriter.h" -#include "xla/service/gpu/gpu_sort_rewriter.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/target_constants.h" +#include "xla/service/gpu/transforms/algebraic_simplifier.h" +#include "xla/service/gpu/transforms/conv_padding_legalization.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" #include "xla/service/gpu/transforms/cublas_pad_for_gemms.h" #include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h" #include "xla/service/gpu/transforms/gpusolver_rewriter.h" -#include "xla/service/gpu/triangular_solve_rewriter.h" +#include "xla/service/gpu/transforms/sort_rewriter.h" +#include "xla/service/gpu/transforms/triangular_solve_rewriter.h" #include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_fix.h" @@ -123,8 +123,8 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(&conv_bf16_support); pipeline.AddPass(); - pipeline.AddPass(gpu_version); - pipeline.AddPass(); + pipeline.AddPass(gpu_version); + pipeline.AddPass(); auto rcc = std::get(gpu_version); pipeline.AddPass(rcc, dnn_version, GetToolkitVersion()); @@ -135,7 +135,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(); pipeline.AddPass(); - // tf2xla bridge, DepthwiseConvolutionConverter and GpuConvRewriter + // tf2xla bridge, DepthwiseConvolutionConverter and ConvRewriter // introduces reshapes and transposes that can be eliminated using // AlgebraicSimplifier We run algsimp to a fixed point. AlgebraicSimplifierOptions options = @@ -144,7 +144,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( options.set_enable_unconditional_reduce_of_concat_replacement(false); pipeline.AddPass>(options, gpu_version); - // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and + // tf2xla bridge, DepthwiseConvolutionConverter, ConvRewriter, and // CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover // to a fixed point. Include algsimp because ReshapeMover relies on it. [&, &pipeline = pipeline.AddPass>( @@ -166,7 +166,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(options, gpu_version); }(); - // GpuConvRewriter, GpuConvPaddingLegalization and + // ConvRewriter, ConvPaddingLegalization and // CudnnConvPadForTensorCores may add instructions which can be simplified // by constant folding. pipeline.AddPass(); @@ -240,7 +240,7 @@ absl::Status AMDGPUCompiler::AddConvAndGemmAutotuningPasses( absl::Status AMDGPUCompiler::AddCustomKernelReplacementPasses( HloPassPipeline* pipeline, const DebugOptions& debug_options) { if (debug_options.xla_gpu_enable_cub_radix_sort()) { - pipeline->AddPass(); + pipeline->AddPass(); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index ed6aba7efab485..aa82b8678bdb1d 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -1,10 +1,6 @@ # Description: # Components that implement GPU autotuning. -load( - "@local_config_rocm//rocm:build_defs.bzl", - "if_rocm_is_configured", -) load( "@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library", @@ -14,10 +10,6 @@ load( "if_cuda_is_configured", ) load("//xla:xla.bzl", "xla_cc_test") -load( - "//xla/stream_executor:build_defs.bzl", - "if_gpu_is_configured", -) load("//xla/tests:build_defs.bzl", "xla_test") package( @@ -35,25 +27,15 @@ package_group( cc_library( name = "gemm_fusion_autotuner", - srcs = if_cuda_is_configured(["gemm_fusion_autotuner.cc"]), - hdrs = if_cuda_is_configured(["gemm_fusion_autotuner.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = if_gpu_is_configured([ + srcs = ["gemm_fusion_autotuner.cc"], + hdrs = ["gemm_fusion_autotuner.h"], + tags = [ + "gpu", + "no_rocm", + ], + deps = [ ":autotuner_compile_util", ":autotuner_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@local_config_cuda//cuda:cuda_headers", "//xla:autotuning_proto_cc", "//xla:shape_util", "//xla:status_macros", @@ -73,18 +55,17 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:buffer_comparator", "//xla/service/gpu:gpu_float_support", - "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_traversal", - "//xla/service/gpu:instruction_fusion", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:priority_fusion", "//xla/service/gpu:split_k_gemm_rewriter", "//xla/service/gpu:stream_executor_util", "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/transforms:cudnn_fusion_compiler", "//xla/service/gpu/transforms:fusion_wrapper", "//xla/service/gpu/transforms:gemm_rewriter", + "//xla/service/gpu/transforms:instruction_fusion", + "//xla/service/gpu/transforms:priority_fusion", "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", @@ -92,6 +73,19 @@ cc_library( "//xla/stream_executor/gpu:redzone_allocator", "//xla/tools:hlo_decomposer_lib", "//xla/tsl/util/proto:proto_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/lib/core:bits", "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", @@ -101,13 +95,13 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:scoped_annotation", - ]), + ], ) xla_test( name = "gemm_fusion_autotuner_test", timeout = "long", - srcs = if_cuda_is_configured(["gemm_fusion_autotuner_test.cc"]), + srcs = ["gemm_fusion_autotuner_test.cc"], backend_tags = {"gpu": [ "requires-gpu-sm80", ]}, @@ -115,6 +109,7 @@ xla_test( "gpu", ], tags = [ + "no_rocm", "nomac", ], deps = [ @@ -153,93 +148,89 @@ xla_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest", + "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) cc_library( name = "gemm_algorithm_picker", - srcs = if_gpu_is_configured(["gemm_algorithm_picker.cc"]), - hdrs = if_gpu_is_configured(["gemm_algorithm_picker.h"]), - deps = if_gpu_is_configured([ - "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu:buffer_comparator", - "//xla/service/gpu:cublas_cudnn", - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu:gpu_conv_runner", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:stream_executor_util", - "//xla/service/gpu:variant_visitor", + srcs = ["gemm_algorithm_picker.cc"], + hdrs = ["gemm_algorithm_picker.h"], + tags = ["gpu"], + deps = [ ":autotuner_compile_util", ":autotuner_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", "//xla:autotune_results_proto_cc", + "//xla:autotuning_proto_cc", + "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", "//xla/service:hlo_pass", - "//xla:status_macros", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:buffer_comparator", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu:variant_visitor", "//xla/stream_executor", "//xla/stream_executor:blas", - "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor/gpu:redzone_allocator", "//xla/tsl/util/proto:proto_utils", - "//xla:util", - "//xla:autotuning_proto_cc", - "//xla:shape_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:scoped_annotation", - ]) + ["@com_google_absl//absl/status"], + ], ) cc_library( name = "autotuner_util", - srcs = if_gpu_is_configured(["autotuner_util.cc"]), - hdrs = if_gpu_is_configured(["autotuner_util.h"]), - deps = if_gpu_is_configured([ - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu:stream_executor_util", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@llvm-project//llvm:Core", - "@llvm-project//llvm:Support", + srcs = ["autotuner_util.cc"], + hdrs = ["autotuner_util.h"], + tags = ["gpu"], + deps = [ "//xla:autotune_results_proto_cc", "//xla:autotuning_proto_cc", "//xla:shape_util", "//xla:status_macros", - "//xla:types", "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:compilation_environments", + "//xla/service:dump", + "//xla/service/gpu:gpu_asm_opts_util", + "//xla/service/gpu:stream_executor_util", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/stream_executor", "//xla/stream_executor/gpu:redzone_allocator", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:base64", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", @@ -247,69 +238,64 @@ cc_library( "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", - ]), + ], ) # We need a separate target, as runtime executable cannot depend on compilation # pipeline. cc_library( name = "autotuner_compile_util", - srcs = if_gpu_is_configured(["autotuner_compile_util.cc"]), - hdrs = if_gpu_is_configured(["autotuner_compile_util.h"]), - deps = if_gpu_is_configured([ + srcs = ["autotuner_compile_util.cc"], + hdrs = ["autotuner_compile_util.h"], + tags = ["gpu"], + deps = [ ":autotuner_util", + "//xla:executable_run_options", + "//xla:shape_util", + "//xla:util", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:compiler", + "//xla/service:executable", + "//xla/service:maybe_owning_device_memory", + "//xla/service:shaped_buffer", "//xla/service/gpu:gpu_executable_run_options", "//xla/service/gpu:ir_emission_utils", + "//xla/stream_executor", + "//xla/stream_executor/gpu:redzone_allocator", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "//xla/hlo/ir:hlo", - "//xla/service:compiler", - "//xla/service:executable", - "//xla/service:hlo_module_config", - "//xla/service:maybe_owning_device_memory", - "//xla/service:shaped_buffer", - "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:redzone_allocator", - "//xla:executable_run_options", - "//xla:shape_util", - "//xla:util", - "//xla:xla_proto_cc", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - ]) + ["@com_google_absl//absl/status"], + ], ) xla_test( name = "autotuner_compile_util_test", - srcs = if_gpu_is_configured(["autotuner_compile_util_test.cc"]), + srcs = ["autotuner_compile_util_test.cc"], backends = ["gpu"], - deps = if_gpu_is_configured( - [ - ":autotuner_compile_util", - ":autotuner_util", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "//xla/hlo/ir:hlo", - "//xla/service:platform_util", - "//xla/stream_executor:platform", - "//xla/tests:hlo_test_base", - "@local_tsl//tsl/platform:statusor", - ], - if_false = [ - "@com_google_googletest//:gtest_main", # b/317293391 - ], - ), + deps = [ + ":autotuner_compile_util", + ":autotuner_util", + "//xla/hlo/ir:hlo", + "//xla/service:platform_util", + "//xla/stream_executor:platform", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], ) xla_test( name = "gemm_algorithm_picker_test", - srcs = if_gpu_is_configured(["gemm_algorithm_picker_test.cc"]), + srcs = ["gemm_algorithm_picker_test.cc"], backends = [ "gpu_v100", "gpu_amd_any", @@ -338,32 +324,14 @@ xla_test( cc_library( name = "conv_algorithm_picker", - srcs = if_gpu_is_configured(["conv_algorithm_picker.cc"]), - hdrs = if_gpu_is_configured(["conv_algorithm_picker.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = if_gpu_is_configured([ + srcs = ["conv_algorithm_picker.cc"], + hdrs = ["conv_algorithm_picker.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + tags = ["gpu"], + deps = [ ":autotuner_compile_util", ":autotuner_util", - "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu:buffer_comparator", - "//xla/service/gpu:cublas_cudnn", - "//xla/service/gpu:gpu_asm_opts_util", ":gpu_autotuning_proto_cc", - "//xla/service/gpu:gpu_conv_runner", - "//xla/service/gpu:hlo_algorithm_denylist", - "//xla/service/gpu:stream_executor_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@local_config_cuda//cuda:cudnn_header", "//xla:autotune_results_proto_cc", "//xla:autotuning_proto_cc", "//xla:debug_options_flags", @@ -376,28 +344,47 @@ cc_library( "//xla/service:hlo_module_config", "//xla/service:hlo_pass", "//xla/service:slow_operation_alarm", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:gpu_asm_opts_util", + "//xla/service/gpu:gpu_conv_runner", + "//xla/service/gpu:hlo_algorithm_denylist", + "//xla/service/gpu:stream_executor_util", "//xla/stream_executor", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:dnn", + "//xla/stream_executor:lazy_op_runner", "//xla/stream_executor:numeric_options", "//xla/stream_executor:scratch_allocator", - "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor:lazy_op_runner", "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/stream_executor/gpu:redzone_allocator", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/tsl/util:env_var", + "//xla/tsl/util/proto:proto_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", - "//xla/tsl/util:env_var", - "@local_tsl//tsl/platform:statusor", - "//xla/tsl/util/proto:proto_utils", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ] + if_cuda_is_configured([ + # keep sorted + "//xla/service/gpu:buffer_comparator", + "//xla/stream_executor/gpu:redzone_allocator", + "@local_config_cuda//cuda:cudnn_header", ]), ) xla_test( name = "conv_algorithm_picker_test", - srcs = if_gpu_is_configured(["conv_algorithm_picker_test.cc"]), + srcs = ["conv_algorithm_picker_test.cc"], backends = [ "gpu_v100", "gpu_amd_any", @@ -415,8 +402,8 @@ xla_test( "//xla/service:pattern_matcher_gmock", "//xla/service:platform_util", "//xla/service:tuple_simplifier", - "//xla/service/gpu:gpu_conv_rewriter", "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/transforms:conv_rewriter", "//xla/stream_executor:device_description", "//xla/stream_executor:platform", "//xla/tests:hlo_test_base", @@ -430,39 +417,18 @@ xla_test( cc_library( name = "custom_kernel_fusion_autotuner", - srcs = if_gpu_is_configured(["custom_kernel_fusion_autotuner.cc"]), - hdrs = if_gpu_is_configured(["custom_kernel_fusion_autotuner.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = if_gpu_is_configured([ + srcs = ["custom_kernel_fusion_autotuner.cc"], + hdrs = ["custom_kernel_fusion_autotuner.h"], + tags = ["gpu"], + deps = [ ":autotuner_compile_util", ":autotuner_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@local_config_cuda//cuda:cuda_headers", "//xla:autotuning_proto_cc", - "//xla:shape_util", "//xla:status_macros", - "//xla:statusor", "//xla:util", - "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:algorithm_util", - "//xla/service:dump", "//xla/service:executable", - "//xla/service:float_normalization", - "//xla/service:hlo_module_config", "//xla/service:hlo_pass", "//xla/service:shaped_buffer", "//xla/service/gpu:backend_configs_cc", @@ -470,7 +436,6 @@ cc_library( "//xla/service/gpu:gpu_float_support", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_traversal", - "//xla/service/gpu:instruction_fusion", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:split_k_gemm_rewriter", @@ -479,29 +444,29 @@ cc_library( "//xla/service/gpu/kernels:custom_kernel_fusion", "//xla/stream_executor", "//xla/stream_executor:device_description", - "//xla/stream_executor:device_memory", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:redzone_allocator", "//xla/tools:hlo_decomposer_lib", - "//xla/tsl/util/proto:proto_utils", - "@local_tsl//tsl/lib/core:bits", - "@local_tsl//tsl/platform:blocking_counter", - "@local_tsl//tsl/platform:env", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:scoped_annotation", - ]), + ], ) xla_test( name = "custom_kernel_fusion_autotuner_test", - srcs = if_cuda_is_configured(["custom_kernel_fusion_autotuner_test.cc"]), + srcs = ["custom_kernel_fusion_autotuner_test.cc"], backends = [ "gpu", ], + tags = ["no_rocm"], deps = [ ":autotuner_util", ":custom_kernel_fusion_autotuner", @@ -530,20 +495,24 @@ tf_proto_library( xla_cc_test( name = "autotuner_util_test", - srcs = if_cuda_is_configured(["autotuner_util_test.cc"]), + srcs = ["autotuner_util_test.cc"], data = [ "//xla/tools/hlo_opt:gpu_specs/a100_sxm_40.txtpb", "//xla/tools/hlo_opt:gpu_specs/a100_sxm_80.txtpb", "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb", ], - deps = if_cuda_is_configured([ - # keep sorted + tags = [ + "gpu", + "no_rocm", + ], + deps = [ ":autotuner_util", "//xla:autotune_results_proto_cc", "//xla:autotuning_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", + "//xla/service:dump", "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:platform", @@ -551,6 +520,7 @@ xla_cc_test( "//xla/stream_executor/host:host_platform", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/container:flat_hash_set", @@ -568,7 +538,5 @@ xla_cc_test( "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", - ]) + [ - "//xla/tests:xla_internal_test_main", # Keep outside GPU guard ], ) diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc index 89853c11c29405..79bb7441ea636e 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/dump.h" #include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/shape.h" @@ -130,6 +131,9 @@ absl::Status AddResultToFileBasedCacheIfEnabled(const AutotuneCacheKey& key, return absl::OkStatus(); } + tsl::Env* default_env = tsl::Env::Default(); + TF_RETURN_IF_ERROR(CreateDirIfNeeded(std::string(cache_dir), default_env)); + TF_ASSIGN_OR_RETURN(const std::string file_path, GetCacheFilePath(cache_dir, key)); @@ -145,7 +149,6 @@ absl::Status AddResultToFileBasedCacheIfEnabled(const AutotuneCacheKey& key, // file. Also avoids reading incomplete files. (This may not work on all file // systems.) std::string temp_file_path = tsl::io::GetTempFilename(".textproto"); - tsl::Env* default_env = tsl::Env::Default(); TF_RETURN_IF_ERROR( tsl::WriteStringToFile(default_env, temp_file_path, result_str)); return default_env->RenameFile(temp_file_path, file_path); diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc index 59de15c525c9d4..974f4d4d2816c2 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/dump.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/tests/hlo_test_base.h" @@ -113,8 +114,7 @@ results { static stream_executor::StreamExecutor* NewStreamExecutor() { stream_executor::Platform* platform = stream_executor::PlatformManager::PlatformWithName("Host").value(); - stream_executor::StreamExecutorConfig config(/*ordinal=*/0); - return platform->GetExecutor(config).value(); + return platform->ExecutorForDevice(/*ordinal=*/0).value(); } absl::Status PopulateResultCache() { @@ -278,8 +278,9 @@ class FileBasedCacheTest : public AutotunerUtilTest { return file_content; } - static void Write(const absl::string_view filepath, - const absl::string_view content) { + void Write(const absl::string_view filepath, + const absl::string_view content) { + TF_CHECK_OK(CreateDirIfNeeded(cache_dir_, tsl::Env::Default())); TF_CHECK_OK(tsl::WriteStringToFile(tsl::Env::Default(), std::string(filepath), content)); } @@ -293,7 +294,6 @@ class FileBasedCacheTest : public AutotunerUtilTest { tsl::Env* default_env = tsl::Env::Default(); std::string cache_dir; CHECK(default_env->LocalTempFilename(&cache_dir)); - CHECK_OK(default_env->CreateDir(cache_dir)); return cache_dir; }(); AutotuneConfig config_ = AutotuneConfig(DeviceConfig{executor_}, [&] { diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc index 62caa23862aebb..617391514b00f7 100644 --- a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc @@ -57,6 +57,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/lazy_op_runner.h" diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h index 32be7011956452..173a0c61481e57 100644 --- a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h +++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h @@ -36,16 +36,10 @@ limitations under the License. #include "xla/service/gpu/gpu_conv_runner.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/shape.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor.h" -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -#include "xla/stream_executor/gpu/redzone_allocator.h" -#endif - namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc index 3c0d49ca650347..96520143e0fe4c 100644 --- a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/autotuning/autotuner_util.h" -#include "xla/service/gpu/gpu_conv_rewriter.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/platform_util.h" @@ -68,7 +68,7 @@ ENTRY main { ->GetDeviceDescription() .gpu_compute_capability(); bool changed = false; - TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvRewriter(cc), m.get())); + TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(ConvRewriter(cc), m.get())); changed = false; DebugOptions opts = DefaultDebugOptionsIgnoringFlags(); @@ -92,7 +92,7 @@ ENTRY main { // should have the new scratch bytes. TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo)); changed = false; - TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvRewriter(cc), m.get())); + TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(ConvRewriter(cc), m.get())); changed = false; TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvAlgorithmPicker(cfg), m.get())); diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index a812e0a3564289..269fd815363c23 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -63,16 +63,16 @@ limitations under the License. #include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/gpu_float_support.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/instruction_fusion.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/service/gpu/priority_fusion.h" #include "xla/service/gpu/split_k_gemm_rewriter.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" #include "xla/service/gpu/transforms/fusion_wrapper.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" +#include "xla/service/gpu/transforms/priority_fusion.h" #include "xla/service/hlo_module_config.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" @@ -321,6 +321,21 @@ absl::StatusOr GetLimits(const HloDotInstruction& dot) { int GetLogEveryN() { return VLOG_IS_ON(3) ? 100 : 1000; } +int64_t PriorityFusionShapeSize(const Shape& shape) { + // The real pointer size is set in GpuCompiler. In HloCostAnalysis, the + // pointer size is used only to determine the size of tuple types. We + // shouldn't have any tuples in the autotuned module, so it's safe to use + // a constant here, instead of piping the real value. + constexpr int64_t kPointerSize = 8; + return ShapeUtil::ByteSizeOf(shape, kPointerSize); +} + +HloCostAnalysis::Options PriorityFusionOptions() { + return {/*shape_size=*/PriorityFusionShapeSize, + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}; +} + absl::StatusOr> TritonGemmAutotuneExtractor( const TritonGemmConfig& config, const se::DeviceDescription& gpu_device_info, @@ -355,19 +370,8 @@ absl::StatusOr> TritonGemmAutotuneExtractor( TF_RETURN_IF_ERROR(float_normalization.Run(new_module.get()).status()); } - auto shape_size_function = [&](const Shape& shape) { - // The real pointer size is set in GpuCompiler. In HloCostAnalysis, the - // pointer size is used only to determine the size of tuple types. We - // shouldn't have any tuples in the autotuned module, so it's safe to use - // a constant here, instead of piping the real value. - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - GpuPriorityFusion priority_fusion( - /*thread_pool=*/nullptr, gpu_device_info, - GpuHloCostAnalysis::Options{/*shape_size=*/shape_size_function, - /*per_second_rates=*/{}, - /*count_multiple_input_accesses=*/true}); + PriorityFusion priority_fusion( + /*thread_pool=*/nullptr, gpu_device_info, PriorityFusionOptions()); TF_RETURN_IF_ERROR(priority_fusion.Run(new_module.get()).status()); // If the priority fusion pass above skipped some instructions, turn them @@ -379,8 +383,9 @@ absl::StatusOr> TritonGemmAutotuneExtractor( } absl::StatusOr> CublasGemmAutotuneExtractor( - const AutotuneConfig& config, const int32_t toolkit_version, - const HloFusionInstruction* fusion, const DebugOptions& debug_opts) { + const AutotuneConfig& config, const se::DeviceDescription& gpu_device_info, + const int32_t toolkit_version, const HloFusionInstruction* fusion, + const DebugOptions& debug_opts) { const HloComputation* fusion_computation = fusion->called_computations().at(0); std::unique_ptr new_module = @@ -400,11 +405,13 @@ absl::StatusOr> CublasGemmAutotuneExtractor( PrecisionConfig::ALG_DOT_F32_F32_F32); } - for (bool fp8 : {true, false}) { + for (GemmRewriterOptions::DType dtype : + {GemmRewriterOptions::DType::kFp8Only, + GemmRewriterOptions::DType::kNonFp8Only}) { GemmRewriter rewriter(config.GetGpuComputeCapability(), toolkit_version, - fp8); - GpuInstructionFusion fusion_pass( - /*may_duplicate=*/false, config.GetExecutor()->GetDeviceDescription()); + GemmRewriterOptions{dtype}); + PriorityFusion fusion_pass( + /*thread_pool=*/nullptr, gpu_device_info, PriorityFusionOptions()); TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); } @@ -529,8 +536,9 @@ absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config, triton_gemm_config, device_desc, fusion, debug_opts, /*allow_filtering_kernels_spilling_registers=*/true); } else if (result.has_gemm()) { - return CublasGemmAutotuneExtractor(autotune_config, toolkit_version, - fusion, debug_opts); + return CublasGemmAutotuneExtractor(autotune_config, device_desc, + toolkit_version, fusion, + debug_opts); } else { LOG(FATAL) << "Unknown result type: " << result.DebugString(); } @@ -783,11 +791,12 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, }) .value_or(nullptr); } else if (std::holds_alternative(config)) { - TF_ASSIGN_OR_RETURN(executable, - compile_util.Compile([&](const DebugOptions& opts) { - return CublasGemmAutotuneExtractor( - config_, toolkit_version_, fusion, opts); - })); + TF_ASSIGN_OR_RETURN( + executable, compile_util.Compile([&](const DebugOptions& opts) { + return CublasGemmAutotuneExtractor( + config_, config_.GetExecutor()->GetDeviceDescription(), + toolkit_version_, fusion, opts); + })); } else { LOG(FATAL) << "Unsupported config type: " << config.index(); } @@ -1199,7 +1208,10 @@ absl::Status ExchangeResults(KeyValueStoreInterface& key_value_store, std::string autotune_results_str, key_value_store.Get( absl::StrFormat("%s_%d_%d", kKeyPrefix, module_id, i), - absl::InfiniteDuration())); + // TODO(b/361009609): reset to infinite duration once solved. + // Using an infinite duration here leads to issues with MPI, see + // https://github.com/google/jax/issues/22995. + absl::Hours(24))); TF_RETURN_IF_ERROR( AutotunerUtil::LoadAutotuneResults(autotune_results_str, true)); } diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index 909591d0213868..7af1805ecf7010 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -468,23 +468,23 @@ ENTRY %e { })"; auto module = ParseAndReturnVerifiedModule(kHloText).value(); - EXPECT_THAT( - backend().compiler()->RunBackend(std::move(module), - backend().default_stream_executor(), - {/*device_allocator=*/nullptr, - /*thread_pool=*/nullptr, - /*layout_canonicalization_callback=*/{}, - /*is_autotuning_compilation=*/true}), - ::testing::AnyOf( - tsl::testing::StatusIs( - tsl::error::CANCELLED, - absl::StrFormat( - "Compilation result discarded due to register spilling")), - // Hopper can't spill registers since wgmma instructions are - // asynchronous, instead it just runs out of them. - tsl::testing::StatusIs( - tsl::error::RESOURCE_EXHAUSTED, - absl::StrFormat("Register allocation failed")))); + EXPECT_THAT(backend().compiler()->RunBackend( + std::move(module), backend().default_stream_executor(), + {/*device_allocator=*/nullptr, + /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*is_autotuning_compilation=*/true}), + ::testing::AnyOf( + tsl::testing::StatusIs( + tsl::error::CANCELLED, + "Compilation result discarded due to register spilling"), + // Hopper can't spill registers since wgmma instructions are + // asynchronous, instead it just runs out of them. + tsl::testing::StatusIs(tsl::error::RESOURCE_EXHAUSTED, + "Register allocation failed"), + tsl::testing::StatusIs( + tsl::error::INTERNAL, + ::testing::HasSubstr("Insufficient registers")))); } // Modify block_k back to 16 once b/337839570 is fixed. @@ -618,9 +618,12 @@ ENTRY main { pipeline.AddPass(autotune_config, GetToolkitVersion(), &thread_pool, key_value_store); pipeline.AddPass(); - for (bool fp8_rewrite : {true, false}) { + for (GemmRewriterOptions::DType dtype : + {GemmRewriterOptions::DType::kFp8Only, + GemmRewriterOptions::DType::kNonFp8Only}) { pipeline.AddPass(autotune_config.GetGpuComputeCapability(), - GetToolkitVersion(), fp8_rewrite); + GetToolkitVersion(), + GemmRewriterOptions{dtype}); } TF_EXPECT_OK(HloTestBase::RunHloPass(&pipeline, module.get())); diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc index c21784b1b3dda8..39f00c826fdc7c 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -28,6 +27,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" @@ -53,25 +54,22 @@ limitations under the License. #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_unnested.h" #include "xla/service/gpu/metrics.h" -#include "xla/service/gpu/runtime/conditional_thunk.h" -#include "xla/service/gpu/runtime/sequential_thunk.h" -#include "xla/service/gpu/runtime/thunk.h" -#include "xla/service/gpu/runtime/while_thunk.h" #include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_ordering.h" #include "xla/service/logical_buffer.h" #include "xla/shape.h" +#include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/casts.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/scoped_annotation.h" namespace xla::gpu { @@ -102,8 +100,10 @@ void RemoveUnusedAndUninitializedGlobals( } } -static absl::Status LoadCache(IrEmitterContext& ir_emitter_context, - absl::string_view cache_file_path) { +} // namespace + +absl::Status LoadCache(IrEmitterContext& ir_emitter_context, + absl::string_view cache_file_path) { std::string resolved_path; if (!tsl::io::ResolveTestPrefixes(cache_file_path, resolved_path)) { return FailedPrecondition("File path can not be resolved: %s", @@ -114,7 +114,7 @@ static absl::Status LoadCache(IrEmitterContext& ir_emitter_context, TF_RETURN_IF_ERROR( tsl::ReadFileToString(tsl::Env::Default(), resolved_path, &serialized)); CompilationCacheProto proto; - if (!proto.ParseFromString(std::string(serialized))) { + if (!proto.ParseFromString(serialized)) { return Internal("Failed to parse serialized CompilationCacheProto."); } // Register all cached kernel names with the name uniquer to avoid @@ -131,8 +131,6 @@ static absl::Status LoadCache(IrEmitterContext& ir_emitter_context, return absl::OkStatus(); } -} // namespace - absl::StatusOr CompileModuleToLlvmIr( HloModule* hlo_module, llvm::LLVMContext* llvm_context, const std::string& target_triple, const std::string& data_layout, diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h index d7005f879c3994..a451af5a149fad 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h @@ -31,6 +31,7 @@ limitations under the License. #include "xla/service/gpu/executable.pb.h" #include "xla/service/gpu/execution_stream_assignment.h" #include "xla/service/gpu/gpu_executable.h" +#include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/hlo.pb.h" @@ -66,6 +67,9 @@ struct CompileModuleResults { void ForAllThunks(const std::function& fn, ThunkSequence* thunk_sequence); +absl::Status LoadCache(IrEmitterContext& ir_emitter_context, + absl::string_view cache_file_path); + absl::StatusOr CompileModuleToLlvmIr( HloModule* hlo_module, llvm::LLVMContext* llvm_context, const std::string& target_triple, const std::string& data_layout, diff --git a/third_party/xla/xla/service/gpu/determinism_test.cc b/third_party/xla/xla/service/gpu/determinism_test.cc index 45b970429d9643..2d4a7a94ce087f 100644 --- a/third_party/xla/xla/service/gpu/determinism_test.cc +++ b/third_party/xla/xla/service/gpu/determinism_test.cc @@ -97,6 +97,7 @@ ENTRY e { if (!rocm.has_hipblaslt()) { GTEST_SKIP() << "No hipblas-lt support on this architecture!"; } + debug_options_.set_xla_gpu_enable_triton_gemm(false); #endif // TENSORFLOW_USE_ROCM debug_options_.set_xla_gpu_triton_fusion_level(0); diff --git a/third_party/xla/xla/service/gpu/execution_stream_assignment.cc b/third_party/xla/xla/service/gpu/execution_stream_assignment.cc index 6ee0f2bfbc1f7b..8a55dc555e0550 100644 --- a/third_party/xla/xla/service/gpu/execution_stream_assignment.cc +++ b/third_party/xla/xla/service/gpu/execution_stream_assignment.cc @@ -34,7 +34,8 @@ limitations under the License. namespace xla::gpu { -ExecutionStreamAssignment::ExecutionStreamAssignment(const HloModule* module) { +ExecutionStreamAssignment::ExecutionStreamAssignment( + const HloModule* module, ExecutionStreamAssignmentOptions options) { std::unique_ptr call_graph = CallGraph::Build(module); // We'll walk the `CallGraph` starting from the entrypoint. The instructions @@ -88,14 +89,18 @@ ExecutionStreamAssignment::ExecutionStreamAssignment(const HloModule* module) { // Asynchronous calls will result in a new `ExecutionStreamId` being // dispensed for the called computations. CHECK_EQ(callsite.instruction()->opcode(), HloOpcode::kAsyncStart); - const ExecutionStreamId async_stream_id = next_stream_id++; - enqueue_called_computations(callsite, async_stream_id); + enqueue_called_computations(callsite, next_stream_id); AsyncExecutionStreamIds streams; streams.source_stream_id = pending.stream_id; - streams.destination_stream_id = async_stream_id; + streams.destination_stream_id = next_stream_id; CHECK(async_instructions_.try_emplace(callsite.instruction(), streams) .second); + + next_stream_id++; + if (next_stream_id.value() > options.number_of_execution_streams) { + next_stream_id = ExecutionStreamId(1); + } } else { // Synchronous calls will result in the called computations being // invoked using the same `ExecutionStreamId`. diff --git a/third_party/xla/xla/service/gpu/execution_stream_assignment.h b/third_party/xla/xla/service/gpu/execution_stream_assignment.h index adbd7f04ace5ec..cb0e87ae0e44f2 100644 --- a/third_party/xla/xla/service/gpu/execution_stream_assignment.h +++ b/third_party/xla/xla/service/gpu/execution_stream_assignment.h @@ -26,6 +26,12 @@ limitations under the License. namespace xla::gpu { +struct ExecutionStreamAssignmentOptions { + // The `ExecutionStreamAssignment` will round-robin across this many + // `ExecutionStreams`. + int number_of_execution_streams = 4; +}; + // `ExecutionStreamAssignments` represent a mapping from `HloInstructions` to // `ExecutionStreamIds`. Asynchronous calls (`async-start`, `async-update`, and // `async-done`) result in the target computations being assigned new @@ -37,7 +43,8 @@ class ExecutionStreamAssignment { // pass the module through the `FlattenCallGraph` pass. // // The ExecutionStreamAssignment does not take ownership of the `HloModule`. - explicit ExecutionStreamAssignment(const HloModule* module); + explicit ExecutionStreamAssignment( + const HloModule* module, ExecutionStreamAssignmentOptions options = {}); // Returns the `ExecutionStreamId` for the given instruction, which *must* be // synchronous. Returns an error if the instruction is either not reachable diff --git a/third_party/xla/xla/service/gpu/execution_stream_assignment_test.cc b/third_party/xla/xla/service/gpu/execution_stream_assignment_test.cc index cf7ec32ab62757..e6abd3e3f5e101 100644 --- a/third_party/xla/xla/service/gpu/execution_stream_assignment_test.cc +++ b/third_party/xla/xla/service/gpu/execution_stream_assignment_test.cc @@ -69,6 +69,10 @@ TEST_F(ExecutionStreamAssignmentTest, AsyncFusion) { p0 = f32[2,2] parameter(0) ROOT add = f32[2,2] add(p0, p0) } + leaf3 { + p0 = f32[2,2] parameter(0) + ROOT add = f32[2,2] add(p0, p0) + } // Entry computation that calls each of the leaves asynchronously. ENTRY entry { @@ -77,21 +81,30 @@ TEST_F(ExecutionStreamAssignmentTest, AsyncFusion) { kind=kLoop, calls=leaf1 start2 = ((f32[2,2]), f32[2,2], s32[]) fusion-start(p0), kind=kLoop, calls=leaf2 + start3 = ((f32[2,2]), f32[2,2], s32[]) fusion-start(p0), + kind=kLoop, calls=leaf3 update1 = ((f32[2,2]), f32[2,2], s32[]) fusion-update(start1) update2 = ((f32[2,2]), f32[2,2], s32[]) fusion-update(start2) + update3 = ((f32[2,2]), f32[2,2], s32[]) fusion-update(start3) done1 = f32[2,2] fusion-done(update1) done2 = f32[2,2] fusion-done(update2) - ROOT done = f32[2,2] add(done1, done2) + done3 = f32[2,2] fusion-done(update3) + ROOT done = f32[2,2] custom-call(done1, done2, done3), + custom_call_target="target" } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleStr)); - ExecutionStreamAssignment assignment(module.get()); + ExecutionStreamAssignment assignment( + module.get(), + ExecutionStreamAssignmentOptions{/*number_of_execution_streams=*/2}); // The outermost computation should run on `ExecutionStreamId(0)`. The two // asynchronous branches should be launched on `ExecutionStreamId(1)` and - // `ExecutionStreamId(2)`, respectively. + // `ExecutionStreamId(2)`, respectively. The third asynchronous branch should + // reuse `ExecutionStreamId(1)` because we set `number_of_execution_streams` + // to `2`. ExpectExecutionStreamForSyncInstructions( assignment, FindComputation(module.get(), "entry"), ExecutionStreamId(0)); for (std::string_view instruction : {"start1", "update1", "done1"}) { @@ -108,6 +121,13 @@ TEST_F(ExecutionStreamAssignmentTest, AsyncFusion) { /*source_stream_id=*/ExecutionStreamId(0), /*destination_stream_id=*/ExecutionStreamId(2)})); } + for (std::string_view instruction : {"start3", "update3", "done3"}) { + EXPECT_THAT(assignment.GetAsyncExecutionStreamIds(Cast( + FindInstruction(module.get(), instruction))), + IsOkAndHolds(AsyncExecutionStreamIds{ + /*source_stream_id=*/ExecutionStreamId(0), + /*destination_stream_id=*/ExecutionStreamId(1)})); + } // Leaf computations should run on the respective asynchronous // `ExecutionStreamIds`. diff --git a/third_party/xla/xla/service/gpu/fusion_pipeline.cc b/third_party/xla/xla/service/gpu/fusion_pipeline.cc index 90e019253fad96..4fc4af0a4cfafa 100644 --- a/third_party/xla/xla/service/gpu/fusion_pipeline.cc +++ b/third_party/xla/xla/service/gpu/fusion_pipeline.cc @@ -19,14 +19,14 @@ limitations under the License. #include #include "xla/service/cpu_gpu_shape_verifier.h" -#include "xla/service/gpu/horizontal_input_fusion.h" -#include "xla/service/gpu/horizontal_loop_fusion.h" -#include "xla/service/gpu/instruction_fusion.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/service/gpu/multi_output_fusion.h" -#include "xla/service/gpu/priority_fusion.h" #include "xla/service/gpu/transforms/fusion_merger.h" -#include "xla/service/gpu/variadic_op_splitter.h" +#include "xla/service/gpu/transforms/horizontal_input_fusion.h" +#include "xla/service/gpu/transforms/horizontal_loop_fusion.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" +#include "xla/service/gpu/transforms/multi_output_fusion.h" +#include "xla/service/gpu/transforms/priority_fusion.h" +#include "xla/service/gpu/transforms/variadic_op_splitter.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_cse.h" #include "xla/service/hlo_dce.h" @@ -63,8 +63,8 @@ HloPassPipeline FusionPipeline( shape_size_bytes_function, /*per_second_rates=*/{}, /*count_multiple_input_accesses=*/true}; - fusion.AddPass(thread_pool, gpu_device_info, - std::move(cost_analysis_options)); + fusion.AddPass(thread_pool, gpu_device_info, + std::move(cost_analysis_options)); } else { fusion.AddPass(/*may_duplicate=*/false, gpu_device_info); @@ -77,8 +77,7 @@ HloPassPipeline FusionPipeline( fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); fusion.AddPass(); - fusion.AddPass(gpu_device_info, - shape_size_bytes_function); + fusion.AddPass(gpu_device_info, shape_size_bytes_function); fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); fusion.AddPass(); @@ -88,8 +87,8 @@ HloPassPipeline FusionPipeline( HloPassPipeline HorizontalFusionPipeline( const se::DeviceDescription& gpu_device_info) { HloPassFix horizontal_fusion("horizontal fusion"); - horizontal_fusion.AddPass(); - horizontal_fusion.AddPass(gpu_device_info); + horizontal_fusion.AddPass(); + horizontal_fusion.AddPass(gpu_device_info); horizontal_fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); horizontal_fusion.AddPass(); diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 58d095b8f2430a..cf9801fe291ffa 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -262,8 +262,8 @@ cc_library( "//xla/service:gpu_plugin", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:affine_map_printer", "//xla/stream_executor:device_description", "//xla/tests:filecheck", @@ -303,10 +303,10 @@ cc_library( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:computation_partitioner", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -348,10 +348,10 @@ cc_library( "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:computation_partitioner", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -399,11 +399,11 @@ cc_library( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:computation_partitioner", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/fusions/mlir:type_util", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -432,6 +432,7 @@ xla_test( "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", ], ) @@ -630,11 +631,11 @@ cc_library( "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:reduction_utils", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:computation_partitioner", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/fusions/mlir:type_util", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -726,10 +727,10 @@ cc_library( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:computation_partitioner", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc index e53dbf294345f1..f97a90dc80b5ef 100644 --- a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc @@ -86,6 +86,14 @@ class DynamicSliceFusionTest : public HloTestBase { return config; } + HloModuleConfig GetModuleConfigWithDeterministicOps() { + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_deterministic_ops(true); + HloModuleConfig config; + config.set_debug_options(debug_options); + return config; + } + std::vector GetAddressComputations(const HloModule& module) { std::vector computations; for (auto computation : module.computations()) { @@ -264,8 +272,10 @@ TEST_F(DynamicSliceFusionTest, CublasGemmWithWorkspace) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, - /*run_hlo_passes=*/false)); + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(), + GetModuleConfigWithDeterministicOps(), error_spec, + /*run_hlo_passes=*/false)); } TEST_F(DynamicSliceFusionTest, ContiguousSlice) { @@ -1354,8 +1364,10 @@ TEST_F(DynamicSliceFusionTest, CublasGemmDynamicWithWorkspace) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, - /*run_hlo_passes=*/false)); + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(), + GetModuleConfigWithDeterministicOps(), error_spec, + /*run_hlo_passes=*/false)); } TEST_F(DynamicSliceFusionTest, DynamicContiguousSlice) { @@ -2183,8 +2195,10 @@ TEST_F(DynamicSliceFusionTest, CublasGemmDUSWithWorkspace) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, - /*run_hlo_passes=*/false)); + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(), + GetModuleConfigWithDeterministicOps(), error_spec, + /*run_hlo_passes=*/false)); } TEST_F(DynamicSliceFusionTest, CublasGemmDUSWorkspaceIgnored) { @@ -2268,8 +2282,10 @@ TEST_F(DynamicSliceFusionTest, CublasGemmDUSWorkspaceIgnored) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, - /*run_hlo_passes=*/false)); + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(), + GetModuleConfigWithDeterministicOps(), error_spec, + /*run_hlo_passes=*/false)); } TEST_F(DynamicSliceFusionTest, CublasGemmDUSOffsetS32NotConstant) { @@ -2462,8 +2478,10 @@ TEST_F(DynamicSliceFusionTest, CublasGemmDUSOffsetOOB) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, - /*run_hlo_passes=*/false)); + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(), + GetModuleConfigWithDeterministicOps(), error_spec, + /*run_hlo_passes=*/false)); } TEST_F(DynamicSliceFusionTest, DynamicCustomCallSimple) { @@ -2472,9 +2490,7 @@ TEST_F(DynamicSliceFusionTest, DynamicCustomCallSimple) { &b, "__xla_test$$memcpy", /*operands=*/ {DynamicSlice(Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}), "p0"), - {Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "start0"), - Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), "start1")}, - {2, 128})}, + {ConstantR0(&b, 2), ConstantR0(&b, 0)}, {2, 128})}, ShapeUtil::MakeShape(F32, {2, 128}), /*opaque=*/"", /*has_side_effect=*/false, /*output_operand_aliasing=*/{}, /*literal=*/nullptr, @@ -2491,7 +2507,6 @@ TEST_F(DynamicSliceFusionTest, DynamicCustomCallSimple) { hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); DynamicSliceFusionRewriter pass(PLATFORM); @@ -2529,11 +2544,7 @@ TEST_F(DynamicSliceFusionTest, DynamicCustomCallWithTuple) { DynamicSlice( Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}), "p0"), - {Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), - "start0"), - Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), - "start1")}, - {3, 128}), + {ConstantR0(&b, 20), ConstantR0(&b, 0)}, {3, 128}), }), }, ShapeUtil::MakeTupleShape({ @@ -2572,6 +2583,15 @@ TEST_F(DynamicSliceFusionTest, DynamicCustomCallWithTuple) { DynamicSliceFusionRewriter pass(PLATFORM); TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); EXPECT_TRUE(changed); + EXPECT_TRUE(*RunFileCheck(hlo_opt->ToString(), R"( + // CHECK: %address-computation{{.+}} { + // CHECK: {{.+}} = {{.+}} slice + // CHECK: {{.+}} = {{.+}} dynamic-slice + // CHECK: {{.+}} = {{.+}} custom-call + // CHECK: ENTRY {{.+}} { + // CHECK-NOT: {{.+}} = {{.+}} slice + // CHECK-NOT: {{.+}} = {{.+}} dynamic-slice + )")); EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), error_spec, /*run_hlo_passes=*/false)); @@ -2936,6 +2956,7 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDUSLoopIterationOffset) { HloModuleConfig ref_config; debugoptions.set_xla_gpu_enable_dynamic_slice_fusion(false); + debugoptions.set_xla_gpu_enable_pipelined_reduce_scatter(false); ref_config.set_debug_options(debugoptions); TF_ASSERT_OK_AND_ASSIGN(auto ref_module, ParseAndReturnVerifiedModule(hlo_ref, ref_config)); @@ -2945,6 +2966,7 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDUSLoopIterationOffset) { HloModuleConfig opt_config; debugoptions.set_xla_gpu_enable_dynamic_slice_fusion(true); opt_config.set_debug_options(debugoptions); + debugoptions.set_xla_gpu_enable_pipelined_reduce_scatter(false); TF_ASSERT_OK_AND_ASSIGN(auto module_with_adddress_computation_flag, ParseAndReturnVerifiedModule(hlo_ref, opt_config)); TF_ASSERT_OK_AND_ASSIGN( diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index d662d254ce4d13..200f06f8461db5 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -113,19 +113,8 @@ std::unique_ptr GetFusionEmitter( .GetModule() ->config() .debug_options(); - auto check_mlir_emitters = [&](int64_t required_level, bool check = true) { - if (opts.xla_gpu_mlir_emitter_level() < required_level) { - return false; - } - CHECK(!check || - mlir_converter::IsHloConversionSupported( - analysis.fusion(), - fusion_info.analysis().device_info().gpu_compute_capability())) - << "Unsupported fusion: " - << analysis.fusion_root(0).instruction().parent()->ToString(); - - VLOG(5) << "Emitting with MLIR."; - return true; + auto check_mlir_emitters = [&](int64_t required_level) { + return opts.xla_gpu_mlir_emitter_level() >= required_level; }; switch (analysis.GetEmitterFusionKind()) { @@ -166,7 +155,7 @@ std::unique_ptr GetFusionEmitter( } return std::make_unique(analysis); case HloFusionAnalysis::EmitterFusionKind::kScatter: { - if (check_mlir_emitters(/*required_level=*/2, false)) { + if (check_mlir_emitters(/*required_level=*/2)) { return std::make_unique(analysis); } return std::make_unique(analysis); diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc index 5297bf7526de9d..d2739e0a8c3765 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc @@ -38,9 +38,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_analysis.h" diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/BUILD similarity index 88% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD rename to third_party/xla/xla/service/gpu/fusions/ir/BUILD index aa8461575c9a5f..8250669ca75d2b 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/ir/BUILD @@ -1,4 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//xla/tests:build_defs.bzl", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -139,3 +140,18 @@ cc_library( "@llvm-project//mlir:Support", ], ) + +xla_test( + name = "xla_gpu_ops_test", + srcs = ["xla_gpu_ops_test.cc"], + backends = ["gpu"], + deps = [ + ":xla_gpu", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:test", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD new file mode 100644 index 00000000000000..381d5a3220b1df --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD @@ -0,0 +1,16 @@ +load("//xla:lit.bzl", "lit_test_suite") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +lit_test_suite( + name = "tests", + srcs = glob(["*.mlir"]), + cfg = "//xla:lit.cfg.py", + tools = [ + "//xla/service/gpu/fusions/tools:mlir_fusions_opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir similarity index 83% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir rename to third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir index 34065f9c19d53a..946e73494584be 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir @@ -140,6 +140,42 @@ func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { // ----- +#indexing_map1 = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0 + 8512), + domain: d0 in [0, 1], d1 in [0, 607]> +#indexing_map2 = #xla_gpu.indexing_map< + (d0, d1, d2) -> (((d1 floordiv 32 + 1) mod 3) * 64 + + (d1 mod 32) * 2 + (d0 floordiv 192) * 192 + d2), + domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]> + +func.func @fold_sequence_no_simplification_needed(%i: index) -> index { + %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} + %ind0 = xla_gpu.apply_indexing #indexing_map1(%i, %thread_id_x) + %ind1 = xla_gpu.apply_indexing #indexing_map2(%ind0, %thread_id_x, %i) + func.return %ind1 : index +} +// CHECK: xla_gpu.apply_indexing +// CHECK-NOT: xla_gpu.apply_indexing + +// ----- + +#indexing_map1 = #xla_gpu.indexing_map<(d0) -> (3 * d0), + domain: d0 in [0, 9407]> +#indexing_map2 = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 floordiv 32 + 1), + domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]> +#indexing_map3 = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 floordiv 32 + 2), + domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]> + +func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) { + %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} + %ind0 = xla_gpu.apply_indexing #indexing_map1(%thread_id_x) + %ind1 = xla_gpu.apply_indexing #indexing_map2(%ind0, %thread_id_x, %i) + %ind2 = xla_gpu.apply_indexing #indexing_map3(%ind0, %thread_id_x, %i) + func.return %ind1, %ind2 : index, index +} +// CHECK-COUNT-3: xla_gpu.apply_indexing + +// ----- + func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir rename to third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/inlining.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/inlining.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/inlining.mlir rename to third_party/xla/xla/service/gpu/fusions/ir/tests/inlining.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir similarity index 88% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir rename to third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir index 81ada2be721dbe..999a6de959328c 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir @@ -209,6 +209,26 @@ func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, %thread #map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]> #map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]> +func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{constraints of indexing maps must be equal for the block_id dimension}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id, %block_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]> +#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]> +func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{constraints of indexing maps must be equal for the block_id dimension}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]> +#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 4 in [0, 0]> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{constraints of indexing maps must be equal for the block_id dimension}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id, %block_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir rename to third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc index b29170fa5aee21..a3220b06ccf9d2 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc @@ -31,7 +31,7 @@ limitations under the License. #define GET_ATTRDEF_LIST #define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td similarity index 97% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td index af5663f2fc0f92..19dd24f2e67a2c 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td @@ -17,7 +17,7 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS include "mlir/IR/AttrTypeBase.td" -include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td" +include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" class XLAGPU_Attr traits = []> : AttrDef { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc similarity index 93% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc index 6fd297f1a6dc98..57d2d706737089 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc @@ -17,12 +17,12 @@ limitations under the License. #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep #include "mlir/IR/OpImplementation.h" // IWYU pragma: keep #include "mlir/Transforms/InliningUtils.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc" #undef GET_ATTRDEF_CLASSES #define GET_TYPEDEF_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_types.cc.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_types.cc.inc" #undef GET_TYPEDEF_CLASSES namespace xla { @@ -116,18 +116,18 @@ struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface { void XlaGpuDialect::initialize() { addOperations< #define GET_OP_LIST -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.cc.inc" #undef GET_OP_LIST >(); addAttributes< #define GET_ATTRDEF_LIST -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc" >(); #undef GET_ATTRDEF_LIST addInterfaces(); addTypes< #define GET_TYPEDEF_LIST -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_types.cc.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_types.cc.inc" #undef GET_TYPEDEF_LIST >(); } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td similarity index 95% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td index b55d3615765e5a..9a5c539e39e591 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td @@ -30,4 +30,4 @@ def XlaGpuDialect : Dialect { let useDefaultTypePrinterParser = 1; } -#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_DIALECT +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc similarity index 87% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc index 5ad1941028a9aa..39a0318af58896 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include #include #include #include -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallBitVector.h" @@ -45,7 +44,7 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.cc.inc" #include "xla/service/gpu/model/indexing_map.h" namespace xla { @@ -270,6 +269,19 @@ struct FoldApplyIndexingSequence LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, PatternRewriter& rewriter) const override { + SmallVector, 2> apply_indexing_ops; + bool all_apply_indexing_operands_have_one_use = true; + for (auto& operand : indexing_op->getOpOperands()) { + if (auto producer = operand.get().getDefiningOp()) { + apply_indexing_ops.push_back({operand.getOperandNumber(), producer}); + all_apply_indexing_operands_have_one_use &= producer->hasOneUse(); + } + } + if (apply_indexing_ops.empty()) { + return rewriter.notifyMatchFailure(indexing_op, + "No apply_indexing sequences found"); + } + MLIRContext* ctx = indexing_op.getContext(); int num_dims = indexing_op.getAffineMap().getNumDims(); int num_syms = indexing_op.getAffineMap().getNumSymbols(); @@ -290,53 +302,44 @@ struct FoldApplyIndexingSequence auto new_sym_vars = this_map.GetRangeVars(); mlir::DenseMap replacements; - for (auto& operand : indexing_op->getOpOperands()) { - if (auto producer = operand.get().getDefiningOp()) { - auto producer_map = producer.getIndexingMap(); - int producer_result_id = - mlir::cast(operand.get()).getResultNumber(); - int num_producer_dims = producer.getAffineMap().getNumDims(); - SmallVector producer_dim_replacements; - SmallVector producer_sym_replacements; - for (auto& producer_operand : producer->getOpOperands()) { - int producer_operand_number = producer_operand.getOperandNumber(); - bool is_dim = producer_operand_number < num_producer_dims; - auto& replacement_expr = operand_exprs[producer_operand.get()]; - if (!replacement_expr) { - if (is_dim) { - int dim_num = producer_operand_number; - replacement_expr = - getAffineDimExpr(num_dims + added_dim_args.size(), ctx); - added_dim_args.push_back(producer_operand.get()); - new_dim_vars.push_back(producer_map.GetDimVars(dim_num)); - } else { - int sym_num = producer_operand_number - - producer.getAffineMap().getNumDims(); - replacement_expr = - getAffineSymbolExpr(num_syms + added_sym_args.size(), ctx); - added_sym_args.push_back(producer_operand.get()); - new_sym_vars.push_back(producer_map.GetRangeVar(sym_num)); - } - } - + for (auto& [operand_id, producer] : apply_indexing_ops) { + auto producer_map = producer.getIndexingMap(); + mlir::OpResult producer_result = producer->getOpResult(0); + int producer_result_id = producer_result.getResultNumber(); + int num_producer_dims = producer.getAffineMap().getNumDims(); + SmallVector producer_dim_replacements; + SmallVector producer_sym_replacements; + for (auto& producer_operand : producer->getOpOperands()) { + int producer_operand_number = producer_operand.getOperandNumber(); + bool is_dim = producer_operand_number < num_producer_dims; + auto& replacement_expr = operand_exprs[producer_operand.get()]; + if (!replacement_expr) { if (is_dim) { - producer_dim_replacements.push_back(replacement_expr); + int dim_num = producer_operand_number; + replacement_expr = + getAffineDimExpr(num_dims + added_dim_args.size(), ctx); + added_dim_args.push_back(producer_operand.get()); + new_dim_vars.push_back(producer_map.GetDimVars(dim_num)); } else { - producer_sym_replacements.push_back(replacement_expr); + int sym_num = + producer_operand_number - producer.getAffineMap().getNumDims(); + replacement_expr = + getAffineSymbolExpr(num_syms + added_sym_args.size(), ctx); + added_sym_args.push_back(producer_operand.get()); + new_sym_vars.push_back(producer_map.GetRangeVar(sym_num)); } } - - replacements[operand_exprs[operand.get()]] = - producer.getAffineMap() - .getResult(producer_result_id) - .replaceDimsAndSymbols(producer_dim_replacements, - producer_sym_replacements); + if (is_dim) { + producer_dim_replacements.push_back(replacement_expr); + } else { + producer_sym_replacements.push_back(replacement_expr); + } } - } - - if (replacements.empty()) { - return rewriter.notifyMatchFailure(indexing_op, - "No apply_indexing sequences found"); + replacements[operand_exprs[producer_result]] = + producer.getAffineMap() + .getResult(producer_result_id) + .replaceDimsAndSymbols(producer_dim_replacements, + producer_sym_replacements); } int new_num_operands = indexing_op->getNumOperands() + @@ -346,10 +349,12 @@ struct FoldApplyIndexingSequence num_syms + added_sym_args.size()); IndexingMap new_indexing_map(new_affine_map, new_dim_vars, new_sym_vars, /*rt_vars=*/{}); - if (!new_indexing_map.Simplify()) { + if (!all_apply_indexing_operands_have_one_use && + !new_indexing_map.Simplify()) { return rewriter.notifyMatchFailure( indexing_op, "Folded indexing map was not simplified"); } + SmallVector new_operands; new_operands.reserve(new_num_operands); @@ -745,6 +750,22 @@ IndexingMap LoopOp::getIndexingMap() { // MaterializeOp //===----------------------------------------------------------------------===// +VariableConstraints GetConstraintsForVariables(const IndexingMap& map) { + VariableConstraints result; + result.constraints_for_dims.resize(map.GetDimensionCount()); + result.constraints_for_symbols.resize(map.GetSymbolCount()); + for (const auto& constraint : map.GetConstraints()) { + constraint.first.walk([&](mlir::AffineExpr leaf) { + if (auto dim = mlir::dyn_cast(leaf)) { + result.constraints_for_dims[dim.getPosition()].push_back(constraint); + } else if (auto sym = mlir::dyn_cast(leaf)) { + result.constraints_for_symbols[sym.getPosition()].push_back(constraint); + } + }); + } + return result; +} + LogicalResult MaterializeOp::verify() { IndexingMap map_in = getMap().getIndexingMap(); IndexingMap map_out = @@ -763,9 +784,11 @@ LogicalResult MaterializeOp::verify() { return emitOpError() << "thread_id dimension must have the same bounds in " "both indexing maps"; } - auto thread_id_constraints_in = map_in.GetConstraintsForDim(0); - auto thread_id_constraints_out = map_out.GetConstraintsForDim(0); - if (thread_id_constraints_in != thread_id_constraints_out) { + + auto variable_constraints_in = GetConstraintsForVariables(map_in); + auto variable_constraints_out = GetConstraintsForVariables(map_out); + if (variable_constraints_in.constraints_for_dims[0] != + variable_constraints_out.constraints_for_dims[0]) { return emitOpError() << "constraints of indexing maps must be equal for " << "the thread_id dimension"; } @@ -781,13 +804,10 @@ LogicalResult MaterializeOp::verify() { return emitOpError() << "domain of symbols of indexing_maps must match"; } } - for (int symbol_id = 0; symbol_id < map_in.GetRangeVarsCount(); ++symbol_id) { - auto constraints_in = map_in.GetConstraintsForSymbol(symbol_id); - auto constraints_out = map_out.GetConstraintsForSymbol(symbol_id); - if (constraints_in != constraints_out) { - return emitOpError() - << "constraints of indexing maps must be equal for all symbols"; - } + if (variable_constraints_in.constraints_for_symbols != + variable_constraints_out.constraints_for_symbols) { + return emitOpError() + << "constraints of indexing maps must be equal for all symbols"; } // The vector mapping indices must not depend on the block ID @@ -801,9 +821,18 @@ LogicalResult MaterializeOp::verify() { } // If there are constraints on the block ID, they must be the same in both // maps - auto block_id_constraints_in = map_in.GetConstraintsForDim(1); - auto block_id_constraints_out = map_out.GetConstraintsForDim(1); - if (block_id_constraints_in != block_id_constraints_out) { + if (map_in.GetDimVarsCount() > 1 && map_out.GetDimVarsCount() > 1) { + if (variable_constraints_in.constraints_for_dims[1] != + variable_constraints_out.constraints_for_dims[1]) { + return emitOpError() << "constraints of indexing maps must be equal for " + << "the block_id dimension"; + } + } else if (map_in.GetDimVarsCount() > 1 && + !variable_constraints_in.constraints_for_dims[1].empty()) { + return emitOpError() << "constraints of indexing maps must be equal for " + << "the block_id dimension"; + } else if (map_out.GetDimVarsCount() > 1 && + !variable_constraints_out.constraints_for_dims[1].empty()) { return emitOpError() << "constraints of indexing maps must be equal for " << "the block_id dimension"; } @@ -815,4 +844,4 @@ LogicalResult MaterializeOp::verify() { } // namespace xla #define GET_OP_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.cc.inc" diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h similarity index 66% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h index 589aa60379cb94..e3b8fd641f9a06 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h @@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_ -#define XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_ +#ifndef XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_ +#define XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_ +#include + +#include "llvm/ADT/SmallVector.h" #include "mlir/Bytecode/BytecodeOpInterface.h" // IWYU pragma: keep #include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep #include "mlir/IR/Attributes.h" // IWYU pragma: keep @@ -26,16 +29,28 @@ limitations under the License. #include "mlir/Interfaces/CallInterfaces.h" // IWYU pragma: keep #include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep #include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc" #include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep #define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" #undef GET_ATTRDEF_CLASSES #define GET_TYPEDEF_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_types.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_types.h.inc" #undef GET_TYPEDEF_CLASSES #define GET_OP_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h.inc" #undef GET_OP_CLASSES -#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_ +namespace xla::gpu { + +struct VariableConstraints { + llvm::SmallVector>> + constraints_for_dims; + llvm::SmallVector>> + constraints_for_symbols; +}; +VariableConstraints GetConstraintsForVariables(const IndexingMap& map); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td index 99e1c6c5eb257f..9eb246f70d9d34 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td @@ -23,9 +23,9 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" -include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td" -include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td" -include "xla/service/gpu/fusions/mlir/ir/xla_gpu_types.td" +include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" +include "xla/service/gpu/fusions/ir/xla_gpu_attrs.td" +include "xla/service/gpu/fusions/ir/xla_gpu_types.td" class XLAGPU_Op traits = []> : Op { @@ -364,4 +364,4 @@ def XLAGPU_InsertOp : XLAGPU_Op<"insert", []> { }]; } -#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_OPS diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc new file mode 100644 index 00000000000000..2d9076d7803280 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" + +#include +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla::gpu { +namespace { + +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +class XLAGPUOpsTest : public HloTestBase { + public: + mlir::MLIRContext mlir_context_; +}; + +TEST_F(XLAGPUOpsTest, GetConstraintsForVariables) { + auto map = IndexingMap( + ParseAffineMap("(d0, d1)[s0, s1] -> (d0+s0, d1+s1)", &mlir_context_), + /*dimensions=*/{{Interval{0, 5}}, {Interval{0, 2}}}, + /*range_vars=*/{{Interval{0, 32}}, {Interval{0, 1024}}}, /*rt_vars=*/{}); + map.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), + Interval{0, 1}); + map.AddConstraint(ParseAffineExpr("s1 mod 4", &mlir_context_), + Interval{0, 2}); + map.AddConstraint(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}); + map.AddConstraint(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4}); + map.AddConstraint(ParseAffineExpr("d1 mod 32", &mlir_context_), + Interval{0, 6}); + + auto constraints_for_variables = GetConstraintsForVariables(map); + EXPECT_THAT(constraints_for_variables.constraints_for_dims[0], + UnorderedElementsAre()); + EXPECT_THAT( + constraints_for_variables.constraints_for_dims[1], + UnorderedElementsAre( + Pair(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4}), + Pair(ParseAffineExpr("d1 mod 32", &mlir_context_), Interval{0, 6}))); + EXPECT_THAT( + constraints_for_variables.constraints_for_symbols[0], + UnorderedElementsAre( + Pair(ParseAffineExpr("s0 mod 4", &mlir_context_), Interval{0, 1}), + Pair(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}))); + EXPECT_THAT( + constraints_for_variables.constraints_for_symbols[1], + UnorderedElementsAre( + Pair(ParseAffineExpr("s1 mod 4", &mlir_context_), Interval{0, 2}), + Pair(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}), + Pair(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4}))); +} + +TEST_F(XLAGPUOpsTest, GetConstraintsForVariablesEmpty) { + auto map = IndexingMap( + ParseAffineMap("(d0, d1)[s0, s1] -> (d0+s0, d1+s1)", &mlir_context_), + /*dimensions=*/{{Interval{0, 5}}, {Interval{0, 2}}}, + /*range_vars=*/{{Interval{0, 32}}, {Interval{0, 1024}}}, /*rt_vars=*/{}); + auto constraints_for_variables = GetConstraintsForVariables(map); + EXPECT_THAT(constraints_for_variables.constraints_for_dims, + ElementsAre(IsEmpty(), IsEmpty())); + EXPECT_THAT(constraints_for_variables.constraints_for_symbols, + ElementsAre(IsEmpty(), IsEmpty())); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_types.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc similarity index 90% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_types.cc rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc index 2a041517b1c2e2..1c1b218db7bc19 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_types.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc @@ -21,14 +21,14 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" // IWYU pragma: keep #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc" #include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep #define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" #undef GET_ATTRDEF_CLASSES #define GET_TYPEDEF_LIST #define GET_TYPEDEF_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_types.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_types.h.inc" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_types.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td similarity index 93% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_types.td rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td index afb474ead47be4..5d73344654d1de 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_types.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td @@ -19,8 +19,8 @@ limitations under the License. include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinTypes.td" include "mlir/IR/BuiltinTypeInterfaces.td" -include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td" -include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td" +include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" +include "xla/service/gpu/fusions/ir/xla_gpu_attrs.td" class XLAGPU_Type traits = []> : TypeDef { diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD index e5c68d0ff35224..98d8ade7c5e5c3 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD @@ -308,7 +308,6 @@ cc_library( ":tiling_util", "//xla:permutation_util", "//xla:shape_util", - "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_fusion_analysis", diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc index a05508c2e378cc..5832d13701a3dc 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc @@ -88,20 +88,20 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - (bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000, - ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200, - ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id + (bl_x * 128 + th_x) floordiv 15000, + ((bl_x * 128 + th_x) floordiv 75) mod 200, + ((bl_x * 128 + th_x) mod 75) * 4 + unroll_id ) domain: th_x in [0, 127] th_y in [0, 0] th_z in [0, 0] - bl_x in [0, 1007] + bl_x in [0, 11718] bl_y in [0, 0] bl_z in [0, 0] - chunk_id in [0, 11] + chunk_id in [0, 0] unroll_id in [0, 3] - bl_x * 128 + chunk_id * 129024 + th_x in [0, 1499999] + bl_x * 128 + th_x in [0, 1499999] )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc index 16806889510f13..d6cbdecf4bfceb 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc @@ -52,7 +52,6 @@ limitations under the License. #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/llvm_ir/loop_emitter.h" #include "xla/shape_util.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -66,12 +65,13 @@ Tiling ComputeTransposeTiling(const se::DeviceDescription& gpu_device_info, static_assert(WarpSize() % kNumRows == 0); // 3D view over the output shape. - Vector3 transposed_dims = tiled_transpose.dimensions; - Vector3 permutation = tiled_transpose.permutation; + absl::InlinedVector transposed_dims = tiled_transpose.dimensions; + absl::InlinedVector permutation = tiled_transpose.permutation; // Note: the supported permutations are their own inverses. Therefore we // always use the permutation, even when we want the inverse. - CHECK((permutation == Vector3{0, 2, 1}) || (permutation == Vector3{2, 1, 0})); + CHECK((permutation == absl::InlinedVector{0, 2, 1}) || + (permutation == absl::InlinedVector{2, 1, 0})); absl::InlinedVector input_dims{transposed_dims[permutation[0]], transposed_dims[permutation[1]], @@ -189,7 +189,7 @@ absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context, } absl::flat_hash_map tiles; - Vector3 permutation; + absl::InlinedVector permutation; for (const auto& [tile_idx, tr] : llvm::enumerate(transposes)) { permutation = tr.permutation; auto tile_size = tiling_.GetBlockTileSize(); diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h index 323113d4e4639d..3366130c05546b 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "llvm/IR/IRBuilder.h" #include "mlir/IR/MLIRContext.h" @@ -30,7 +31,6 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/llvm_ir/ir_array.h" -#include "xla/util.h" namespace xla { namespace gpu { @@ -82,7 +82,7 @@ class TransposeFusion : public KernelFusionEmitterBase { private: const HloFusionAnalysis& analysis_; Tiling tiling_; - Vector3 permutation_; + absl::InlinedVector permutation_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc index f33ae04b958a73..43a417843858db 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc @@ -118,18 +118,18 @@ TEST_F(TransposeTest, ThreadIndexing021) { )")); } -TEST_F(TransposeTest, ThreadIndexing201) { +TEST_F(TransposeTest, ThreadIndexing201_SimplifiedTo021) { auto module = ParseAndReturnVerifiedModule(R"( HloModule module fusion { - %input = f32[100,64,32] parameter(0) - ROOT transpose = f32[32,100,64] transpose(%input), dimensions={2,0,1} + %input = f32[1,6400,32] parameter(0) + ROOT transpose = f32[1,32,6400] transpose(%input), dimensions={0,2,1} } ENTRY entry { - %input = f32[100,64,32] parameter(0) - ROOT %fusion = f32[32,100,64] fusion(%input), kind=kInput, calls=fusion + %input = f32[1,6400,32] parameter(0) + ROOT %fusion = f32[1,32,6400] fusion(%input), kind=kInput, calls=fusion })") .value(); @@ -142,8 +142,8 @@ TEST_F(TransposeTest, ThreadIndexing201) { fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32, + 0, + d3 * 32 + s1 * 4 + d0 floordiv 32, d0 mod 32 ) domain: @@ -162,9 +162,9 @@ TEST_F(TransposeTest, ThreadIndexing201) { fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + 0, d0 floordiv 32 + s1 * 4, - d3 floordiv 2, - (d3 mod 2) * 32 + d0 mod 32 + d3 * 32 + d0 mod 32 ) domain: d0 in [0, 127] @@ -185,13 +185,13 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { HloModule m fused_computation { - %p0 = f64[24,2,6,4] parameter(0) - ROOT %t = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0} + %p0 = f64[24,2,24] parameter(0) + ROOT %t = f64[24,2,24] transpose(%p0), dimensions={2,1,0} } ENTRY main { - %p0 = f64[24,2,6,4] parameter(0) - ROOT %fusion = f64[6,4,2,24] fusion(%p0), kind=kInput, + %p0 = f64[24,2,24] parameter(0) + ROOT %fusion = f64[24,2,24] fusion(%p0), kind=kInput, calls=%fused_computation } )") @@ -208,8 +208,7 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d0 floordiv 32 + s0 * 4, d3, - (d0 floordiv 4) mod 8, - d0 mod 4 + d0 mod 32 ) domain: d0 in [0, 127] @@ -227,8 +226,7 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - s0, - d0 floordiv 32, + d0 floordiv 32 + s0 * 4, d3, d0 mod 32 ) diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc index 9db9173ab05141..4c6bdac0ce01ad 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc @@ -35,9 +35,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/launch_dimensions.h" diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc index eda89f8b70bb70..37efa18945e58d 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc @@ -29,100 +29,6 @@ namespace { using MlirLoopFusionTest = MlirEmitterTestBase; -TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( - HloModule module - - neg { - %input = f32[100,200,300] parameter(0) - ROOT neg = f32[100,200,300] negate(%input) - } - ENTRY entry { - %input = f32[100,200,300] parameter(0) - ROOT %fusion = f32[100,200,300] fusion(%input), kind=kLoop, calls=neg - } - )")); - thread_id_printer_.SetSymbolName(0, "chunk_id"); - thread_id_printer_.SetSymbolName(1, "unroll_id"); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - MlirLoopFusion fusion(analysis); - auto thread_id_to_output_indexing = - fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - - EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - (bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000, - ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200, - ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id - ) - domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 1007] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 11] - unroll_id in [0, 3] - bl_x * 128 + chunk_id * 129024 + th_x in [0, 1499999] -)")); -} - -TEST_F(MlirLoopFusionTest, ThreadId_IndexingNotUnrolled) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( - HloModule module - - neg { - %input = f32[20] parameter(0) - ROOT neg = f32[20] negate(%input) - } - ENTRY entry { - %input = f32[20] parameter(0) - ROOT %fusion = f32[20] fusion(%input), kind=kLoop, calls=neg - } - )")); - thread_id_printer_.SetSymbolName(0, "chunk_id"); - thread_id_printer_.SetSymbolName(1, "unroll_id"); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - - MlirLoopFusion fusion(analysis); - auto thread_id_to_output_indexing = - fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) - domain: - th_x in [0, 19] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 0] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - )")); - auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_input_indexing->ToString(thread_id_printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) - domain: - th_x in [0, 19] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 0] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - )")); -} - TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule module @@ -182,42 +88,6 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { )")); } -TEST_F(MlirLoopFusionTest, Constant_Broadcast) { - auto kHloString = R"( - HloModule module - - bcast { - zero = bf16[] constant(0) - ROOT broadcast = bf16[2,16,48]{2,1,0} broadcast(zero), dimensions={} - } - - ENTRY entry { - ROOT %fusion = bf16[2,16,48]{2,1,0} fusion(), kind=kLoop, calls=bcast - } - )"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: #[[MAP0:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 1024 + d0) - // CHECK-DAG: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 1024 + d0) floordiv 768) - // CHECK-DAG: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 1024 + d0) floordiv 48) mod 16) - // CHECK-DAG: #[[MAP3:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 1024 + d0) mod 48) - // CHECK: func.func @fused_computation(%[[ARG0:.*]]: tensor<2x16x48xbf16> - // CHECK: %[[UPPER_BOUND:.*]] = arith.constant 1535 : index - // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id - // CHECK: %[[BLOCK_ID:.*]] = gpu.block_id - // CHECK: %[[LINEAR:.*]] = xla_gpu.apply_indexing #[[MAP0]] - // CHECL: %[[IN_BOUNDS:.*]] = arith.cmpi sle, %[[LINEAR]], %[[UPPER_BOUND]] : index - // scf.if %[[IN_BOUNDS]] - // CHECK: %[[I0:.*]] = xla_gpu.apply_indexing #[[MAP1]] - // CHECK: %[[I1:.*]] = xla_gpu.apply_indexing #[[MAP2]] - // CHECK: %[[I2:.*]] = xla_gpu.apply_indexing #[[MAP3]] - // CHECK: %[[BCAST:.*]] = xla_gpu.pure_call @bcast_broadcast - // CHECK: %[[INSERTED:.*]] = tensor.insert %[[BCAST]] into %[[ARG0]][%[[I0]], %[[I1]], %[[I2]]] - // CHECK: func.func private @bcast_broadcast - // CHECK: arith.constant 0.000000e+00 - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0})); -} - TEST_F(MlirLoopFusionTest, NoCodeDuplication) { // This test HLO is copied from // xla/service/fusion_node_indexing_evaluation_test.cc. @@ -253,85 +123,6 @@ TEST_F(MlirLoopFusionTest, NoCodeDuplication) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } -TEST_F(MlirLoopFusionTest, TwoUsersConsistentIndexing) { - auto kHloString = R"( - HloModule test_module - - %fused_computation (param: f32[6]) -> f32[2] { - %p0 = f32[2]{0} parameter(0) - %p1 = f32[2]{0} parameter(1) - %add = f32[2] add(%p0, %p1) - %sub = f32[2] subtract(%p0, %p1) - %mul = f32[2] multiply(%add, %sub) - %div = f32[2] divide(%add, %sub) - ROOT %atan2 = f32[2] atan2(%mul, %div) - } - ENTRY entry_computation { - p0 = f32[2] parameter(0) - p1 = f32[2] parameter(1) - ROOT %fusion = f32[2] fusion(p0, p1), kind=kLoop, calls=%fused_computation - } - )"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: func.func @fused_computation - // CHECK-NEXT: gpu.thread_id - // CHECK-NEXT: pure_call @fused_computation_atan2 - // CHECK-NEXT: tensor.insert - // CHECK-NEXT: return - - // CHECK: func.func private @fused_computation_atan2 - // CHECK-NEXT: tensor.extract - // CHECK-NEXT: tensor.extract - // CHECK-NEXT: addf - // CHECK-NEXT: subf - // CHECK-NEXT: mulf - // CHECK-NEXT: divf - // CHECK-NEXT: atan2 - // CHECK-NEXT: return - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirLoopFusionTest, ComplexOps) { - auto kHloString = R"( - HloModule test_module - - %fused_computation { - %p0 = f32[2]{0} parameter(0) - %p1 = f32[2]{0} parameter(1) - %p2 = c64[2]{0} parameter(2) - %complex = c64[2] complex(%p0, %p1) - %add = c64[2] add(%complex, %p2) - %cst = c64[2]{0} constant({(2.0, 0.0), (0.0, 2.0)}) - ROOT %mul = c64[2] multiply(%add, %cst) - } - ENTRY entry_computation { - p0 = f32[2] parameter(0) - p1 = f32[2] parameter(1) - p2 = c64[2] parameter(2) - ROOT %fusion = c64[2] fusion(p0, p1, p2), kind=kLoop, calls=%fused_computation - } - )"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: func.func @fused_computation - // CHECK-NEXT: gpu.thread_id - // CHECK-NEXT: pure_call @fused_computation_mul - // CHECK-NEXT: tensor.insert - // CHECK-NEXT: return - - // CHECK: func.func private @fused_computation_mul - // CHECK-NEXT: arith.constant - // CHECK-NEXT: tensor.extract - // CHECK-NEXT: tensor.extract - // CHECK-NEXT: complex.create - // CHECK-NEXT: tensor.extract - // CHECK-NEXT: complex.add - // CHECK-NEXT: tensor.extract - // CHECK-NEXT: complex.mul - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - TEST_F(MlirLoopFusionTest, IotaCopyBitcastBroadcastReshapeReverseTranspose) { auto kHloString = R"( HloModule test_module @@ -359,137 +150,6 @@ TEST_F(MlirLoopFusionTest, IotaCopyBitcastBroadcastReshapeReverseTranspose) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } -TEST_F(MlirLoopFusionTest, VariadicReduce) { - auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - scalar_lhs.0 = f32[] parameter(0) - scalar_lhs.1 = f32[] parameter(1) - scalar_rhs.0 = f32[] parameter(2) - scalar_rhs.1 = f32[] parameter(3) - add = f32[] add(scalar_lhs.0, scalar_rhs.0) - mul = f32[] multiply(scalar_lhs.1, scalar_rhs.1) - ROOT t = (f32[], f32[]) tuple(add, mul) - } - fused_computation { - param_0 = f32[3,4,5]{2,1,0} parameter(0) - param_1 = f32[3,4,5]{2,1,0} parameter(1) - param_2 = f32[] parameter(2) - ROOT d.1 = (f32[4], f32[4]) reduce(f32[3,4,5]{2,1,0} param_0, - f32[3,4,5]{2,1,0} %param_1, f32[] param_2, f32[] param_2), - dimensions={0,2}, to_apply=Add - } - ENTRY main { - a = f32[3,4,5]{2,1,0} parameter(0) - b = f32[3,4,5]{2,1,0} parameter(1) - c = f32[] constant(0) - ROOT fusion = (f32[4]{0}, f32[4]{0}) fusion(a, b, c), - kind=kLoop, calls=fused_computation - } - )"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: func @fused_computation( - // CHECK: %[[TID_X:.*]] = gpu.thread_id x - // CHECK: %[[SCALARS_0:.*]], %[[SCALARS_1:.*]] = xla_gpu.pure_call @fused_computation_d_1 - // CHECK: %[[INSERTED_1:.*]] = tensor.insert %[[SCALARS_0]] into %{{.*}}[%[[TID_X]]] - // CHECK: %[[INSERTED_2:.*]] = tensor.insert %[[SCALARS_1]] into %{{.*}}[%[[TID_X]]] - // CHECK: return %[[INSERTED_1]], %[[INSERTED_2]] - - // CHECK: func private @fused_computation_d_1 - // CHECK: %[[RET:.*]]:2 = func.call @Add_t - // CHECK: yield %[[RET]]#0, %[[RET]]#1 - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirLoopFusionTest, MinimumMaximum) { - auto kHloString = R"( - HloModule Test - - fused_computation { - param0 = f64[] parameter(0) - param1 = f64[] parameter(1) - - minimum = f64[] minimum(f64[] param0, f64[] param1) - maximum = f64[] maximum(f64[] param0, f64[] param1) - ROOT tuple = (f64[], f64[]) tuple(minimum, maximum) - } - - ENTRY main { - param0 = f64[] parameter(0) - param1 = f64[] parameter(1) - ROOT fusion = (f64[], f64[]) fusion(f64[] param0, f64[] param1), kind=kLoop, calls=fused_computation - } - )"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: func.func @fused_computation - // CHECK: xla_gpu.pure_call @fused_computation_tuple - // CHECK: func.func private @fused_computation_tuple - // CHECK-DAG: arith.minimumf - // CHECK-DAG: arith.maximumf - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirLoopFusionTest, TupleBitcast) { - auto kHloString = R"( - HloModule Test - - fused_computation { - param0 = f64[8] parameter(0) - param1 = f64[8] parameter(1) - - minimum = f64[8] minimum(param0, param1) - maximum = f64[8] maximum(param0, param1) - bc = f64[2, 4] bitcast(maximum) - ROOT tuple = (f64[8], f64[2,4]) tuple(minimum, bc) - } - - ENTRY main { - param0 = f64[8] parameter(0) - param1 = f64[8] parameter(1) - ROOT fusion = (f64[8], f64[2,4]) fusion(param0, param1), - kind=kLoop, calls=fused_computation - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirLoopFusionTest, NestedTuple) { - auto kHloString = R"( - add { - scalar_lhs.0 = f32[] parameter(0) - scalar_lhs.1 = f32[] parameter(1) - scalar_rhs.0 = f32[] parameter(2) - scalar_rhs.1 = f32[] parameter(3) - add = f32[] add(scalar_lhs.0, scalar_rhs.0) - mul = f32[] multiply(scalar_lhs.1, scalar_rhs.1) - ROOT t = (f32[], f32[]) tuple(add, mul) - } - fused_computation { - param_0 = f32[3,4,5]{2,1,0} parameter(0) - param_1 = f32[3,4,5]{2,1,0} parameter(1) - param_2 = f32[] parameter(2) - param_3 = f32[4] parameter(3) - reduce = (f32[4], f32[4]) reduce(f32[3,4,5]{2,1,0} param_0, - f32[3,4,5]{2,1,0} %param_1, f32[] param_2, f32[] param_2), - dimensions={0,2}, to_apply=add - log = f32[4] log(param_3) - ROOT tuple = ((f32[4], f32[4]), f32[4]) tuple(reduce, log) - } - ENTRY main { - a = f32[3,4,5]{2,1,0} parameter(0) - b = f32[3,4,5]{2,1,0} parameter(1) - c = f32[] constant(0) - d = f32[4] parameter(2) - ROOT fusion = ((f32[4], f32[4]), f32[4]) fusion(a, b, c, d), - kind=kLoop, calls=fused_computation - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - TEST_F(MlirLoopFusionTest, DynamicSliceWith64BitInput) { // Lowering this kernel with 32 bit indices causes an underflow of `c`, // resulting in slicing the last four elements instead of the first four. @@ -511,63 +171,6 @@ TEST_F(MlirLoopFusionTest, DynamicSliceWith64BitInput) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } -TEST_F(MlirLoopFusionTest, DynamicUpdateSlice) { - constexpr auto kHloString = R"( - %fused_computation { - in = c64[2,3] parameter(0) - updates = c64[2,2] parameter(1) - i0 = s32[] parameter(2) - i1 = s32[] parameter(3) - updated = c64[2,3] dynamic-update-slice(in, updates, i0, i1) - ROOT transpose = c64[3,2] transpose(updated), dimensions={1,0} - } - - ENTRY main { - p0 = c64[2,3] parameter(0) - p1 = c64[2,2] parameter(1) - p2 = s32[] parameter(2) - p3 = s32[] parameter(3) - ROOT %fusion = c64[3,2] fusion(p0, p1, p2, p3), kind=kLoop, calls=%fused_computation - })"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: scf.if - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirLoopFusionTest, NotPred) { - constexpr auto kHloString = R"( - %fused_computation { - p0 = s8[1000] parameter(0) - cvt = pred[1000] convert(p0) - ROOT not = pred[1000] not(cvt) - } - - ENTRY main { - p0 = s8[1000] parameter(0) - ROOT %fusion = pred[1000] fusion(p0), kind=kLoop, calls=%fused_computation - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirLoopFusionTest, MulPred) { - constexpr auto kHloString = R"( - %fused_computation { - p0 = s8[1000] parameter(0) - p1 = s8[1000] parameter(1) - cvt0 = pred[1000] convert(p0) - cvt1 = pred[1000] convert(p1) - ROOT mul = pred[1000] multiply(cvt0, cvt1) - } - - ENTRY main { - p0 = s8[1000] parameter(0) - p1 = s8[1000] parameter(1) - ROOT %fusion = pred[1000] fusion(p0, p1), kind=kLoop, calls=%fused_computation - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD index 7fe72d7f9932c3..08a159f9b268f3 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD @@ -75,7 +75,7 @@ cc_library( "//xla/mlir_hlo:type_conversion", "//xla/service:algorithm_util", "//xla/service/gpu:hlo_traversal", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", @@ -118,7 +118,7 @@ xla_cc_test( "//xla/mlir_hlo", "//xla/service:hlo_parser", "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:launch_dim", @@ -170,7 +170,7 @@ cc_library( "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:target_util", "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/transforms:passes", "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/runtime:kernel_thunk", diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index 52af21061c6311..839dd96ab48fea 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" @@ -67,24 +66,18 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/mlir/utils/type_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" -#include "xla/mlir_hlo/mhlo/utils/type_conversion.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/type_util.h" -#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/stream_executor/device_description.h" #include "xla/translate/hlo_to_mhlo/hlo_utils.h" -#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla { @@ -1179,57 +1172,6 @@ absl::StatusOr> HloToMlir( } // namespace -bool IsHloOpSupported(const HloInstruction* instr, - se::CudaComputeCapability compute_capability) { - return !(kUnsupportedOps.contains(instr->opcode()) || - IsUnsupportedGather(instr)); -} - -bool IsHloConversionSupported(const HloComputation* computation, - se::GpuComputeCapability compute_capability) { - if (!std::holds_alternative(compute_capability)) { - // ROCM is not tested. - return false; - } - auto cuda_compute_capability = - std::get(compute_capability); - - return absl::c_all_of( - computation->instructions(), - [=](const HloInstruction* instr) { - return absl::c_all_of(instr->called_computations(), - [&](const HloComputation* called) { - return IsHloConversionSupported( - called, compute_capability); - }) && - IsHloOpSupported(instr, cuda_compute_capability); - }) && - (computation->IsFusionComputation() || - (absl::c_all_of( - computation->parameter_instructions(), [](auto* param) { - return param->shape().IsArray() && param->shape().rank() == 0; - }))); -} - -bool IsHloConversionSupported(const HloFusionAdaptor& fusion, - se::GpuComputeCapability compute_capability) { - if (!std::holds_alternative(compute_capability)) { - // ROCM is not tested. - return false; - } - auto cuda_compute_capability = - std::get(compute_capability); - - return !HloAnyOf(fusion, [=](HloInstructionAdaptor instr) { - return !absl::c_all_of(instr.instruction().called_computations(), - [&](const HloComputation* called) { - return IsHloConversionSupported( - called, compute_capability); - }) || - !IsHloOpSupported(&instr.instruction(), cuda_compute_capability); - }); -} - ValueRange ProvideParameter(const PartitionedComputation& computation, const HloInstruction* instr, int operand_index, ValueRange indices, diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h index 1a97b575fa73fd..82811ea56fa97a 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h @@ -62,19 +62,6 @@ llvm::SmallVector ProvideParameterRange( const CallTargetProvider& call_target_provider, mlir::func::FuncOp this_fn, mlir::ImplicitLocOpBuilder& builder); -// Checks whether the given HLO instruction can be converted to MLIR. -bool IsHloOpSupported(const HloInstruction* instr, - se::CudaComputeCapability compute_capability); - -// Checks whether the given HLO computation is supported by the MLIR converter: -// - all instructions in it are supported -// - the signature is supported: if the computation is not a fusion computation, -// all arguments have rank 0. -bool IsHloConversionSupported(const HloComputation* computation, - se::GpuComputeCapability compute_capability); -bool IsHloConversionSupported(const HloFusionAdaptor& fusion, - se::GpuComputeCapability compute_capability); - // Converts a function (subgraph) to an MLIR function producing one element of // the result. The function must have the correct interface. absl::Status SubgraphToMlirFunction( diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index 7c983d215917cb..eab1568fa2eb38 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -36,8 +36,8 @@ limitations under the License. #include "mlir/Transforms/Passes.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc index 6c864c974aa0e7..efb13ae94e090f 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -83,9 +83,9 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/dump.h" #include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/type_util.h" #include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/ir_emitter_context.h" @@ -323,6 +323,7 @@ MlirFusionEmitterBase::CreateLLVMModule( // opportunities for LICM. This would not be necessary if LICM also moved // instructions over ifs. pm.addPass(mlir::createLoopInvariantCodeMotionPass()); + pm.addPass(CreateFlattenTensorsPass()); pm.addNestedPass(CreateVectorizeLoadsAndStoresPass()); pm.addNestedPass(CreateOptimizeLoopsPass()); pm.addNestedPass(CreateConvertPureCallOpsPass()); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc index b76a953cf80076..4921e745b5176f 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc @@ -76,7 +76,7 @@ class DummyCopyFusionEmitter : public MlirFusionEmitterBase { const mlir_converter::PartitionedComputations& computations, const mlir_converter::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, - const HloFusionInstruction& fusion) const { + const HloFusionInstruction& fusion) const override { mlir::ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function); b.setInsertionPointToStart(entry_function.addEntryBlock()); auto thread_id = EmitThreadId(b, 0); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD deleted file mode 100644 index 608692dd0016f2..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD +++ /dev/null @@ -1,44 +0,0 @@ -load("//xla:lit.bzl", "lit_test_suite") -load("//xla:xla.bzl", "xla_cc_binary") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -xla_cc_binary( - name = "mlir_fusions_opt", - srcs = ["mlir_fusions_opt.cc"], - visibility = ["//xla/service/gpu/fusions:__subpackages__"], - deps = [ - "//xla/mlir_hlo", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", - "//xla/service/gpu/fusions/transforms:passes", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ComplexDialect", - "@llvm-project//mlir:DLTIDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncExtensions", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MlirOptLib", - "@llvm-project//mlir:NVVMDialect", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorDialect", - ], -) - -lit_test_suite( - name = "tests", - srcs = glob(["*.mlir"]), - cfg = "//xla:lit.cfg.py", - tools = [ - ":mlir_fusions_opt", - "@llvm-project//llvm:FileCheck", - ], -) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc index 350b3e5e148a81..db4a93a9bbfebc 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc @@ -38,7 +38,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/affine_map_printer.h" #include "xla/tests/filecheck.h" diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index f2859bc7b6bacd..075678d2e58605 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -47,9 +47,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/type_util.h" #include "xla/service/gpu/fusions/reduction_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc index ba6d4ab527a143..479852851322c0 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/strings/str_cat.h" -#include "absl/strings/substitute.h" #include "absl/types/span.h" #include "xla/error_spec.h" #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" @@ -55,69 +54,8 @@ class ReductionTest : public MlirEmitterTestBase { } }; -using MlirRowReductionTest = ReductionTest; -using MlirColumnReductionTest = ReductionTest; using MlirMultiRowReductionTest = ReductionTest; -constexpr std::string_view kVariadicRowReduction = R"( - Add { - scalar_lhs.0 = f32[] parameter(0) - scalar_rhs.0 = f32[] parameter(1) - scalar_lhs.1 = f32[] parameter(2) - scalar_rhs.1 = f32[] parameter(3) - add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1) - add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1) - ROOT t = (f32[], f32[]) tuple(add.0, add.1) - } - fused_computation { - param_0 = f32[2, 3, 2048] parameter(0) - param_1 = f32[2, 3, 2048] parameter(1) - param_2 = f32[] parameter(2) - ROOT d.1 = (f32[2, 3], f32[2, 3]) - reduce(param_0, param_1, param_2, param_2), dimensions={2}, to_apply=Add - } - ENTRY main { - a = f32[2, 3, 2048] parameter(0) - b = f32[2, 3, 2048] parameter(1) - c = f32[] constant(0) - ROOT fusion = (f32[2, 3], f32[2, 3]) fusion(a, b, c), - kind=kInput, calls=fused_computation - })"; - -constexpr std::string_view kF64RowReduction = R"( - Add { - lhs = f64[] parameter(0) - rhs = f64[] parameter(1) - ROOT add = f64[] add(lhs, rhs) - } - fused_computation { - param_0 = f64[100,128] parameter(0) - param_1 = f64[] parameter(1) - ROOT reduce = f64[100] reduce(param_0, param_1), dimensions={1}, to_apply=Add - } - ENTRY main { - a = f64[100,128] parameter(0) - c = f64[] constant(0) - ROOT fusion = f64[100] fusion(a, c), kind=kInput, calls=fused_computation - })"; - -constexpr auto kRowReductionMinorAndMajor = R"( - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[7,100,128] parameter(0) - param_1 = f32[] parameter(1) - ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={0,2}, to_apply=Add - } - ENTRY main { - a = f32[7,100,128] parameter(0) - c = f32[] constant(0) - ROOT fusion = f32[100] fusion(a, c), kind=kInput, calls=fused_computation - })"; - constexpr auto kMultiRowReductionX8 = R"( Add { lhs = f32[] parameter(0) @@ -180,181 +118,6 @@ constexpr auto kMultiRowReductionX16VectorX2 = R"( ROOT fusion = pred[76800] fusion(p0), kind=kInput, calls=fusion })"; -constexpr std::string_view kRowReductionSideOutput = R"( - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[8,2048] parameter(0) - param_1 = f32[] parameter(1) - exp = f32[8,2048] exponential(param_0) - reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add - ROOT t = (f32[8], f32[8,2048]) tuple(reduce, exp) - } - ENTRY main { - a = f32[8,2048] parameter(0) - c = f32[] constant(0) - ROOT fusion = (f32[8], f32[8,2048]) fusion(a, c), kind=kInput, - calls=fused_computation - })"; - -TEST_F(MlirRowReductionTest, VariadicRowReductionIndexing) { - auto fusion = GetEmitter(kVariadicRowReduction); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_), - {2, 3, 2048})); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {2, 3})); -} - -TEST_F(MlirRowReductionTest, VariadicRowReductionCorrectness) { - EXPECT_TRUE(RunAndCompareNoHloPasses(kVariadicRowReduction, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, RowReduceEpilogue) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[8,2048] parameter(0) - param_1 = f32[] parameter(1) - reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add - ROOT log = f32[8] log(reduce) - } - ENTRY main { - a = f32[8,2048] parameter(0) - c = f32[] constant(0) - ROOT fusion = f32[8] fusion(a, c), kind=kInput, calls=fused_computation - })"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: pure_call @Add_add - // CHECK: shuffle_reduce - // CHECK: allocate_shared - // CHECK: sync_threads - // CHECK: shuffle_reduce - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, RowReduceMOFEpilogue) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - Mul { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT mul = f32[] multiply(lhs, rhs) - } - fused_computation { - param_0 = f32[8,1024] parameter(0) - param_1 = f32[] parameter(1) - reduce1 = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add - reduce2 = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Mul - log = f32[8] log(reduce1) - abs = f32[8] abs(reduce1) - neg = f32[8] negate(reduce2) - ROOT tuple = (f32[8], f32[8], f32[8]) tuple(log, neg, abs) - } - ENTRY main { - a = f32[8,1024] parameter(0) - c = f32[] constant(0) - ROOT fusion = (f32[8], f32[8], f32[8]) fusion(a, c), kind=kInput, - calls=fused_computation - })"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: pure_call @Add_add - // CHECK-DAG: shuffle_reduce @Add_add - // CHECK-DAG: pure_call @Mul_mul - // CHECK-DAG: shuffle_reduce @Mul_mul - // CHECK: allocate_shared - // CHECK: allocate_shared - // CHECK: sync_threads - // CHECK-DAG: shuffle_reduce @Add_add - // CHECK-DAG: shuffle_reduce @Mul_mul - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, RowReduceMOFGroups) { - constexpr auto kHloString = R"( - %add_f32 { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %add = f32[] add(%x, %y) - } - - %fused_computation { - %param0 = f32[1024] parameter(0) - %param1 = f32[1024] parameter(1) - %constant0 = f32[] constant(0) - %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32 - %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 - ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2) - } - - ENTRY %cluster { - %param0 = f32[1024] parameter(0) - %param1 = f32[1024] parameter(1) - ROOT %fusion = (f32[], f32[]) - fusion(%param0, %param1), kind=kInput, calls=%fused_computation - })"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: scf.index_switch %block_id_y - // CHECK: case 1 { - // CHECK: default { - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, F64RowReductionIndexing) { - auto fusion = GetEmitter(kF64RowReduction); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_), - /*shape=*/{100, 128})); - TF_EXPECT_OK( - TestBijection(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), - /*shape=*/{100})); -} - -TEST_F(MlirRowReductionTest, F64RowReductionIr) { - // This reduction is small enough not to require shared memory. - TF_ASSERT_OK(EmitAndCheckIR(kF64RowReduction, R"( - // CHECK-NOT: allocate_shared - )")); -} - -TEST_F(MlirRowReductionTest, F64RowReductionCorrectness) { - EXPECT_TRUE(RunAndCompareNoHloPasses(kF64RowReduction, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, RowReductionMinorAndMajorIndexing) { - auto fusion = GetEmitter(kRowReductionMinorAndMajor); - - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_), - /*shape=*/{7, 100, 128})); - TF_EXPECT_OK( - TestBijection(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), - /*shape=*/{100})); -} - -TEST_F(MlirRowReductionTest, RowReductionMinorAndMajorCorrectness) { - EXPECT_TRUE( - RunAndCompareNoHloPasses(kRowReductionMinorAndMajor, ErrorSpec{1e-3})); -} - TEST_F(MlirMultiRowReductionTest, MultiRowReductionIndexing) { auto fusion = GetEmitter(kMultiRowReductionX8); @@ -380,207 +143,6 @@ TEST_F(MlirMultiRowReductionTest, MultiRowReductionCorrectness) { EXPECT_TRUE(RunAndCompareNoHloPasses(kMultiRowReductionX8, ErrorSpec{1e-3})); } -TEST_F(MlirRowReductionTest, NonPowerOfTwoRowReduction) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[100,568] parameter(0) - param_1 = f32[] parameter(1) - ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={1}, to_apply=Add - } - ENTRY main { - a = f32[100,568] parameter(0) - c = f32[] constant(0) - ROOT fusion = f32[100] fusion(a, c), kind=kInput, calls=fused_computation - })"; - TF_EXPECT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0] -> ((d1 mod 64) * 2 + s0 * 128 + d0), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 3]> - // CHECK-DAG: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 mod 64) * 2 + d0 + 512), domain: d0 in [0, 1], d1 in [0, 255]> - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index - // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index - // CHECK: %[[FULL_TILES:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] - // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] - // CHECK-NOT: scf.if - // CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]], %thread_id_x)[%[[I]]] - // CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%{{.*}} = %[[FULL_TILES]]) - // CHECK: scf.if - // CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]], %thread_id_x) - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirMultiRowReductionTest, NonTrivialEpilogueCorrectness) { - constexpr auto kHloString = R"( - HloModule module - add { - p0 = f64[] parameter(0) - p1 = f64[] parameter(1) - ROOT add = f64[] add(p0, p1) - } - fusion { - %p0 = f64[4] parameter(0) - %p1 = f64[4] parameter(1) - %c0 = f64[] constant(-inf) - %reduce0 = f64[] reduce(p1, c0), dimensions={0}, to_apply=add - %bc0 = f64[4] broadcast(reduce0), dimensions={} - %compare0 = pred[4] compare(p1, bc0), direction=EQ - %c1 = f64[] constant(0) - %bc1 = f64[4] broadcast(c1), dimensions={} - %select.3.1 = f64[4] select(compare0, p0, bc1) - %reduce1 = f64[] reduce(select.3.1, c1), dimensions={0}, to_apply=add - %convert0 = f64[4] convert(compare0) - %reduce2 = f64[] reduce(convert0, c1), dimensions={0}, to_apply=add - ROOT %tuple.1 = (f64[], f64[], f64[]) tuple(%reduce1, reduce0, reduce2) - } - ENTRY main { - %p0 = f64[4] parameter(0) - %p1 = f64[4] parameter(1) - ROOT %fusion = (f64[], f64[], f64[]) fusion(%p0, %p1), kind=kInput, - calls=fusion - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, SideOutputIndexing) { - auto fusion = GetEmitter(kRowReductionSideOutput); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_), - {8, 2048})); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {8})); - TF_EXPECT_OK( - TestBijection(*fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context_), - {8, 2048})); // Side output. -} - -TEST_F(MlirRowReductionTest, SideOutputIr) { - TF_ASSERT_OK(EmitAndCheckIR(kRowReductionSideOutput, R"( - // CHECK: @fused_computation - // CHECK: scf.for - // CHECK: scf.for - // CHECK: %[[SIDE_OUTPUT:.*]] = xla_gpu.pure_call @fused_computation_exp - // CHECK-NEXT: tensor.insert %[[SIDE_OUTPUT]] - )")); -} - -TEST_F(MlirRowReductionTest, SideOutputCorrectness) { - EXPECT_TRUE( - RunAndCompareNoHloPasses(kRowReductionSideOutput, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, UnsignedSideOutputCorrectness) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - lhs = u32[] parameter(0) - rhs = u32[] parameter(1) - ROOT add = u32[] add(lhs, rhs) - } - fused_computation { - param_0 = u32[8,2048] parameter(0) - param_1 = u32[] parameter(1) - add = u32[8,2048] add(param_0, param_0) - reduce = u32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add - ROOT t = (u32[8], u32[8,2048]) tuple(reduce, add) - } - ENTRY main { - a = u32[8,2048] parameter(0) - c = u32[] constant(0) - ROOT fusion = (u32[8], u32[8,2048]) fusion(a, c), kind=kInput, - calls=fused_computation - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, BroadcastSideOutputCorrectness) { - constexpr auto kHloString = R"( - %add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - %fusion { - %p0 = f32[6,6] parameter(0) - %c0 = f32[] constant(0) - %reduce = f32[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add - %broadcast = f32[6,6] broadcast(%reduce), dimensions={} - ROOT %tuple = (f32[6,6], f32[]) tuple(%broadcast, %reduce) - } - ENTRY main { - %p0 = f32[6,6] parameter(0) - ROOT %fusion = (f32[6,6], f32[]) fusion(%p0), kind=kInput, calls=%fusion - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, VariadicMOFCorrectness) { - constexpr auto kHloString = R"( - %reducer1 { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - %reducer2 { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - p2 = f32[] parameter(2) - p3 = f32[] parameter(3) - add0 = f32[] add(p0, p2) - add1 = f32[] add(p1, p3) - ROOT tuple = (f32[], f32[]) tuple(add0, add1) - } - %fusion { - %p0 = f32[6,6] parameter(0) - %c0 = f32[] constant(0) - %neg = f32[6,6] negate(%p0) - %reduce1 = f32[] reduce(%neg, %c0), dimensions={0,1}, to_apply=%reducer1 - %reduce2 = (f32[], f32[]) reduce(%p0, %p0, %c0, %c0), dimensions={0,1}, to_apply=%reducer2 - ROOT %tuple = (f32[], (f32[], f32[]), f32[6,6]) tuple(%reduce1, %reduce2, %neg) - } - ENTRY main { - %p0 = f32[6,6] parameter(0) - ROOT %fusion = (f32[], (f32[], f32[]), f32[6,6]) fusion(%p0), kind=kInput, calls=%fusion - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, OutputLayoutCorrectness) { - constexpr std::string_view kHloString = R"( - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - - fusion { - %input = f32[17,19,127] parameter(0) - %c0 = f32[] constant(0) - ROOT reduce = f32[17,19]{0,1} reduce(%input, %c0), dimensions={2}, to_apply=add - } - - ENTRY entry { - %input = f32[17,19,127] parameter(0) - ROOT %fusion = f32[17,19]{0,1} fusion(%input), kind=kInput, calls=fusion - })"; - - auto fusion = GetEmitter(kHloString); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_), - {17, 19, 127})); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {17, 19})); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - TEST_F(MlirMultiRowReductionTest, TwoGroups) { auto module = ParseAndReturnVerifiedModule(R"( add { @@ -642,225 +204,6 @@ TEST_F(MlirMultiRowReductionTest, OneGroup) { EXPECT_THAT(mlir_fusion.GetGroups().grouped_roots, SizeIs(1)); } -constexpr absl::string_view kColumnVectorizationTemplate = R"( - add { - b = $0[] parameter(1) - a = $0[] parameter(0) - ROOT out = $0[] add(a, b) - } - fusion { - %p0 = $0[192,64,1536] parameter(0) - %p1 = $0[] parameter(1) - ROOT reduce = $0[192,1536] reduce(p0, p1), dimensions={1}, to_apply=add - } - ENTRY entry { - %p0 = $0[192,64,1536] parameter(0) - %p1 = $0[] parameter(1) - ROOT %fusion = $0[192,1536] fusion(p0, p1), kind=kInput, calls=fusion - })"; - -TEST_F(MlirColumnReductionTest, ColumnReduction) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[13,1051,321] parameter(0) - param_1 = f32[] parameter(1) - ROOT reduce = f32[13,321] reduce(param_0, param_1), dimensions={1}, to_apply=Add - } - ENTRY main { - a = f32[13,1051,321] parameter(0) - c = f32[] constant(0) - ROOT fusion = f32[13,321] fusion(a, c), kind=kInput, calls=fused_computation - })"; - - auto module = ParseAndReturnVerifiedModule(kHloString).value(); - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - MlirColumnReductionFusion fusion(analysis); - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( - d3 floordiv 11, - d0 floordiv 32 + s0 * 32, - (d3 mod 11) * 32 + d0 mod 32 - ) - domain: - d0 in [0, 1023] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 142] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 32] - s1 in [0, 0] - (d3 mod 11) * 32 + d0 mod 32 in [0, 320] - d0 floordiv 32 + s0 * 32 in [0, 1050] - )")); - EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0] -> ( - d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32 - ) - domain: - d0 in [0, 992] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 142] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - (d3 mod 11) * 32 + d0 floordiv 32 in [0, 320] - d0 mod 32 in [0, 0] - )")); - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: xla_gpu.pure_call @Add_add - // CHECK: allocate_shared - // CHECK: tensor.insert - // CHECK: sync_threads - // CHECK: predicated_extract - // CHECK: shuffle_reduce - // CHECK: predicated_insert - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirColumnReductionTest, SmallColumnReduction) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[3,128,4] parameter(0) - param_1 = f32[] parameter(1) - ROOT reduce = f32[3,4] reduce(param_0, param_1), dimensions={1}, to_apply=Add - } - ENTRY main { - a = f32[3,128,4] parameter(0) - c = f32[] constant(0) - ROOT fusion = f32[3,4] fusion(a, c), kind=kInput, calls=fused_computation - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirColumnReductionTest, MixedIndexing) { - constexpr auto kHloString = R"( - HloModule module - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - fusion { - %param_0 = f32[64,128] parameter(0) - %constant_0 = f32[] constant(0) - %reduce.1 = f32[128] reduce(f32[64,128] %param_0, f32[] %constant_0), dimensions={0}, to_apply=%add - %neg = f32[64,128] negate(f32[64,128] %param_0) - %bitcast = f32[8,8,128]{2,1,0} bitcast(f32[64,128] %neg) - %reduce.2 = f32[128] reduce(f32[8,8,128]{2,1,0} %bitcast, f32[] %constant_0), dimensions={0,1}, to_apply=%add - ROOT %tuple.12 = (f32[128], f32[128]) tuple(f32[128] %reduce.1, f32[128] %reduce.2) - } - ENTRY entry { - %param_0 = f32[64,128] parameter(0) - ROOT %fusion = (f32[128], f32[128]) fusion(%param_0), kind=kInput, calls=fusion - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirColumnReductionTest, ColumnReductionVectorizationCorrectness) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[2048,16384] parameter(0) - param_1 = f32[] parameter(1) - ROOT reduce = f32[16384] reduce(param_0, param_1), dimensions={0}, to_apply=Add - } - ENTRY main { - a = f32[2048,16384] parameter(0) - c = f32[] constant(0) - ROOT fusion = f32[16384] fusion(a, c), kind=kInput, calls=fused_computation - })"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: vector<2xf32> - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirColumnReductionTest, ColumnReductionVectorization_v4) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - Add { - lhs = s16[] parameter(0) - rhs = s16[] parameter(1) - ROOT add = s16[] add(lhs, rhs) - } - fused_computation { - param_0 = s16[2048,16384] parameter(0) - param_1 = s16[] parameter(1) - ROOT reduce = s16[16384] reduce(param_0, param_1), dimensions={0}, to_apply=Add - } - ENTRY main { - a = s16[2048,16384] parameter(0) - c = s16[] constant(0) - ROOT fusion = s16[16384] fusion(a, c), kind=kInput, calls=fused_computation - })"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: vector<4xi16> - )")); - // We don't use RunAndCompareNoHloPasses because the interpreter is too slow - // for this input. -} - -TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v2) { - const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "f32"); - auto fusion = GetEmitter(hlo_string); - EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing( - 0, 0, &mlir_context_)), - ElementsAre(2 /* major reduced */, 2 /* vector size */)); -} - -TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v4) { - const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "f16"); - auto fusion = GetEmitter(hlo_string); - EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing( - 0, 0, &mlir_context_)), - ElementsAre(2 /* major reduced */, 4 /* vector size */)); -} - -TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_f64) { - // Verifies that we do not use the vectorized indexing for f64. - const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "f64"); - auto fusion = GetEmitter(hlo_string); - EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing( - 0, 0, &mlir_context_)), - ElementsAre(2 /* major reduced */, 1 /* vector size */)); -} - -TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_Complex) { - // Verifies that we do not use the vectorized indexing for complex types. - const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "c64"); - auto fusion = GetEmitter(hlo_string); - EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing( - 0, 0, &mlir_context_)), - ElementsAre(2 /* major reduced */, 1 /* vector size */)); -} - TEST_F(MlirMultiRowReductionTest, VectorizedX4Indexing) { auto fusion = GetEmitter(kMultiRowReductionX2VectorX4); @@ -884,61 +227,6 @@ TEST_F(MlirMultiRowReductionTest, VectorizedX4Correctness) { RunAndCompareNoHloPasses(kMultiRowReductionX2VectorX4, ErrorSpec{1e-3})); } -TEST_F(MlirRowReductionTest, LargeToUnit) { - // Regression test for a bug where not all threads in the warp produced a - // valid value for the final warp shuffle. - constexpr auto kHloString = R"( - and { - p0 = pred[] parameter(0) - p1 = pred[] parameter(1) - ROOT and = pred[] and(p0, p1) - } - - %fused_reduce { - c1 = pred[] constant(true) - p0 = pred[10000] broadcast(c1), dimensions={} - ROOT reduce = pred[] reduce(p0, c1), dimensions={0}, to_apply=and - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, MOFTwoVariadic) { - // Regression test for a compilation crash with a MOF with two variadic - // reductions. - constexpr auto kHloString = R"( - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - p2 = f32[] parameter(2) - p3 = f32[] parameter(3) - a = f32[] add(p0, p2) - b = f32[] add(p1, p3) - ROOT out = (f32[], f32[]) tuple(a, b) - } - - fused_reduce { - p0 = f32[3,2] parameter(0) - p1 = f32[3,2] parameter(1) - c0 = f32[] constant(0) - iota0 = f32[3,2] iota(), iota_dimension=1 - iota1 = f32[3,2] iota(), iota_dimension=1 - reduce0 = (f32[3], f32[3]) reduce(p0, iota0, c0, c0), dimensions={1}, - to_apply=add - reduce1 = (f32[3], f32[3]) reduce(p1, iota1, c0, c0), dimensions={1}, - to_apply=add - ROOT tuple = ((f32[3], f32[3]), (f32[3], f32[3])) tuple(reduce0, %reduce1) - } - - ENTRY main { - p0 = f32[3,2] parameter(0) - p1 = f32[3,2] parameter(1) - ROOT fusion = ((f32[3], f32[3]), (f32[3], f32[3])) fusion(p0, p1), - kind=kInput, calls=fused_reduce - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc index 14ec832be1a9f4..85e1e504e79d73 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc @@ -39,9 +39,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/primitive_util.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" diff --git a/third_party/xla/xla/service/gpu/fusions/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/tests/BUILD new file mode 100644 index 00000000000000..d3e3b665e75d3b --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/BUILD @@ -0,0 +1,19 @@ +load("//xla:lit.bzl", "lit_test_suite") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +lit_test_suite( + name = "tests", + srcs = glob(["**/*.hlo"]), + cfg = "//xla:lit.cfg.py", + default_tags = ["requires-gpu-sm80-only"], + tools = [ + "//xla/service/gpu/fusions/tools:fusion_to_mlir", + "//xla/service/gpu/fusions/tools:mlir_fusions_opt", + "//xla/service/gpu/fusions/tools:test_correctness", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant.hlo new file mode 100644 index 00000000000000..d0dd73d59081cd --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant.hlo @@ -0,0 +1,19 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline | FileCheck %s --dump-input=always +// RUN: test_correctness %s --bijection_outputs=broadcast + +bcast { + zero = bf16[] constant(0) + ROOT broadcast = bf16[2,16,48]{2,1,0} broadcast(zero), dimensions={} +} + +// CHECK-DAG: #[[MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 6) +// CHECK-DAG: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 48) mod 16) +// CHECK-DAG: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) mod 48) +// CHECK: func.func @main(%[[ARG0:.*]]: tensor<2x16x48xbf16> +// CHECK: %[[THREAD_ID:.*]] = gpu.thread_id +// CHECK: %[[BLOCK_ID:.*]] = gpu.block_id +// CHECK: %[[I0:.*]] = xla_gpu.apply_indexing #[[MAP0]] +// CHECK: %[[I1:.*]] = xla_gpu.apply_indexing #[[MAP1]] +// CHECK: %[[I2:.*]] = xla_gpu.apply_indexing #[[MAP2]] +// CHECK: %[[CST:.*]] = arith.constant 0.000 +// CHECK: %[[INSERTED:.*]] = tensor.insert %[[CST]] into %[[ARG0]][%[[I0]], %[[I1]], %[[I2]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/complex.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/complex.hlo new file mode 100644 index 00000000000000..2f6a5aa41c664f --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/complex.hlo @@ -0,0 +1,28 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +%fused_computation { + %p0 = f32[2]{0} parameter(0) + %p1 = f32[2]{0} parameter(1) + %p2 = c64[2]{0} parameter(2) + %complex = c64[2] complex(%p0, %p1) + %add = c64[2] add(%complex, %p2) + %cst = c64[2]{0} constant({(2.0, 0.0), (0.0, 2.0)}) + ROOT %mul = c64[2] multiply(%add, %cst) +} + +// CHECK: func.func @main +// CHECK-NEXT: gpu.thread_id +// CHECK-NEXT: pure_call @fused_computation_mul +// CHECK-NEXT: tensor.insert +// CHECK-NEXT: return + +// CHECK: func.func private @fused_computation_mul +// CHECK-NEXT: arith.constant +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: complex.create +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: complex.add +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: complex.mul diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/dynamic_update_slice.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/dynamic_update_slice.hlo new file mode 100644 index 00000000000000..f9976a51a3994c --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/dynamic_update_slice.hlo @@ -0,0 +1,27 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline | FileCheck %s --dump-input=always +// RUN: test_correctness %s + +%fused_computation { + in = c64[2,3] parameter(0) + updates = c64[2,2] parameter(1) + i0 = s32[] parameter(2) + i1 = s32[] parameter(3) + updated = c64[2,3] dynamic-update-slice(in, updates, i0, i1) + // Add some random epilogue to prevent in-place DUS from triggering. + ROOT negated = c64[2,3] negate(updated) +} + +// CHECK: func.func @main +// CHECK-SAME: %[[IN:.*]]: tensor<2x3xcomplex> {xla.slice_index = 0 +// CHECK-SAME: %[[UPDATES:.*]]: tensor<2x2xcomplex> {xla.slice_index = 1 +// CHECK-SAME: %[[I0:.*]]: tensor {xla.slice_index = 2 +// CHECK-SAME: %[[I1:.*]]: tensor {xla.slice_index = 3 + +// No need to load i0, since its value is irrelevant. +// CHECK-NOT: tensor.extract %[[I0]] +// CHECK: tensor.extract %[[I1]] +// CHECK-NOT: tensor.extract %[[I0]] +// CHECK: scf.if +// CHECK: tensor.extract %[[UPDATES]] +// CHECK: } else { +// CHECK: tensor.extract %[[IN]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/minimum_maximum.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/minimum_maximum.hlo new file mode 100644 index 00000000000000..7d7fdf79fe2b55 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/minimum_maximum.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +fused_computation { + param0 = f64[] parameter(0) + param1 = f64[] parameter(1) + + minimum = f64[] minimum(param0, param1) + maximum = f64[] maximum(param0, param1) + ROOT tuple = (f64[], f64[]) tuple(minimum, maximum) +} + +// CHECK: func.func @main +// CHECK: xla_gpu.pure_call @fused_computation_tuple +// CHECK: func.func private @fused_computation_tuple +// CHECK-DAG: arith.minimumf +// CHECK-DAG: arith.maximumf diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_mul.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_mul.hlo new file mode 100644 index 00000000000000..017019e436d125 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_mul.hlo @@ -0,0 +1,15 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +%fused_computation { + p0 = s8[1000] parameter(0) + p1 = s8[1000] parameter(1) + cvt0 = pred[1000] convert(p0) + cvt1 = pred[1000] convert(p1) + ROOT mul = pred[1000] multiply(cvt0, cvt1) +} + +// CHECK: %[[A:.*]] = arith.cmpi ne, +// CHECK: %[[B:.*]] = arith.cmpi ne, +// CHECK: %[[R:.*]] = arith.andi %[[A]], %[[B]] +// CHECK: arith.extui %[[R]] : i1 to i8 diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_not.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_not.hlo new file mode 100644 index 00000000000000..0597b3590cbbd4 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_not.hlo @@ -0,0 +1,13 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +%fused_computation { + p0 = s8[1000] parameter(0) + cvt = pred[1000] convert(p0) + ROOT not = pred[1000] not(cvt) +} + +// CHECK: %[[C0:.*]] = arith.constant 0 : i8 +// CHECK: %[[NONZERO:.*]] = arith.cmpi eq, {{.*}}, %[[C0]] +// CHECK: %[[CVT:.*]] = arith.extui %[[NONZERO]] : i1 to i8 +// CHECK: return %[[CVT]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo new file mode 100644 index 00000000000000..f77a3c38cd8ded --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo @@ -0,0 +1,21 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt --xla-gpu-test-to-inline | FileCheck %s +// RUN: test_correctness %s + +fused_computation { + param0 = f64[8] parameter(0) + param1 = f64[8] parameter(1) + + minimum = f64[8] minimum(param0, param1) + maximum = f64[8] maximum(param0, param1) + bc = f64[2, 4] bitcast(maximum) + ROOT tuple = (f64[8], f64[2,4]) tuple(minimum, bc) +} + +// CHECK: #[[MAJOR:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 4), +// CHECK: #[[MINOR:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 4), + +// CHECK: %[[TID:.*]] = gpu.thread_id +// CHECK-DAG: %[[MAJOR_IDX:.*]] = xla_gpu.apply_indexing #[[MAJOR]] +// CHECK-DAG: %[[MINOR_IDX:.*]] = xla_gpu.apply_indexing #[[MINOR]] +// CHECK-DAG: tensor.insert {{.*}}[%[[MAJOR_IDX]], %[[MINOR_IDX]]] +// CHECK-DAG: tensor.insert {{.*}}[%[[TID]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_nested.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_nested.hlo new file mode 100644 index 00000000000000..ac5f26682ec356 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_nested.hlo @@ -0,0 +1,35 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +add { + scalar_lhs.0 = f32[] parameter(0) + scalar_lhs.1 = f32[] parameter(1) + scalar_rhs.0 = f32[] parameter(2) + scalar_rhs.1 = f32[] parameter(3) + add = f32[] add(scalar_lhs.0, scalar_rhs.0) + mul = f32[] multiply(scalar_lhs.1, scalar_rhs.1) + ROOT t = (f32[], f32[]) tuple(add, mul) +} + +fused_computation { + param_0 = f32[3,4,5]{2,1,0} parameter(0) + param_1 = f32[3,4,5]{2,1,0} parameter(1) + param_2 = f32[] parameter(2) + param_3 = f32[4] parameter(3) + reduce = (f32[4], f32[4]) reduce(f32[3,4,5]{2,1,0} param_0, + f32[3,4,5]{2,1,0} %param_1, f32[] param_2, f32[] param_2), + dimensions={0,2}, to_apply=add + log = f32[4] log(param_3) + ROOT tuple = ((f32[4], f32[4]), f32[4]) tuple(reduce, log) +} + +// CHECK: @main +// CHECK: %[[R0:.*]], %[[R1:.*]], %[[R2:.*]] = xla_gpu.pure_call @fused_computation_tuple +// CHECK-DAG: tensor.insert %[[R0]] +// CHECK-DAG: tensor.insert %[[R1]] +// CHECK-DAG: tensor.insert %[[R2]] + +// CHECK: @fused_computation_tuple +// CHECK: %[[REDUCTION:.*]]:2 = scf.for +// CHECK: %[[LOG:.*]] = math.log +// CHECK: return %[[REDUCTION]]#0, %[[REDUCTION]]#1, %[[LOG]] \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/two_users.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/two_users.hlo new file mode 100644 index 00000000000000..b16b005897b3f7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/two_users.hlo @@ -0,0 +1,30 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +// We have two users of add and sub, but they use consistent indexing, so they +// can be generated as a single function (fused_computation_atan2). +%fused_computation { + %p0 = f32[2] parameter(0) + %p1 = f32[2] parameter(1) + %add = f32[2] add(%p0, %p1) + %sub = f32[2] subtract(%p0, %p1) + %mul = f32[2] multiply(%add, %sub) + %div = f32[2] divide(%add, %sub) + ROOT %atan2 = f32[2] atan2(%mul, %div) +} + +// CHECK: func.func @main +// CHECK-NEXT: gpu.thread_id +// CHECK-NEXT: pure_call @fused_computation_atan2 +// CHECK-NEXT: tensor.insert +// CHECK-NEXT: return + +// CHECK: func.func private @fused_computation_atan2 +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: addf +// CHECK-NEXT: subf +// CHECK-NEXT: mulf +// CHECK-NEXT: divf +// CHECK-NEXT: atan2 +// CHECK-NEXT: return \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/variadic_reduce.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/variadic_reduce.hlo new file mode 100644 index 00000000000000..8ac83ced80af31 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/variadic_reduce.hlo @@ -0,0 +1,31 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +add { + scalar_lhs.0 = f32[] parameter(0) + scalar_lhs.1 = f32[] parameter(1) + scalar_rhs.0 = f32[] parameter(2) + scalar_rhs.1 = f32[] parameter(3) + add = f32[] add(scalar_lhs.0, scalar_rhs.0) + mul = f32[] multiply(scalar_lhs.1, scalar_rhs.1) + ROOT t = (f32[], f32[]) tuple(add, mul) +} + +fused_computation { + param_0 = f32[3,4,5] parameter(0) + param_1 = f32[3,4,5] parameter(1) + c = f32[] constant(0) + ROOT d.1 = (f32[4], f32[4]) reduce(param_0, param_1, c, c), dimensions={0,2}, + to_apply=add +} + +// CHECK: func @main( +// CHECK: %[[TID_X:.*]] = gpu.thread_id x +// CHECK: %[[SCALARS_0:.*]], %[[SCALARS_1:.*]] = xla_gpu.pure_call @fused_computation_d_1 +// CHECK: %[[INSERTED_1:.*]] = tensor.insert %[[SCALARS_0]] into %{{.*}}[%[[TID_X]]] +// CHECK: %[[INSERTED_2:.*]] = tensor.insert %[[SCALARS_1]] into %{{.*}}[%[[TID_X]]] +// CHECK: return %[[INSERTED_1]], %[[INSERTED_2]] + +// CHECK: func private @fused_computation_d_1 +// CHECK: %[[RET:.*]]:2 = func.call @add_t +// CHECK: yield %[[RET]]#0, %[[RET]]#1 diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x1_too_small.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x1_too_small.hlo new file mode 100644 index 00000000000000..75510894bcadd2 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x1_too_small.hlo @@ -0,0 +1,10 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline -xla-gpu-test-vectorize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=neg:0 --bijection_outputs=neg + +neg { + %input = f32[20] parameter(0) + ROOT neg = f32[20] negate(%input) +} + +// CHECK-NOT: vector. +// CHECK: tensor.extract diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x4.hlo new file mode 100644 index 00000000000000..549231c7aa4447 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x4.hlo @@ -0,0 +1,13 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline -xla-gpu-test-vectorize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=neg:0 --bijection_outputs=neg + +neg { + %input = f32[20,40,300] parameter(0) + ROOT neg = f32[20,40,300] negate(%input) +} + +// CHECK-NOT: tensor. +// CHECK: vector.transfer_read {{.*}} vector<4xf32> +// CHECK-NOT: tensor. +// CHECK: vector.transfer_write {{.*}} vector<4xf32> +// CHECK-NOT: tensor. \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/mof_heteorgeneous_input_shapes.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/mof_heteorgeneous_input_shapes.hlo new file mode 100644 index 00000000000000..1646aded57fdf4 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/mof_heteorgeneous_input_shapes.hlo @@ -0,0 +1,19 @@ +// RUN: test_correctness %s --bijection_inputs=reduce.1:0 --bijection_inputs=reduce.2:0 --bijection_outputs=reduce.1 --bijection_outputs=reduce.2 + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +fusion { + %param_0 = f32[64,128] parameter(0) + %constant_0 = f32[] constant(0) + %reduce.1 = f32[128] reduce(param_0, constant_0), dimensions={0}, + to_apply=%add + %neg = f32[64,128] negate(param_0) + %bitcast = f32[8,8,128] bitcast(neg) + %reduce.2 = f32[128] reduce(bitcast, constant_0), dimensions={0,1}, + to_apply=%add + ROOT %tuple = (f32[128], f32[128]) tuple(reduce.1, reduce.2) +} diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo new file mode 100644 index 00000000000000..e7ae070f2938e7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo @@ -0,0 +1,22 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0 = f32[13,1051,321] parameter(0) + param_1 = f32[] parameter(1) + ROOT reduce = f32[13,321] reduce(param_0, param_1), dimensions={1}, to_apply=add +} + +// CHECK: xla_gpu.pure_call @add_add +// CHECK: allocate_shared +// CHECK: tensor.insert +// CHECK: sync_threads +// CHECK: predicated_extract +// CHECK: shuffle_reduce +// CHECK: predicated_insert diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/small.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/small.hlo new file mode 100644 index 00000000000000..958b391179001f --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/small.hlo @@ -0,0 +1,13 @@ +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0 = f32[3,128,4] parameter(0) + c0 = f32[] constant(0) + ROOT reduce = f32[3,4] reduce(param_0, c0), dimensions={1}, to_apply=add +} diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo new file mode 100644 index 00000000000000..a2a22363108b10 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = c64[] parameter(0) + rhs = c64[] parameter(1) + ROOT add = c64[] add(lhs, rhs) +} + +fused_computation { + param_0 = c64[128,64] parameter(0) + c0 = c64[] constant((0, 0)) + ROOT reduce = c64[64] reduce(param_0, c0), dimensions={0}, + to_apply=add +} + +// CHECK-NOT: vector< \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo new file mode 100644 index 00000000000000..660664bba95f37 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f64[] parameter(0) + rhs = f64[] parameter(1) + ROOT add = f64[] add(lhs, rhs) +} + +fused_computation { + param_0 = f64[128,64] parameter(0) + c0 = f64[] constant(0) + ROOT reduce = f64[64] reduce(param_0, c0), dimensions={0}, + to_apply=add +} + +// CHECK-NOT: vector< \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo new file mode 100644 index 00000000000000..a142ad4a164100 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0 = f32[2048,64] parameter(0) + c0 = f32[] constant(0) + ROOT reduce = f32[64] reduce(param_0, c0), dimensions={0}, + to_apply=add +} + +// CHECK: vector<2xf32> \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo new file mode 100644 index 00000000000000..81da088974132f --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = s16[] parameter(0) + rhs = s16[] parameter(1) + ROOT add = s16[] add(lhs, rhs) +} + +fused_computation { + param_0 = s16[256,128] parameter(0) + c0 = s16[] constant(0) + ROOT reduce = s16[128] reduce(param_0, c0), dimensions={0}, + to_apply=add +} + +// CHECK: vector<4xi16> \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo new file mode 100644 index 00000000000000..f8a9e86ff48f65 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo @@ -0,0 +1,22 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -xla-erase-dead-functions -inline -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0 = f32[8,2048] parameter(0) + param_1 = f32[] parameter(1) + reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=add + ROOT log = f32[8] log(reduce) +} + +// CHECK: shuffle_reduce +// CHECK: allocate_shared +// CHECK: sync_threads +// CHECK: shuffle_reduce +// CHECK-NEXT: %[[OUT:.*]] = math.log +// CHECK: predicated_insert %[[OUT]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo new file mode 100644 index 00000000000000..bc841743d9d3f8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo @@ -0,0 +1,43 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce1:0 --bijection_inputs=reduce2:0 --bijection_outputs=reduce1 --bijection_outputs=reduce2 + +add { + p0 = f64[] parameter(0) + p1 = f64[] parameter(1) + ROOT add = f64[] add(p0, p1) +} + +// This fusion is valid, but we can't efficiently codegen it. +fusion { + %p0 = f64[4] parameter(0) + %p1 = f64[4] parameter(1) + %c0 = f64[] constant(-inf) + %reduce0 = f64[] reduce(p1, c0), dimensions={0}, to_apply=add + %bc0 = f64[4] broadcast(reduce0), dimensions={} + %compare0 = pred[4] compare(p1, bc0), direction=EQ + %c1 = f64[] constant(0) + %bc1 = f64[4] broadcast(c1), dimensions={} + %select.3.1 = f64[4] select(compare0, p0, bc1) + %reduce1 = f64[] reduce(select.3.1, c1), dimensions={0}, to_apply=add + %convert0 = f64[4] convert(compare0) + %reduce2 = f64[] reduce(convert0, c1), dimensions={0}, to_apply=add + ROOT %tuple.1 = (f64[], f64[], f64[]) tuple(%reduce1, reduce0, reduce2) +} + +// We read all of %p1 once from each thread, and then read one element again. +// CHECK: func.func @main +// CHECK-SAME: , %[[P1:.*]]: tensor<4xf64> {xla.slice_index = 1 : index} +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[CST0:.*]] = arith.constant 0xFFF0000000000000 +// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x + +// reduce0 in the context of reduce2 and reduce1's prologue: +// CHECK: scf.for %[[I:.*]] = %[[C0]] +// CHECK-NEXT: tensor.extract %[[P1]][%[[I]]] +// CHECK-NEXT: addf +// CHECK-NEXT: yield + +// reduce0 again, in the context of its status as a fusion hero: +// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[P1]][%[[TID_X]]] +// CHECK: %[[ADDED:.*]] = arith.addf %[[CST0]], %[[EXTRACTED]] +// CHECK: shuffle_reduce @add_add(%[[ADDED]]) to 2 diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/large_to_unit.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/large_to_unit.hlo new file mode 100644 index 00000000000000..ee155c86e2bb54 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/large_to_unit.hlo @@ -0,0 +1,15 @@ +// Regression test for a bug where not all threads in the warp produced a valid +// value for the final warp shuffle. +// RUN: test_correctness %s + +and { + p0 = pred[] parameter(0) + p1 = pred[] parameter(1) + ROOT and = pred[] and(p0, p1) +} + +fused_reduce { + c1 = pred[] constant(true) + p0 = pred[10000] broadcast(c1), dimensions={} + ROOT reduce = pred[] reduce(p0, c1), dimensions={0}, to_apply=and +} diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo new file mode 100644 index 00000000000000..102e32b861e648 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +fusion { + %input = f32[17,19,127] parameter(0) + %c0 = f32[] constant(0) + // The output is physically transposed. + ROOT reduce = f32[17,19]{0,1} reduce(%input, %c0), dimensions={2}, to_apply=add +} + +// CHECK: xla_gpu.predicated_insert {{.*}} : tensor<17x19xf32, dense<[0, 1]> : tensor<2xi64>> diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo new file mode 100644 index 00000000000000..c9481f35bf7fe3 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo @@ -0,0 +1,20 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -inline -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0 = f32[7,100,128] parameter(0) + param_1 = f32[] parameter(1) + ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={0,2}, to_apply=add +} + +// Our codegen doesn't support parallelizing the major reduction dimension. In +// principle, this could be done via shared memory. +// CHECK-NOT: allocate_shared +// CHECK: shuffle_reduce +// CHECK-NOT: allocate_shared diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo new file mode 100644 index 00000000000000..315d604b563ebe --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo @@ -0,0 +1,40 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -xla-erase-dead-functions -inline -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce1:0 --bijection_inputs=reduce2:0 --bijection_outputs=reduce1 --bijection_outputs=reduce2 + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} + +fused_computation { + param_0 = f32[8,1024] parameter(0) + c0 = f32[] constant(0) + c1 = f32[] constant(1) + reduce1 = f32[8] reduce(param_0, c0), dimensions={1}, to_apply=add + reduce2 = f32[8] reduce(param_0, c1), dimensions={1}, to_apply=mul + log = f32[8] log(reduce1) + abs = f32[8] abs(reduce1) + neg = f32[8] negate(reduce2) + ROOT tuple = (f32[8], f32[8], f32[8]) tuple(log, neg, abs) +} + +// CHECK-DAG: shuffle_reduce @add_add +// CHECK-DAG: shuffle_reduce @mul_mul +// CHECK: allocate_shared +// CHECK: allocate_shared +// CHECK: sync_threads +// CHECK-DAG: %[[ADDED:.*]] = xla_gpu.shuffle_reduce @add_add +// CHECK-DAG: %[[MULTIPLIED:.*]] = xla_gpu.shuffle_reduce @mul_mul +// CHECK-DAG: %[[LOG:.*]] = math.log %[[ADDED]] +// CHECK-DAG: %[[ABS:.*]] = math.absf %[[ADDED]] +// CHECK-DAG: %[[NEG:.*]] = arith.negf %[[MULTIPLIED]] +// CHECK-DAG: xla_gpu.predicated_insert %[[LOG]] +// CHECK-DAG: xla_gpu.predicated_insert %[[ABS]] +// CHECK-DAG: xla_gpu.predicated_insert %[[NEG]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_scalar_variadic.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_scalar_variadic.hlo new file mode 100644 index 00000000000000..48a20334c7ea03 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_scalar_variadic.hlo @@ -0,0 +1,26 @@ +// RUN: test_correctness %s + +%reducer1 { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +%reducer2 { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + p2 = f32[] parameter(2) + p3 = f32[] parameter(3) + add0 = f32[] add(p0, p2) + add1 = f32[] add(p1, p3) + ROOT tuple = (f32[], f32[]) tuple(add0, add1) +} + +%fusion { + %p0 = f32[6,6] parameter(0) + %c0 = f32[] constant(0) + %neg = f32[6,6] negate(%p0) + %reduce1 = f32[] reduce(%neg, %c0), dimensions={0,1}, to_apply=%reducer1 + %reduce2 = (f32[], f32[]) reduce(%p0, %p0, %c0, %c0), dimensions={0,1}, to_apply=%reducer2 + ROOT %tuple = (f32[], (f32[], f32[]), f32[6,6]) tuple(%reduce1, %reduce2, %neg) +} \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_variadic_variadic.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_variadic_variadic.hlo new file mode 100644 index 00000000000000..6d47fc6b842b9f --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_variadic_variadic.hlo @@ -0,0 +1,26 @@ +// Regression test for a compilation crash with a MOF with two variadic +// reductions. +// RUN: test_correctness %s + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + p2 = f32[] parameter(2) + p3 = f32[] parameter(3) + a = f32[] add(p0, p2) + b = f32[] add(p1, p3) + ROOT out = (f32[], f32[]) tuple(a, b) +} + +fused_reduce { + p0 = f32[3,2] parameter(0) + p1 = f32[3,2] parameter(1) + c0 = f32[] constant(0) + iota0 = f32[3,2] iota(), iota_dimension=1 + iota1 = f32[3,2] iota(), iota_dimension=1 + reduce0 = (f32[3], f32[3]) reduce(p0, iota0, c0, c0), dimensions={1}, + to_apply=add + reduce1 = (f32[3], f32[3]) reduce(p1, iota1, c0, c0), dimensions={1}, + to_apply=add + ROOT tuple = ((f32[3], f32[3]), (f32[3], f32[3])) tuple(reduce0, %reduce1) +} diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/non_power_of_two.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/non_power_of_two.hlo new file mode 100644 index 00000000000000..30202d0f2613b8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/non_power_of_two.hlo @@ -0,0 +1,31 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -inline -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} +fused_computation { + param_0 = f32[100,568] parameter(0) + param_1 = f32[] parameter(1) + ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={1}, to_apply=add +} + +// CHECK-DAG: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0] -> ((d1 mod 64) * 2 + s0 * 128 + d0), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 3]> +// CHECK-DAG: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 mod 64) * 2 + d0 + 512), domain: d0 in [0, 1], d1 in [0, 255]> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + +// The full loop without bounds checks: +// CHECK: %[[FULL_TILES:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] +// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] +// CHECK-NOT: scf.if +// CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]], %thread_id_x)[%[[I]]] + +// The tail loop: +// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%{{.*}} = %[[FULL_TILES]]) +// CHECK: scf.if +// CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]], %thread_id_x) diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/reduction_groups.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/reduction_groups.hlo new file mode 100644 index 00000000000000..a7e64151affdda --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/reduction_groups.hlo @@ -0,0 +1,22 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +%add_f32 { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +%fused_computation { + %param0 = f32[1024] parameter(0) + %param1 = f32[1024] parameter(1) + %constant0 = f32[] constant(0) + %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2) +} + +// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y +// CHECK: scf.index_switch %[[BLOCK_ID_Y]] +// CHECK: case 1 { +// CHECK: default { diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo new file mode 100644 index 00000000000000..e950e3cbdf8d83 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo @@ -0,0 +1,24 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce --bijection_outputs=exp + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0 = f32[8,2048] parameter(0) + param_1 = f32[] parameter(1) + exp = f32[8,2048] exponential(param_0) + reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=add + ROOT t = (f32[8], f32[8,2048]) tuple(reduce, exp) +} + +// CHECK: @fused_computation +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[SIDE_OUTPUT:.*]] = xla_gpu.pure_call @fused_computation_exp +// CHECK-NEXT: tensor.insert %[[SIDE_OUTPUT]] +// CHECK: scf.yield +// CHECK: scf.yield diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_broadcast.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_broadcast.hlo new file mode 100644 index 00000000000000..0db1901a532501 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_broadcast.hlo @@ -0,0 +1,15 @@ +// RUN: test_correctness %s + +%add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +%fusion { + %p0 = f32[6,6] parameter(0) + %c0 = f32[] constant(0) + %reduce = f32[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add + %broadcast = f32[6,6] broadcast(%reduce), dimensions={} + ROOT %tuple = (f32[6,6], f32[]) tuple(%broadcast, %reduce) +} \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_unsigned.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_unsigned.hlo new file mode 100644 index 00000000000000..5371b80532bad7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_unsigned.hlo @@ -0,0 +1,15 @@ +// RUN: test_correctness %s + +add { + lhs = u32[] parameter(0) + rhs = u32[] parameter(1) + ROOT add = u32[] add(lhs, rhs) +} + +fused_computation { + param_0 = u32[8,2048] parameter(0) + param_1 = u32[] parameter(1) + add = u32[8,2048] add(param_0, param_0) + reduce = u32[8] reduce(param_0, param_1), dimensions={1}, to_apply=add + ROOT t = (u32[8], u32[8,2048]) tuple(reduce, add) +} diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/small_f64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/small_f64.hlo new file mode 100644 index 00000000000000..56e326608a0826 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/small_f64.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f64[] parameter(0) + rhs = f64[] parameter(1) + ROOT add = f64[] add(lhs, rhs) +} + +fused_computation { + param_0 = f64[100,128] parameter(0) + param_1 = f64[] parameter(1) + ROOT reduce = f64[100] reduce(param_0, param_1), dimensions={1}, to_apply=add +} + +// This reduction is small enough to not require any shared memory. +// CHECK-NOT: allocate_shared diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo new file mode 100644 index 00000000000000..b28bff49d8245d --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo @@ -0,0 +1,23 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0,1 --bijection_outputs=reduce + +add { + scalar_lhs.0 = f32[] parameter(0) + scalar_rhs.0 = f32[] parameter(1) + scalar_lhs.1 = f32[] parameter(2) + scalar_rhs.1 = f32[] parameter(3) + add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1) + add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1) + ROOT t = (f32[], f32[]) tuple(add.0, add.1) +} + +fused_computation { + param_0 = f32[2, 3, 2048] parameter(0) + param_1 = f32[2, 3, 2048] parameter(1) + c0 = f32[] constant(0) + ROOT reduce = (f32[2, 3], f32[2, 3]) + reduce(param_0, param_1, c0, c0), dimensions={2}, to_apply=add +} + +// CHECK: allocate_shared +// CHECK: allocate_shared diff --git a/third_party/xla/xla/service/gpu/fusions/tools/BUILD b/third_party/xla/xla/service/gpu/fusions/tools/BUILD new file mode 100644 index 00000000000000..2886ad1f7578bf --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tools/BUILD @@ -0,0 +1,113 @@ +load("//xla:xla.bzl", "xla_cc_binary") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +xla_cc_binary( + name = "mlir_fusions_opt", + srcs = ["mlir_fusions_opt.cc"], + visibility = ["//xla/service/gpu/fusions:__subpackages__"], + deps = [ + "//xla/mlir_hlo", + "//xla/service/gpu/fusions/ir:xla_gpu", + "//xla/service/gpu/fusions/transforms:passes", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorDialect", + ], +) + +cc_library( + name = "test_lib", + testonly = 1, + srcs = ["test_lib.cc"], + hdrs = ["test_lib.h"], + deps = [ + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/fusions", + "//xla/service/gpu/fusions/ir:xla_gpu", + "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/stream_executor:device_description", + "//xla/tools:hlo_module_loader", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:VectorDialect", + ], +) + +xla_cc_binary( + name = "fusion_to_mlir", + testonly = 1, + srcs = ["fusion_to_mlir.cc"], + visibility = ["//xla/service/gpu/fusions:__subpackages__"], + deps = [ + ":test_lib", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_binary( + name = "test_correctness", + testonly = 1, + srcs = ["test_correctness.cc"], + visibility = ["//xla/service/gpu/fusions:__subpackages__"], + deps = [ + ":test_lib", + "//xla:debug_options_flags", + "//xla:error_spec", + "//xla:shape_util", + "//xla/service:gpu_plugin", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/tools/fusion_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/tools/fusion_to_mlir.cc new file mode 100644 index 00000000000000..9fe41b6cb97a5b --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tools/fusion_to_mlir.cc @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "llvm/Support/raw_ostream.h" +#include "xla/service/gpu/fusions/tools/test_lib.h" +#include "tsl/platform/init_main.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +absl::Status Run(const std::string& filename) { + TF_ASSIGN_OR_RETURN(auto module, LoadTestModule(filename)); + TF_ASSIGN_OR_RETURN(auto emitter_data, GetMlirFusionEmitter(*module)); + + auto context = GetMlirContextForTest(); + TF_ASSIGN_OR_RETURN(auto mlir_module, + emitter_data->emitter->CreateMLIRModule( + context, *emitter_data->fusion, "main", + /*buffer_assignment=*/nullptr)); + llvm::outs() << *mlir_module; + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla + +int main(int argc, char** argv) { + tsl::port::InitMain(argv[0], &argc, &argv); + CHECK_EQ(argc, 2) << "Must specify an input file"; + CHECK_OK(xla::gpu::Run(argv[1])); + return 0; +} diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc similarity index 52% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc rename to third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc index b6c63ecc690e4f..780ede0d6d061c 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc +++ b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" @@ -26,14 +29,15 @@ limitations under the License. #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Transforms/Passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/transforms/passes.h" -int main(int argc, char **argv) { +int main(int argc, char** argv) { mlir::DialectRegistry registry; registry.insert + errorHandler) { + if (!options.empty()) return mlir::failure(); + + pm.addNestedPass( + xla::gpu::CreateSimplifyArithPass()); + pm.addPass(xla::gpu::CreateEraseDeadFunctionsPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) { + pm.addPass(mlir::createCSEPass()); + })); + return mlir::success(); + }, + [](llvm::function_ref) {}); + mlir::registerPassPipeline( + "xla-gpu-test-vectorize", + "Test pipeline for vectorization. Should run after " + "xla-gpu-test-to-inline.", + [=](mlir::OpPassManager& pm, llvm::StringRef options, + llvm::function_ref + errorHandler) { + if (!options.empty()) return mlir::failure(); + pm.addNestedPass( + xla::gpu::CreateLowerXlaGpuLoopsToScfPass()); + pm.addPass(mlir::createLoopInvariantCodeMotionPass()); + pm.addNestedPass( + xla::gpu::CreateUnswitchLoopsPass()); + pm.addPass(mlir::createLoopInvariantCodeMotionPass()); + pm.addPass(xla::gpu::CreateFlattenTensorsPass()); + pm.addNestedPass( + xla::gpu::CreateVectorizeLoadsAndStoresPass()); + return mlir::success(); + }, + [](llvm::function_ref) {}); return mlir::failed( MlirOptMain(argc, argv, "XLA MLIR Fusion Pass Driver\n", registry)); diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc new file mode 100644 index 00000000000000..72529cd6545c4d --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc @@ -0,0 +1,192 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "xla/debug_options_flags.h" +#include "xla/error_spec.h" +#include "xla/service/gpu/fusions/tools/test_lib.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/shape.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +struct Flags { + std::string input_file = ""; + float abs_error_bound = 1e-4; + float rel_error_bound = 1e-4; + std::vector>> bijection_inputs; + std::vector bijection_outputs; +}; + +Flags& flags = *new Flags; + +namespace xla { +namespace gpu { +namespace { + +using CorrectnessTest = HloTestBase; + +const Shape& GetFirstArrayShape(const Shape& shape) { + if (shape.IsArray()) { + return shape; + } + CHECK(shape.IsTuple()); + return GetFirstArrayShape(shape.tuple_shapes(0)); +} + +absl::Status TestBijection(const IndexingMap& map, + absl::Span shape) { + std::vector intervals; + for (int64_t size : shape) { + intervals.push_back({0, size - 1}); + } + auto status = VerifyBijection(map, intervals); + if (status.ok()) return status; + return absl::FailedPreconditionError( + absl::StrCat(status.message(), " in map ", map.ToString())); +} + +TEST_F(CorrectnessTest, RunAndCompare) { + TF_ASSERT_OK_AND_ASSIGN(auto module, LoadTestModule(flags.input_file)); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), + ErrorSpec{flags.abs_error_bound, flags.rel_error_bound})); +} + +absl::StatusOr GetHeroIndex(absl::string_view name, + const HloFusionAnalysis& analysis) { + for (auto [index, hero] : llvm::enumerate(analysis.fusion_heroes())) { + if (hero.name() == name) { + return index; + } + } + return absl::NotFoundError(absl::StrCat("Hero ", name, " not found")); +} + +std::pair> ParseHeroAndIds( + absl::string_view hero_and_ids) { + std::pair hero_and_ids_pair = + absl::StrSplit(hero_and_ids, ':'); + std::vector ids; + for (absl::string_view id : absl::StrSplit(hero_and_ids_pair.second, ',')) { + ids.push_back(std::stoi(std::string(absl::StripAsciiWhitespace(id)))); + } + return {std::string(absl::StripAsciiWhitespace(hero_and_ids_pair.first)), + ids}; +} + +TEST_F(CorrectnessTest, InputIndexingIsBijection) { + auto context = GetMlirContextForTest(); + TF_ASSERT_OK_AND_ASSIGN(auto module, LoadTestModule(flags.input_file)); + TF_ASSERT_OK_AND_ASSIGN(auto emitter_data, GetMlirFusionEmitter(*module)); + for (const auto& [hero_name, ids] : flags.bijection_inputs) { + TF_ASSERT_OK_AND_ASSIGN(int64_t hero_index, + GetHeroIndex(hero_name, *emitter_data->analysis)); + for (int64_t id : ids) { + auto indexing = emitter_data->emitter->ComputeThreadIdToInputIndexing( + hero_index, id, &context); + ASSERT_TRUE(indexing.has_value()); + TF_ASSERT_OK(TestBijection(*indexing, + emitter_data->analysis->fusion_hero(hero_index) + .GetOperand(id) + .shape() + .dimensions())) + << "Expected operand " << id << " of " << hero_name << " (root index " + << hero_index << ") to be read exactly once."; + } + } +} + +TEST_F(CorrectnessTest, OutputIndexingIsBijection) { + auto context = GetMlirContextForTest(); + TF_ASSERT_OK_AND_ASSIGN(auto module, LoadTestModule(flags.input_file)); + TF_ASSERT_OK_AND_ASSIGN(auto emitter_data, GetMlirFusionEmitter(*module)); + for (const auto& hero_name : flags.bijection_outputs) { + TF_ASSERT_OK_AND_ASSIGN(int64_t hero_index, + GetHeroIndex(hero_name, *emitter_data->analysis)); + auto indexing = emitter_data->emitter->ComputeThreadIdToOutputIndexing( + hero_index, &context); + ASSERT_TRUE(indexing.has_value()); + TF_ASSERT_OK(TestBijection( + *indexing, GetFirstArrayShape( + emitter_data->analysis->fusion_root(hero_index).shape()) + .dimensions())) + << "Expected output of " << hero_name << " (root index " << hero_index + << ") to be written exactly once."; + } +} + +} // namespace +} // namespace gpu +} // namespace xla + +int main(int argc, char* argv[]) { + std::vector flag_list = { + tsl::Flag("abs_error_bound", &flags.abs_error_bound, + "Absolute error bound."), + tsl::Flag("rel_error_bound", &flags.rel_error_bound, + "Relative error bound."), + tsl::Flag( + "bijection_inputs", + [](std::string name_and_ids) { + if (name_and_ids.empty()) return false; + flags.bijection_inputs.push_back( + xla::gpu::ParseHeroAndIds(name_and_ids)); + return true; + }, + "", + "The name of a hero followed by operand ids that should be read " + "exactly once, i.e. there's a bijection between a subset of threads " + "and the input shape. Example: 'reduction0: 0, 1'."), + tsl::Flag( + "bijection_outputs", + [](std::string name) { + if (name.empty()) return false; + flags.bijection_outputs.push_back(name); + return true; + }, + "", + "The name of a hero whose outputs should be written exactly once, " + "i.e. there's a bijection between a subset of threads and the output " + "shape.")}; + + xla::AppendDebugOptionsFlags(&flag_list); + std::string usage = tsl::Flags::Usage(argv[0], flag_list); + bool parseResult = tsl::Flags::Parse(&argc, argv, flag_list); + if (!parseResult || argc != 2) { + LOG(ERROR) << "\n" << usage; + return 1; + } + + flags.input_file = argv[1]; + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc new file mode 100644 index 00000000000000..11b82ddd517072 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc @@ -0,0 +1,118 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/tools/test_lib.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/status_macros.h" +#include "xla/tools/hlo_module_loader.h" + +namespace xla { +namespace gpu { + +absl::StatusOr> LoadTestModule( + absl::string_view filename) { + auto module = *xla::LoadModuleFromFile(std::string(filename)); + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_mlir_emitter_level(4); + + int num_fusions = absl::c_count_if( + module->entry_computation()->instructions(), + [](const HloInstruction* instruction) { + return instruction->opcode() == xla::HloOpcode::kFusion; + }); + TF_RET_CHECK(num_fusions <= 1) << "HLO must contain at most one fusion"; + + if (num_fusions == 0) { + // Generate a fusion from the entry computation. + HloComputation::Builder builder("generated_main"); + std::vector params; + for (const auto* param : + module->entry_computation()->parameter_instructions()) { + params.push_back(*builder.AddParameter(param->Clone(/*suffix=*/""))); + } + builder.AddInstruction(HloInstruction::CreateFusion( + module->entry_computation()->root_instruction()->shape(), + HloInstruction::FusionKind::kLoop /* irrelevant */, params, + module->entry_computation())); + + auto* new_entry = module->AddComputationAndUnifyNamesAndIds( + builder.Build(), /*is_entry=*/false); + module->ReplaceEntryComputation(new_entry); + } + + return module; +} + +absl::StatusOr> GetMlirFusionEmitter( + const HloModule& module) { + auto data = std::make_unique(); + data->fusion = DynCast( + module.entry_computation()->root_instruction()); + TF_RET_CHECK(data->fusion != nullptr) << "Root instruction must be a fusion"; + data->device.emplace(TestGpuDeviceInfo::RTXA6000DeviceInfo()); + data->analysis.emplace( + HloFusionAnalysis::Create(*data->fusion, data->device.value())); + PreBufferAssignmentFusionInfo info(data->analysis.value()); + auto emitter = GetFusionEmitter(info); + + auto mlir_emitter = dynamic_cast(emitter.get()); + TF_RET_CHECK(mlir_emitter != nullptr) + << "Expected emitter to be an MlirFusionEmitter"; + + emitter.release(); + data->emitter.reset(mlir_emitter); + return data; +} + +mlir::MLIRContext GetMlirContextForTest() { + mlir::DialectRegistry registry; + registry.insert(); + return mlir::MLIRContext(registry); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.h b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.h new file mode 100644 index 00000000000000..5dfa3009f71c40 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.h @@ -0,0 +1,58 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_TOOLS_TEST_LIB_H_ +#define XLA_SERVICE_GPU_FUSIONS_TOOLS_TEST_LIB_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { + +namespace gpu { + +// Loads a test module from the given filename, ensuring it has a single fusion. +// If the file contains more than one fusion, the function fails. If the file +// contains no fusions, the function generates a fusion from the entry +// computation. +absl::StatusOr> LoadTestModule( + absl::string_view filename); + +// Returns the MLIR fusion emitter for the given module, which should have been +// loaded using LoadTestModule. +struct EmitterData { + HloFusionInstruction* fusion; + std::optional device; + std::optional analysis; + std::unique_ptr emitter; +}; +absl::StatusOr> GetMlirFusionEmitter( + const HloModule& module); + +// Returns an MLIR context with all the dialects needed for testing. +mlir::MLIRContext GetMlirContextForTest(); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_TOOLS_TEST_LIB_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD index 77009eacebfdd1..24fb1963afccaa 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD @@ -58,13 +58,14 @@ cc_library( "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc index 72446ec1ba0a10..0c9053a5570654 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc @@ -17,7 +17,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc b/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc index 285b40de81ec86..3918a191fee3cb 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc @@ -21,7 +21,7 @@ limitations under the License. #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc index 2274576dcb23ca..66ff74413ef25c 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc b/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc index a5a00eba7cd2b3..c854507003c44f 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc @@ -21,6 +21,9 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Utils/Utils.h" @@ -43,9 +46,10 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/layout_util.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -71,6 +75,7 @@ using mlir::func::FuncOp; using mlir::func::ReturnOp; using mlir::scf::ForOp; using mlir::scf::IfOp; +using mlir::scf::IndexSwitchOp; using mlir::tensor::ExtractOp; using mlir::tensor::InsertOp; @@ -79,12 +84,25 @@ RankedTensorType GetFlattenedType(RankedTensorType tensor_type) { tensor_type.getElementType()); } +bool IsScalarOrFlat(Type type) { + auto tensor_type = mlir::dyn_cast(type); + if (!tensor_type) return true; + return tensor_type.getRank() < 2; +} + bool HasOnlyFlatTensorsOrScalars(TypeRange types) { - return llvm::all_of(types, [](Type ty) { - auto tensor_type = mlir::dyn_cast(ty); - if (!tensor_type) return true; - return tensor_type.getRank() < 2; - }); + return llvm::all_of(types, IsScalarOrFlat); +} + +Value Flatten(Value value, PatternRewriter& rewriter) { + auto tensor_type = mlir::dyn_cast(value.getType()); + if (!tensor_type || tensor_type.getRank() < 2) { + return value; + } + auto flat_type = GetFlattenedType(tensor_type); + return rewriter + .create(value.getLoc(), flat_type, value) + .getResult(0); } struct RewriteFunctionSignatures : OpRewritePattern { @@ -109,20 +127,9 @@ struct RewriteFunctionSignatures : OpRewritePattern { rewriter.setInsertionPoint(terminator); for (Value result : terminator->getOperands()) { - auto tensor_type = mlir::dyn_cast(result.getType()); - if (!tensor_type) { - new_result_types.push_back(result.getType()); - new_results.push_back(result); - continue; - } - auto new_result_type = GetFlattenedType(tensor_type); - new_result_types.push_back(new_result_type); - - Value result_1d = - rewriter - .create(loc, new_result_type, result) - .getResult(0); - new_results.push_back(result_1d); + Value flattened = Flatten(result, rewriter); + new_results.push_back(flattened); + new_result_types.push_back(flattened.getType()); } rewriter.replaceOpWithNewOp(terminator, new_results); @@ -130,16 +137,14 @@ struct RewriteFunctionSignatures : OpRewritePattern { SmallVector new_operand_types(input_types); rewriter.setInsertionPointToStart(entry_block); for (auto&& [index, operand_type] : llvm::enumerate(new_operand_types)) { - if (auto tensor_type = mlir::dyn_cast(operand_type)) { - if (tensor_type.getRank() > 1) { - mlir::BlockArgument func_argument = op.getArgument(index); - auto cast_to_orig_type = rewriter.create( - loc, operand_type, func_argument); - func_argument.replaceAllUsesExcept(cast_to_orig_type.getResult(0), - cast_to_orig_type); - operand_type = GetFlattenedType(tensor_type); - } - } + if (IsScalarOrFlat(operand_type)) continue; + mlir::BlockArgument func_argument = op.getArgument(index); + auto cast_to_orig_type = rewriter.create( + loc, operand_type, func_argument); + func_argument.replaceAllUsesExcept(cast_to_orig_type.getResult(0), + cast_to_orig_type); + operand_type = + GetFlattenedType(mlir::cast(operand_type)); } // Replace the function arguments with the new types. for (auto [arg, arg_type] : @@ -152,6 +157,51 @@ struct RewriteFunctionSignatures : OpRewritePattern { } }; +struct RewritePureCall : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(PureCallOp op, + PatternRewriter& rewriter) const override { + if (HasOnlyFlatTensorsOrScalars(op.getOperandTypes()) && + HasOnlyFlatTensorsOrScalars(op.getResultTypes())) { + return rewriter.notifyMatchFailure(op, "nothing to flatten"); + } + SmallVector flat_operands; + flat_operands.reserve(op.getNumOperands()); + for (Value operand : op.getOperands()) { + flat_operands.push_back(Flatten(operand, rewriter)); + } + SmallVector flat_result_types; + flat_result_types.reserve(op.getNumResults()); + llvm::SmallBitVector results_to_update(op.getNumResults(), false); + for (auto [index, result_type] : llvm::enumerate(op.getResultTypes())) { + if (IsScalarOrFlat(result_type)) { + flat_result_types.push_back(result_type); + continue; + } + results_to_update.set(index); + flat_result_types.push_back( + GetFlattenedType(mlir::cast(result_type))); + } + Location loc = op.getLoc(); + auto new_call_op = rewriter.create( + loc, flat_result_types, op.getCalleeAttr(), flat_operands); + SmallVector new_results; + new_results.reserve(op.getNumResults()); + for (auto [index, new_result] : llvm::enumerate(new_call_op.getResults())) { + if (results_to_update.test(index)) { + new_results.push_back(new_result); + continue; + } + auto cast_to_orig_type = rewriter.create( + loc, op.getResult(index).getType(), new_result); + new_results.push_back(cast_to_orig_type.getResult(0)); + } + rewriter.replaceOp(op, new_results); + return mlir::success(); + } +}; + // Returns the linearized index, if the rank is greater than 1. Otherwise, // returns nullptr. Value LinearizeIndex(TypedValue tensor, @@ -174,6 +224,43 @@ Value LinearizeIndex(TypedValue tensor, return result.front(); } +struct RewriteAllocateShared : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllocateSharedOp op, + PatternRewriter& rewriter) const override { + auto tensor_type = op.getResult().getType(); + if (IsScalarOrFlat(tensor_type)) { + return rewriter.notifyMatchFailure(op, "the tensor is already flat"); + } + auto flat_type = GetFlattenedType(tensor_type); + Location loc = op.getLoc(); + Value new_op = rewriter.create(op.getLoc(), flat_type); + auto cast_to_orig_type = + rewriter.create(loc, tensor_type, new_op); + rewriter.replaceOp(op, cast_to_orig_type.getResult(0)); + return mlir::success(); + } +}; + +struct RewriteTensorConstant : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::arith::ConstantOp op, + PatternRewriter& rewriter) const override { + if (IsScalarOrFlat(op.getType())) { + return rewriter.notifyMatchFailure(op, "the tensor is already flat"); + } + auto tensor_type = mlir::cast(op.getType()); + auto dense_attr = mlir::dyn_cast(op.getValue()); + Value new_constant = rewriter.create( + op.getLoc(), dense_attr.reshape(GetFlattenedType(tensor_type))); + rewriter.replaceOpWithNewOp(op, tensor_type, + new_constant); + return mlir::success(); + } +}; + struct RewriteTensorExtract : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -262,7 +349,7 @@ std::optional GetDelinearizedTensor(Value value) { return cast->getOperand(0); } -struct RewriteForOp : public OpRewritePattern { +struct RewriteFor : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ForOp op, @@ -337,7 +424,7 @@ struct RewriteForOp : public OpRewritePattern { } }; -struct RewriteIfOp : public OpRewritePattern { +struct RewriteIf : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(IfOp op, @@ -405,6 +492,113 @@ struct RewriteIfOp : public OpRewritePattern { } }; +struct RewriteIndexSwitch : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IndexSwitchOp op, + PatternRewriter& rewriter) const override { + auto result_types = op.getResultTypes(); + if (HasOnlyFlatTensorsOrScalars(result_types)) { + return rewriter.notifyMatchFailure(op, "nothing to flatten"); + } + auto default_yield = + mlir::cast(op.getDefaultBlock().getTerminator()); + SmallVector new_result_types; + new_result_types.reserve(default_yield.getNumOperands()); + bool found_cast = false; + for (auto& result : default_yield->getOpOperands()) { + auto delinearized_tensor = GetDelinearizedTensor(result.get()); + if (!delinearized_tensor.has_value()) { + new_result_types.push_back(result.get().getType()); + continue; + } + new_result_types.push_back(delinearized_tensor->getType()); + result.set(*delinearized_tensor); + found_cast = true; + } + if (!found_cast) { + return rewriter.notifyMatchFailure(op, "no cast found"); + } + Location loc = op.getLoc(); + // Update the "case" regions. + for (auto& case_region : op.getCaseRegions()) { + auto yield = mlir::cast( + case_region.getBlocks().front().getTerminator()); + mlir::OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(yield); + for (auto&& [result, type] : + llvm::zip(yield->getOpOperands(), new_result_types)) { + if (result.get().getType() == type) continue; + result.set( + rewriter.create(loc, type, result.get()) + .getResult(0)); + } + } + // Create new IndexSwitchOp and move the old op's regions to the new one. + auto new_index_switch = rewriter.create( + loc, new_result_types, op.getArg(), op.getCases(), op.getNumCases()); + for (auto&& [old_region, new_region] : + llvm::zip(op.getRegions(), new_index_switch.getRegions())) { + rewriter.inlineRegionBefore(*old_region, *new_region, new_region->end()); + } + // Update the results. + rewriter.setInsertionPointAfter(new_index_switch); + SmallVector new_results(new_index_switch.getResults()); + for (auto&& [index, result] : llvm::enumerate(new_results)) { + Type old_type = op->getResult(index).getType(); + if (result.getType() == old_type) continue; + result = + rewriter.create(loc, old_type, result) + .getResult(0); + } + rewriter.replaceOp(op, new_results); + return mlir::success(); + } +}; + +struct RewriteSyncThreads : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SyncThreadsOp op, + PatternRewriter& rewriter) const override { + auto types = op.getResultTypes(); + if (HasOnlyFlatTensorsOrScalars(types)) { + return rewriter.notifyMatchFailure(op, "nothing to flatten"); + } + + auto loc = op.getLoc(); + + SmallVector new_operands; + new_operands.reserve(op.getNumOperands()); + llvm::SmallBitVector results_to_update(op.getNumResults(), false); + for (auto& operand : op->getOpOperands()) { + auto tensor_type = mlir::cast(operand.get().getType()); + if (tensor_type.getRank() < 2) continue; + results_to_update.set(operand.getOperandNumber()); + new_operands.push_back( + rewriter + .create( + loc, GetFlattenedType(tensor_type), operand.get()) + .getResult(0)); + } + auto new_op = rewriter.create(loc, TypeRange(new_operands), + new_operands); + SmallVector new_results; + new_results.reserve(op.getNumResults()); + for (auto [index, result] : llvm::enumerate(new_op.getResults())) { + if (!results_to_update.test(index)) { + new_results.push_back(result); + continue; + } + auto cast_to_orig_type = rewriter.create( + loc, result.getType(), result); + new_results.push_back(cast_to_orig_type.getResult(0)); + } + rewriter.replaceOp(op, new_results); + return mlir::success(); + } +}; + class FlattenTensorsPass : public impl::FlattenTensorsPassBase { public: @@ -414,10 +608,15 @@ class FlattenTensorsPass mlir::RewritePatternSet patterns(mlir_context); // clang-format off patterns.add< + RewriteAllocateShared, RewriteAtomicRMW, - RewriteForOp, + RewriteFor, RewriteFunctionSignatures, - RewriteIfOp, + RewriteIf, + RewriteIndexSwitch, + RewritePureCall, + RewriteSyncThreads, + RewriteTensorConstant, RewriteTensorExtract, RewriteTensorInsert >(mlir_context); diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc index ff36d19cb756f5..63e5c75f56c03a 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/LogicalResult.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -57,10 +56,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/layout_util.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" -#include "xla/service/gpu/model/indexing_analysis.h" -#include "xla/shape_util.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" @@ -172,27 +168,14 @@ struct RewriteFunctionSignatures : mlir::OpRewritePattern { } }; -Value GetLinearIndex(TypedValue tensor, - ValueRange indices, mlir::PatternRewriter& rewriter) { - auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape()); - if (auto encoding = tensor.getType().getEncoding()) { - *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector( - mlir::cast(encoding).getValues())); - } - auto linear_shape = - ShapeUtil::MakeShape(U8, {ShapeUtil::ElementsIn(byte_shape)}); - auto linearize_map = - GetBitcastMap(byte_shape, linear_shape, tensor.getContext()); - mlir::SmallVector result; - rewriter.createOrFold(result, tensor.getLoc(), indices, - ValueRange{}, linearize_map); - CHECK_EQ(result.size(), 1); - auto index = result.front(); - auto index_ty = rewriter.getIntegerType( - mlir::DataLayout::closest(rewriter.getInsertionBlock()->getParentOp()) +Value GetLinearIndex(ValueRange indices, mlir::ImplicitLocOpBuilder& b) { + CHECK_LE(indices.size(), 1) << "Only 0D and 1D tensors are supported"; + auto index = indices.empty() ? b.create(0) + : indices.front(); + auto index_ty = b.getIntegerType( + mlir::DataLayout::closest(b.getInsertionBlock()->getParentOp()) .getTypeSizeInBits(index.getType())); - return rewriter.create(tensor.getLoc(), index_ty, - index); + return b.create(index_ty, index); } std::tuple GetI4IndexAndNibble(Value linear_index, @@ -206,28 +189,25 @@ std::tuple GetI4IndexAndNibble(Value linear_index, } mlir::LLVM::GEPOp CreateGep(TypedValue tensor, - Value linear_index, mlir::PatternRewriter& rewriter, + Value linear_index, mlir::ImplicitLocOpBuilder& b, Type element_type = nullptr) { if (!element_type) { element_type = tensor.getType().getElementType(); } - auto ptr = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); - auto tensor_ptr = rewriter - .create( - tensor.getLoc(), ptr, tensor) - .getResult(0); - mlir::LLVMTypeConverter converter(rewriter.getContext()); + auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext()); + auto tensor_ptr = + b.create(ptr, tensor).getResult(0); + mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_element_type = converter.convertType(element_type); - auto gep = rewriter.create( - tensor.getLoc(), ptr, llvm_element_type, tensor_ptr, linear_index); + auto gep = b.create(ptr, llvm_element_type, tensor_ptr, + linear_index); gep.setInbounds(true); return gep; } mlir::LLVM::GEPOp CreateGep(TypedValue tensor, - ValueRange indices, - mlir::PatternRewriter& rewriter) { - return CreateGep(tensor, GetLinearIndex(tensor, indices, rewriter), rewriter); + ValueRange indices, mlir::ImplicitLocOpBuilder& b) { + return CreateGep(tensor, GetLinearIndex(indices, b), b); } struct RewriteTensorExtract : mlir::OpRewritePattern { @@ -237,8 +217,7 @@ struct RewriteTensorExtract : mlir::OpRewritePattern { mlir::tensor::ExtractOp op, mlir::PatternRewriter& rewriter) const override { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto linear_index = - GetLinearIndex(op.getTensor(), op.getIndices(), rewriter); + auto linear_index = GetLinearIndex(op.getIndices(), b); Type element_type = op.getTensor().getType().getElementType(); Value is_low_nibble = nullptr; if (element_type == rewriter.getI4Type()) { @@ -247,7 +226,7 @@ struct RewriteTensorExtract : mlir::OpRewritePattern { GetI4IndexAndNibble(linear_index, b); } - auto gep = CreateGep(op.getTensor(), linear_index, rewriter, element_type); + auto gep = CreateGep(op.getTensor(), linear_index, b, element_type); auto load = rewriter .create(gep.getLoc(), gep.getElemType(), gep) @@ -296,7 +275,7 @@ struct RewriteTransferRead op.getSource()); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto linear_index = GetLinearIndex(source, op.getIndices(), rewriter); + auto linear_index = GetLinearIndex(op.getIndices(), b); mlir::VectorType vector_type = op.getVectorType(); if (vector_type.getElementType().isInteger(1)) { @@ -309,7 +288,7 @@ struct RewriteTransferRead b.create(1, linear_index.getType())); gep_element_type = b.getI8Type(); } - auto gep = CreateGep(source, linear_index, rewriter, gep_element_type); + auto gep = CreateGep(source, linear_index, b, gep_element_type); mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_vector_type = converter.convertType(vector_type); @@ -345,7 +324,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto tensor_dest = mlir::cast>(dest); - auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter); + auto linear_index = GetLinearIndex(op.getIndices(), b); auto element_type = tensor_dest.getType().getElementType(); Value is_low_nibble = nullptr; @@ -355,7 +334,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { GetI4IndexAndNibble(linear_index, b); } - auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type); + auto gep = CreateGep(tensor_dest, linear_index, b, element_type); auto scalar_value = op.getScalar(); if (is_low_nibble) { @@ -402,7 +381,7 @@ struct RewriteTransferWrite mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto tensor_dest = mlir::cast>(dest); - auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter); + auto linear_index = GetLinearIndex(op.getIndices(), b); auto element_type = tensor_dest.getType().getElementType(); mlir::Value vector_value = op.getVector(); @@ -420,7 +399,7 @@ struct RewriteTransferWrite // elements. vector_value = PermutePairsInVector(vector_value, b); } - auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type); + auto gep = CreateGep(tensor_dest, linear_index, b, element_type); mlir::LLVMTypeConverter converter(getContext()); auto llvm_type = converter.convertType(vector_value.getType()); @@ -724,7 +703,8 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern { Location loc = op.getLoc(); llvm::StringRef sync_scope = is_amd_ ? "agent" : ""; - Value addr = CreateGep(op.getInput(), op.getIndices(), rewriter); + mlir::ImplicitLocOpBuilder b(loc, rewriter); + Value addr = CreateGep(op.getInput(), op.getIndices(), b); switch (atomic_bin_op) { case ml::AtomicBinOp::xchg: { @@ -932,7 +912,8 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern { mlir::IntegerType::get(op.getContext(), small_type ? 32 : result_size); // Calculate load address for the input. - Value addr = CreateGep(input, op.getIndices(), rewriter); + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value addr = CreateGep(input, op.getIndices(), b); Value shift, mask; if (small_type) { // Update input pointer by discarding the last two bits - i.e. align to diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc index 133f2005c0d609..cbd64b870e83d1 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc @@ -38,8 +38,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc index 9125ba09cda661..e483bfebedb979 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc @@ -41,7 +41,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/model/indexing_map.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc index 4cd9d4bef008b7..7c0845fff2011c 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/Utils/Utils.h" @@ -31,8 +32,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/model/indexing_map.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc index 490d3cd9b2d372..acbd9d3735ea46 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc @@ -41,7 +41,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/model/indexing_map.h" diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc index 43da6080ff14a2..f3d67e24ee3248 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" @@ -29,7 +30,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/model/indexing_map.h" diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD index 11db73111408ea..381d5a3220b1df 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD @@ -10,7 +10,7 @@ lit_test_suite( srcs = glob(["*.mlir"]), cfg = "//xla:lit.cfg.py", tools = [ - "//xla/service/gpu/fusions/mlir/tests:mlir_fusions_opt", + "//xla/service/gpu/fusions/tools:mlir_fusions_opt", "@llvm-project//llvm:FileCheck", ], ) diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir index 21a8dc2a0b7e79..18e3e30bc309c2 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir @@ -8,8 +8,7 @@ func.func @tensor_extract( : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> func.return %v : f32 } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0), -// CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 2]> +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0), domain: d0 in [0, 1], d1 in [0, 2]> // CHECK-LABEL: func.func @tensor_extract( // CHECK-SAME: %[[SRC:.*]]: tensor<6xf32>, @@ -37,6 +36,26 @@ func.func @tensor_insert( // ----- +func.func @update(%arg0: tensor<10x24xf32>) -> tensor<10x24xf32> { + %c1 = arith.constant 1 : index + %c42_f32 = arith.constant 42.0 : f32 + %out = tensor.insert %c42_f32 into %arg0[%c1, %c1] : tensor<10x24xf32> + func.return %out : tensor<10x24xf32> +} + +func.func @pure_call(%arg0: tensor<10x24xf32>) -> tensor<10x24xf32> { + %updated_tensor = xla_gpu.pure_call @update(%arg0) + : (tensor<10x24xf32>) -> (tensor<10x24xf32>) + func.return %updated_tensor : tensor<10x24xf32> +} +// CHECK-LABEL: func.func @pure_call( +// CHECK-SAME: %[[TENSOR:.*]]: tensor<240xf32>) -> tensor<240xf32> { +// CHECK-NEXT: xla_gpu.pure_call @update(%[[TENSOR]]) +// CHECK-SAME: : (tensor<240xf32>) -> tensor<240xf32> +// CHECK-NEXT: return + +// ----- + func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index) -> (tensor<2x4xf32>) { %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> { @@ -47,9 +66,7 @@ func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index) } return %ret : tensor<2x4xf32> } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 4 + d1), -// CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 3]> - +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 4 + d1), domain: d0 in [0, 1], d1 in [0, 3]> // CHECK-LABEL: func.func @atomic_rmw( // CHECK-SAME: %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index, // CHECK-SAME: %[[J:.*]]: index) -> tensor<8xf32> { @@ -74,11 +91,8 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) } {some_attr} return %for#0, %for#1, %c0_f32 : tensor<32x1024xf32>, tensor<64x8x4xf32>, f32 } - -// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1024), -// CHECK-SAME: domain: d0 in [0, 1023]> -// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 * 32 + 5), -// CHECK-SAME: domain: d0 in [0, 63]> +// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1024) +// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 * 32 + 5) // CHECK-LABEL: func.func @for_loop( // CHECK-SAME: %[[T0:.*]]: tensor<32768xf32>, // CHECK-SAME: %[[T1:.*]]: tensor<2048xf32>) -> (tensor<32768xf32>, tensor<2048xf32>, f32) { @@ -98,12 +112,9 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) // ----- -#map = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36), - domain: d0 in [0, 127], d1 in [0, 393749]> -#map1 = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), - domain: d0 in [0, 127], d1 in [0, 393749]> -#map2 = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) mod 9), - domain: d0 in [0, 127], d1 in [0, 393749]> +#map = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36), domain: d0 in [0, 127], d1 in [0, 393749]> +#map1 = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), domain: d0 in [0, 127], d1 in [0, 393749]> +#map2 = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) mod 9), domain: d0 in [0, 127], d1 in [0, 393749]> func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>) -> tensor<4000x4x9xf32> { @@ -138,9 +149,67 @@ func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, // ----- +func.func @allocate_shared() -> tensor<10x15xf32> { + %shmem = xla_gpu.allocate_shared : tensor<10x15xf32> + func.return %shmem : tensor<10x15xf32> +} +// CHECK-LABEL: func.func @allocate_shared() -> tensor<150xf32> +// CHECK: xla_gpu.allocate_shared : tensor<150xf32> +// CHECK-NOT: builtin.unrealized_conversion_cast + +// ----- + +func.func @sync() -> (tensor<8x4xf32>, tensor<8x4xf32>) { + %shared1 = xla_gpu.allocate_shared : tensor<8x4xf32> + %shared2 = xla_gpu.allocate_shared : tensor<8x4xf32> + %sync:2 = xla_gpu.sync_threads %shared1, %shared2 + : tensor<8x4xf32>, tensor<8x4xf32> + return %sync#0, %sync#1 : tensor<8x4xf32>, tensor<8x4xf32> +} +// CHECK-LABEL: func.func @sync() -> (tensor<32xf32>, tensor<32xf32>) { +// CHECK: %[[SHARED1:.*]] = xla_gpu.allocate_shared : tensor<32xf32> +// CHECK: %[[SHARED2:.*]] = xla_gpu.allocate_shared : tensor<32xf32> +// CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHARED1]], %[[SHARED2]] +// CHECK-SAME: : tensor<32xf32>, tensor<32xf32> +// CHECK-NEXT: return + +// ----- + +func.func @index_switch(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>, + %arg2: tensor<2x3xf32>, %arg3: tensor<2x3xf32> + ) -> (tensor<2x3xf32>, tensor<2x3xf32>) { + %block_id_y = gpu.block_id y {xla.range = [0 : index, 1 : index]} + %0:2 = scf.index_switch %block_id_y -> tensor<2x3xf32>, tensor<2x3xf32> + case 1 { + scf.yield %arg0, %arg3 : tensor<2x3xf32>, tensor<2x3xf32> + } + default { + scf.yield %arg1, %arg2 : tensor<2x3xf32>, tensor<2x3xf32> + } + return %0#0, %0#1: tensor<2x3xf32>, tensor<2x3xf32> +} +// CHECK-LABEL: func.func @index_switch +// CHECK-SAME: -> (tensor<6xf32>, tensor<6xf32>) +// CHECK-NOT: builtin.unrealized_conversion_cast + +// ----- + +func.func @constant() -> tensor<2x3xf32> { + %cst = arith.constant dense<[ + [-3.000000e+00, 2.000000e+00, 1.000000e+00], + [0.000000e+00, -3.000000e+00, 1.000000e+00] + ]> : tensor<2x3xf32> + return %cst : tensor<2x3xf32> +} +// CHECK-LABEL: func.func @constant +// CHECK-SAME: -> tensor<6xf32> +// CHECK-NOT: builtin.unrealized_conversion_cast + +// ----- + func.func @dangling_cast(%arg0: tensor<6xf32>, %arg1: index) -> i32 { %v = tensor.extract %arg0[%arg1] : tensor<6xf32> %cast = builtin.unrealized_conversion_cast %v : f32 to i32 func.return %cast : i32 } -// CHECK: FlattenTensorsPass failed to converge +// CHECK: FlattenTensorsPass failed to converge \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir index be8eb1eef94f8e..822c3a85c9a2a0 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir @@ -80,55 +80,28 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry> // ----- -module { - func.func @layout( - %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>, - %arg1: index, %arg2: index) -> f32 { - %v = tensor.extract %arg0[%arg1, %arg2] - : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> - func.return %v : f32 - } -} - -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0), -// CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 2]> -// CHECK-LABEL: @layout( -// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, -// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index -// CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[X]], %[[Y]]) -// CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i64 -// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]] -// CHECK: llvm.load %[[PTR]] +func.func @store_control_flow( %arg0: tensor<2xf32>, %arg1: index) + -> tensor<2xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %cst = arith.constant 0.0 : f32 + %cst2 = arith.constant 1.0 : f32 -// ----- + %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> { + %new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32> + scf.yield %new_out : tensor<2xf32> + } -module { - func.func @store_control_flow( - %arg0: tensor<2xf32>, - %arg1: index - ) -> tensor<2xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %cst = arith.constant 0.0 : f32 - %cst2 = arith.constant 1.0 : f32 - - %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> { - %new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32> - scf.yield %new_out : tensor<2xf32> - } - - %inbounds = arith.cmpi sle, %arg1, %c1 : index - %result = scf.if %inbounds -> tensor<2xf32> { - %if = tensor.insert %cst2 into %for[%arg1] : tensor<2xf32> - scf.yield %if : tensor<2xf32> - } else { - scf.yield %for : tensor<2xf32> - } - func.return %result : tensor<2xf32> + %inbounds = arith.cmpi sle, %arg1, %c1 : index + %result = scf.if %inbounds -> tensor<2xf32> { + %if = tensor.insert %cst2 into %for[%arg1] : tensor<2xf32> + scf.yield %if : tensor<2xf32> + } else { + scf.yield %for : tensor<2xf32> } + func.return %result : tensor<2xf32> } - // CHECK: @store_control_flow(%[[ARG0:.*]]: !llvm.ptr, %[[X:.*]]: index) { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index @@ -145,33 +118,25 @@ module { // ----- -module { - func.func @large_tensor( - %arg0: tensor<1024x1024x1024x6xf32>, - %arg1: index) -> f32 { - %v = tensor.extract %arg0[%arg1, %arg1, %arg1, %arg1] : tensor<1024x1024x1024x6xf32> - func.return %v : f32 - } +func.func @large_tensor(%arg0: tensor<8000000000xf32>, %arg1: index) -> f32 { + %v = tensor.extract %arg0[%arg1] : tensor<8000000000xf32> + func.return %v : f32 } - -// CHECK: @large_tensor +// CHECK-LABEL: @large_tensor // CHECK: arith.index_castui {{.*}} : index to i64 // ----- -module { - func.func @extract_from_constant(%arg0: tensor<2x1xf32>, - %arg1: index, %arg2: index) -> f32 { - %cst = arith.constant dense<[[1.000000e+00], [2.000000e+00]]> : tensor<2x1xf32> - %extracted = tensor.extract %arg0[%arg1, %arg2] : tensor<2x1xf32> - %extracted_0 = tensor.extract %cst[%arg1, %arg2] : tensor<2x1xf32> - %0 = arith.addf %extracted, %extracted_0 : f32 - return %0 : f32 - } +func.func @extract_from_constant(%arg0: tensor<2xf32>, %arg1: index) -> f32 { + %cst = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> + %extracted = tensor.extract %arg0[%arg1] : tensor<2xf32> + %extracted_0 = tensor.extract %cst[%arg1] : tensor<2xf32> + %0 = arith.addf %extracted, %extracted_0 : f32 + return %0 : f32 } // CHECK: llvm.mlir.global private constant @global_cst_0(dense< // CHECK-SAME: [1.000000e+00, 2.000000e+00]> : tensor<2xf32>) {addr_space = 0 : i32} : !llvm.array<2 x f32> -// CHECK: @extract_from_constant +// CHECK-LABEL: @extract_from_constant // CHECK: %[[ADDR_OF:.*]] = llvm.mlir.addressof @global_cst_0 : !llvm.ptr // CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ADDR_OF]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> f32 @@ -180,31 +145,25 @@ module { // ----- -module { - func.func @vector_constant() -> vector<2xindex> { - %c1 = arith.constant dense<[1, 2]> : vector<2xindex> - func.return %c1 : vector<2xindex> - } +func.func @vector_constant() -> vector<2xindex> { + %c1 = arith.constant dense<[1, 2]> : vector<2xindex> + func.return %c1 : vector<2xindex> } - // vector constants should not be rewritten. // CHECK: @vector_constant // CHECK-NEXT: arith.constant // ----- -module { - func.func @complex_tensor_insert( - %arg0: tensor<10xcomplex>) -> tensor<10xcomplex> { - %c1 = arith.constant 1 : index - %real = arith.constant 3.0 : f32 - %imag = arith.constant 2.0 : f32 - %complex = complex.create %real, %imag : complex - %out = tensor.insert %complex into %arg0[%c1] : tensor<10xcomplex> - func.return %out : tensor<10xcomplex> - } +func.func @complex_tensor_insert( + %arg0: tensor<10xcomplex>) -> tensor<10xcomplex> { + %c1 = arith.constant 1 : index + %real = arith.constant 3.0 : f32 + %imag = arith.constant 2.0 : f32 + %complex = complex.create %real, %imag : complex + %out = tensor.insert %complex into %arg0[%c1] : tensor<10xcomplex> + func.return %out : tensor<10xcomplex> } - // CHECK: @complex_tensor_insert(%[[ARG0:.*]]: !llvm.ptr // CHECK: %[[C:.*]] = complex.create // CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)> @@ -213,15 +172,12 @@ module { // ----- -module { - func.func @complex_tensor_extract( - %arg0: tensor<10xcomplex>) -> complex { - %c1 = arith.constant 1 : index - %v2 = tensor.extract %arg0[%c1] : tensor<10xcomplex> - func.return %v2 : complex - } +func.func @complex_tensor_extract( + %arg0: tensor<10xcomplex>) -> complex { + %c1 = arith.constant 1 : index + %v2 = tensor.extract %arg0[%c1] : tensor<10xcomplex> + func.return %v2 : complex } - // CHECK: @complex_tensor_extract(%[[ARG0:.*]]: !llvm.ptr // CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)> // CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> !llvm.struct<(f32, f32)> @@ -229,46 +185,33 @@ module { // ----- -module { - // This example is a bit silly, in real life there wouldn't be a loop (the - // loop body would be executed by different threads). We're just doing it this - // way so control flow with shared memory is tested as well. - func.func @transpose_shared(%in: tensor<32x32xf32>, - %out: tensor<32x32xf32>) -> tensor<32x32xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - - %shared = xla_gpu.allocate_shared : tensor<32x32xf32> - %loaded_tile = scf.for %i = %c0 to %c32 step %c1 - iter_args(%tile = %shared) -> tensor<32x32xf32> { - %inner_loaded_tile = scf.for %j = %c0 to %c32 step %c1 - iter_args(%inner_tile = %tile) -> tensor<32x32xf32> { - %v = tensor.extract %in[%i, %j] : tensor<32x32xf32> - %inserted = tensor.insert %v into %inner_tile[%i, %j] - : tensor<32x32xf32> - scf.yield %inserted : tensor<32x32xf32> - } - scf.yield %inner_loaded_tile : tensor<32x32xf32> - } - - %synced = xla_gpu.sync_threads %shared : tensor<32x32xf32> - %written_tile = scf.for %i = %c0 to %c32 step %c1 - iter_args(%written = %out) -> tensor<32x32xf32> { - %inner_written_tile = scf.for %j = %c0 to %c32 step %c1 - iter_args(%inner_written = %written) -> tensor<32x32xf32> { - %v = tensor.extract %shared[%j, %i] : tensor<32x32xf32> - %inserted = tensor.insert %v into %inner_written[%i, %j] - : tensor<32x32xf32> - scf.yield %inserted : tensor<32x32xf32> - } - scf.yield %inner_written_tile : tensor<32x32xf32> - } - - return %written_tile : tensor<32x32xf32> +// This example is a bit silly, in real life there wouldn't be a loop (the +// loop body would be executed by different threads). We're just doing it this +// way so control flow with shared memory is tested as well. +func.func @transpose_shared(%in: tensor<1024xf32>, + %out: tensor<1024xf32>) -> tensor<1024xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + + %shared = xla_gpu.allocate_shared : tensor<1024xf32> + %loaded_tile = scf.for %i = %c0 to %c1024 step %c1 + iter_args(%tile = %shared) -> tensor<1024xf32> { + %v = tensor.extract %in[%i] : tensor<1024xf32> + %inserted = tensor.insert %v into %tile[%i] : tensor<1024xf32> + scf.yield %inserted : tensor<1024xf32> } -} + %synced = xla_gpu.sync_threads %shared : tensor<1024xf32> + %written_tile = scf.for %i = %c0 to %c1024 step %c1 + iter_args(%written = %out) -> tensor<1024xf32> { + %v = tensor.extract %shared[%i] : tensor<1024xf32> + %inserted = tensor.insert %v into %written[%i] : tensor<1024xf32> + scf.yield %inserted : tensor<1024xf32> + } + + return %written_tile : tensor<1024xf32> +} // CHECK: llvm.mlir.global private @[[SHARED:shared_.*]]() // CHECK-SAME: {addr_space = 3 : i32} : !llvm.array<1024 x f32> // CHECK: @transpose_shared @@ -276,30 +219,24 @@ module { // CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[ADDR]] // CHECK-SAME: : !llvm.ptr<3> to !llvm.ptr // CHECK: scf.for -// CHECK: scf.for -// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]] -// CHECK: llvm.store {{.*}} %[[ELEM_ADDR]] +// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]] +// CHECK: llvm.store {{.*}} %[[ELEM_ADDR]] // CHECK: gpu.barrier // CHECK: scf.for -// CHECK: scf.for -// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]] -// CHECK: llvm.load %[[ELEM_ADDR]] +// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]] +// CHECK: llvm.load %[[ELEM_ADDR]] // ----- -module { - func.func @atomic_rmw_f32(%in: tensor<2x4xf32>, %i: index, %j: index) - -> (tensor<2x4xf32>) { - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> { - ^bb0(%current : f32): - %c42 = arith.constant 1.0 : f32 - %add = arith.minimumf %current, %c42 : f32 - xla_gpu.yield %add : f32 - } - return %ret : tensor<2x4xf32> +func.func @atomic_rmw_f32(%in: tensor<8xf32>, %i: index) -> (tensor<8xf32>) { + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf32> { + ^bb0(%current : f32): + %c42 = arith.constant 1.0 : f32 + %add = arith.minimumf %current, %c42 : f32 + xla_gpu.yield %add : f32 } + return %ret : tensor<8xf32> } - // CHECK: @atomic_rmw_f32 // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK-NEXT: %[[INIT:.*]] = llvm.load %[[ADDR]] @@ -309,19 +246,16 @@ module { // ----- -module { - func.func @atomic_rmw_f16(%in: tensor<2x4xf16>, %i: index, %j: index) - -> (tensor<2x4xf16>) { - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> { - ^bb0(%current : f16): - %c1 = arith.constant 1.0 : f16 - %add = arith.addf %current, %c1 : f16 - xla_gpu.yield %add : f16 - } - return %ret : tensor<2x4xf16> +func.func @atomic_rmw_f16(%in: tensor<8xf16>, %i: index) + -> (tensor<8xf16>) { + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf16> { + ^bb0(%current : f16): + %c1 = arith.constant 1.0 : f16 + %add = arith.addf %current, %c1 : f16 + xla_gpu.yield %add : f16 } + return %ret : tensor<8xf16> } - // CHECK: @atomic_rmw_f16 // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]] @@ -342,16 +276,14 @@ module { // ----- -module { - func.func @atomic_rmw_overwrite(%in: tensor<2x4xf16>, %i: index, %j: index) - -> (tensor<2x4xf16>) { - %c1 = arith.constant 1.0 : f16 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> { - ^bb0(%current : f16): - xla_gpu.yield %c1 : f16 - } - return %ret : tensor<2x4xf16> +func.func @atomic_rmw_overwrite(%in: tensor<8xf16>, %i: index) + -> (tensor<8xf16>) { + %c1 = arith.constant 1.0 : f16 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf16> { + ^bb0(%current : f16): + xla_gpu.yield %c1 : f16 } + return %ret : tensor<8xf16> } // CHECK: @atomic_rmw_overwrite // CHECK: %[[ADDR:.*]] = llvm.getelementptr @@ -370,26 +302,21 @@ module { // ----- -module { - func.func @shared_complex() -> tensor<10xcomplex> { - %shared = xla_gpu.allocate_shared : tensor<10xcomplex> - return %shared : tensor<10xcomplex> - } +func.func @shared_complex() -> tensor<10xcomplex> { + %shared = xla_gpu.allocate_shared : tensor<10xcomplex> + return %shared : tensor<10xcomplex> } - // CHECK: llvm.mlir.global private @{{.*}}() {addr_space = 3 : i32} : !llvm.array<10 x struct<(f32, f32)>> // CHECK: @shared_complex // ----- -module { - func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) -> tensor<10xi4> { - %v = tensor.extract %arg[%i] : tensor<10xi4> - %r = tensor.insert %v into %arg[%j] : tensor<10xi4> - return %r : tensor<10xi4> - } +func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) + -> tensor<10xi4> { + %v = tensor.extract %arg[%i] : tensor<10xi4> + %r = tensor.insert %v into %arg[%j] : tensor<10xi4> + return %r : tensor<10xi4> } - // CHECK: @i4_load_store // CHECK: llvm.getelementptr // CHECK-SAME: -> !llvm.ptr, i8 @@ -401,16 +328,14 @@ module { // ----- -module { - func.func @direct_atomic_rmw_overwrite(%in: tensor<2x4xi32>, - %i: index, %j: index) -> (tensor<2x4xi32>) { - %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> { - ^bb0(%current : i32): - xla_gpu.yield %c2 : i32 - } - return %ret : tensor<2x4xi32> +func.func @direct_atomic_rmw_overwrite(%in: tensor<8xi32>, + %i: index) -> (tensor<8xi32>) { + %c2 = arith.constant 2 : i32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + ^bb0(%current : i32): + xla_gpu.yield %c2 : i32 } + return %ret : tensor<8xi32> } // CHECK: @direct_atomic_rmw_overwrite // CHECK: %[[C2:.*]] = arith.constant 2 @@ -419,17 +344,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_addi(%in: tensor<2x4xi32>, - %i: index, %j: index) -> (tensor<2x4xi32>) { - %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> { - ^bb0(%current : i32): - %min = arith.addi %current, %c2 : i32 - xla_gpu.yield %c2 : i32 - } - return %ret : tensor<2x4xi32> +func.func @direct_atomic_rmw_addi(%in: tensor<8xi32>, + %i: index) -> (tensor<8xi32>) { + %c2 = arith.constant 2 : i32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + ^bb0(%current : i32): + %min = arith.addi %current, %c2 : i32 + xla_gpu.yield %c2 : i32 } + return %ret : tensor<8xi32> } // CHECK: @direct_atomic_rmw_addi // CHECK: %[[C2:.*]] = arith.constant 2 @@ -438,17 +361,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_maxsi(%in: tensor<2x4xi32>, - %i: index, %j: index) -> (tensor<2x4xi32>) { - %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> { - ^bb0(%current : i32): - %min = arith.maxsi %current, %c2 : i32 - xla_gpu.yield %c2 : i32 - } - return %ret : tensor<2x4xi32> +func.func @direct_atomic_rmw_maxsi(%in: tensor<8xi32>, + %i: index) -> (tensor<8xi32>) { + %c2 = arith.constant 2 : i32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + ^bb0(%current : i32): + %min = arith.maxsi %current, %c2 : i32 + xla_gpu.yield %c2 : i32 } + return %ret : tensor<8xi32> } // CHECK: @direct_atomic_rmw_maxsi // CHECK: %[[C2:.*]] = arith.constant 2 @@ -457,17 +378,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_maxui(%in: tensor<2x4xi32>, - %i: index, %j: index) -> (tensor<2x4xi32>) { - %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> { - ^bb0(%current : i32): - %min = arith.maxui %current, %c2 : i32 - xla_gpu.yield %c2 : i32 - } - return %ret : tensor<2x4xi32> +func.func @direct_atomic_rmw_maxui(%in: tensor<8xi32>, + %i: index) -> (tensor<8xi32>) { + %c2 = arith.constant 2 : i32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + ^bb0(%current : i32): + %min = arith.maxui %current, %c2 : i32 + xla_gpu.yield %c2 : i32 } + return %ret : tensor<8xi32> } // CHECK: @direct_atomic_rmw_maxui // CHECK: %[[C2:.*]] = arith.constant 2 @@ -476,17 +395,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_minsi(%in: tensor<2x4xi32>, - %i: index, %j: index) -> (tensor<2x4xi32>) { - %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> { - ^bb0(%current : i32): - %min = arith.minsi %current, %c2 : i32 - xla_gpu.yield %c2 : i32 - } - return %ret : tensor<2x4xi32> +func.func @direct_atomic_rmw_minsi(%in: tensor<8xi32>, + %i: index) -> (tensor<8xi32>) { + %c2 = arith.constant 2 : i32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + ^bb0(%current : i32): + %min = arith.minsi %current, %c2 : i32 + xla_gpu.yield %c2 : i32 } + return %ret : tensor<8xi32> } // CHECK: @direct_atomic_rmw_minsi // CHECK: %[[C2:.*]] = arith.constant 2 @@ -495,17 +412,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_minui(%in: tensor<2x4xi32>, - %i: index, %j: index) -> (tensor<2x4xi32>) { - %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> { - ^bb0(%current : i32): - %min = arith.minui %current, %c2 : i32 - xla_gpu.yield %c2 : i32 - } - return %ret : tensor<2x4xi32> +func.func @direct_atomic_rmw_minui(%in: tensor<8xi32>, + %i: index) -> (tensor<8xi32>) { + %c2 = arith.constant 2 : i32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + ^bb0(%current : i32): + %min = arith.minui %current, %c2 : i32 + xla_gpu.yield %c2 : i32 } + return %ret : tensor<8xi32> } // CHECK: @direct_atomic_rmw_minui // CHECK: %[[C2:.*]] = arith.constant 2 @@ -514,17 +429,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_fadd_f32(%in: tensor<2x4xf32>, - %i: index, %j: index) -> (tensor<2x4xf32>) { - %c2 = arith.constant 2.0 : f32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> { - ^bb0(%current : f32): - %min = arith.addf %current, %c2 : f32 - xla_gpu.yield %c2 : f32 - } - return %ret : tensor<2x4xf32> +func.func @direct_atomic_rmw_fadd_f32(%in: tensor<8xf32>, + %i: index) -> (tensor<8xf32>) { + %c2 = arith.constant 2.0 : f32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf32> { + ^bb0(%current : f32): + %min = arith.addf %current, %c2 : f32 + xla_gpu.yield %c2 : f32 } + return %ret : tensor<8xf32> } // CHECK-LABEL: @direct_atomic_rmw_fadd_f32 // CHECK: %[[C2:.*]] = arith.constant 2 @@ -555,17 +468,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_fadd_f16(%in: tensor<2x4xf16>, - %i: index, %j: index) -> (tensor<2x4xf16>) { - %c2 = arith.constant 2.0 : f16 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> { - ^bb0(%current : f16): - %min = arith.addf %current, %c2 : f16 - xla_gpu.yield %c2 : f16 - } - return %ret : tensor<2x4xf16> +func.func @direct_atomic_rmw_fadd_f16(%in: tensor<8xf16>, + %i: index) -> (tensor<8xf16>) { + %c2 = arith.constant 2.0 : f16 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf16> { + ^bb0(%current : f16): + %min = arith.addf %current, %c2 : f16 + xla_gpu.yield %c2 : f16 } + return %ret : tensor<8xf16> } // CHECK-LABEL: @direct_atomic_rmw_fadd_f16 // CHECK-NOT: llvm.atomicrmw fadd @@ -591,17 +502,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_fadd_bf16(%in: tensor<2x4xbf16>, - %i: index, %j: index) -> (tensor<2x4xbf16>) { - %c2 = arith.constant 2.0 : bf16 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xbf16> { - ^bb0(%current : bf16): - %min = arith.addf %current, %c2 : bf16 - xla_gpu.yield %c2 : bf16 - } - return %ret : tensor<2x4xbf16> +func.func @direct_atomic_rmw_fadd_bf16(%in: tensor<8xbf16>, + %i: index) -> (tensor<8xbf16>) { + %c2 = arith.constant 2.0 : bf16 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xbf16> { + ^bb0(%current : bf16): + %min = arith.addf %current, %c2 : bf16 + xla_gpu.yield %c2 : bf16 } + return %ret : tensor<8xbf16> } // CHECK-LABEL: @direct_atomic_rmw_fadd_bf16 // CHECK-NOT: llvm.atomicrmw fadd @@ -613,17 +522,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_fadd_f64(%in: tensor<2x4xf64>, - %i: index, %j: index) -> (tensor<2x4xf64>) { - %c2 = arith.constant 2.0 : f64 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf64> { - ^bb0(%current : f64): - %min = arith.addf %current, %c2 : f64 - xla_gpu.yield %c2 : f64 - } - return %ret : tensor<2x4xf64> +func.func @direct_atomic_rmw_fadd_f64(%in: tensor<8xf64>, + %i: index) -> (tensor<8xf64>) { + %c2 = arith.constant 2.0 : f64 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf64> { + ^bb0(%current : f64): + %min = arith.addf %current, %c2 : f64 + xla_gpu.yield %c2 : f64 } + return %ret : tensor<8xf64> } // CHECK-LABEL: @direct_atomic_rmw_fadd_f64 // CHECK: %[[C2:.*]] = arith.constant 2 @@ -648,17 +555,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_maximumf(%in: tensor<2x4xf32>, - %i: index, %j: index) -> (tensor<2x4xf32>) { - %c2 = arith.constant 2.0 : f32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> { - ^bb0(%current : f32): - %min = arith.maximumf %current, %c2 : f32 - xla_gpu.yield %c2 : f32 - } - return %ret : tensor<2x4xf32> +func.func @direct_atomic_rmw_maximumf(%in: tensor<8xf32>, + %i: index) -> (tensor<8xf32>) { + %c2 = arith.constant 2.0 : f32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf32> { + ^bb0(%current : f32): + %min = arith.maximumf %current, %c2 : f32 + xla_gpu.yield %c2 : f32 } + return %ret : tensor<8xf32> } // CHECK-LABEL: @direct_atomic_rmw_maximumf @@ -687,18 +592,15 @@ module { // ----- -module { - func.func @atomic_rmw_c32(%in: tensor<2x4xcomplex>, %i: index, %j: index) - -> (tensor<2x4xcomplex>) { - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xcomplex> { - ^bb0(%current : complex): - %a = complex.add %current, %current : complex - xla_gpu.yield %a : complex - } - return %ret : tensor<2x4xcomplex> +func.func @atomic_rmw_c32(%in: tensor<8xcomplex>, %i: index) + -> (tensor<8xcomplex>) { + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xcomplex> { + ^bb0(%current : complex): + %a = complex.add %current, %current : complex + xla_gpu.yield %a : complex } + return %ret : tensor<8xcomplex> } - // CHECK-LABEL: @atomic_rmw_c32 // CHECK: scf.while (%[[ITER_ARG:.*]] = %{{.*}}) : (i64) -> i64 @@ -709,21 +611,18 @@ module { // ----- -module { - func.func @unused_index_switch_results(%i: index) -> index { - %ret, %ret2 = scf.index_switch %i -> tensor<2x4xi32>, tensor<3xf32> - case 0 { - %x, %y = "dummy.op1"() : () -> (tensor<2x4xi32>, tensor<3xf32>) - scf.yield %x, %y : tensor<2x4xi32>, tensor<3xf32> - } - default { - %x, %y = "dummy.op2"() : () -> (tensor<2x4xi32>, tensor<3xf32>) - scf.yield %x, %y : tensor<2x4xi32>, tensor<3xf32> - } - return %i : index +func.func @unused_index_switch_results(%i: index) -> index { + %ret, %ret2 = scf.index_switch %i -> tensor<8xi32>, tensor<3xf32> + case 0 { + %x, %y = "dummy.op1"() : () -> (tensor<8xi32>, tensor<3xf32>) + scf.yield %x, %y : tensor<8xi32>, tensor<3xf32> + } + default { + %x, %y = "dummy.op2"() : () -> (tensor<8xi32>, tensor<3xf32>) + scf.yield %x, %y : tensor<8xi32>, tensor<3xf32> } + return %i : index } - // CHECK-LABEL: func.func @unused_index_switch_results // CHECK-SAME: (%[[I:.*]]: index) // CHECK-NEXT: scf.index_switch %[[I]] @@ -738,17 +637,14 @@ module { // ----- -module { - func.func @transfer_write(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> { - %c16 = arith.constant 16 : index - %c22 = arith.constant 22 : index - %cst = arith.constant dense<[1.0, 2.0]> : vector<2xf32> - %out = vector.transfer_write %cst, %arg0[%c16] : vector<2xf32>, tensor<43xf32> - %out2 = vector.transfer_write %cst, %out[%c22] : vector<2xf32>, tensor<43xf32> - func.return %out2 : tensor<43xf32> - } +func.func @transfer_write(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> { + %c16 = arith.constant 16 : index + %c22 = arith.constant 22 : index + %cst = arith.constant dense<[1.0, 2.0]> : vector<2xf32> + %out = vector.transfer_write %cst, %arg0[%c16] : vector<2xf32>, tensor<43xf32> + %out2 = vector.transfer_write %cst, %out[%c22] : vector<2xf32>, tensor<43xf32> + func.return %out2 : tensor<43xf32> } - // CHECK-LABEL: @transfer_write // CHECK: %[[PTR1:.*]] = llvm.getelementptr inbounds %[[BUF:.*]][16] // CHECK-NEXT: llvm.store %[[CST:.*]], %[[PTR1]] @@ -757,32 +653,26 @@ module { // ----- -module { - func.func @transfer_read(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> vector<2xf32> { - %c16 = arith.constant 16 : index - %c0 = arith.constant 0.0 : f32 - %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf32>, vector<2xf32> - func.return %out : vector<2xf32> - } +func.func @transfer_read(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> vector<2xf32> { + %c16 = arith.constant 16 : index + %c0 = arith.constant 0.0 : f32 + %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf32>, vector<2xf32> + func.return %out : vector<2xf32> } - // CHECK-LABEL: @transfer_read // CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[16] // CHECK-NEXT: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xf32> // ----- -module { - func.func @transfer_write_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}, - %v1: vector<2xi1>, %v2: vector<2xi1>) -> tensor<43xi1> { - %c16 = arith.constant 16 : index - %c22 = arith.constant 22 : index - %out = vector.transfer_write %v1, %arg0[%c16] : vector<2xi1>, tensor<43xi1> - %out2 = vector.transfer_write %v2, %out[%c22] : vector<2xi1>, tensor<43xi1> - func.return %out2 : tensor<43xi1> - } +func.func @transfer_write_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}, + %v1: vector<2xi1>, %v2: vector<2xi1>) -> tensor<43xi1> { + %c16 = arith.constant 16 : index + %c22 = arith.constant 22 : index + %out = vector.transfer_write %v1, %arg0[%c16] : vector<2xi1>, tensor<43xi1> + %out2 = vector.transfer_write %v2, %out[%c22] : vector<2xi1>, tensor<43xi1> + func.return %out2 : tensor<43xi1> } - // CHECK-LABEL: @transfer_write_i1 // CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr // CHECK-SAME: %[[V1:.*]]: vector<2xi1>, %[[V2:.*]]: vector<2xi1>) @@ -795,15 +685,12 @@ module { // ----- -module { - func.func @transfer_read_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}) -> vector<2xi1> { - %c16 = arith.constant 16 : index - %false = arith.constant false - %out = vector.transfer_read %arg0[%c16], %false : tensor<43xi1>, vector<2xi1> - func.return %out : vector<2xi1> - } +func.func @transfer_read_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}) -> vector<2xi1> { + %c16 = arith.constant 16 : index + %false = arith.constant false + %out = vector.transfer_read %arg0[%c16], %false : tensor<43xi1>, vector<2xi1> + func.return %out : vector<2xi1> } - // CHECK-LABEL: @transfer_read_i1 // CHECK-DAG: %[[C0:.*]] = arith.constant dense<0> : vector<2xi8> // CHECK-DAG: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[16] @@ -811,44 +698,3 @@ module { // CHECK: %[[CAST:.*]] = arith.cmpi ne, %[[LOADED]], %[[C0]] // CHECK: return %[[CAST]] : vector<2xi1> -// ----- - -module { - func.func @transfer_write_i4(%arg0: tensor<43xi4> {xla.slice_index = 1}, - %v1: vector<4xi4>) -> tensor<43xi4> { - %c16 = arith.constant 16 : index - %out = vector.transfer_write %v1, %arg0[%c16] : vector<4xi4>, tensor<43xi4> - func.return %out : tensor<43xi4> - } -} - -// CHECK-LABEL: @transfer_write_i4 -// CHECK-SAME: , %[[V1:.*]]: vector<4xi4> -// CHECK-DAG: %[[A0:.*]] = vector.extract %[[V1]][0] -// CHECK-DAG: %[[A1:.*]] = vector.extract %[[V1]][1] -// CHECK-DAG: %[[A2:.*]] = vector.extract %[[V1]][2] -// CHECK-DAG: %[[A3:.*]] = vector.extract %[[V1]][3] -// CHECK-DAG: vector.insert %[[A0]], {{.*}}[1] -// CHECK-DAG: vector.insert %[[A1]], {{.*}}[0] -// CHECK-DAG: vector.insert %[[A2]], {{.*}}[3] -// CHECK-DAG: vector.insert %[[A3]], {{.*}}[2] - -module { - func.func @transfer_read_i4(%arg0: tensor<43xi4> {xla.slice_index = 1}) -> vector<4xi4> { - %c16 = arith.constant 16 : index - %c0 = arith.constant 0 : i4 - %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xi4>, vector<4xi4> - func.return %out : vector<4xi4> - } -} - -// CHECK-LABEL: @transfer_read_i4 -// CHECK: %[[LOADED:.*]] = llvm.load -// CHECK-DAG: %[[A0:.*]] = vector.extract %[[LOADED]][0] -// CHECK-DAG: %[[A1:.*]] = vector.extract %[[LOADED]][1] -// CHECK-DAG: %[[A2:.*]] = vector.extract %[[LOADED]][2] -// CHECK-DAG: %[[A3:.*]] = vector.extract %[[LOADED]][3] -// CHECK-DAG: vector.insert %[[A0]], {{.*}}[1] -// CHECK-DAG: vector.insert %[[A1]], {{.*}}[0] -// CHECK-DAG: vector.insert %[[A2]], {{.*}}[3] -// CHECK-DAG: vector.insert %[[A3]], {{.*}}[2] diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir index 16e4498b0c5380..09769fc382bd58 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir @@ -1,27 +1,25 @@ -// RUN: mlir_fusions_opt -allow-unregistered-dialect %s -split-input-file -xla-gpu-vectorize-loads-stores -canonicalize -cse | FileCheck %s - +// RUN: mlir_fusions_opt -allow-unregistered-dialect %s -split-input-file \ +// RUN: -xla-gpu-vectorize-loads-stores -canonicalize -cse \ +// RUN: | FileCheck %s #map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), domain: d0 in [0, 63], s0 in [0, 1]> -module { - func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c64 = arith.constant 64 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] - %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 +func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 } - return %outer : f32 + scf.yield %inner : f32 } + return %outer : f32 } - // CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 63]> // CHECK-LABEL: @simple_read // CHECK-SAME: (%[[ARG0:.*]]: tensor @@ -38,57 +36,25 @@ module { // ----- -module { - func.func @simple_read_2d(%arg0: tensor<64x2xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c64 = arith.constant 64 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %extracted = tensor.extract %arg0[%i, %j] : tensor<64x2xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 - } - return %outer : f32 - } -} - -// CHECK-LABEL: @simple_read_2d -// CHECK-SAME: (%[[ARG0:.*]]: tensor -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: scf.for %[[I:.*]] = %[[C0]] -// CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[C0]]] -// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] -// CHECK-NEXT: vector.extract %[[V]][%[[J]]] - -// ----- - #map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 + 1), domain: d0 in [0, 63], s0 in [0, 1]> -module { - func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c63 = arith.constant 63 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] - %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 +func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c63 = arith.constant 63 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 } - return %outer : f32 + scf.yield %inner : f32 } + return %outer : f32 } - // CHECK-LABEL: @misaligned_indexing_map // CHECK-NOT: vector.transfer_read @@ -96,50 +62,47 @@ module { #map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 3 + s0), domain: d0 in [0, 63], s0 in [0, 1]> -module { - func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c63 = arith.constant 63 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] - %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 +func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c63 = arith.constant 63 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 } - return %outer : f32 + scf.yield %inner : f32 } + return %outer : f32 } - // CHECK-LABEL: @misaligned_indexing_map_2 // CHECK-NOT: vector.transfer_read // ----- -module { - func.func @misaligned_shape(%arg0: tensor<64x3xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c64 = arith.constant 64 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %extracted = tensor.extract %arg0[%i, %j] : tensor<64x3xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 +#map = #xla_gpu.indexing_map<(d0)[s0] -> (3 * d0 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> +func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<192xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 } - return %outer : f32 + scf.yield %inner : f32 } + return %outer : f32 } - // CHECK-LABEL: @misaligned_shape // CHECK-NOT: vector.transfer_read @@ -147,26 +110,23 @@ module { #map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2), domain: d0 in [0, 63], s0 in [0, 1]> -module { - func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c63 = arith.constant 63 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] - %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 +func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c63 = arith.constant 63 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 } - return %outer : f32 + scf.yield %inner : f32 } + return %outer : f32 } - // CHECK-LABEL: @wrong_stride // CHECK-NOT: vector.transfer_read @@ -174,19 +134,20 @@ module { // We could vectorize this as a float vector load of double the size, but we // don't currently. -module { - func.func @simple_read_complex(%arg0: tensor<64x2xcomplex>, %i: index) -> (complex) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %cst = complex.constant [0.0 : f32, 0.0 : f32] : complex - %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %cst) -> complex { - %extracted = tensor.extract %arg0[%i, %j] : tensor<64x2xcomplex> - %added = complex.add %iter, %extracted : complex - scf.yield %added : complex - } - return %loop : complex +#map = #xla_gpu.indexing_map<(d0)[s0] -> (2 * d0 + s0), + domain: d0 in [0, 127], s0 in [0, 1]> +func.func @simple_read_complex(%arg0: tensor<128xcomplex>, %i: index) -> (complex) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %cst = complex.constant [0.0 : f32, 0.0 : f32] : complex + %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %cst) -> complex { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xcomplex> + %added = complex.add %iter, %extracted : complex + scf.yield %added : complex } + return %loop : complex } // CHECK-LABEL: @simple_read_complex @@ -195,153 +156,140 @@ module { // ----- // This is vectorizable, but not currently supported. -module { - func.func @layout(%arg0: tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c64 = arith.constant 64 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %extracted = tensor.extract %arg0[%j, %i] - : tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 +func.func @layout(%arg0: tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %extracted = tensor.extract %arg0[%j, %i] + : tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 } - return %outer : f32 + scf.yield %inner : f32 } + return %outer : f32 } - // CHECK-LABEL: @layout // CHECK-NOT: vector.transfer_read // ----- -module { - func.func @simple_write(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4 = arith.constant 2 : index - %cst = arith.constant 0.0 : f32 - %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> { - %inserted = tensor.insert %cst into %iter[%i, %j] : tensor<16x4xf32> - scf.yield %inserted : tensor<16x4xf32> - } - return %loop : tensor<16x4xf32> +func.func @simple_write(%arg0: tensor<64xf32>) -> tensor<64xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 2 : index + %cst = arith.constant 0.0 : f32 + %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<64xf32> { + %inserted = tensor.insert %cst into %iter[%j] : tensor<64xf32> + scf.yield %inserted : tensor<64xf32> } + return %loop : tensor<64xf32> } - // CHECK-LABEL: @simple_write -// CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[I:.*]]: index +// CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}) // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[V:.*]] = scf.for // CHECK-NEXT: vector.insert // CHECK-NEXT: scf.yield -// CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[V]], %[[ARG0]][%[[I]], %[[C0]]] +// CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[V]], %[[ARG0]][%[[C0]]] // CHECK-NEXT: return %[[WRITTEN]] // ----- -module { - func.func @write_with_use(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4 = arith.constant 2 : index - %cst = arith.constant 0.0 : f32 - %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> { - %inserted = tensor.insert %cst into %iter[%i, %j] : tensor<16x4xf32> - "dummy.op1"(%inserted) : (tensor<16x4xf32>) -> () - scf.yield %inserted : tensor<16x4xf32> - } - return %loop : tensor<16x4xf32> +func.func @write_with_use(%arg0: tensor<64xf32>) -> tensor<64xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 2 : index + %cst = arith.constant 0.0 : f32 + %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<64xf32> { + %inserted = tensor.insert %cst into %iter[%j] : tensor<64xf32> + "dummy.op1"(%inserted) : (tensor<64xf32>) -> () + scf.yield %inserted : tensor<64xf32> } + return %loop : tensor<64xf32> } - // CHECK-LABEL: @write_with_use // CHECK-NOT: transfer_write // ----- -module { - func.func @write_not_to_iter_arg(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> { + func.func @write_not_to_iter_arg(%arg0: tensor<64xf32>) -> tensor<64xf32> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 2 : index %cst = arith.constant 0.0 : f32 - %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> { - %inserted = tensor.insert %cst into %arg0[%i, %j] : tensor<16x4xf32> - scf.yield %inserted : tensor<16x4xf32> + %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<64xf32> { + %inserted = tensor.insert %cst into %arg0[%j] : tensor<64xf32> + scf.yield %inserted : tensor<64xf32> } - return %loop : tensor<16x4xf32> + return %loop : tensor<64xf32> } -} // CHECK-LABEL: @write_not_to_iter_arg // CHECK-NOT: transfer_write // ----- -module { - func.func @write_not_yielded(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4 = arith.constant 2 : index - %cst = arith.constant 0.0 : f32 - %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> { - %inserted = tensor.insert %cst into %arg0[%i, %j] : tensor<16x4xf32> - scf.yield %arg0 : tensor<16x4xf32> - } - return %loop : tensor<16x4xf32> +func.func @write_not_yielded(%arg0: tensor<64xf32>) -> tensor<64xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 2 : index + %cst = arith.constant 0.0 : f32 + %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<64xf32> { + %inserted = tensor.insert %cst into %arg0[%j] : tensor<64xf32> + scf.yield %arg0 : tensor<64xf32> } + return %loop : tensor<64xf32> } - // CHECK-LABEL: @write_not_yielded // CHECK-NOT: transfer_write // ----- #map = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), - domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]> -module { - func.func @multiple(%arg0: tensor<32x4096xf32>, %arg1: tensor<4096xbf16>, - %arg2: tensor<32xf32>, %arg3: tensor<32x4096xf32>, - %arg4: index) -> (tensor<32x4096xf32>, f32) { - %cst = arith.constant 1.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c8 = arith.constant 8 : index - %extracted1 = tensor.extract %arg2[%arg4] : tensor<32xf32> - %0:2 = scf.for %i = %c0 to %c8 step %c1 iter_args(%iter0 = %arg3, %iter1 = %cst) -> (tensor<32x4096xf32>, f32) { - %1:2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter2 = %iter0, %iter3 = %iter1) -> (tensor<32x4096xf32>, f32) { - %2 = xla_gpu.apply_indexing #map(%j, %arg4)[%i] - %extracted2 = tensor.extract %arg0[%i, %2] : tensor<32x4096xf32> - %extracted3 = tensor.extract %arg1[%2] : tensor<4096xbf16> - %3 = arith.extf %extracted3 : bf16 to f32 - %4 = arith.addf %extracted2, %3 : f32 - %5 = arith.addf %extracted1, %4 : f32 - %6 = arith.addf %iter3, %5 : f32 - %inserted = tensor.insert %5 into %iter2[%i, %2] : tensor<32x4096xf32> - scf.yield %inserted, %6 : tensor<32x4096xf32>, f32 - } - scf.yield %1#0, %1#1 : tensor<32x4096xf32>, f32 + domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7]> +#map1 = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0 * 32 + d2 * 2 + d1 + s0 * 512), + domain: d0 in [0, 7], d1 in [0, 1], d2 in [0, 255], s0 in [0, 7]> +func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, + %arg2: tensor<32xf32>, %arg3: tensor<131072xf32>, + %arg4: index) -> (tensor<131072xf32>, f32) { + %cst = arith.constant 1.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %extracted1 = tensor.extract %arg2[%arg4] : tensor<32xf32> + %0:2 = scf.for %i = %c0 to %c8 step %c1 iter_args(%iter0 = %arg3, %iter1 = %cst) -> (tensor<131072xf32>, f32) { + %1:2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter2 = %iter0, %iter3 = %iter1) -> (tensor<131072xf32>, f32) { + %2 = xla_gpu.apply_indexing #map(%j, %arg4)[%i] + %idx = xla_gpu.apply_indexing #map1(%i, %j, %arg4)[%i] + %extracted2 = tensor.extract %arg0[%idx] : tensor<131072xf32> + %extracted3 = tensor.extract %arg1[%2] : tensor<4096xbf16> + %3 = arith.extf %extracted3 : bf16 to f32 + %4 = arith.addf %extracted2, %3 : f32 + %5 = arith.addf %extracted1, %4 : f32 + %6 = arith.addf %iter3, %5 : f32 + %inserted = tensor.insert %5 into %iter2[%idx] : tensor<131072xf32> + scf.yield %inserted, %6 : tensor<131072xf32>, f32 } - return %0#0, %0#1 : tensor<32x4096xf32>, f32 + scf.yield %1#0, %1#1 : tensor<131072xf32>, f32 } + return %0#0, %0#1 : tensor<131072xf32>, f32 } - -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 * 512), -// CHECK-SAME: domain: d0 in [0, 255], s0 in [0, 7]> +// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 * 512), domain: d0 in [0, 255], s0 in [0, 7]> +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 * 32 + d1 * 2 + s0 * 512), domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7]> // CHECK-LABEL: @multiple // CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index) // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]])[%[[I]]] +// CHECK-DAG: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]])[%[[I]]] +// CHECK-DAG: %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]], %[[ARG4]])[%[[I]]] // CHECK: %[[READ1:.*]] = vector.transfer_read %[[ARG1]][%[[BASE]]] -// CHECK: %[[READ2:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[BASE]]] +// CHECK: %[[READ2:.*]] = vector.transfer_read %[[ARG0]][%[[IDX]]] // CHECK: %[[INNER:.*]]:2 = scf.for %[[J:.*]] = %[[C0]] {{.*}} iter_args(%[[F:.*]] = {{.*}}, %[[V:.*]] = {{.*}}) -> (f32, vector<2xf32>) // CHECK-DAG: vector.extract %[[READ1]][%[[J]]] // CHECK-DAG: vector.extract %[[READ2]][%[[J]]] @@ -351,23 +299,78 @@ module { // CHECK-NEXT: %[[TO_YIELD:.*]] = arith.addf // CHECK-NEXT: %[[V_NEXT:.*]] = vector.insert %[[TO_INSERT]], %[[V]] [%[[J]]] // CHECK-NEXT: scf.yield %[[TO_YIELD]], %[[V_NEXT]] -// CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[INNER]]#1, %{{.*}}[%[[I]], %[[BASE]]] +// CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[INNER]]#1, %{{.*}}[%[[IDX]]] // CHECK: scf.yield %[[WRITTEN]], %[[INNER]]#0 // ----- #map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 64 + s0), domain: d0 in [0, 63], s0 in [0, 1]> +func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c63 = arith.constant 63 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 + } + scf.yield %inner : f32 + } + return %outer : f32 +} +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> ((d0 mod 16) * 4), +// CHECK-LABEL: @remainder_with_modulo +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: scf.for %[[I:.*]] = %[[C0]] +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]] +// CHECK: vector.transfer_read {{.*}}[%[[BASE]]] + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 65 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> +func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c63 = arith.constant 63 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 + } + scf.yield %inner : f32 + } + return %outer : f32 +} +// CHECK-LABEL: @remainder_with_modulo_misaligned +// CHECK-NOT: vector.transfer_read + +// ----- + +#map0 = #xla_gpu.indexing_map<(d0) -> (d0 + 5), + domain: d0 in [0, 63]> +#map1 = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> module { - func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { + func.func @apply_indexing_sequence(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c63 = arith.constant 63 : index %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %offset = xla_gpu.apply_indexing #map0(%i) %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] + %idx = xla_gpu.apply_indexing #map1(%offset)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -378,19 +381,21 @@ module { } } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> ((d0 mod 16) * 4), -// CHECK-LABEL: @remainder_with_modulo -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: scf.for %[[I:.*]] = %[[C0]] -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]] +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2 + 10), +// CHECK-SAME: domain: d0 in [0, 63]> +// CHECK-LABEL: @apply_indexing_sequence +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP0]] // CHECK: vector.transfer_read {{.*}}[%[[BASE]]] // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 65 + s0), - domain: d0 in [0, 63], s0 in [0, 1]> + +#map0 = #xla_gpu.indexing_map<(d0) -> (d0 + 5), + domain: d0 in [0, 63]> +#map1 = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> module { - func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { + func.func @apply_indexing_sequence_same_block(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index @@ -398,7 +403,10 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] + // Usually, this will be hoisted by LICM or folded, so we do not detect + // this pattern. + %offset = xla_gpu.apply_indexing #map0(%i) + %idx = xla_gpu.apply_indexing #map1(%offset)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -409,5 +417,5 @@ module { } } -// CHECK-LABEL: @remainder_with_modulo_misaligned +// CHECK-LABEL: @apply_indexing_sequence_same_block // CHECK-NOT: vector.transfer_read diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc b/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc index b4f9dd9371fe2d..9795a96e387f53 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc @@ -40,15 +40,16 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" namespace xla { namespace gpu { +namespace { #define GEN_PASS_DEF_VECTORIZELOADSANDSTORESPASS #include "xla/service/gpu/fusions/transforms/passes.h.inc" -namespace { +using mlir::Value; // Tries to find the stride of a symbol or dimension in an affine expression. // Returns std::nullopt if the stride could not be determined. @@ -120,12 +121,15 @@ int64_t GetAlignmentOfRemainder(mlir::AffineExpr expr, // - checks that the upper bound is 2 or 4. // Returns a vector type with the given upper bound and the tensor's element // type. +// All tensors are 1D after flatten-tensors pass. mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type, mlir::scf::ForOp loop) { - // TODO(jreiffers): Support layouts. if (tensor_type.getEncoding()) { return nullptr; } + if (tensor_type.getRank() != 1) { + return nullptr; + } if (!mlir::VectorType::isValidElementType(tensor_type.getElementType())) { return nullptr; } @@ -138,37 +142,22 @@ mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type, if (vector_size != 2 && vector_size != 4) { return nullptr; // Unsupported vector size. } - if (tensor_type.getRank() > 1 && - tensor_type.getShape().back() % *vector_size) { + if (tensor_type.getShape().back() % *vector_size) { return nullptr; // Misaligned start indices. } return mlir::VectorType::get({*vector_size}, tensor_type.getElementType()); } -std::optional> GetVectorBaseIndices( - mlir::ValueRange indices, mlir::scf::ForOp loop, - mlir::VectorType vector_type, mlir::ImplicitLocOpBuilder& b) { - if (indices.empty()) { - return std::nullopt; - } - - // The major dimensions' indices must all be defined outside the loop. - for (int i = 0; i < indices.size() - 1; ++i) { - if (!indices[i].getParentRegion()->isProperAncestor( - &loop.getBodyRegion())) { - return std::nullopt; - } - } - - mlir::Value induction_var = loop.getInductionVar(); - if (indices.back() == induction_var) { - llvm::SmallVector ret = indices; - ret.back() = b.create(0); - return ret; +std::optional GetVectorBaseIndices(Value index, mlir::scf::ForOp loop, + mlir::VectorType vector_type, + mlir::ImplicitLocOpBuilder& b) { + Value induction_var = loop.getInductionVar(); + if (index == induction_var) { + return b.create(0); } auto apply_indexing = - mlir::dyn_cast_or_null(indices.back().getDefiningOp()); + mlir::dyn_cast_or_null(index.getDefiningOp()); if (!apply_indexing) { return std::nullopt; } @@ -192,6 +181,11 @@ std::optional> GetVectorBaseIndices( ? mlir::getAffineDimExpr(index, b.getContext()) : mlir::getAffineSymbolExpr( index - map.getNumDims(), b.getContext()); + } else if (!operand.getParentRegion()->isProperAncestor( + &loop.getBodyRegion())) { + // If the operand is defined inside the loop, we can't hoist the + // apply_indexing outside the loop. + return std::nullopt; } } if (!induction_var_expr) { @@ -212,11 +206,8 @@ std::optional> GetVectorBaseIndices( operands[induction_var_operand_index] = b.create(0); - llvm::SmallVector ret = indices; - ret.back() = - b.create(operands, apply_indexing.getIndexingMap()) - ->getResult(0); - return ret; + return b.create(operands, apply_indexing.getIndexingMap()) + ->getResult(0); } bool IsConflictFree(mlir::tensor::ExtractOp op) { @@ -246,16 +237,14 @@ struct VectorizeLoad : mlir::OpRewritePattern { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); b.setInsertionPoint(loop); - auto vector_indices = - GetVectorBaseIndices(op.getIndices(), loop, vector_type, b); - if (!vector_indices) { + auto vector_index = + GetVectorBaseIndices(op.getIndices().front(), loop, vector_type, b); + if (!vector_index) { return rewriter.notifyMatchFailure( op, "the instruction does not access contiguous elements"); } - auto loaded_vector = b.create( - vector_type, op.getTensor(), *vector_indices, - llvm::ArrayRef{true}); + vector_type, op.getTensor(), *vector_index, llvm::ArrayRef{true}); rewriter.replaceOpWithNewOp( op, loaded_vector, loop.getInductionVar()); return mlir::success(); @@ -296,9 +285,9 @@ struct VectorizeStore : mlir::OpRewritePattern { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); b.setInsertionPoint(loop); - auto vector_indices = - GetVectorBaseIndices(op.getIndices(), loop, vector_type, b); - if (!vector_indices) { + auto vector_index = + GetVectorBaseIndices(op.getIndices().front(), loop, vector_type, b); + if (!vector_index) { return rewriter.notifyMatchFailure( op, "the instruction does not access contiguous elements"); } @@ -313,7 +302,7 @@ struct VectorizeStore : mlir::OpRewritePattern { .getInductionVar(); auto insert_op = yield_b.create( yield_loc, op.getScalar(), bbarg.front(), induction_var); - return llvm::SmallVector{insert_op.getResult()}; + return llvm::SmallVector{insert_op.getResult()}; }; int result_index = op->use_begin()->getOperandNumber(); auto new_for = *loop.replaceWithAdditionalYields( @@ -325,7 +314,7 @@ struct VectorizeStore : mlir::OpRewritePattern { auto filled_vector = new_for->getResults().back(); auto written = b.create( - filled_vector, new_for.getInits()[result_index], *vector_indices, + filled_vector, new_for.getInits()[result_index], *vector_index, llvm::ArrayRef{true}); new_for->getResult(result_index).replaceAllUsesWith(written.getResult()); diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index 946815d5b6f152..a06b9d9dc052de 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -45,9 +45,9 @@ limitations under the License. #include "xla/permutation_util.h" #include "xla/primitive_util.h" #include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/type_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -84,7 +84,8 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) : analysis_(analysis), transpose_(analysis.tiled_transpose()), permutation_(transpose_.permutation), - input_shape_(Permute(transpose_.dimensions, permutation_)) { + input_shape_( + Permute(transpose_.dimensions, InversePermutation(permutation_))) { ConstHloInstructionSet transposes_to_tile; int index = 0; int64_t shmem_usage = 0; @@ -115,15 +116,20 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) auto compute_block_sizes = [this](int vector_size) { vector_size_ = vector_size; block_size_ = kBaseBlockSize * vector_size_; + block_sizes_.assign(input_shape_.size(), 1); if (MostMinorDimensionUnchanged()) { - block_sizes_ = {block_size_, block_size_, input_shape_.back()}; + block_sizes_.back() = input_shape_.back(); + block_sizes_[block_sizes_.size() - 2] = block_size_; + block_sizes_[permutation_[block_sizes_.size() - 2]] = block_size_; } else { - block_sizes_ = {1, 1, block_size_}; - block_sizes_[permutation_[2]] = block_size_; + block_sizes_.back() = block_size_; + block_sizes_[permutation_.back()] = block_size_; + } + output_block_sizes_ = Permute(block_sizes_, permutation_); + block_counts_.resize(block_sizes_.size()); + for (int64_t i = 0; i < block_sizes_.size(); ++i) { + block_counts_[i] = CeilOfRatio(input_shape_[i], block_sizes_[i]); } - block_counts_ = {CeilOfRatio(input_shape_[0], block_sizes_[0]), - CeilOfRatio(input_shape_[1], block_sizes_[1]), - CeilOfRatio(input_shape_[2], block_sizes_[2])}; }; // Compute initial block sizes without vectorization. We use the result to // determine whether we can vectorize. @@ -141,11 +147,12 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) device.threads_per_core_limit(); bool enough_shmem = shmem_usage * elems_per_thread <= device.shared_memory_per_block(); - bool aligned_dims = (input_shape_[2] % vec_size == 0) && - (input_shape_[permutation_[2]] % vec_size == 0); + bool aligned_dims = (input_shape_.back() % vec_size == 0) && + (input_shape_[permutation_.back()] % vec_size == 0); if (MostMinorDimensionUnchanged()) { aligned_dims = - input_shape_[0] % vec_size == 0 && input_shape_[1] % vec_size == 0; + input_shape_[input_shape_.size() - 2] % vec_size == 0 && + input_shape_[permutation_[input_shape_.size() - 2]] % vec_size == 0; } if (enough_work && enough_shmem && aligned_dims) { compute_block_sizes(vec_size); @@ -161,12 +168,8 @@ std::optional MlirTransposeFusion::ComputeThreadIdToOutputIndexing( analysis_.fusion_root(root_index).instruction(), hero)) { // The shape of non-transpose roots are bitcast compatible with the input // shape of transpose heroes. - auto map = ComposeIndexingMaps( - GetIndexing(/*input=*/true, hero.shape(), mlir_context), - GetBitcastMap(hero.shape(), analysis_.fusion_root(root_index).shape(), - mlir_context)); - map.Simplify(); - return map; + return GetIndexing(/*input=*/true, + analysis_.fusion_root(root_index).shape(), mlir_context); } return GetIndexing(/*input=*/false, hero.shape(), mlir_context); } @@ -196,17 +199,29 @@ LaunchDimensions MlirTransposeFusion::launch_dimensions() const { IndexingMap MlirTransposeFusion::GetSharedMemoryIndexing( bool read, mlir::MLIRContext* ctx) const { - auto thread_offsets = - Permute(GetThreadOffsets(ctx), read ? Vector3{0, 1, 2} : permutation_); + auto thread_offsets = GetThreadOffsets(/*read=*/true, ctx); + if (!read) { + // Regarding shared memory indexing, the permutation we need to apply is + // just a swap of the two dimensions that are tiled. + if (MostMinorDimensionUnchanged()) { + std::swap(thread_offsets[thread_offsets.size() - 2], + thread_offsets[permutation_[permutation_.size() - 2]]); + } else { + std::swap(thread_offsets.back(), thread_offsets[permutation_.back()]); + } + } + std::vector dim_var_sizes(6, 1); + dim_var_sizes[KernelFusionInterface::kIndexingMapThreadIdxDims[0]] = + kNumThreadsPerBlock; if (MostMinorDimensionUnchanged()) { return {mlir::AffineMap::get(6, 3, thread_offsets, ctx), - DimVarsFromTensorSizes({kNumThreadsPerBlock, 1, 1, 1, 1, 1}), + DimVarsFromTensorSizes(dim_var_sizes), RangeVarsFromTensorSizes( - {block_size_ / kNumRows, vector_size_, input_shape_[2]}), + {block_size_ / kNumRows, vector_size_, input_shape_.back()}), {}}; } return {mlir::AffineMap::get(6, 2, thread_offsets, ctx), - DimVarsFromTensorSizes({kNumThreadsPerBlock, 1, 1, 1, 1, 1}), + DimVarsFromTensorSizes(dim_var_sizes), RangeVarsFromTensorSizes({block_size_ / kNumRows, vector_size_}), {}}; } @@ -221,7 +236,9 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( auto shmem_tensor_size = block_sizes_; // Avoid bank conflicts. if (MostMinorDimensionUnchanged()) { - ++shmem_tensor_size[1]; + // Increase the dimension that is actually iterated over. The most minor + // dimension is always completely loaded into the shared memory tile. + ++shmem_tensor_size[shmem_tensor_size.size() - 2]; } else { ++shmem_tensor_size.back(); } @@ -386,7 +403,7 @@ absl::Status MlirTransposeFusion::EmitEntryFunction( } llvm::SmallVector MlirTransposeFusion::GetThreadOffsets( - mlir::MLIRContext* ctx) const { + bool read, mlir::MLIRContext* ctx) const { auto thread = mlir::getAffineDimExpr( KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx); auto loop = mlir::getAffineSymbolExpr(0, ctx); @@ -395,9 +412,10 @@ llvm::SmallVector MlirTransposeFusion::GetThreadOffsets( auto linear_index = loop * loop_stride + thread * vector_size_ + vector; if (MostMinorDimensionUnchanged()) { auto minor_dim = mlir::getAffineSymbolExpr(2, ctx); - linear_index = linear_index * input_shape_[2] + minor_dim; + linear_index = linear_index * input_shape_.back() + minor_dim; } - return DelinearizeInBoundsIndex(linear_index, block_sizes_); + return DelinearizeInBoundsIndex(linear_index, + read ? block_sizes_ : output_block_sizes_); } IndexingMap MlirTransposeFusion::GetIndexing(bool input, @@ -405,23 +423,30 @@ IndexingMap MlirTransposeFusion::GetIndexing(bool input, mlir::MLIRContext* ctx) const { auto raw_id = mlir::getAffineDimExpr( KernelFusionInterface::kIndexingMapBlockIdxDims[0], ctx); - auto block_ids = Permute(DelinearizeInBoundsIndex(raw_id, block_counts_), - input ? Vector3{0, 1, 2} : permutation_); - auto thread_offsets = GetThreadOffsets(ctx); + auto block_ids = DelinearizeInBoundsIndex(raw_id, block_counts_); + if (!input) { + absl::c_copy(Permute(block_ids, permutation_), block_ids.begin()); + } + auto thread_offsets = GetThreadOffsets(input, ctx); + const auto& permuted_block_sizes = input ? block_sizes_ : output_block_sizes_; llvm::SmallVector offsets; for (auto [block_id, block_size, thread] : - llvm::zip(block_ids, block_sizes_, thread_offsets)) { + llvm::zip(block_ids, permuted_block_sizes, thread_offsets)) { offsets.push_back(block_id * block_size + thread); } + std::vector dim_var_sizes(6, 1); + dim_var_sizes[KernelFusionInterface::kIndexingMapThreadIdxDims[0]] = + kNumThreadsPerBlock; + dim_var_sizes[KernelFusionInterface::kIndexingMapBlockIdxDims[0]] = + Product(block_counts_); auto range_var_sizes = std::vector{block_size_ / kNumRows, vector_size_}; if (MostMinorDimensionUnchanged()) { - range_var_sizes.push_back(input_shape_[2]); + range_var_sizes.push_back(input_shape_.back()); } IndexingMap result{ mlir::AffineMap::get(6, range_var_sizes.size(), offsets, ctx), - DimVarsFromTensorSizes( - {kNumThreadsPerBlock, 1, 1, Product(block_counts_), 1, 1}), + DimVarsFromTensorSizes(dim_var_sizes), RangeVarsFromTensorSizes(range_var_sizes), {}}; auto normalized_shape = diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h index 8d9e8b69327f25..afb2777967220e 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -35,7 +36,6 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" -#include "xla/util.h" namespace xla { namespace gpu { @@ -97,13 +97,14 @@ class MlirTransposeFusion : public MlirFusionEmitterBase { mlir::MLIRContext* ctx) const; IndexingMap GetSharedMemoryIndexing(bool read, mlir::MLIRContext* ctx) const; llvm::SmallVector GetThreadOffsets( - mlir::MLIRContext* ctx) const; + bool read, mlir::MLIRContext* ctx) const; bool MostMinorDimensionUnchanged() const; TransposeDescription transpose_; - Vector3 permutation_; + absl::InlinedVector permutation_; std::vector input_shape_; std::vector block_sizes_; // In input elements. + std::vector output_block_sizes_; std::vector block_counts_; int vector_size_; int block_size_; diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc index eec531df4634c7..d773503859d934 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "mlir/IR/MLIRContext.h" #include "xla/error_spec.h" #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" @@ -87,17 +88,17 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing021) { )")); } -TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { +TEST_F(MlirTransposeFusionTest, ThreadIndexing201_SimplifiedTo021) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule module fusion { - %input = f32[100,64,32] parameter(0) - ROOT transpose = f32[32,100,64] transpose(%input), dimensions={2,0,1} + %input = f32[1,6400,32] parameter(0) + ROOT transpose = f32[1,32,6400] transpose(%input), dimensions={0,2,1} } ENTRY entry { - %input = f32[100,64,32] parameter(0) - ROOT %fusion = f32[32,100,64] fusion(%input), kind=kInput, calls=fusion + %input = f32[1,6400,32] parameter(0) + ROOT %fusion = f32[1,32,6400] fusion(%input), kind=kInput, calls=fusion })")); auto* root = module->entry_computation()->root_instruction(); @@ -108,8 +109,8 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( - d3 floordiv 2, - (d3 mod 2) * 32 + s0 * 4 + d0 floordiv 32, + 0, + d3 * 32 + s0 * 4 + d0 floordiv 32, d0 mod 32 ) domain: @@ -127,9 +128,9 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( + 0, d0 floordiv 32 + s0 * 4, - d3 floordiv 2, - (d3 mod 2) * 32 + d0 mod 32 + d3 * 32 + d0 mod 32 ) domain: d0 in [0, 127] @@ -144,6 +145,72 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { )")); } +TEST_F(MlirTransposeFusionTest, Transpose_ThreadIndexing1302) { + auto kHloString = R"( + HloModule Transpose + + %fused_computation { + %param_0 = f32[19, 16, 16, 144] parameter(0) + ROOT %transpose= f32[16, 144, 19, 16] transpose( %param_0), + dimensions={1,3,0,2} + } + ENTRY main { + %param = f32[19, 16, 16, 144] parameter(0) + ROOT %fusion = f32[16, 144, 19, 16] fusion(%param), kind=kInput, + calls=%fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); + + MlirTransposeFusion fusion(analysis); + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( + d3 floordiv 80, + (d3 floordiv 5) mod 16, + d0 floordiv 32 + s0 * 4, + (d3 mod 5) * 32 + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 1519] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 3] + s1 in [0, 0] + (d3 mod 5) * 32 + d0 mod 32 in [0, 143] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( + (d3 floordiv 5) mod 16, + (d3 mod 5) * 32 + s0 * 4 + d0 floordiv 32, + d3 floordiv 80, + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 1519] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 7] + s1 in [0, 0] + (d3 mod 5) * 8 + s0 in [0, 35] + d0 mod 32 in [0, 15] + )")); +} + TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized021) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule module @@ -449,25 +516,6 @@ TEST_F(MlirTransposeFusionTest, Transpose021_NoEpilogue) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } -TEST_F(MlirTransposeFusionTest, Transpose_4D) { - auto kHloString = R"( - HloModule Transpose - - %fused_computation { - %param_0 = f64[2,24,6,4] parameter(0) - ROOT %transpose= f64[6,4,2,24] transpose(f64[2,24,6,4] %param_0), - dimensions={2,3,0,1} - } - ENTRY main { - %param = f64[2,24,6,4] parameter(0) - ROOT %fusion = f64[6,4,2,24] fusion(%param), kind=kInput, - calls=%fused_computation - } - )"; - TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - TEST_F(MlirTransposeFusionTest, Transpose_2D) { auto kHloString = R"( HloModule Transpose @@ -483,30 +531,24 @@ TEST_F(MlirTransposeFusionTest, Transpose_2D) { calls=%fused_computation } )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } -TEST_F(MlirTransposeFusionTest, Transpose_2D_2) { +TEST_F(MlirTransposeFusionTest, Transpose_4D) { auto kHloString = R"( - HloModule m + HloModule Transpose %fused_computation { - %p0 = f32[17,2820]{0,1} parameter(0) - %p1 = f32[30,17,94] parameter(1) - - %bitcast0 = f32[2,3,5,17,94] bitcast(f32[30,17,94] %p1) - %transpose = f32[2,3,5,94,17] transpose(f32[2,3,5,17,94] %bitcast0), dimensions={0,1,2,4,3} - %bitcast1 = f32[2820,17]{1,0} bitcast(f32[2,3,5,94,17] %transpose) - %bitcast2 = f32[2820,17]{1,0} bitcast(f32[17,2820]{0,1} %p0) - %neg = f32[2820,17]{1,0} negate(f32[2820,17] %bitcast2) - ROOT %add = f32[2820,17]{1,0} add(f32[2820,17] %bitcast1, f32[2820,17]{1,0} %neg) + %param_0 = f32[19, 16, 16, 144] parameter(0) + ROOT %transpose= f32[16, 144, 19, 16] transpose( %param_0), + dimensions={1,3,0,2} } - ENTRY main { - %p1 = f32[30,17,94]{2,1,0} parameter(1) - %p0 = f32[17,2820]{0,1} parameter(0) - ROOT %fusion = f32[2820,17]{1,0} fusion(%p0, %p1), kind=kInput, calls=%fused_computation + %param = f32[19, 16, 16, 144] parameter(0) + ROOT %fusion = f32[16, 144, 19, 16] fusion(%param), kind=kInput, + calls=%fused_computation } )"; TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); @@ -518,19 +560,19 @@ TEST_F(MlirTransposeFusionTest, MultipleRootsForTranspose) { HloModule m %fused_computation { - %iota.0 = s32[200,200] iota(), iota_dimension=1 - %iota.1 = s32[200,200] iota(), iota_dimension=0 - %compare = pred[200,200] compare(%iota.0, %iota.1), direction=GE - %transpose = pred[200,200] transpose(%compare), dimensions={1,0} - %copy = pred[200,200] copy(%transpose) - %copy.1 = pred[200,200] copy(%transpose) - ROOT %tuple = (pred[200,200], pred[200,200], pred[200,200]{1,0}) + %iota.0 = s32[1,200,200] iota(), iota_dimension=1 + %iota.1 = s32[1,200,200] iota(), iota_dimension=0 + %compare = pred[1,200,200] compare(%iota.0, %iota.1), direction=GE + %transpose = pred[1,200,200] transpose(%compare), dimensions={0,2,1} + %copy = pred[1,200,200] copy(%transpose) + %copy.1 = pred[1,200,200] copy(%transpose) + ROOT %tuple = (pred[1,200,200], pred[1,200,200], pred[1,200,200]) tuple(%transpose, %copy, %copy.1) } ENTRY main { ROOT %fusion = - (pred[200,200]{1,0}, pred[200,200]{1,0}, pred[200,200]{1,0}) + (pred[1,200,200], pred[1,200,200], pred[1,200,200]) fusion(), kind=kInput, calls=%fused_computation } )"; @@ -543,13 +585,13 @@ TEST_F(MlirTransposeFusionTest, PartialTile) { HloModule m fused_computation { - %p0 = f64[24,2,6,4] parameter(0) - ROOT %t = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0} + %p0 = f64[24,2,24] parameter(0) + ROOT %t = f64[24,2,24] transpose(%p0), dimensions={2,1,0} } ENTRY main { - %p0 = f64[24,2,6,4] parameter(0) - ROOT %fusion = f64[6,4,2,24] fusion(%p0), kind=kInput, calls=%fused_computation + %p0 = f64[24,2,24] parameter(0) + ROOT %fusion = f64[24,2,24] fusion(%p0), kind=kInput, calls=%fused_computation } )"; TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); @@ -561,20 +603,19 @@ TEST_F(MlirTransposeFusionTest, MixedIndexing) { HloModule m fused_computation { - %p0 = f64[24,2,6,4] parameter(0) - %bc = f64[24,2,24] bitcast(%p0) - %t1 = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0} - %t2 = f64[24,2,24] transpose(%bc), dimensions={2,1,0} + %p0 = f64[24,2,24] parameter(0) + %t1 = f64[24,2,24] transpose(%p0), dimensions={2,1,0} + %b = f64[6,4,2,24] bitcast(%t1) %p1 = f64[] parameter(1) %bc1 = f64[6,4,2,24] broadcast(%p1), dimensions={} %bc2 = f64[24,2,24] broadcast(%p1), dimensions={} - %a1 = f64[6,4,2,24] add(%t1, %bc1) - %a2 = f64[24,2,24] add(%t2, %bc2) + %a1 = f64[6,4,2,24] add(%b, %bc1) + %a2 = f64[24,2,24] add(%t1, %bc2) ROOT %t = (f64[6,4,2,24], f64[24,2,24]) tuple(%a1, %a2) } ENTRY main { - %p0 = f64[24,2,6,4] parameter(0) + %p0 = f64[24,2,24] parameter(0) %p1 = f64[] parameter(1) ROOT %fusion = (f64[6,4,2,24], f64[24,2,24]) fusion(%p0, %p1), kind=kInput, calls=%fused_computation diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index b8113f9600cdb6..d859942239c617 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -42,6 +42,7 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/mlir_hlo", @@ -58,8 +59,8 @@ cc_library( "//xla/service/gpu:target_util", "//xla/service/gpu:triton_fusion_analysis", "//xla/service/gpu:triton_tiling_propagation", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/fusions/transforms:passes", "//xla/service/gpu/llvm_gpu_backend", "//xla/service/gpu/model:affine_map_printer", @@ -68,6 +69,7 @@ cc_library( "//xla/service/gpu/model:symbolic_tiled_hlo_instruction", "//xla/service/gpu/model:tiled_hlo_computation", "//xla/service/gpu/model:tiled_hlo_instruction", + "//xla/service/gpu/model:triton_emitter_constraints", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", "//xla/stream_executor:launch_dim", @@ -155,6 +157,7 @@ gentbl_cc_library( cc_library( name = "passes", srcs = [ + "generalize_kernel_signature.cc", "passes.cc", "prevent_mmav3_loop_unrolling.cc", "sparse_extensions.cc", @@ -223,6 +226,7 @@ xla_test( "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", @@ -298,6 +302,7 @@ cc_library( "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu/model:tiled_hlo_computation", + "//xla/service/gpu/model:triton_emitter_constraints", "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", @@ -327,6 +332,7 @@ xla_cc_test( "//xla/service/gpu/model:symbolic_tile_analysis", "//xla/service/gpu/model:tiled_hlo_computation", "//xla/service/gpu/model:tiled_hlo_instruction", + "//xla/service/gpu/model:triton_emitter_constraints", "//xla/service/llvm_ir:llvm_util", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", @@ -401,8 +407,14 @@ xla_test( cc_library( name = "triton_support", - srcs = ["triton_support.cc"], - hdrs = ["triton_support.h"], + srcs = [ + "triton_support.cc", + "triton_support_legacy.cc", + ], + hdrs = [ + "triton_support.h", + "triton_support_legacy.h", + ], deps = [ "//xla:shape_util", "//xla:xla_data_proto_cc", @@ -442,8 +454,10 @@ xla_cc_test( "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status_matchers", diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc index 59be95deb7182b..2a95ea833f4bcc 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc @@ -56,8 +56,6 @@ absl::Status CreateTritonPipeline( mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, const BlockLevelParameters& block_level_parameters, mt::nvidia_gpu::ClusterInfo& out_cluster_info) { - // TODO(ROCm): Check whether value different than 0 can be used. - const int ccAsInt = 0; // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64. const int threadsPerWarp = 32; auto ccRocm = std::get(cc); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/generalize_kernel_signature.cc b/third_party/xla/xla/service/gpu/fusions/triton/generalize_kernel_signature.cc new file mode 100644 index 00000000000000..7ce29350b2d42c --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/generalize_kernel_signature.cc @@ -0,0 +1,130 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "xla/service/gpu/fusions/triton/passes.h" + +namespace xla::gpu { +namespace { + +// Extract additional attributes from an LLVM function that are not passed +// to the builder directly. +mlir::SmallVector GetExtraAttrs( + mlir::LLVM::LLVMFuncOp func) { + llvm::StringSet<> registered_attr_names{ + func.getSymNameAttrName().getValue(), + func.getFunctionTypeAttrName().getValue(), + func.getLinkageAttrName().getValue(), + func.getDsoLocalAttrName().getValue(), + func.getCConvAttrName().getValue(), + func.getArgAttrsAttrName().getValue(), + func.getFunctionEntryCountAttrName().getValue()}; + return llvm::to_vector( + llvm::make_filter_range(func->getAttrs(), [&](mlir::NamedAttribute attr) { + return !registered_attr_names.contains(attr.getName().getValue()); + })); +} + +// Strip address spaces from function parameters. +void StripParameterAddressSpaces(mlir::RewriterBase& rewriter, + mlir::LLVM::LLVMFuncOp func) { + // Figure out what the new signature should be. + mlir::LLVM::LLVMFunctionType func_ty = func.getFunctionType(); + mlir::SmallVector generic_func_params( + llvm::map_range(func_ty.getParams(), [](mlir::Type type) -> mlir::Type { + auto ptr_ty = mlir::dyn_cast(type); + if (!ptr_ty) return type; + if (ptr_ty.getAddressSpace() != mlir::NVVM::kGlobalMemorySpace) + return type; + return mlir::LLVM::LLVMPointerType::get(ptr_ty.getContext()); + })); + mlir::LLVM::LLVMFunctionType generic_func_ty = + func_ty.clone(generic_func_params, func_ty.getReturnTypes()); + + // Create a function with the new signature. + mlir::SmallVector arg_attrs(llvm::map_range( + func.getArgAttrsAttr().getValue(), [](mlir::Attribute attr) { + return mlir::cast(attr); + })); + auto generic_func = rewriter.create( + func.getLoc(), func.getSymName(), generic_func_ty, func.getLinkage(), + func.getDsoLocal(), func.getCConv(), /*comdat=*/nullptr, + GetExtraAttrs(func), arg_attrs, func.getFunctionEntryCount()); + + // Convert generic address spaces back to original ones within the function + // body. + mlir::Block* entry = generic_func.addEntryBlock(rewriter); + rewriter.setInsertionPointToEnd(entry); + mlir::SmallVector converted_args; + for (auto [arg, type] : + llvm::zip(generic_func.getArguments(), func_ty.getParams())) { + mlir::Value converted = arg; + if (arg.getType() != type) { + converted = + rewriter.create(arg.getLoc(), type, arg); + } + converted_args.push_back(converted); + } + + // Move the rest of function body from the original function. + rewriter.cloneRegionBefore(func.getBody(), generic_func.getBody(), + generic_func.getBody().end()); + rewriter.eraseOp(func); + rewriter.mergeBlocks(entry->getNextNode(), entry, converted_args); +} + +#define GEN_PASS_DEF_GENERALIZEKERNELSIGNATUREPASS +#include "xla/service/gpu/fusions/triton/passes.h.inc" + +// Rewrite signatures of kernel functions to use generic data pointers and +// cast them to global ones within the kernel. +struct GeneralizeKernelSignaturePass + : public impl::GeneralizeKernelSignaturePassBase< + GeneralizeKernelSignaturePass> { + void runOnOperation() override { + mlir::IRRewriter rewriter(&getContext()); + getOperation()->walk([&](mlir::LLVM::LLVMFuncOp func) { + if (!func->hasAttr(mlir::NVVM::NVVMDialect::getKernelFuncAttrName())) { + return; + } + rewriter.setInsertionPointAfter(func); + StripParameterAddressSpaces(rewriter, func); + }); + } +}; + +} // namespace + +std::unique_ptr CreateGeneralizeKernelSignaturePass() { + return std::make_unique(); +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton/passes.cc b/third_party/xla/xla/service/gpu/fusions/triton/passes.cc index fefb0f7f398f87..0d0ff381874644 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/passes.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/passes.cc @@ -20,19 +20,14 @@ limitations under the License. #include "mlir/IR/Visitors.h" namespace xla::gpu { -namespace { - -using ::mlir::WalkResult; - -} // namespace bool ContainsOp(mlir::Operation* op, llvm::function_ref fn) { - return op - ->walk([&](mlir::Operation* nested_op) { - return fn(nested_op) ? WalkResult::interrupt() : WalkResult::advance(); - }) - .wasInterrupted(); + auto visitor = [&](mlir::Operation* nested_op) { + return fn(nested_op) ? mlir::WalkResult::interrupt() + : mlir::WalkResult::advance(); + }; + return op->walk(visitor).wasInterrupted(); } } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton/passes.h b/third_party/xla/xla/service/gpu/fusions/triton/passes.h index 39066ba0f40654..9bb3ab6a92d6cf 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/passes.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/passes.h @@ -36,6 +36,7 @@ std::unique_ptr CreateSparseLocalLoadToLLVMPass(); std::unique_ptr CreateSparseDotOpToLLVMPass(); std::unique_ptr CreateSparseWGMMAOpToLLVMPass(); std::unique_ptr CreatePreventMmaV3LoopUnrollingPass(); +std::unique_ptr CreateGeneralizeKernelSignaturePass(); // Returns true if the `op` contains an operation in it's regions that satisfies // the `fn`. diff --git a/third_party/xla/xla/service/gpu/fusions/triton/passes.td b/third_party/xla/xla/service/gpu/fusions/triton/passes.td index b1366d0bd8c7e3..f437a44b37c8a4 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/passes.td +++ b/third_party/xla/xla/service/gpu/fusions/triton/passes.td @@ -19,7 +19,7 @@ limitations under the License. include "mlir/Pass/PassBase.td" def SparseAddEncodingPass : Pass<"sparse-add-encoding", "mlir::ModuleOp"> { - let summary = "Adds sparse dot encoding."; + let summary = "Add sparse encoding for all the arguments of a SparseDotOp."; let options = [ Option<"num_warps_", "num-warps", "int32_t", /*default=*/"4", "Number of warps">, @@ -36,6 +36,10 @@ def SparseAddEncodingPass : Pass<"sparse-add-encoding", "mlir::ModuleOp"> { def SparseBlockedToMMAPass : Pass<"sparse-blocked-to-mma", "mlir::ModuleOp"> { let summary = "Add convert layouts to/from MMA before and after SparseDotOp."; + let description = [{ + Add convert layouts to and from MMA before and after SparseDotOp. In MMAV3, + shared memory allocations will be used for A and B operands. + }]; let dependentDialects = [ "triton::gpu::TritonGPUDialect", ]; @@ -91,4 +95,15 @@ def PreventMmaV3LoopUnrollingPass let constructor = "CreatePreventMmaV3LoopUnrollingPass()"; } + +def GeneralizeKernelSignaturePass + : Pass<"generalize-kernel-signature"> { + let summary = "Rewrite kernels to use generic data pointer arguments."; + let description = [{ + Rewrite signatures of kernel functions from global pointers to generic + pointers and cast them to global ones within the kernel. + }]; + let constructor = "CreateGeneralizeKernelSignaturePass()"; +} + #endif // XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_TD_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc b/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc index 5b2d331a136467..7bd6a3b04fbaa2 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc @@ -99,7 +99,17 @@ namespace { #define GEN_PASS_DEF_SPARSEWGMMAOPTOLLVMPASS #include "xla/service/gpu/fusions/triton/passes.h.inc" -// Add sparse encoding for all the arguments of a SparseDotOp. +constexpr int kThreadsPerWarp = 32; +// Each 16x16 original sparse matrix tile requires 16 metadata values of +// 16-bit size, where the first thread (T0) in each 4-thread group holds two +// such values in a register (32-bit). +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#sparse-matrix-storage +constexpr int kTileSize = 16; +constexpr int kMetaElementsBitSize = 2; +// Metadata elements are packed into 16-bits values. +constexpr int kMetaElementsPerPackedValue = 16 / kMetaElementsBitSize; +constexpr int kColumnsPerCtaTile = kTileSize / kMetaElementsPerPackedValue; + struct SparseAddEncoding : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -215,8 +225,6 @@ struct SparseAddEncodingPass } }; -// Add convert layouts to and from MMA before and after SparseDotOp. In MMAV3, -// shared memory allocations will be used for A and B operands. class SparseBlockedToMMA : public RewritePattern { using ConvertLayoutOp = triton::gpu::ConvertLayoutOp; using SparseDotOp = triton::gpu::SparseDotOp; @@ -336,6 +344,42 @@ struct SparseBlockedToMMAPass } }; +struct SparseRemoveLayoutConversionPass + : public impl::SparseRemoveLayoutConversionPassBase< + SparseRemoveLayoutConversionPass> { + void runOnOperation() override { + getOperation().walk([&](triton::gpu::ConvertLayoutOp op) { + ImplicitLocOpBuilder builder(op.getLoc(), op); + // Skip if the source is already in shared memory. + auto src_encoding = + cast(op.getSrc().getType()).getEncoding(); + if (isa(src_encoding)) { + return; + } + auto dst_type = cast(op.getType()); + // Skip if the destination is not a sparse dot meta. + if (!isa( + dst_type.getEncoding())) { + return; + } + + auto shared_layout = builder.getAttr( + // Packing metadata elements together. No swizzling. + /*vec=*/kMetaElementsPerPackedValue, /*perPhase=*/1, /*maxPhase=*/1, + triton::gpu::getOrder(src_encoding), + triton::gpu::getCTALayout(src_encoding)); + auto mem_type = triton::MemDescType::get( + dst_type.getShape(), dst_type.getElementType(), shared_layout, + builder.getAttr()); + Value alloc = + builder.create(mem_type, op.getSrc()); + Value convert = builder.create(dst_type, alloc); + op.replaceAllUsesWith(convert); + op.erase(); + }); + } +}; + class SparseLocalLoadToLLVM : public ConvertOpToLLVMPattern { public: @@ -359,17 +403,6 @@ class SparseLocalLoadToLLVM LogicalResult lowerSharedToSparseMeta( triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - constexpr int kThreadsPerWarp = 32; - // Each 16x16 original sparse matrix tile requires 16 metadata values of - // 16-bit size, where the first thread (T0) in each 4-thread group holds two - // such values in a register (32-bit). - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#sparse-matrix-storage - constexpr int kTileSize = 16; - constexpr int kMetaElementsBitSize = 2; - // Metadata elements are packed into 16-bits values. - constexpr int kMetaElementsPerPackedValue = 16 / kMetaElementsBitSize; - constexpr int kColumnsPerCtaTile = kTileSize / kMetaElementsPerPackedValue; - auto loc = op.getLoc(); auto load_sparse_encoding = cast( cast(op.getResult().getType()).getEncoding()); @@ -451,38 +484,6 @@ class SparseLocalLoadToLLVM } }; -struct SparseRemoveLayoutConversionPass - : public impl::SparseRemoveLayoutConversionPassBase< - SparseRemoveLayoutConversionPass> { - void runOnOperation() override { - getOperation().walk([&](triton::gpu::ConvertLayoutOp op) { - ImplicitLocOpBuilder builder(op.getLoc(), op); - auto srcEncoding = - cast(op.getSrc().getType()).getEncoding(); - if (isa(srcEncoding)) { - return; - } - auto dstType = cast(op.getType()); - if (!isa(dstType.getEncoding())) { - return; - } - - auto ctaLayout = triton::gpu::getCTALayout(srcEncoding); - auto sharedLayout = builder.getAttr( - 8, 1, 1, triton::gpu::getOrder(srcEncoding), ctaLayout); - auto sharedMemorySpace = - builder.getAttr(); - auto memType = - triton::MemDescType::get(dstType.getShape(), dstType.getElementType(), - sharedLayout, sharedMemorySpace); - Value alloc = - builder.create(memType, op.getSrc()); - Value convert = builder.create(dstType, alloc); - op.replaceAllUsesWith(convert); - op.erase(); - }); - } -}; bool IsLocalLoadWithSparseEncoding(Operation *op) { auto local_load = mlir::dyn_cast(op); @@ -501,7 +502,7 @@ struct SparseLocalLoadToLLVMPass // Allocate shared memory and set barrier // This is also done in the TritonGPUToLLVMPass but we need to do it before // we write the local load op to LLVM to have barriers in the right place. - // See b/351986109. + // See b/358375493. ModuleAllocation allocation(getOperation()); ModuleMembarAnalysis membar_pass(&allocation); membar_pass.run(); @@ -656,7 +657,6 @@ LogicalResult convertSparseMMA(triton::gpu::SparseDotOp op, // ----- Hopper implementation. -constexpr int kThreadsPerWarp = 32; constexpr int kWarpsInGroup = 4; constexpr int kMmaAccumulatorCount = 2; constexpr int kMmaLineSize = 128; @@ -867,6 +867,8 @@ struct SparseDotOpToLLVMPass TritonGPUToLLVMTypeConverter typeConverter(context, option); RewritePatternSet patterns(context); patterns.add(typeConverter); + // TODO(b/358375493): Remove this once TritonGPUToLLVMTypeConverter is + // splitted into smaller passes. populateGpuToNVVMConversionPatterns(typeConverter, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index 5488e48dc3016d..97c3576f9d092d 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -47,7 +47,6 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/Linker/Linker.h" #include "llvm/Support/FileSystem.h" -#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" @@ -77,7 +76,6 @@ limitations under the License. #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" @@ -87,7 +85,6 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Support/TypeID.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" @@ -110,9 +107,10 @@ limitations under the License. #include "xla/service/algorithm_util.h" #include "xla/service/dump.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/service/gpu/fusions/triton/passes.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" @@ -123,6 +121,7 @@ limitations under the License. #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" +#include "xla/service/gpu/model/triton_emitter_constraints.h" #include "xla/service/gpu/target_util.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/gpu/triton_tiling_propagation.h" @@ -135,6 +134,7 @@ limitations under the License. #include "xla/stream_executor/launch_dim.h" #include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" #include "xla/util.h" +#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" @@ -153,7 +153,6 @@ namespace gpu { namespace ma = ::mlir::arith; namespace mm = ::mlir::math; -namespace ml = ::mlir::LLVM; namespace mn = ::mlir::NVVM; namespace mt = ::mlir::triton; @@ -188,6 +187,10 @@ absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t) { return b.getI1Type(); case S8: return b.getI8Type(); + case S4: // The unpacking to i8 is supported by the emitter. + // We pass the s4 tensor as i8 tensor with the minor dimension having 2x + // less elements and unpack in the inner loop of the triton kernel. + return b.getI8Type(); case F8E5M2: return b.getFloat8E5M2Type(); case F8E4M3FN: @@ -647,16 +650,27 @@ struct DimProperties { int split_value; }; -absl::StatusOr EmitBroadcast( - ImplicitLocOpBuilder& b, const TritonFusionAnalysis* analysis, - TritonFusionAnalysis::Scope scope, - absl::Span tiled_dimensions, - const HloInstruction& broadcast, Value input) { +struct Side { + explicit Side(TritonFusionAnalysis::Scope scope, + std::vector tiled_dims = {}, + std::optional batch_dim_idx = std::nullopt) + : scope(scope), tiled_dims(tiled_dims), batch_dim_idx(batch_dim_idx) {} + TritonFusionAnalysis::Scope scope; + std::vector tiled_dims; + std::optional batch_dim_idx; + int64_t unpack_dim_idx = 0; +}; + +absl::StatusOr EmitBroadcast(ImplicitLocOpBuilder& b, + const TritonFusionAnalysis* analysis, + const Side& side, + const HloInstruction& broadcast, + Value input) { TF_RET_CHECK(analysis != nullptr); std::vector out_shape; - for (const DimProperties& dim : tiled_dimensions) { + for (const DimProperties& dim : side.tiled_dims) { const TensorIterationSpec::DimIterationSpec* spec = - analysis->IterSpec(scope, &broadcast, dim.index); + analysis->IterSpec(side.scope, &broadcast, dim.index); if (spec != nullptr && spec->at(0).stride > 0) { out_shape.push_back(dim.block_size); } @@ -673,10 +687,10 @@ absl::StatusOr EmitBroadcast( // Add broadcasted dimensions one by one. Value expanded_input = tensor_input; int dim_idx = 0; - for (const DimProperties& dim : tiled_dimensions) { - if (analysis->IterSpec(scope, &broadcast, dim.index) != nullptr && - analysis->IterSpec(scope, &broadcast, dim.index)->at(0).stride > 0) { - if (analysis->IterSpec(scope, broadcast.operand(0), dim.index) == + for (const DimProperties& dim : side.tiled_dims) { + if (auto* spec = analysis->IterSpec(side.scope, &broadcast, dim.index); + spec != nullptr && spec->at(0).stride > 0) { + if (analysis->IterSpec(side.scope, broadcast.operand(0), dim.index) == nullptr) { // Broadcasted dimension. expanded_input = b.create(expanded_input, dim_idx); @@ -690,8 +704,7 @@ absl::StatusOr EmitBroadcast( absl::StatusOr EmitScope( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, - const TritonFusionAnalysis* analysis, TritonFusionAnalysis::Scope scope, - absl::Span tiled_dimensions, + const TritonFusionAnalysis* analysis, const Side& side, absl::Span instructions, absl::flat_hash_map& values); @@ -797,7 +810,7 @@ absl::StatusOr EmitReduce( TF_ASSIGN_OR_RETURN( Value result, EmitScope(b, libdevice_path, device_info, /*analysis=*/nullptr, - TritonFusionAnalysis::Scope::OUTPUT, {}, to_emit, + Side(TritonFusionAnalysis::Scope::OUTPUT), to_emit, region_values)); b.create(SmallVector({result})); b.setInsertionPointAfter(reduction); @@ -851,7 +864,7 @@ absl::StatusOr EmitNestedFusion( TF_RET_CHECK(to_emit.back() == fusion_computation->root_instruction()); return EmitScope(b, libdevice_path, device_info, /*analysis=*/nullptr, - TritonFusionAnalysis::Scope::OUTPUT, {}, to_emit, + Side(TritonFusionAnalysis::Scope::OUTPUT), to_emit, region_values); } @@ -1014,19 +1027,58 @@ absl::StatusOr EmitTiledScope( return values[tiled_computation.GetRoot()]; } +// Emit sequence of operations for unpacking 2xi4 -> i8. +absl::StatusOr EmitUnpackInt4(ImplicitLocOpBuilder& b, + const HloInstruction* hlo, + const Side& side, Value& value) { + VLOG(6) << "EmitUnpackInt4: " << hlo->ToString(); + auto input_type = mlir::cast(value.getType()); + if (input_type.getShape().size() != 2) { + return absl::InvalidArgumentError( + absl::StrCat("UnpackInt4 works only for 2d inputs: ", hlo->ToString())); + } + // We use shifts instead the mask because we need to keep the sign bit. + Value shift4 = + Splat(b, CreateConst(b, b.getI8Type(), 4), input_type.getShape()); + Value lo = b.create(b.create(value, shift4), shift4); + Value hi = b.create(value, shift4); + Value result = b.create(hi, lo); + SmallVector result_shape(input_type.getShape()); + result_shape[side.unpack_dim_idx] *= 2; + if (side.unpack_dim_idx == 0) { + result = b.create(result, b.getDenseI32ArrayAttr({0, 2, 1})); + } + auto type = mlir::RankedTensorType::get(result_shape, b.getI8Type()); + return b.create(type, result, /*allow_reorder=*/false); +} + // Emit sequence of instructions using compatible tiling ordered producers // before consumers. absl::StatusOr EmitScope( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, - const TritonFusionAnalysis* analysis, TritonFusionAnalysis::Scope scope, - absl::Span tiled_dimensions, + const TritonFusionAnalysis* analysis, const Side& side, absl::Span instructions, absl::flat_hash_map& values) { for (const HloInstruction* hlo : instructions) { Value result; - if (hlo->opcode() == HloOpcode::kConcatenate || - hlo->opcode() == HloOpcode::kDynamicSlice) { + if (hlo->opcode() == HloOpcode::kConvert && + hlo->operand(0)->shape().element_type() == S4) { + if (!hlo->GetModule() + ->config() + .debug_options() + .xla_gpu_enable_triton_gemm_int4()) { + return absl::UnimplementedError( + "Int4 support is not enabled in the debug options."); + } + + TF_ASSIGN_OR_RETURN( + auto unpacked, EmitUnpackInt4(b, hlo, side, values[hlo->operand(0)])); + std::vector operands({unpacked}); + TF_ASSIGN_OR_RETURN(result, EmitElementwise(b, libdevice_path, + device_info, *hlo, operands)); + } else if (hlo->opcode() == HloOpcode::kConcatenate || + hlo->opcode() == HloOpcode::kDynamicSlice) { // Parameter loads and their concatenations are handled outside EmitScope. TF_RET_CHECK(values.contains(hlo)) << hlo->ToString(); continue; @@ -1042,9 +1094,8 @@ absl::StatusOr EmitScope( // Splat makes it a tensor to avoid type mismatches. result = Splat(b, constant, {}); } else if (hlo->opcode() == HloOpcode::kBroadcast) { - TF_ASSIGN_OR_RETURN( - result, EmitBroadcast(b, analysis, scope, tiled_dimensions, *hlo, - values[hlo->operand(0)])); + TF_ASSIGN_OR_RETURN(result, EmitBroadcast(b, analysis, side, *hlo, + values[hlo->operand(0)])); } else if (HloInstruction::IsOpElementwise(hlo->opcode())) { std::vector operands; operands.reserve(hlo->operands().size()); @@ -1079,86 +1130,6 @@ absl::StatusOr EmitScope( return values[instructions.back()]; } -// Extract additional attributes from an LLVM function that are not passed -// to the builder directly. -SmallVector GetExtraAttrs(ml::LLVMFuncOp func) { - llvm::StringSet<> registered_attr_names{ - func.getSymNameAttrName().getValue(), - func.getFunctionTypeAttrName().getValue(), - func.getLinkageAttrName().getValue(), - func.getDsoLocalAttrName().getValue(), - func.getCConvAttrName().getValue(), - func.getArgAttrsAttrName().getValue(), - func.getFunctionEntryCountAttrName().getValue()}; - return llvm::to_vector( - llvm::make_filter_range(func->getAttrs(), [&](mlir::NamedAttribute attr) { - return !registered_attr_names.contains(attr.getName().getValue()); - })); -} - -// Strip address spaces from function parameters. -void StripParameterAddressSpaces(mlir::RewriterBase& rewriter, - ml::LLVMFuncOp func) { - // Figure out what the new signature should be. - ml::LLVMFunctionType func_ty = func.getFunctionType(); - SmallVector generic_func_params( - llvm::map_range(func_ty.getParams(), [](Type type) -> Type { - auto ptr_ty = mlir::dyn_cast(type); - if (!ptr_ty) return type; - if (ptr_ty.getAddressSpace() != mn::kGlobalMemorySpace) return type; - return ml::LLVMPointerType::get(ptr_ty.getContext()); - })); - ml::LLVMFunctionType generic_func_ty = - func_ty.clone(generic_func_params, func_ty.getReturnTypes()); - - // Create a function with the new signature. - SmallVector arg_attrs(llvm::map_range( - func.getArgAttrsAttr().getValue(), [](mlir::Attribute attr) { - return mlir::cast(attr); - })); - auto generic_func = rewriter.create( - func.getLoc(), func.getSymName(), generic_func_ty, func.getLinkage(), - func.getDsoLocal(), func.getCConv(), /*comdat=*/nullptr, - GetExtraAttrs(func), arg_attrs, func.getFunctionEntryCount()); - - // Convert generic address spaces back to original ones within the function - // body. - mlir::Block* entry = generic_func.addEntryBlock(rewriter); - rewriter.setInsertionPointToEnd(entry); - SmallVector converted_args; - for (auto [arg, type] : - llvm::zip(generic_func.getArguments(), func_ty.getParams())) { - Value converted = arg; - if (arg.getType() != type) { - converted = rewriter.create(arg.getLoc(), type, arg); - } - converted_args.push_back(converted); - } - - // Move the rest of function body from the original function. - rewriter.cloneRegionBefore(func.getBody(), generic_func.getBody(), - generic_func.getBody().end()); - rewriter.eraseOp(func); - rewriter.mergeBlocks(entry->getNextNode(), entry, converted_args); -} - -// Rewrite signatures of kernel functions to use generic data pointers and -// cast them to global ones within the kernel. -struct GeneralizeKernelSignaturePass - : mlir::PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GeneralizeKernelSignaturePass); - void runOnOperation() override { - mlir::IRRewriter rewriter(&getContext()); - getOperation()->walk([&](ml::LLVMFuncOp func) { - if (!func->hasAttr(mn::NVVMDialect::getKernelFuncAttrName())) { - return; - } - rewriter.setInsertionPointAfter(func); - StripParameterAddressSpaces(rewriter, func); - }); - } -}; - const TensorIterationSpec::DimIterationSpec* GetLhsNoncontractingSplitSpec( const TritonFusionAnalysis& analysis, int64_t lhs_noncontracting_dim_idx) { const TensorIterationSpec::DimIterationSpec* result = nullptr; @@ -1385,12 +1356,6 @@ absl::Status ValidateMatMulConfig(const TritonGemmConfig& config, return absl::OkStatus(); } -struct Side { - TritonFusionAnalysis::Scope scope; - std::vector tiled_dims; - std::optional batch_dim_idx; -}; - // if (index < limits[0]) { // return choices[0]; // } else if (index < limits[1]) { @@ -1522,11 +1487,10 @@ class MatMulEmitterHelper { return to_emit; } - Value MakeInput(Side& side, int64_t operand_index, + Value MakeInput(const Side& side, int64_t operand_index, absl::flat_hash_map& values) { return *EmitScope( - b_, libdevice_path_, device_info_, &analysis_, side.scope, - side.tiled_dims, + b_, libdevice_path_, device_info_, &analysis_, side, dot_instr_->parent()->MakeInstructionPostOrderFrom( const_cast(*dot_instr_->operand(operand_index))), values); @@ -1551,6 +1515,7 @@ class MatMulEmitterHelper { Value base; std::vector bounds; std::vector strides; + std::vector strides_sizes; // We use it to detect the minor dim. // Offsets from tensor origin, same for all thread blocks. std::vector tensor_offsets; std::vector block_dims; @@ -1641,7 +1606,9 @@ class MatMulEmitterHelper { for (const HloInstruction* input : inputs) { specs.push_back( analysis_.IterSpec(side.scope, input, properties.index)); - input_strides.push_back(Cst64(specs.back()->at(0).stride)); + const auto stride = specs.back()->at(0).stride; + strides_sizes.push_back(stride); + input_strides.push_back(Cst64(stride)); input_offsets.push_back(b_.create( pid_offset, Cst32(specs.back()->at(0).slice_start))); input_bounds.push_back(Cst64(specs.back()->at(0).count)); @@ -1816,9 +1783,14 @@ class MatMulEmitterHelper { if (has_batch_offset) { Value pid_batch = b_.create(launch_config_.batch_program_id_dim); + Value pid_offset_batch = b_.create( b_.create(Cst(offset_batch), ConvertScalar(pid_batch)), batch_stride); + + if (hlo->shape().element_type() == PrimitiveType::S4) { + pid_offset_batch = b_.create(pid_offset_batch, Cst(2)); + } base = AddPtr(b_, base, pid_offset_batch); } @@ -1837,6 +1809,18 @@ class MatMulEmitterHelper { // Load of a scalar. return base; } + if (hlo->shape().element_type() == PrimitiveType::S4) { + // Divide the stride by 2 for S4 inputs except for the minor dimension. + for (int i = 0; i < strides.size(); ++i) { + // We assume that the pack happens along the minor dimension. + if (strides_sizes[i] == 1) { // minor dimension + auto s4_bound = b_.create(bounds[i], Cst64(2)); + bounds[i] = s4_bound; + continue; + } + strides[i] = b_.create(strides[i], Cst64(2)); + } + } auto tensor_ptr = mlir::cast( b_.create(base, bounds, strides, tensor_offsets, block_dims, dim_order) @@ -2156,6 +2140,140 @@ absl::Status CheckGemmTilingComplexityHeuristic( return absl::OkStatus(); } +class Scopes { + public: + Scopes(ImplicitLocOpBuilder& b, const TritonFusionAnalysis& analysis, + const MatMulDims& dims, const TritonGemmConfig& config, + const MatMulLaunchConfig launch_config, bool is_sparse) + : lhs_(TritonFusionAnalysis::Scope::LHS), + rhs_(TritonFusionAnalysis::Scope::RHS), + out_(TritonFusionAnalysis::Scope::OUTPUT) { + constexpr int group_m = 8; + const int64_t width = group_m * launch_config.grid_n; + + auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); }; + + auto pid_nc = b.create( + launch_config.noncontracting_program_id_dim); + pid_k_ = (config.split_k > 1) + ? b.create(mt::ProgramIDDim::Z) + : Value{}; + + auto group_id = b.create(pid_nc, c32(width)); + ma::ConstantOp group_m_op = c32(group_m); + auto first_pid_m = b.create(group_id, group_m_op); + auto sub0 = b.create(c32(launch_config.grid_m), first_pid_m); + auto group_size = b.create( + b.create(ma::CmpIPredicate::slt, sub0, group_m_op), sub0, + group_m_op); + + pid_m_ = b.create(first_pid_m, + b.create(pid_nc, group_size)); + + pid_n_ = b.create(b.create(pid_nc, c32(width)), + group_size); + + int lhs_non_contracting_block_size = config.block_m; + int lhs_contracting_block_size = config.block_k; + int lhs_unpack_dim_idx = 0; + if (is_int4_param(analysis, TritonFusionAnalysis::Scope::LHS)) { + if (dims.lhs_contracting_dim_idx > dims.lhs_noncontracting_dim_idx) { + // lhs is int4 and the contracting dimension is minor. + lhs_contracting_block_size /= 2; + lhs_unpack_dim_idx = 1; + } else { + // lhs is int4 and the contracting dimension is major. + lhs_non_contracting_block_size /= 2; + lhs_unpack_dim_idx = 0; + } + } + if (is_sparse) { + lhs_contracting_block_size /= 2; + } + lhs_.tiled_dims = { + DimProperties(dims.lhs_noncontracting_dim_idx, pid_m_, + lhs_non_contracting_block_size, + /*split_value=*/1), + DimProperties(dims.lhs_contracting_dim_idx, pid_k_, + lhs_contracting_block_size, config.split_k)}; + lhs_.batch_dim_idx = dims.lhs_batch_dim_idx; + lhs_.unpack_dim_idx = lhs_unpack_dim_idx; + + int rhs_contracting_block_size = config.block_k; + int rhs_non_contracting_block_size = config.block_n; + int rhs_unpack_dim_idx = 0; + if (is_int4_param(analysis, TritonFusionAnalysis::Scope::RHS)) { + if (dims.rhs_contracting_dim_idx > dims.rhs_noncontracting_dim_idx) { + // rhs is int4 and the contracting dimension is minor. + rhs_contracting_block_size /= 2; + rhs_unpack_dim_idx = 0; + } else { + // rhs is int4 and the contracting dimension is major. + rhs_non_contracting_block_size /= 2; + rhs_unpack_dim_idx = 1; + } + } + rhs_.tiled_dims = { + DimProperties(dims.rhs_contracting_dim_idx, pid_k_, + rhs_contracting_block_size, config.split_k), + DimProperties(dims.rhs_noncontracting_dim_idx, pid_n_, + rhs_non_contracting_block_size, + /*split_value=*/1)}; + rhs_.batch_dim_idx = dims.rhs_batch_dim_idx; + rhs_.unpack_dim_idx = rhs_unpack_dim_idx; + + out_.tiled_dims = {DimProperties(dims.out_lhs_noncontracting_dim_idx, + pid_m_, config.block_m, + /*split_value=*/1), + DimProperties(dims.out_rhs_noncontracting_dim_idx, + pid_n_, config.block_n, + /*split_value=*/1)}; + out_.batch_dim_idx = dims.out_batch_dim_idx; + + if (is_sparse) { + meta_ = Side{TritonFusionAnalysis::Scope::META, + /*tiled_dims=*/ + {DimProperties(dims.lhs_noncontracting_dim_idx, pid_m_, + config.block_m, + /*split_value=*/1), + DimProperties(dims.lhs_contracting_dim_idx, pid_k_, + config.block_k / 16, config.split_k)}, + dims.lhs_batch_dim_idx}; + } + } + + std::vector input_scopes() const { + if (meta_.has_value()) { + return {&lhs_, &rhs_, &meta_.value()}; + } + return {&lhs_, &rhs_}; + } + const Side& lhs() const { return lhs_; } + const Side& rhs() const { return rhs_; } + const Side& out() const { return out_; } + const std::optional& meta() const { return meta_; } + const Value& pid_m() const { return pid_m_; } + const Value& pid_k() const { return pid_k_; } + const Value& pid_n() const { return pid_n_; } + + static bool is_int4_param(const TritonFusionAnalysis& analysis, + TritonFusionAnalysis::Scope scope) { + const ConstHloInstructionSet& params = analysis.ScopeParameters(scope); + return params.size() == 1 && + (*params.cbegin())->shape().element_type() == S4; + } + + private: + Side lhs_; + Side rhs_; + Side out_; + std::optional meta_; + + Value pid_m_; + Value pid_k_; + Value pid_n_; +}; + } // namespace // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. @@ -2240,30 +2358,6 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, MatMulEmitterHelper emitter(libdevice_path, device_info, dot_instr, b, index_ty, dims, launch_config, analysis); - constexpr int group_m = 8; - const int64_t width = group_m * launch_config.grid_n; - - auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); }; - - auto pid_nc = - b.create(launch_config.noncontracting_program_id_dim); - Value pid_k = (split_k > 1) - ? b.create(mt::ProgramIDDim::Z) - : Value{}; - - auto group_id = b.create(pid_nc, c32(width)); - ma::ConstantOp group_m_op = c32(group_m); - auto first_pid_m = b.create(group_id, group_m_op); - auto sub0 = b.create(c32(launch_config.grid_m), first_pid_m); - auto group_size = b.create( - b.create(ma::CmpIPredicate::slt, sub0, group_m_op), sub0, - group_m_op); - - auto pid_m = b.create(first_pid_m, - b.create(pid_nc, group_size)); - auto pid_n = b.create(b.create(pid_nc, c32(width)), - group_size); - TF_ASSIGN_OR_RETURN(mlir::FloatType acc_ty, emitter.GetDotAccumulatorType()); ma::ConstantOp accumulator_init = @@ -2274,46 +2368,17 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, absl::flat_hash_map iter_args_to_inputs; absl::flat_hash_map> iter_args_to_boundary_checks; - Side lhs{TritonFusionAnalysis::Scope::LHS, - /*tiled_dims=*/ - {DimProperties(dims.lhs_noncontracting_dim_idx, pid_m, block_m, - /*split_value=*/1), - DimProperties(dims.lhs_contracting_dim_idx, pid_k, - block_k / (1 + is_sparse), split_k)}, - dims.lhs_batch_dim_idx}; - Side rhs{ - TritonFusionAnalysis::Scope::RHS, - /*tiled_dims=*/ - {DimProperties(dims.rhs_contracting_dim_idx, pid_k, block_k, split_k), - DimProperties(dims.rhs_noncontracting_dim_idx, pid_n, block_n, - /*split_value=*/1)}, - dims.rhs_batch_dim_idx}; - Side out{TritonFusionAnalysis::Scope::OUTPUT, - /*tiled_dims=*/ - {DimProperties(dims.out_lhs_noncontracting_dim_idx, pid_m, block_m, - /*split_value=*/1), - DimProperties(dims.out_rhs_noncontracting_dim_idx, pid_n, block_n, - /*split_value=*/1)}, - dims.out_batch_dim_idx}; - - std::vector scopes = {lhs, rhs}; - if (is_sparse) { - scopes.push_back( - {TritonFusionAnalysis::Scope::META, - /*tiled_dims=*/ - {DimProperties(dims.lhs_noncontracting_dim_idx, pid_m, block_m, - /*split_value=*/1), - DimProperties(dims.lhs_contracting_dim_idx, pid_k, block_k / 16, - split_k)}, - dims.lhs_batch_dim_idx}); - } + // Calculate the sizes of the lhs, rhs, meta, and output sides. + Scopes scopes(b, analysis, dims, config, launch_config, is_sparse); + + auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); }; constexpr size_t kLhsMetaOperandIdx = HloDotInstruction::kOperands; size_t lsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::LHS).size(); size_t rsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::RHS).size(); absl::flat_hash_map triton_type_for_input; - for (const Side& side : {lhs, rhs}) { + for (const Side& side : {scopes.lhs(), scopes.rhs()}) { for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) { TF_ASSIGN_OR_RETURN(Type input_ty, TritonType(b, input->shape().element_type())); @@ -2330,7 +2395,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, // Load tiles of all parameters of LHS and RHS scopes and advance pointers. for (int i = 0; i < iter_args.size() - 1; ++i) { const int index = i < lsize ? 0 : i < lsize + rsize ? 1 : 2; - Side& side = scopes[index]; + const Side& side = *(scopes.input_scopes()[index]); const HloInstruction* param_hlo = iter_args_to_inputs[i]; Type param_ty = index == kLhsMetaOperandIdx @@ -2370,10 +2435,10 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, } // Emit all operations of LHS and RHS scopes. - Value dot_input_lhs = emitter.MakeInput(lhs, 0, values[0]); - Value dot_input_rhs = emitter.MakeInput(rhs, 1, values[1]); + Value dot_input_lhs = emitter.MakeInput(scopes.lhs(), 0, values[0]); + Value dot_input_rhs = emitter.MakeInput(scopes.rhs(), 1, values[1]); Value dot_input_meta = - is_sparse ? emitter.MakeInput(scopes.back(), 2, values[2]) : Value{}; + is_sparse ? emitter.MakeInput(*scopes.meta(), 2, values[2]) : Value{}; // Operation in the fusion before the dot can alter the elements of the // tiles that were zero masked during loads. These have to be zeroed here @@ -2386,9 +2451,10 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, auto elements_in_tile = b.create(c32(dims.k / denom), ki); int size = block_k / denom; auto range_k = Range(b, size); - if (pid_k != nullptr) { + if (scopes.pid_k() != nullptr) { range_k = b.create( - range_k, Splat(b, b.create(pid_k, c32(size)), size)); + range_k, + Splat(b, b.create(scopes.pid_k(), c32(size)), size)); } auto ty = mlir::cast(input.getType()); TensorValue range_expanded = mlir::cast( @@ -2464,15 +2530,15 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, SmallVector iter_args; iter_args.reserve(lsize + rsize + 1 + is_sparse); - for (const Side& side : scopes) { - for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) { + for (const Side* side : scopes.input_scopes()) { + for (const HloInstruction* input : ScopeInputs(analysis, side->scope)) { TF_RET_CHECK( iter_args_to_inputs.insert({iter_args.size(), input}).second); TF_ASSIGN_OR_RETURN(SmallVector arguments, GetArguments(fn, *input)); TF_ASSIGN_OR_RETURN(Value tensor_ptr, emitter.EmitTensorPointer( - input, side, arguments, pid_k, + input, *side, arguments, scopes.pid_k(), iter_args_to_boundary_checks[iter_args.size()])); iter_args.push_back(tensor_ptr); } @@ -2499,17 +2565,17 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, std::vector boundary_checks; TF_ASSIGN_OR_RETURN(SmallVector arguments, GetArguments(fn, *input)); - TF_ASSIGN_OR_RETURN(Value tensor_pointer, - emitter.EmitTensorPointer(input, out, arguments, - pid_k, boundary_checks)); + TF_ASSIGN_OR_RETURN( + Value tensor_pointer, + emitter.EmitTensorPointer(input, scopes.out(), arguments, + scopes.pid_k(), boundary_checks)); TF_RET_CHECK(values_out .insert({input, EmitParameterLoad(b, tensor_pointer, boundary_checks)}) .second); } TF_RETURN_IF_ERROR(EmitScope(b, libdevice_path, device_info, &analysis, - TritonFusionAnalysis::Scope::OUTPUT, - out.tiled_dims, to_emit, values_out) + scopes.out(), to_emit, values_out) .status()); } @@ -2522,9 +2588,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, TF_ASSIGN_OR_RETURN( Value tensor_pointer, emitter.EmitTensorPointer( - producer, out, - {fn.getArgument(i + dot_instr->parent()->num_parameters())}, pid_k, - boundary_checks)); + producer, scopes.out(), + {fn.getArgument(i + dot_instr->parent()->num_parameters())}, + scopes.pid_k(), boundary_checks)); b.create(tensor_pointer, values_out[producer], boundary_checks, mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); } @@ -2665,8 +2731,9 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, const BlockLevelParameters& block_level_parameters) { const HloComputation* computation = fusion->fused_instructions_computation(); SymbolicTileAnalysisOrError symbolic_tile_analysis_or = - SymbolicTileAnalysis::AnalyzeComputation(*computation, - builder.getContext()); + SymbolicTileAnalysis::AnalyzeComputation( + *computation, builder.getContext(), + TritonEmitterConstraints::GetBuilder()); if (std::holds_alternative(symbolic_tile_analysis_or)) { return Internal( "Unsupported fusion in EmitGeneric: %s", @@ -2756,10 +2823,19 @@ absl::Status CreateInternalError(std::string_view message, os << message << "\n"; os << fusion->fused_instructions_computation()->ToString() << "\n"; os << "triton_module: \n"; - triton_module->print(os); + triton_module->print(os, mlir::OpPrintingFlags().enableDebugInfo(true, true)); return absl::InternalError(err); } +absl::Status DoSupportType(const DebugOptions& debug_options, + PrimitiveType type) { + if (type == S4 && !debug_options.xla_gpu_enable_triton_gemm_int4()) { + return absl::FailedPreconditionError( + "Int4 support is not enabled in the debug options."); + } + return absl::OkStatus(); +} + absl::StatusOr> CreateTritonModule( absl::string_view fn_name, const HloFusionInstruction* fusion, const se::DeviceDescription& device_info, @@ -2776,10 +2852,12 @@ absl::StatusOr> CreateTritonModule( llvm_ir::CreateMlirModuleOp(loc); b.setInsertionPointToEnd(triton_module->getBody()); + const auto debug_options = fusion->GetModule()->config().debug_options(); // Build Triton kernel. SmallVector fn_arg_types; for (HloInstruction* p : hlo_computation->parameter_instructions()) { PrimitiveType type = p->shape().element_type(); + TF_RETURN_IF_ERROR(DoSupportType(debug_options, type)); Type ir_type; if (type == U16) { ir_type = b.getI16Type(); @@ -2837,10 +2915,17 @@ absl::StatusOr> CreateTritonModule( "Failed to create Triton module for fusion:", fusion, *triton_module); } - VLOG(6) << llvm_ir::DumpToString(*triton_module); + auto dump_triton_ir = [&]() { + std::string triton_ir; + llvm::raw_string_ostream os(triton_ir); + triton_module->print(os, + mlir::OpPrintingFlags().enableDebugInfo(true, true)); + return triton_ir; + }; + VLOG(6) << dump_triton_ir(); if (DumpingEnabledForHloModule(*hlo_computation->parent())) { DumpToFileInDirOrStdout(*hlo_computation->parent(), "triton_ir", "ttir", - llvm_ir::DumpToString(*triton_module)); + dump_triton_ir()); } return std::move(triton_module); @@ -2948,16 +3033,16 @@ absl::StatusOr CompileTritonToLLVM( .ok()) { return Internal("Failed to create Triton pipeline."); } - if (log_stream.has_value()) { - pm.printAsTextualPipeline(log_stream.value()); - log_stream->write("\n\n", 2); - } // Triton generates pointers to the global address space, while XLA needs a // kernel signature with pointers to the generic address space. - pm.addPass(std::make_unique()); + pm.addPass(CreateGeneralizeKernelSignaturePass()); // llvm::Linker::linkModules() segfaults if we don't strip locations. pm.addPass(mlir::createStripDebugInfoPass()); + if (log_stream.has_value()) { + pm.printAsTextualPipeline(log_stream.value()); + log_stream->write("\n\n", 2); + } bool succeeded = mlir::succeeded(pm.run(triton_module)); if (log_stream.has_value()) { diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index 720a8b4ca305c1..016d9925c4dcff 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include #include "absl/algorithm/container.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" @@ -108,6 +109,7 @@ class TritonGemmTest : public TritonTest { debug_options.set_xla_gpu_enable_split_k_autotuning(false); // Always rewrite Gemms with Triton regardless of size. debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); + debug_options.set_xla_gpu_enable_triton_gemm_int4(true); return debug_options; } @@ -136,6 +138,245 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest { } }; +TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDim) { + // We prove that triton can handle int4 dot with non minor + // lhs_contracting_dim. + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[1024,8]{1,0} parameter(0) + lhs_converted = bf16[1024,8]{1,0} convert(lhs) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={0}, + rhs_contracting_dims={0} + } + + ENTRY main { + lhs = s4[1024,8]{1,0} parameter(0) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDimWithBatchDim0) { + // We prove that triton can handle int4 dot with non minor + // lhs_contracting_dim. + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[16,1024,8]{2,1,0} parameter(0) + lhs_converted = bf16[16,1024,8]{2,1,0} convert(lhs) + rhs = bf16[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4]{2,1,0} dot(lhs_converted, rhs), + lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + } + + ENTRY main { + lhs = s4[16,1024,8]{2,1,0} parameter(0) + rhs = bf16[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4]{2,1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, LHSInt4MinorContractingDim) { + // We prove that triton can handle int4 dot with minor lhs_contracting_dim. + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[8,1024]{1,0} parameter(0) + lhs_converted = bf16[8,1024]{1,0} convert(lhs) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + lhs = s4[8,1024]{1,0} parameter(0) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, Int4ConvertPlusNegate) { + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[8,1024]{1,0} parameter(0) + lhs_converted = bf16[8,1024]{1,0} convert(lhs) + lhs_negated = bf16[8,1024]{1,0} negate(lhs_converted) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} dot(lhs_negated, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + lhs = s4[8,1024]{1,0} parameter(0) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, LHSInt4MinorContractingDimWithBatchDim0) { + // We prove that triton can handle int4 dot with minor lhs_contracting_dim. + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[16,8,1024]{2,1,0} parameter(0) + lhs_converted = bf16[16,8,1024]{2,1,0} convert(lhs) + rhs = bf16[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4]{2,1,0} dot(lhs_converted, rhs), + lhs_batch_dims={0}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + } + + ENTRY main { + lhs = s4[16,8,1024]{2,1,0} parameter(0) + rhs = bf16[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4]{2,1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDim) { + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[8,1024]{1,0} parameter(0) + rhs = s4[1024,4]{1,0} parameter(1) + rhs_converted = bf16[1024,4]{1,0} convert(rhs) + ROOT dot = bf16[8,4] dot(lhs, rhs_converted), + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + + ENTRY main { + lhs = bf16[8,1024]{1,0} parameter(0) + rhs = s4[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4] fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDim) { + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[8,1024]{1,0} parameter(0) + rhs = s4[4,1024]{1,0} parameter(1) + rhs_converted = bf16[4,1024]{1,0} convert(rhs) + ROOT dot = bf16[8,4] dot(lhs, rhs_converted), + lhs_contracting_dims={1}, + rhs_contracting_dims={1} + } + + ENTRY main { + lhs = bf16[8,1024]{1,0} parameter(0) + rhs = s4[4,1024]{1,0} parameter(1) + ROOT dot = bf16[8,4] fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDimWithBatchDim) { + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[16,8,1024]{2,1,0} parameter(0) + rhs = s4[16,1024,4]{2,1,0} parameter(1) + rhs_converted = bf16[16,1024,4]{2,1,0} convert(rhs) + ROOT dot = bf16[16,8,4] dot(lhs, rhs_converted), + lhs_batch_dims={0}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + } + + ENTRY main { + lhs = bf16[16,8,1024]{2,1,0} parameter(0) + rhs = s4[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4] fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDimWithBatchDim0) { + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[16,8,1024]{2,1,0} parameter(0) + rhs = s4[16,4,1024]{2,1,0} parameter(1) + rhs_converted = bf16[16,4,1024]{2,1,0} convert(rhs) + ROOT dot = bf16[16,8,4] dot(lhs, rhs_converted), + lhs_batch_dims={0}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={2} + } + + ENTRY main { + lhs = bf16[16,8,1024]{2,1,0} parameter(0) + rhs = s4[16,4,1024]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4] fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + TEST_F(TritonTest, TestGemm) { const std::string kHloText = R"( HloModule t, is_scheduled=true @@ -1564,7 +1805,7 @@ ENTRY e { p0 = bf16[8192,512]{1,0} parameter(0) p1 = bf16[512,512]{1,0} parameter(1) p2 = bf16[8192,512]{1,0} parameter(2) - ROOT fusion = bf16[8192,512]{1,0} fusion(p0,p1,p2), kind=kCustom, calls=triton_computation, + ROOT fusion = bf16[8192,512]{1,0} fusion(p0,p1,p2), kind=kCustom, calls=triton_computation, backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"64","block_n":"256","block_k":"32","split_k":"1","num_stages":"4","num_warps":"4","num_ctas":"1"}}} })"; @@ -2558,10 +2799,12 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetOptimizedModule(kHloText)); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Transpose( - m::Fusion(m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom)))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Bitcast( + m::Fusion(m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom)) + .WithFusionKind(HloInstruction::FusionKind::kInput)))); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } @@ -2649,9 +2892,9 @@ TEST_F(TritonGemmTestAny, HloModule t ENTRY e { - parameter_0 = f32[32,4000] parameter(0) - parameter_1 = f32[32,4000,6400] parameter(1) - ROOT dot = f32[32,6400] dot(parameter_0, parameter_1), lhs_batch_dims={0}, + parameter_0 = f32[1,40] parameter(0) + parameter_1 = f32[1,40,250000] parameter(1) + ROOT dot = f32[1,250000] dot(parameter_0, parameter_1), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} })"; @@ -2671,9 +2914,9 @@ TEST_F(TritonGemmTestAny, HloModule t ENTRY e { - parameter_0 = f32[32,4000,6400] parameter(0) - parameter_1 = f32[32,4000] parameter(1) - ROOT dot = f32[32,6400] dot(parameter_0, parameter_1), lhs_batch_dims={0}, + parameter_0 = f32[1,40,250000] parameter(0) + parameter_1 = f32[1,40] parameter(1) + ROOT dot = f32[1,250000] dot(parameter_0, parameter_1), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} })"; diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index d59b6cc48c7ac0..28f630ba5f5a98 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -609,6 +609,64 @@ ENTRY entry { // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be // moved to deviceless test file. +TEST_F(TritonEmitterTest, + EmitterFailsIfFusionBackendConfigDoesNotSatisfyConstraints) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +max_computation { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(param_0, param_1) +} + +fused_computation { + param_0 = f32[8192,50304] parameter(0) + constant = f32[] constant(-inf) + reduce = f32[8192] reduce(param_0, constant), dimensions={1}, to_apply=max_computation + broadcast = f32[8192,50304] broadcast(reduce), dimensions={0} + ROOT subtract = f32[8192,50304] subtract(param_0, broadcast) +} + +ENTRY entry_computation { + param_0 = f32[8192,50304] parameter(0) + ROOT fusion = f32[8192,50304] fusion(param_0), + kind=kCustom, calls=fused_computation, + backend_config={"fusion_backend_config": { + "kind":"__triton", + "block_level_fusion_config": {"output_tile_sizes": ["1024","1"], + "num_warps": "1"}}} +})")); + const HloFusionInstruction* triton_fusion = Cast( + hlo_module->entry_computation()->root_instruction()); + + auto compute_capability = + se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, /*minor=*/0}; + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(compute_capability); + llvm::LLVMContext llvm_ctx; + llvm::Module llvm_module("module", llvm_ctx); + mlir::MLIRContext mlir_context; + + BlockLevelParameters block_level_parameters; + block_level_parameters.output_tile_sizes = {1024, 1}; + block_level_parameters.num_warps = 1; + + // Because of reduce, we need to load full rows from param_0 and the load tile + // will be 1024 * 65536 = 67108864 elements, that is larger than the limit of + // 1048576. + EXPECT_THAT( + TritonWrapper("test_fn", triton_fusion, compute_capability, dev_info, + block_level_parameters, &llvm_module, mlir_context), + tsl::testing::StatusIs( + absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr( + "Tile parameters 1024, 1 do not satisfy constraints."))); +} + +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should b +// moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterReductionFusion) { const std::string kHloText = R"( HloModule t diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc index a327c0c1c74c88..92616fa78f7225 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc @@ -48,6 +48,7 @@ limitations under the License. #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" +#include "xla/service/gpu/model/triton_emitter_constraints.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" @@ -112,7 +113,9 @@ TritonMakeTensorPtrTest::CreateAndTileParameterHloInstruction( verified_hlo_module->entry_computation()->root_instruction()); SymbolicTileAnalysisOrError symbolic_tile_analysis_or = - SymbolicTileAnalysis::AnalyzeFusion(*fusion_adaptor, &mlir_context_); + SymbolicTileAnalysis::AnalyzeFusion( + *fusion_adaptor, &mlir_context_, + TritonEmitterConstraints::GetBuilder()); CHECK( std::holds_alternative(symbolic_tile_analysis_or)); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc index 7ae8f5efcfd4b6..0c66c03d8aed7c 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" -#include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc index 44c9d51c5921d0..942e27f3226982 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc @@ -15,10 +15,7 @@ limitations under the License. #include "xla/service/gpu/fusions/triton/triton_support.h" -#include -#include #include -#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" @@ -33,279 +30,43 @@ limitations under the License. #include "xla/layout.h" #include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/variant_visitor.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/tensor_float_32_utils.h" namespace xla { namespace gpu { -namespace legacy_triton { - -bool IsDistributiveOverAddition(const HloInstruction& hlo) { - // The list is most likely incomplete. - // For example division can be added too but only for operand #0. - if (hlo.opcode() == HloOpcode::kMultiply || - hlo.opcode() == HloOpcode::kNegate || - hlo.opcode() == HloOpcode::kBitcast || - hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kCopy || - hlo.opcode() == HloOpcode::kTranspose || - hlo.opcode() == HloOpcode::kConvert || - hlo.opcode() == HloOpcode::kBroadcast || - hlo.opcode() == HloOpcode::kSlice) { - return true; - } - return false; -} - -// Types that are supported by Triton as dot output. -// -// BF16 is supported in a sense that all operations on it are implemented -// through F32 and converts have to be inserted into the HLO graph, but -// they can be missing during fusion. -bool IsTritonSupportedDotOutputType( - const PrimitiveType t, const se::GpuComputeCapability& gpu_version) { - switch (t) { - case F16: - case F32: - return true; - case F8E5M2: - return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { - return cc.IsAtLeastAmpere(); - }, - [](const se::RocmComputeCapability& cc) { - return false; - }}, - gpu_version); - - case F8E4M3FN: - return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { - return cc.IsAtLeastHopper(); - }, - [](const se::RocmComputeCapability& cc) { - return false; - }}, - gpu_version); - case BF16: - return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { - return true; - }, - [](const se::RocmComputeCapability& cc) { - return cc.has_bf16_dtype_support(); - }}, - gpu_version); - default: - return false; - } -}; - -// Data types that are supported by the Triton emitters. -// TODO(b/266862493): Support more data types (F8, F64, etc.). -bool IsTritonSupportedDataType(PrimitiveType type, - const se::GpuComputeCapability& gpu_version) { - if (IsTritonSupportedDotOutputType(type, gpu_version)) { - return true; - } - switch (type) { - case PRED: - case S8: - case S16: - case S32: - return true; - default: - return false; - } -} -std::vector TritonSupportedUnaryElementwiseUpToFloatNormalization( - PrimitiveType element_type) { - std::vector ret = {HloOpcode::kConvert}; - if (element_type == PrimitiveType::PRED) { - ret.push_back(HloOpcode::kNot); - return ret; - } - ret.push_back(HloOpcode::kAbs); - ret.push_back(HloOpcode::kNegate); - if (element_type == PrimitiveType::F32 || - element_type == PrimitiveType::BF16 || - element_type == PrimitiveType::F64) { - absl::c_copy(std::vector{HloOpcode::kCos, HloOpcode::kExp, - HloOpcode::kExpm1, HloOpcode::kFloor, - HloOpcode::kCeil, HloOpcode::kLog, - HloOpcode::kLog1p, HloOpcode::kRsqrt, - HloOpcode::kSin, HloOpcode::kSqrt, - HloOpcode::kCbrt, HloOpcode::kTan, - HloOpcode::kTanh, HloOpcode::kErf}, - std::back_inserter(ret)); - } - return ret; -} - -std::vector TritonSupportedBinaryElementwiseUpToFloatNormalization( - PrimitiveType element_type) { - if (element_type == PrimitiveType::PRED) { - return {HloOpcode::kAnd, HloOpcode::kOr, HloOpcode::kXor, - HloOpcode::kCompare}; - } - std::vector ret = {HloOpcode::kAdd, HloOpcode::kCompare, - HloOpcode::kMaximum, HloOpcode::kMinimum, - HloOpcode::kMultiply, HloOpcode::kSubtract}; - if (element_type == PrimitiveType::F32 || - element_type == PrimitiveType::BF16 || - element_type == PrimitiveType::F64) { - ret.push_back(HloOpcode::kAtan2); - ret.push_back(HloOpcode::kDivide); - ret.push_back(HloOpcode::kPower); - } - return ret; -} - -std::vector TritonSupportedTernaryElementwiseUpToFloatNormalization( - PrimitiveType element_type) { - return {HloOpcode::kSelect, HloOpcode::kClamp}; -} - -bool IsTritonSupportedElementwiseUpToFloatNormalization( - HloOpcode opcode, PrimitiveType element_type) { - return absl::c_linear_search( - TritonSupportedUnaryElementwiseUpToFloatNormalization( - element_type), - opcode) || - absl::c_linear_search( - TritonSupportedBinaryElementwiseUpToFloatNormalization( - element_type), - opcode) || - absl::c_linear_search( - TritonSupportedTernaryElementwiseUpToFloatNormalization( - element_type), - opcode); -} - -CodegenDecision CanTritonHandleElementwise( - const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { - if (!IsTritonSupportedDataType(instr.shape().element_type(), gpu_version)) { - return "Unsupported output data type."; - } - - for (const HloInstruction* operand : instr.operands()) { - if (!IsTritonSupportedDataType(operand->shape().element_type(), - gpu_version)) { - return "Unsupported input data type."; - } - } - - if (instr.opcode() == HloOpcode::kConstant) { - return CodegenDecision{}; - } else if (!IsTritonSupportedElementwiseUpToFloatNormalization( - instr.opcode(), instr.operand(0)->shape().element_type())) { - return "Unsupported elementwise operation."; - } - return CodegenDecision{}; -} - -bool IsDotAlgorithmSupportedByTriton( - PrecisionConfig::Algorithm algorithm, - const se::GpuComputeCapability& gpu_version) { - auto cuda_compute_capability = - std::get_if(&gpu_version); - auto rocm_compute_capability = - std::get_if(&gpu_version); - switch (algorithm) { - case PrecisionConfig::ALG_DOT_TF32_TF32_F32: - if (cuda_compute_capability) { - return true; - } - return false; - case PrecisionConfig::ALG_DOT_BF16_BF16_F32: - case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: - case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: - if (cuda_compute_capability) { - return true; - } - if (rocm_compute_capability) { - return rocm_compute_capability->has_bf16_dtype_support(); - } - return false; - - // TODO(b/326579472): Fix the support of this algorithm and maybe allow it - // here. - case PrecisionConfig::ALG_DOT_F16_F16_F32: - // TODO(b/311331155): Triton F32 is about 3x slower than Triton TF32 and is - // slow to compile. Disable it for now. - case PrecisionConfig::ALG_DOT_F32_F32_F32: - default: - return false; - } -} - -// Filters GEMMs which can be handled using Triton. -CodegenDecision CanTritonHandleGEMM( - const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) { - auto cuda_compute_capability = - std::get_if(&gpu_version); - auto rocm_compute_capability = - std::get_if(&gpu_version); - - CHECK(cuda_compute_capability || rocm_compute_capability); - - if (dot.precision_config().algorithm() == PrecisionConfig::ALG_UNSET) { - if (!tsl::tensor_float_32_execution_enabled() || - absl::c_any_of(dot.precision_config().operand_precision(), - [](int x) { return x != PrecisionConfig::DEFAULT; })) { - return "Having non-default operand precisions or TensorFloat-32 disabled " - "for Dot op with unset algorithm."; - } - } else { - if (!IsDotAlgorithmSupportedByTriton(dot.precision_config().algorithm(), - gpu_version)) { - return "Unsupported algorithm on the current device(s)."; - } - } - - // TODO(b/266862493): Support more output types. - if (!IsTritonSupportedDotOutputType(dot.shape().element_type(), - gpu_version)) { - return "Unsupported output data type for Dot op."; - } - - if (!IsTritonSupportedDataType(dot.operand(0)->shape().element_type(), - gpu_version) || - !IsTritonSupportedDataType(dot.operand(1)->shape().element_type(), - gpu_version)) { - return "Unsupported input data type for Dot op."; - } - - const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); - - // TODO(b/269580541): support multiple batch dimensions. - if (dim_numbers.lhs_batch_dimensions().size() > 1) { - return "Multiple batch dimensions."; - } - - return CodegenDecision{}; -} +namespace legacy_triton { // Filters Reduces which can be handled using Triton. +// TODO(b/345763510): The function is in use by the new version of the triton +// support but the implementation of this function relies on the legacy +// IsTritonSupport... functions. It should be rewritten for the new +// infrastructure. legacy_triton:: prefix is used to avoid name collision with +// the new implementation and for clarity. CodegenDecision CanTritonHandleReduce( const HloReduceInstruction& reduce, const se::GpuComputeCapability& gpu_version) { - if (!IsTritonSupportedDataType(reduce.shape().element_type(), gpu_version)) { + if (!legacy_triton::IsTritonSupportedDataType(reduce.shape().element_type(), + gpu_version)) { return "Unsupported output data type for Reduce op."; } for (const HloInstruction* operand : reduce.operands()) { - if (!IsTritonSupportedDataType(operand->shape().element_type(), - gpu_version)) { + if (!legacy_triton::IsTritonSupportedDataType( + operand->shape().element_type(), gpu_version)) { return "Unsupported input data type for Reduce op."; } } bool is_triton_supported_reduction_computation = [&]() { - return absl::c_all_of( - reduce.to_apply()->instructions(), [&](const HloInstruction* instr) { - return IsTritonSupportedInstruction(*instr, gpu_version); - }); + return absl::c_all_of(reduce.to_apply()->instructions(), + [&](const HloInstruction* instr) { + return legacy_triton::IsTritonSupportedInstruction( + *instr, gpu_version); + }); }(); if (!is_triton_supported_reduction_computation) { return "Unsupported reduction computation by Triton."; @@ -317,96 +78,6 @@ CodegenDecision CanTritonHandleReduce( return "Reduction is not a row-reduction of a single operand."; } -bool NoNonContractingDimension(const HloDotInstruction& dot) { - const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); - if (dim_numbers.lhs_batch_dimensions().size() + - dim_numbers.lhs_contracting_dimensions().size() == - dot.operand(0)->shape().rank() || - dim_numbers.rhs_batch_dimensions().size() + - dim_numbers.rhs_contracting_dimensions().size() == - dot.operand(1)->shape().rank()) { - return true; - } - return false; -} - -CodegenDecision IsTritonSupportedDynamicSlice( - const HloDynamicSliceInstruction& instr) { - for (const HloInstruction* index_operand : instr.index_operands()) { - switch (index_operand->shape().element_type()) { - case S8: - case S16: - case S32: - break; // supported - default: - return CodegenDecision( - "Dynamic slice is only supported with S8, S16, or S32 indices."); - } - } - - // Similar to normal slice, we cannot slice a non-major-most dimension as - // that would introduce non-contiguous strides under tiling. The existing - // check against this in GetRequirementsIfSupportedOrder is not suitable for - // dynamic slices, so we instead check for this here. - const HloInstruction* input = instr.operand(0); - Layout in_layout = input->shape().layout(); - int64_t majormost_dim_id = - in_layout.minor_to_major(in_layout.minor_to_major_size() - 1); - - for (int i = 0; i < input->shape().dimensions_size(); ++i) { - if (i == majormost_dim_id) { - continue; - } else if (input->shape().dimensions(i) != instr.slice_sizes(i)) { - return CodegenDecision( - "Unsupported dynamic slice on non-major-most dimension."); - } - } - - // TODO(b/343143854): Check the subtleties of which dynamic slices are - // supported, for example that a fragmented dimension cannot be sliced. - return CodegenDecision{}; -} - -CodegenDecision IsTritonSupportedInstruction( - const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { - if (instr.IsElementwise()) { - return CanTritonHandleElementwise(instr, gpu_version); - } - - switch (instr.opcode()) { - case HloOpcode::kDot: { - auto* dot = Cast(&instr); - // Cases where lhs or rhs have no non-contracting dims are not handled. - if (NoNonContractingDimension(*dot)) { - return "No non-contracting dimensions."; - } - return CanTritonHandleGEMM(*dot, gpu_version); - } - case HloOpcode::kTuple: { - if (instr.IsRoot()) { - return CodegenDecision{}; - } - return "Only supports root tuples."; - } - case HloOpcode::kDynamicSlice: { - return IsTritonSupportedDynamicSlice( - *Cast(&instr)); - } - case HloOpcode::kBitcast: - case HloOpcode::kTranspose: - case HloOpcode::kSlice: - case HloOpcode::kReshape: - case HloOpcode::kPad: - case HloOpcode::kConcatenate: - case HloOpcode::kParameter: - case HloOpcode::kBroadcast: - return CodegenDecision{}; - default: - break; - } - return "Unsupported opcode."; -} - } // namespace legacy_triton namespace { @@ -563,8 +234,8 @@ CodegenDecision IsTritonSupportedInstructionImpl( return CodegenDecision{}; } - bool output_type_is_supported = - IsTritonSupportedDataType(instr.shape().element_type(), gpu_version); + auto type = instr.shape().element_type(); + bool output_type_is_supported = IsTritonSupportedDataType(type, gpu_version); if (!output_type_is_supported) { return "Unsupported output data type."; diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.h index abd2a4087216a7..14431e85b74f33 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.h @@ -18,91 +18,16 @@ limitations under the License. // This file is the home of the basic Triton support checks which are used by // multiple other components. -#include - #include "absl/status/status.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" namespace xla { namespace gpu { -using CodegenDecision = FusionDecision; - -namespace legacy_triton { - -// Tells if f(a+b) == f(a) + f(b). -bool IsDistributiveOverAddition(const HloInstruction& hlo); - -// Allowlist of unary elementwise operations supported by the legacy Triton -// emitters. -// -// Note: this is not an accurate representation of what is actually supported by -// the Triton emitters, because operations affected by FloatNormalization may -// be tagged as "supported" here, even though FloatNormalization is required to -// make them work. We could fix this, but this is code we aim to delete soon, so -// it doesn't seem worth it. We'll revisit this decision if the code doesn't go -// away soon. -std::vector TritonSupportedUnaryElementwiseUpToFloatNormalization( - PrimitiveType); - -// Allowlist of binary elementwise operations supported by the legacy Triton -// emitters. -// -// Note: this is not an accurate representation of what is actually supported by -// the Triton emitters, because operations affected by FloatNormalization may -// be tagged as "supported" here, even though FloatNormalization is required to -// make them work. We could fix this, but this is code we aim to delete soon, so -// it doesn't seem worth it. We'll revisit this decision if the code doesn't go -// away soon. -std::vector TritonSupportedBinaryElementwiseUpToFloatNormalization( - PrimitiveType); - -// Allowlist of ternary elementwise operations supported by the legacy Triton -// emitters. -// -// Note: this is not an accurate representation of what is actually supported by -// the Triton emitters, because operations affected by FloatNormalization may -// be tagged as "supported" here, even though FloatNormalization is required to -// make them work. We could fix this, but this is code we aim to delete soon, so -// it doesn't seem worth it. We'll revisit this decision if the code doesn't go -// away soon. -std::vector TritonSupportedTernaryElementwiseUpToFloatNormalization( - PrimitiveType); -// Data types that are supported by the legacy Triton emitters. -bool IsTritonSupportedDataType(PrimitiveType, const se::GpuComputeCapability&); - -// Checks elementwise operation against unary, binary, and ternary elementwise -// operations supported by the legacy Triton emitters. -// -// Note: this is not an accurate representation of what is actually supported by -// the Triton emitters, because operations affected by FloatNormalization may -// be tagged as "supported" here, even though FloatNormalization is required to -// make them work. We could fix this, but this is code we aim to delete soon, so -// it doesn't seem worth it. We'll revisit this decision if the code doesn't go -// away soon. -bool IsTritonSupportedElementwiseUpToFloatNormalization(HloOpcode, - PrimitiveType); - -CodegenDecision CanTritonHandleGEMM( - const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version); - -// Checks instruction against the requirements of the legacy Triton emitters. -CodegenDecision IsTritonSupportedInstruction( - const HloInstruction& instr, const se::GpuComputeCapability& gpu_version); - -// Checks dynamic slice against the requirements of the legacy Triton emitters. -// -// This is exposed separately from IsTritonSupportedInstruction because we can -// use it in the dimension order propagation without adding a dependency on the -// GPU version. -CodegenDecision IsTritonSupportedDynamicSlice( - const HloDynamicSliceInstruction& instr); -} // namespace legacy_triton +using CodegenDecision = FusionDecision; // Checks that Triton officially supports the provided compute capability. // diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc new file mode 100644 index 00000000000000..b07630b7cb7734 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc @@ -0,0 +1,396 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/triton/triton_support.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/variant_visitor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/tensor_float_32_utils.h" + +namespace xla { +namespace gpu { +namespace legacy_triton { + +bool IsDistributiveOverAddition(const HloInstruction& hlo) { + // The list is most likely incomplete. + // For example division can be added too but only for operand #0. + if (hlo.opcode() == HloOpcode::kMultiply || + hlo.opcode() == HloOpcode::kNegate || + hlo.opcode() == HloOpcode::kBitcast || + hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kCopy || + hlo.opcode() == HloOpcode::kTranspose || + hlo.opcode() == HloOpcode::kConvert || + hlo.opcode() == HloOpcode::kBroadcast || + hlo.opcode() == HloOpcode::kSlice) { + return true; + } + return false; +} + +// Types that are supported by Triton as dot output. +// +// BF16 is supported in a sense that all operations on it are implemented +// through F32 and converts have to be inserted into the HLO graph, but +// they can be missing during fusion. +bool IsTritonSupportedDotOutputType( + const PrimitiveType t, const se::GpuComputeCapability& gpu_version) { + switch (t) { + case F16: + case F32: + return true; + case F8E5M2: + return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { + return cc.IsAtLeastAmpere(); + }, + [](const se::RocmComputeCapability& cc) { + return false; + }}, + gpu_version); + + case F8E4M3FN: + return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { + return cc.IsAtLeastHopper(); + }, + [](const se::RocmComputeCapability& cc) { + return false; + }}, + gpu_version); + case BF16: + return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { + return true; + }, + [](const se::RocmComputeCapability& cc) { + return cc.has_bf16_dtype_support(); + }}, + gpu_version); + default: + return false; + } +}; + +// Data types that are supported by the Triton emitters. +// TODO(b/266862493): Support more data types (F8, F64, etc.). +bool IsTritonSupportedDataType(PrimitiveType type, + const se::GpuComputeCapability& gpu_version) { + if (IsTritonSupportedDotOutputType(type, gpu_version)) { + return true; + } + switch (type) { + case PRED: + case S8: + case S16: + case S32: + return true; + default: + return false; + } +} + +CodegenDecision IsInstructionSupportsDataTypes( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { + if (!IsTritonSupportedDataType(instr.shape().element_type(), gpu_version)) { + return "Unsupported output data type."; + } + + for (const HloInstruction* operand : instr.operands()) { + const auto operand_type = operand->shape().element_type(); + switch (instr.opcode()) { + case HloOpcode::kConvert: + // TODO(b/358580281): remove DebugOptions from this function after + // enabling int4 in Triton GEMM. + if (operand_type == S4 && instr.GetModule() + ->config() + .debug_options() + .xla_gpu_enable_triton_gemm_int4()) { + continue; + } + [[fallthrough]]; + default: + if (!IsTritonSupportedDataType(operand_type, gpu_version)) { + return "Unsupported input data type."; + } + } + } + return CodegenDecision{}; +} + +std::vector TritonSupportedUnaryElementwiseUpToFloatNormalization( + PrimitiveType element_type) { + std::vector ret = {HloOpcode::kConvert}; + if (element_type == PrimitiveType::PRED) { + ret.push_back(HloOpcode::kNot); + return ret; + } + ret.push_back(HloOpcode::kAbs); + ret.push_back(HloOpcode::kNegate); + if (element_type == PrimitiveType::F32 || + element_type == PrimitiveType::BF16 || + element_type == PrimitiveType::F64) { + absl::c_copy(std::vector{HloOpcode::kCos, HloOpcode::kExp, + HloOpcode::kExpm1, HloOpcode::kFloor, + HloOpcode::kCeil, HloOpcode::kLog, + HloOpcode::kLog1p, HloOpcode::kRsqrt, + HloOpcode::kSin, HloOpcode::kSqrt, + HloOpcode::kCbrt, HloOpcode::kTan, + HloOpcode::kTanh, HloOpcode::kErf}, + std::back_inserter(ret)); + } + return ret; +} + +std::vector TritonSupportedBinaryElementwiseUpToFloatNormalization( + PrimitiveType element_type) { + if (element_type == PrimitiveType::PRED) { + return {HloOpcode::kAnd, HloOpcode::kOr, HloOpcode::kXor, + HloOpcode::kCompare}; + } + std::vector ret = {HloOpcode::kAdd, HloOpcode::kCompare, + HloOpcode::kMaximum, HloOpcode::kMinimum, + HloOpcode::kMultiply, HloOpcode::kSubtract}; + if (element_type == PrimitiveType::F32 || + element_type == PrimitiveType::BF16 || + element_type == PrimitiveType::F64) { + ret.push_back(HloOpcode::kAtan2); + ret.push_back(HloOpcode::kDivide); + ret.push_back(HloOpcode::kPower); + } + return ret; +} + +std::vector TritonSupportedTernaryElementwiseUpToFloatNormalization( + PrimitiveType element_type) { + return {HloOpcode::kSelect, HloOpcode::kClamp}; +} + +bool IsTritonSupportedElementwiseUpToFloatNormalization( + HloOpcode opcode, PrimitiveType element_type) { + return absl::c_linear_search( + TritonSupportedUnaryElementwiseUpToFloatNormalization( + element_type), + opcode) || + absl::c_linear_search( + TritonSupportedBinaryElementwiseUpToFloatNormalization( + element_type), + opcode) || + absl::c_linear_search( + TritonSupportedTernaryElementwiseUpToFloatNormalization( + element_type), + opcode); +} + +CodegenDecision CanTritonHandleElementwise( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { + if (auto decision = IsInstructionSupportsDataTypes(instr, gpu_version); + !decision.CanFuse()) { + return decision; + } + if (instr.opcode() == HloOpcode::kConstant) { + return CodegenDecision{}; + } else if (!IsTritonSupportedElementwiseUpToFloatNormalization( + instr.opcode(), instr.operand(0)->shape().element_type())) { + return "Unsupported elementwise operation."; + } + return CodegenDecision{}; +} + +bool IsDotAlgorithmSupportedByTriton( + PrecisionConfig::Algorithm algorithm, + const se::GpuComputeCapability& gpu_version) { + auto cuda_compute_capability = + std::get_if(&gpu_version); + auto rocm_compute_capability = + std::get_if(&gpu_version); + switch (algorithm) { + case PrecisionConfig::ALG_DOT_TF32_TF32_F32: + if (cuda_compute_capability) { + return true; + } + return false; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: + if (cuda_compute_capability) { + return true; + } + if (rocm_compute_capability) { + return rocm_compute_capability->has_bf16_dtype_support(); + } + return false; + + // TODO(b/326579472): Fix the support of this algorithm and maybe allow it + // here. + case PrecisionConfig::ALG_DOT_F16_F16_F32: + // TODO(b/311331155): Triton F32 is about 3x slower than Triton TF32 and is + // slow to compile. Disable it for now. + case PrecisionConfig::ALG_DOT_F32_F32_F32: + default: + return false; + } +} + +// Filters GEMMs which can be handled using Triton. +CodegenDecision CanTritonHandleGEMM( + const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) { + auto cuda_compute_capability = + std::get_if(&gpu_version); + auto rocm_compute_capability = + std::get_if(&gpu_version); + + CHECK(cuda_compute_capability || rocm_compute_capability); + + if (dot.precision_config().algorithm() == PrecisionConfig::ALG_UNSET) { + if (!tsl::tensor_float_32_execution_enabled() || + absl::c_any_of(dot.precision_config().operand_precision(), + [](int x) { return x != PrecisionConfig::DEFAULT; })) { + return "Having non-default operand precisions or TensorFloat-32 disabled " + "for Dot op with unset algorithm."; + } + } else { + if (!IsDotAlgorithmSupportedByTriton(dot.precision_config().algorithm(), + gpu_version)) { + return "Unsupported algorithm on the current device(s)."; + } + } + + // TODO(b/266862493): Support more output types. + if (!IsTritonSupportedDotOutputType(dot.shape().element_type(), + gpu_version)) { + return "Unsupported output data type for Dot op."; + } + + if (!IsTritonSupportedDataType(dot.operand(0)->shape().element_type(), + gpu_version) || + !IsTritonSupportedDataType(dot.operand(1)->shape().element_type(), + gpu_version)) { + return "Unsupported input data type for Dot op."; + } + + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + + // TODO(b/269580541): support multiple batch dimensions. + if (dim_numbers.lhs_batch_dimensions().size() > 1) { + return "Multiple batch dimensions."; + } + + return CodegenDecision{}; +} + +bool NoNonContractingDimension(const HloDotInstruction& dot) { + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + if (dim_numbers.lhs_batch_dimensions().size() + + dim_numbers.lhs_contracting_dimensions().size() == + dot.operand(0)->shape().rank() || + dim_numbers.rhs_batch_dimensions().size() + + dim_numbers.rhs_contracting_dimensions().size() == + dot.operand(1)->shape().rank()) { + return true; + } + return false; +} + +CodegenDecision IsTritonSupportedDynamicSlice( + const HloDynamicSliceInstruction& instr) { + for (const HloInstruction* index_operand : instr.index_operands()) { + switch (index_operand->shape().element_type()) { + case S8: + case S16: + case S32: + break; // supported + default: + return CodegenDecision( + "Dynamic slice is only supported with S8, S16, or S32 indices."); + } + } + + // Similar to normal slice, we cannot slice a non-major-most dimension as + // that would introduce non-contiguous strides under tiling. The existing + // check against this in GetRequirementsIfSupportedOrder is not suitable for + // dynamic slices, so we instead check for this here. + const HloInstruction* input = instr.operand(0); + Layout in_layout = input->shape().layout(); + int64_t majormost_dim_id = + in_layout.minor_to_major(in_layout.minor_to_major_size() - 1); + + for (int i = 0; i < input->shape().dimensions_size(); ++i) { + if (i == majormost_dim_id) { + continue; + } else if (input->shape().dimensions(i) != instr.slice_sizes(i)) { + return CodegenDecision( + "Unsupported dynamic slice on non-major-most dimension."); + } + } + + // TODO(b/343143854): Check the subtleties of which dynamic slices are + // supported, for example that a fragmented dimension cannot be sliced. + return CodegenDecision{}; +} + +CodegenDecision IsTritonSupportedInstruction( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { + if (instr.IsElementwise()) { + return CanTritonHandleElementwise(instr, gpu_version); + } + + switch (instr.opcode()) { + case HloOpcode::kDot: { + auto* dot = Cast(&instr); + // Cases where lhs or rhs have no non-contracting dims are not handled. + if (NoNonContractingDimension(*dot)) { + return "No non-contracting dimensions."; + } + return CanTritonHandleGEMM(*dot, gpu_version); + } + case HloOpcode::kTuple: { + if (instr.IsRoot()) { + return CodegenDecision{}; + } + return "Only supports root tuples."; + } + case HloOpcode::kDynamicSlice: { + return IsTritonSupportedDynamicSlice( + *Cast(&instr)); + } + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kSlice: + case HloOpcode::kReshape: + case HloOpcode::kPad: + case HloOpcode::kConcatenate: + case HloOpcode::kParameter: + case HloOpcode::kBroadcast: + return CodegenDecision{}; + default: + break; + } + return "Unsupported opcode."; +} + +} // namespace legacy_triton +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.h new file mode 100644 index 00000000000000..da088465fa43e8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.h @@ -0,0 +1,110 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_SUPPORT_LEGACY_H_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_SUPPORT_LEGACY_H_ + +// This file is the home of the basic Triton support checks which are used by +// multiple other components. + +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/instruction_fusion.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +using CodegenDecision = FusionDecision; + +namespace legacy_triton { + +// Tells if f(a+b) == f(a) + f(b). +bool IsDistributiveOverAddition(const HloInstruction& hlo); + +// Allowlist of unary elementwise operations supported by the legacy Triton +// emitters. +// +// Note: this is not an accurate representation of what is actually supported by +// the Triton emitters, because operations affected by FloatNormalization may +// be tagged as "supported" here, even though FloatNormalization is required to +// make them work. We could fix this, but this is code we aim to delete soon, so +// it doesn't seem worth it. We'll revisit this decision if the code doesn't go +// away soon. +std::vector TritonSupportedUnaryElementwiseUpToFloatNormalization( + PrimitiveType); + +// Allowlist of binary elementwise operations supported by the legacy Triton +// emitters. +// +// Note: this is not an accurate representation of what is actually supported by +// the Triton emitters, because operations affected by FloatNormalization may +// be tagged as "supported" here, even though FloatNormalization is required to +// make them work. We could fix this, but this is code we aim to delete soon, so +// it doesn't seem worth it. We'll revisit this decision if the code doesn't go +// away soon. +std::vector TritonSupportedBinaryElementwiseUpToFloatNormalization( + PrimitiveType); + +// Allowlist of ternary elementwise operations supported by the legacy Triton +// emitters. +// +// Note: this is not an accurate representation of what is actually supported by +// the Triton emitters, because operations affected by FloatNormalization may +// be tagged as "supported" here, even though FloatNormalization is required to +// make them work. We could fix this, but this is code we aim to delete soon, so +// it doesn't seem worth it. We'll revisit this decision if the code doesn't go +// away soon. +std::vector TritonSupportedTernaryElementwiseUpToFloatNormalization( + PrimitiveType); + +// Data types that are supported by the legacy Triton emitters. +bool IsTritonSupportedDataType(PrimitiveType, const se::GpuComputeCapability&); + +// Checks elementwise operation against unary, binary, and ternary elementwise +// operations supported by the legacy Triton emitters. +// +// Note: this is not an accurate representation of what is actually supported by +// the Triton emitters, because operations affected by FloatNormalization may +// be tagged as "supported" here, even though FloatNormalization is required to +// make them work. We could fix this, but this is code we aim to delete soon, so +// it doesn't seem worth it. We'll revisit this decision if the code doesn't go +// away soon. +bool IsTritonSupportedElementwiseUpToFloatNormalization(HloOpcode, + PrimitiveType); + +CodegenDecision CanTritonHandleGEMM( + const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version); + +// Checks instruction against the requirements of the legacy Triton emitters. +CodegenDecision IsTritonSupportedInstruction( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version); + +// Checks dynamic slice against the requirements of the legacy Triton emitters. +// +// This is exposed separately from IsTritonSupportedInstruction because we can +// use it in the dimension order propagation without adding a dependency on the +// GPU version. +CodegenDecision IsTritonSupportedDynamicSlice( + const HloDynamicSliceInstruction& instr); + +} // namespace legacy_triton +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_SUPPORT_LEGACY_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy_test.cc index 3eefd362564f71..41adc715e3849e 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy_test.cc @@ -15,6 +15,8 @@ limitations under the License. // TODO(b/343158720): Simplify the tests in this file after a generic emitter // has landed. +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" + #include #include #include @@ -34,7 +36,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" -#include "xla/service/gpu/fusions/triton/triton_support.h" #include "xla/service/gpu/fusions/triton/triton_test_utils.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc index a47dd530cd228b..89da38e7350091 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/fusions/triton/triton_support.h" +#include #include #include #include @@ -26,8 +27,10 @@ limitations under the License. #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" @@ -145,7 +148,7 @@ auto AllDevicesToTest() { // Generates all the possible test combinations for a given opcodes. A test // combination is a tuple of the form (data_type, opcode, compute_capability). -auto AllTestCombinationsForOpcodes(std::vector&& opcodes) { +auto AllTestCombinationsForOpcodes(absl::Span opcodes) { std::vector> test_combinations; for (PrimitiveType data_type : AllXlaDataTypes()) { @@ -226,10 +229,13 @@ ENTRY triton_computation { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16}, cc); } -INSTANTIATE_TEST_SUITE_P(BitcastOrReshapeTestSuite, BitcastOrReshapeTest, - AllTestCombinationsForOpcodes({HloOpcode::kBitcast, - HloOpcode::kReshape}), - TritonSupportTestTypeOpcodeAndDeviceToString); +constexpr std::array kTestedOpsBitcastReshape = {HloOpcode::kBitcast, + HloOpcode::kReshape}; + +INSTANTIATE_TEST_SUITE_P( + BitcastOrReshapeTestSuite, BitcastOrReshapeTest, + AllTestCombinationsForOpcodes(kTestedOpsBitcastReshape), + TritonSupportTestTypeOpcodeAndDeviceToString); using UnaryElementwiseTest = TritonSupportTestWithParam; @@ -280,36 +286,38 @@ ENTRY triton_computation { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc); } +constexpr std::array kTestedOpsUnaryElementwise = {HloOpcode::kAbs, + HloOpcode::kCbrt, + HloOpcode::kCeil, + HloOpcode::kClz, + HloOpcode::kConvert, + HloOpcode::kCos, + HloOpcode::kErf, + HloOpcode::kExp, + HloOpcode::kExpm1, + HloOpcode::kFloor, + HloOpcode::kImag, + HloOpcode::kIsFinite, + HloOpcode::kLog, + HloOpcode::kLog1p, + HloOpcode::kLogistic, + HloOpcode::kNegate, + HloOpcode::kNot, + HloOpcode::kPopulationCount, + HloOpcode::kReal, + HloOpcode::kReducePrecision, + HloOpcode::kRoundNearestAfz, + HloOpcode::kRoundNearestEven, + HloOpcode::kRsqrt, + HloOpcode::kSign, + HloOpcode::kSin, + HloOpcode::kSqrt, + HloOpcode::kTan, + HloOpcode::kTanh}; + INSTANTIATE_TEST_SUITE_P( UnaryElementwiseTestSuite, UnaryElementwiseTest, - AllTestCombinationsForOpcodes({HloOpcode::kAbs, - HloOpcode::kCbrt, - HloOpcode::kCeil, - HloOpcode::kClz, - HloOpcode::kConvert, - HloOpcode::kCos, - HloOpcode::kErf, - HloOpcode::kExp, - HloOpcode::kExpm1, - HloOpcode::kFloor, - HloOpcode::kImag, - HloOpcode::kIsFinite, - HloOpcode::kLog, - HloOpcode::kLog1p, - HloOpcode::kLogistic, - HloOpcode::kNegate, - HloOpcode::kNot, - HloOpcode::kPopulationCount, - HloOpcode::kReal, - HloOpcode::kReducePrecision, - HloOpcode::kRoundNearestAfz, - HloOpcode::kRoundNearestEven, - HloOpcode::kRsqrt, - HloOpcode::kSign, - HloOpcode::kSin, - HloOpcode::kSqrt, - HloOpcode::kTan, - HloOpcode::kTanh}), + AllTestCombinationsForOpcodes(kTestedOpsUnaryElementwise), TritonSupportTestTypeOpcodeAndDeviceToString); using BinaryElementwiseTest = TritonSupportTestWithParam; @@ -353,15 +361,27 @@ ENTRY triton_computation { skip_failure_branch_to_avoid_crash); } +constexpr std::array kTestedOpsBinaryElementwise = { + HloOpcode::kAnd, + HloOpcode::kOr, + HloOpcode::kXor, + HloOpcode::kAdd, + HloOpcode::kMultiply, + HloOpcode::kMaximum, + HloOpcode::kMinimum, + HloOpcode::kSubtract, + HloOpcode::kAtan2, + HloOpcode::kDivide, + HloOpcode::kRemainder, + HloOpcode::kPower, + HloOpcode::kShiftLeft, + HloOpcode::kShiftRightArithmetic, + HloOpcode::kShiftRightLogical, + HloOpcode::kCompare}; + INSTANTIATE_TEST_SUITE_P( BinaryElementwiseTestSuite, BinaryElementwiseTest, - AllTestCombinationsForOpcodes( - {HloOpcode::kAnd, HloOpcode::kOr, HloOpcode::kXor, HloOpcode::kAdd, - HloOpcode::kMultiply, HloOpcode::kMaximum, HloOpcode::kMinimum, - HloOpcode::kSubtract, HloOpcode::kAtan2, HloOpcode::kDivide, - HloOpcode::kRemainder, HloOpcode::kPower, HloOpcode::kShiftLeft, - HloOpcode::kShiftRightArithmetic, HloOpcode::kShiftRightLogical, - HloOpcode::kCompare}), + AllTestCombinationsForOpcodes(kTestedOpsBinaryElementwise), TritonSupportTestTypeOpcodeAndDeviceToString); using TernaryElementwiseTest = TritonSupportTestWithParam; @@ -387,10 +407,13 @@ ENTRY triton_computation { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc); } -INSTANTIATE_TEST_SUITE_P(TernaryElementwiseTestSuite, TernaryElementwiseTest, - AllTestCombinationsForOpcodes({HloOpcode::kSelect, - HloOpcode::kClamp}), - TritonSupportTestTypeOpcodeAndDeviceToString); +constexpr std::array kTestedOpsTernaryElementwise = {HloOpcode::kSelect, + HloOpcode::kClamp}; + +INSTANTIATE_TEST_SUITE_P( + TernaryElementwiseTestSuite, TernaryElementwiseTest, + AllTestCombinationsForOpcodes(kTestedOpsTernaryElementwise), + TritonSupportTestTypeOpcodeAndDeviceToString); using ReduceTest = TritonSupportTestWithParam; @@ -569,8 +592,10 @@ ENTRY triton_computation { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } +constexpr std::array kTestedOpsReduction = {HloOpcode::kReduce}; + INSTANTIATE_TEST_SUITE_P(ReduceTestSuite, ReduceTest, - AllTestCombinationsForOpcodes({HloOpcode::kReduce}), + AllTestCombinationsForOpcodes(kTestedOpsReduction), TritonSupportTestTypeOpcodeAndDeviceToString); using CollectiveTest = TritonSupportTestWithParam; @@ -642,13 +667,119 @@ TEST_P(CollectiveTest, UnsupportedCollectivesFailGracefullyWithTriton) { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } -INSTANTIATE_TEST_SUITE_P( - CollectiveTestSuite, CollectiveTest, - AllTestCombinationsForOpcodes({HloOpcode::kAllGather, HloOpcode::kAllReduce, - HloOpcode::kAllToAll, - HloOpcode::kCollectivePermute, - HloOpcode::kReduceScatter}), - TritonSupportTestTypeOpcodeAndDeviceToString); +constexpr std::array kTestedOpsCollectives = { + HloOpcode::kAllGather, HloOpcode::kAllReduce, HloOpcode::kAllToAll, + HloOpcode::kCollectivePermute, HloOpcode::kReduceScatter}; + +INSTANTIATE_TEST_SUITE_P(CollectiveTestSuite, CollectiveTest, + AllTestCombinationsForOpcodes(kTestedOpsCollectives), + TritonSupportTestTypeOpcodeAndDeviceToString); + +absl::flat_hash_set AllTestedOpcodes() { + // The return set is initialized with ops that are implicitly tested. + absl::flat_hash_set ret{HloOpcode::kParameter}; + + ret.insert(kTestedOpsBitcastReshape.begin(), kTestedOpsBitcastReshape.end()); + ret.insert(kTestedOpsUnaryElementwise.begin(), + kTestedOpsUnaryElementwise.end()); + ret.insert(kTestedOpsBinaryElementwise.begin(), + kTestedOpsBinaryElementwise.end()); + ret.insert(kTestedOpsTernaryElementwise.begin(), + kTestedOpsTernaryElementwise.end()); + ret.insert(kTestedOpsReduction.begin(), kTestedOpsReduction.end()); + ret.insert(kTestedOpsCollectives.begin(), kTestedOpsCollectives.end()); + return ret; +} + +absl::flat_hash_set AllUntestedOpcodes() { + return absl::flat_hash_set{HloOpcode::kAddDependency, + HloOpcode::kAfterAll, + HloOpcode::kAllGatherDone, + HloOpcode::kAllGatherStart, + HloOpcode::kAllReduceDone, + HloOpcode::kAllReduceStart, + HloOpcode::kAsyncDone, + HloOpcode::kAsyncStart, + HloOpcode::kAsyncUpdate, + HloOpcode::kBatchNormGrad, + HloOpcode::kBatchNormInference, + HloOpcode::kBatchNormTraining, + HloOpcode::kBitcastConvert, + HloOpcode::kBroadcast, + HloOpcode::kCall, + HloOpcode::kCholesky, + HloOpcode::kCollectiveBroadcast, + HloOpcode::kCollectivePermuteDone, + HloOpcode::kCollectivePermuteStart, + HloOpcode::kComplex, + HloOpcode::kConcatenate, + HloOpcode::kConditional, + HloOpcode::kConstant, + HloOpcode::kConvolution, + HloOpcode::kCopy, + HloOpcode::kCopyDone, + HloOpcode::kCopyStart, + HloOpcode::kCustomCall, + HloOpcode::kDomain, + HloOpcode::kDot, + HloOpcode::kDynamicReshape, + HloOpcode::kDynamicSlice, + HloOpcode::kDynamicUpdateSlice, + HloOpcode::kFft, + HloOpcode::kFusion, + HloOpcode::kGather, + HloOpcode::kGetDimensionSize, + HloOpcode::kGetTupleElement, + HloOpcode::kInfeed, + HloOpcode::kIota, + HloOpcode::kMap, + HloOpcode::kOptimizationBarrier, + HloOpcode::kOutfeed, + HloOpcode::kPad, + HloOpcode::kPartitionId, + HloOpcode::kRecv, + HloOpcode::kRecvDone, + HloOpcode::kReduceWindow, + HloOpcode::kReplicaId, + HloOpcode::kReverse, + HloOpcode::kRng, + HloOpcode::kRngBitGenerator, + HloOpcode::kRngGetAndUpdateState, + HloOpcode::kScatter, + HloOpcode::kSelectAndScatter, + HloOpcode::kSend, + HloOpcode::kSendDone, + HloOpcode::kSetDimensionSize, + HloOpcode::kSlice, + HloOpcode::kSort, + HloOpcode::kStochasticConvert, + HloOpcode::kTopK, + HloOpcode::kTranspose, + HloOpcode::kTriangularSolve, + HloOpcode::kTuple, + HloOpcode::kWhile}; +} + +TEST(OpCoverage, TestedAndUntestedDoNotOverlap) { + absl::flat_hash_set untested_opcodes = AllUntestedOpcodes(); + for (HloOpcode tested : AllTestedOpcodes()) { + EXPECT_FALSE(untested_opcodes.contains(tested)) + << "Opcode `" << HloOpcodeString(tested) + << "` appears in both tested and untested opcodes."; + } +} + +TEST(OpCoverage, AllOpcodesAppearInTestedOrUntested) { + absl::flat_hash_set untested_opcodes = AllUntestedOpcodes(); + absl::flat_hash_set tested_opcodes = AllTestedOpcodes(); + for (int opcode_index = 0; opcode_index < HloOpcodeCount(); ++opcode_index) { + auto opcode = static_cast(opcode_index); + EXPECT_TRUE(untested_opcodes.contains(opcode) || + tested_opcodes.contains(opcode)) + << "Opcode `" << HloOpcodeString(opcode) + << "` does not appear in tested or untested opcodes."; + } +} } // namespace } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index ae1de21696a693..be7ba5c92c3da9 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -74,7 +74,6 @@ limitations under the License. #include "xla/service/all_reduce_folder.h" #include "xla/service/all_reduce_promotion.h" #include "xla/service/all_reduce_reassociate.h" -#include "xla/service/all_reduce_splitter.h" #include "xla/service/async_collective_creator.h" #include "xla/service/batchnorm_expander.h" #include "xla/service/bitcast_dtypes_expander.h" @@ -114,25 +113,14 @@ limitations under the License. #include "xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h" #include "xla/service/gpu/compile_module_to_llvm_ir.h" #include "xla/service/gpu/conv_layout_normalization.h" -#include "xla/service/gpu/dot_operand_converter.h" #include "xla/service/gpu/execution_stream_assignment.h" #include "xla/service/gpu/fusion_pipeline.h" -#include "xla/service/gpu/gpu_algebraic_simplifier.h" -#include "xla/service/gpu/gpu_all_gather_optimizer.h" -#include "xla/service/gpu/gpu_async_collective_annotator.h" -#include "xla/service/gpu/gpu_conv_rewriter.h" -#include "xla/service/gpu/gpu_convert_async_collectives_to_sync.h" #include "xla/service/gpu/gpu_executable.h" #include "xla/service/gpu/gpu_float_support.h" #include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" -#include "xla/service/gpu/gpu_layout_assignment.h" #include "xla/service/gpu/gpu_p2p_pipeliner.h" -#include "xla/service/gpu/gpu_reduce_scatter_creator.h" -#include "xla/service/gpu/gpu_sanitize_constant_names.h" -#include "xla/service/gpu/gpu_scatter_expander.h" #include "xla/service/gpu/gpu_spmd_pipeline.h" -#include "xla/service/gpu/gpu_windowed_einsum_handler.h" #include "xla/service/gpu/hlo_fusion_stats.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" @@ -142,32 +130,27 @@ limitations under the License. #include "xla/service/gpu/metrics.h" #include "xla/service/gpu/model/gpu_cost_model_stats_collection.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/service/gpu/move_copy_to_users.h" -#include "xla/service/gpu/pipelined_p2p_rewriter.h" #include "xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h" -#include "xla/service/gpu/reduction_degenerate_dim_remover.h" -#include "xla/service/gpu/reduction_dimension_grouper.h" -#include "xla/service/gpu/reduction_layout_normalizer.h" -#include "xla/service/gpu/reduction_splitter.h" #include "xla/service/gpu/reduction_utils.h" -#include "xla/service/gpu/rename_fusions.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/runtime_intrinsics.h" -#include "xla/service/gpu/scatter_slice_simplifier.h" -#include "xla/service/gpu/softmax_rewriter_triton.h" -#include "xla/service/gpu/stream_attribute_annotator.h" -#include "xla/service/gpu/stream_attribute_async_wrapper.h" #include "xla/service/gpu/stream_executor_util.h" -#include "xla/service/gpu/topk_specializer.h" -#include "xla/service/gpu/topk_splitter.h" +#include "xla/service/gpu/transforms/algebraic_simplifier.h" #include "xla/service/gpu/transforms/algorithm_checker.h" +#include "xla/service/gpu/transforms/all_gather_optimizer.h" #include "xla/service/gpu/transforms/all_reduce_blueconnect.h" +#include "xla/service/gpu/transforms/all_reduce_splitter.h" +#include "xla/service/gpu/transforms/async_collective_annotator.h" +#include "xla/service/gpu/transforms/async_wrapper.h" #include "xla/service/gpu/transforms/collective_permute_cycle_decomposer.h" #include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h" #include "xla/service/gpu/transforms/command_buffer_scheduling.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" +#include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h" #include "xla/service/gpu/transforms/cudnn_custom_call_converter.h" #include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h" #include "xla/service/gpu/transforms/dot_dimension_sorter.h" +#include "xla/service/gpu/transforms/dot_operand_converter.h" #include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h" #include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h" #include "xla/service/gpu/transforms/fusion_wrapper.h" @@ -175,8 +158,27 @@ limitations under the License. #include "xla/service/gpu/transforms/gemm_fusion.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" #include "xla/service/gpu/transforms/gemv_rewriter.h" -#include "xla/service/gpu/tree_reduction_rewriter.h" -#include "xla/service/gpu/triton_fusion_numerics_verifier.h" +#include "xla/service/gpu/transforms/layout_assignment.h" +#include "xla/service/gpu/transforms/move_copy_to_users.h" +#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h" +#include "xla/service/gpu/transforms/reduce_scatter_creator.h" +#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h" +#include "xla/service/gpu/transforms/reduction_dimension_grouper.h" +#include "xla/service/gpu/transforms/reduction_layout_normalizer.h" +#include "xla/service/gpu/transforms/reduction_splitter.h" +#include "xla/service/gpu/transforms/rename_fusions.h" +#include "xla/service/gpu/transforms/sanitize_constant_names.h" +#include "xla/service/gpu/transforms/scatter_expander.h" +#include "xla/service/gpu/transforms/scatter_slice_simplifier.h" +#include "xla/service/gpu/transforms/softmax_rewriter_triton.h" +#include "xla/service/gpu/transforms/stream_attribute_annotator.h" +#include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h" +#include "xla/service/gpu/transforms/topk_specializer.h" +#include "xla/service/gpu/transforms/topk_splitter.h" +#include "xla/service/gpu/transforms/transpose_dimension_grouper.h" +#include "xla/service/gpu/transforms/tree_reduction_rewriter.h" +#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h" +#include "xla/service/gpu/transforms/windowed_einsum_handler.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_computation_deduplicator.h" #include "xla/service/hlo_constant_folding.h" @@ -408,6 +410,16 @@ GpuThunkAotCompilationResult::LoadExecutable( platform_name, gpu_device_info, mlir_context.get(), llvm_module.get(), /*llvm_module_constants=*/nullptr, /*emit_kernels=*/false); + + absl::string_view cache_file_path = + hlo_module->config().debug_options().xla_gpu_kernel_cache_file(); + if (!cache_file_path.empty() && + hlo_module->config() + .debug_options() + .xla_gpu_enable_llvm_module_compilation_parallelism()) { + TF_RETURN_IF_ERROR(LoadCache(ir_emitter_context, cache_file_path)); + } + auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context); TF_RETURN_IF_ERROR( ir_emitter->EmitHloComputation(hlo_module->entry_computation())); @@ -495,7 +507,7 @@ AlgebraicSimplifierOptions LayoutInsensitiveAlgebraicSimplifierOptions( AlgebraicSimplifierOptions layout_insensitive_algsimp_opts = opts_from_compiler; layout_insensitive_algsimp_opts.set_conv_is_lowerable_callback( - GpuConvRewriter::ConvIsLowerable); + ConvRewriter::ConvIsLowerable); layout_insensitive_algsimp_opts.set_enable_dot_strength_reduction( hlo_module_config.debug_options() .xla_gpu_enable_dot_strength_reduction()); @@ -629,7 +641,7 @@ absl::Status RunOptimizationPasses( HloPassPipeline pipeline("optimization"); AddHloVerifier(&pipeline); if (debug_options.xla_gpu_multi_streamed_windowed_einsum()) { - pipeline.AddPass(); + pipeline.AddPass(); } pipeline.AddPass(); pipeline.AddPass(); @@ -1123,7 +1135,7 @@ absl::Status RunPostFusionCollectiveOptimizationPasses(HloModule* hlo_module) { return false; } }; - pipeline.AddPass(convert_to_async); + pipeline.AddPass(convert_to_async); return pipeline.Run(hlo_module).status(); } @@ -1179,6 +1191,8 @@ absl::Status RunPostFusionVerificationPasses( absl::Status GpuCompiler::OptimizeHloModule( HloModule* hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options, const TargetConfig& gpu_target_config) { + tsl::profiler::TraceMe traceme("GpuCompiler::OptimizeHloModule"); + CheckNotScheduled(hlo_module); LogDebugOptions(hlo_module); @@ -1307,6 +1321,30 @@ absl::Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { .status(); } +namespace { +void AddGemmRewriterPasses(HloPassPipeline& pipeline, + const DebugOptions& debug_options, + const se::GpuComputeCapability gpu_version, + const int32_t toolkit_version) { + // Adding bias to GEMMs is helpful for skipping kernel launches for `add` + // operations. However, the bias term can add dependencies between the GEMMs + // that could otherwise be parallelized. Because of this, we disable bias + // addition when async dot is enabled. + GemmRewriterOptions::BiasMode bias_mode = + GemmRewriterOptions::BiasMode::kBias; + if (debug_options.xla_gpu_async_dot()) { + bias_mode = GemmRewriterOptions::BiasMode::kNoBias; + } + + pipeline.AddPass( + gpu_version, toolkit_version, + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only, bias_mode}); + pipeline.AddPass( + gpu_version, toolkit_version, + GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only, bias_mode}); +} +} // namespace + absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( HloModule* hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options, const TargetConfig& gpu_target_config, @@ -1399,10 +1437,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(gpu_version); } - pipeline.AddPass(gpu_version, GetToolkitVersion(), - /*f8_rewrite=*/true); - pipeline.AddPass(gpu_version, GetToolkitVersion(), - /*f8_rewrite=*/false); + // Rewrite GEMMs into custom calls. + AddGemmRewriterPasses(pipeline, debug_options, gpu_version, + GetToolkitVersion()); // Rewrite GEMMs with broadcasted inputs as strided GEMMs. pipeline.AddPass(); @@ -1417,6 +1454,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); // Run Softmax fusion after layout normalization. We expect a default layout @@ -1445,7 +1483,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( bool ignore_small_reduce_dims = !debug_options.xla_gpu_enable_priority_fusion(); pipeline.AddPass>(ignore_small_reduce_dims); - pipeline.AddPass>(gpu_version); + pipeline.AddPass>(gpu_version); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -1470,13 +1508,23 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(); // TODO(tdanyluk): Apply CublasPadForGemms to the cuBLAS GEMMs generated // here for possibly better cuBLAS performance. - pipeline.AddPass(gpu_version, GetToolkitVersion(), - /*f8_rewrite=*/true); - pipeline.AddPass(gpu_version, GetToolkitVersion(), - /*f8_rewrite=*/false); + AddGemmRewriterPasses(pipeline, debug_options, gpu_version, + GetToolkitVersion()); + // Rewrite GEMMs with broadcasted inputs as strided GEMMs. pipeline.AddPass(); + // Wrap `dot` operations into async computations in an effort to parallelize + // matrix operations. This pass needs to run after the GEMM rewriter so that + // we still use the native GEMM implementation. + if (debug_options.xla_gpu_async_dot()) { + pipeline.AddPass([](HloInstruction* instruction) { + // TODO(b/339654953): Use a better heuristic to determine whether a + // `dot` operation should be wrapped in an async computation. + return instruction->opcode() == HloOpcode::kCustomCall; + }); + } + pipeline.AddPass( static_cast(stream_executor::MemoryType::kHost), /* after_layout= */ true); @@ -2062,6 +2110,8 @@ GpuCompiler::CompileToBackendResult( HloModule* module, llvm::LLVMContext* llvm_context, se::StreamExecutor* executor, const CompileOptions& options, const se::DeviceDescription& gpu_device_info) { + tsl::profiler::TraceMe traceme("GpuCompiler::CompileToBackendResult"); + TF_RETURN_IF_ERROR(RunPreSchedulingPasses(module, executor)); TF_ASSIGN_OR_RETURN( ScheduleMetadata schedule_metadata, @@ -2150,8 +2200,8 @@ absl::StatusOr> GpuCompiler::RunBackend( }}; BinaryMap dnn_compiled_graphs; if (stream_exec) { - TF_RETURN_IF_ERROR(RunCudnnFusionCompilerPass(module.get(), stream_exec, - &dnn_compiled_graphs)); + TF_RETURN_IF_ERROR(RunCudnnCompilerPasses(module.get(), stream_exec, + &dnn_compiled_graphs)); } const DebugOptions& debug_opts = module->config().debug_options(); @@ -2484,7 +2534,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( pipeline.AddPass( gpu_device_info, toolkit_version, driver_version.value_or(toolkit_version)); - pipeline.AddPass(); + pipeline.AddPass(); } AddHloVerifier(&main_pipeline, diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h index aa22bfcf3ba338..456e6755b0d83a 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.h +++ b/third_party/xla/xla/service/gpu/gpu_compiler.h @@ -171,10 +171,10 @@ class GpuCompiler : public LLVMCompiler { return absl::OkStatus(); } - // Runs cuDNN fusion compiler pass. - virtual absl::Status RunCudnnFusionCompilerPass( - HloModule* module, se::StreamExecutor* stream_exec, - BinaryMap* dnn_compiled_graphs) { + // Runs cuDNN fusion and custom call compiler passes. + virtual absl::Status RunCudnnCompilerPasses(HloModule* module, + se::StreamExecutor* stream_exec, + BinaryMap* dnn_compiled_graphs) { return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index 93057c595976dd..9e07fb6e38fe66 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -33,8 +34,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/primitive_util.h" +#include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/gpu_hlo_schedule.h" @@ -44,8 +49,11 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/xla_debug_info_manager.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" @@ -804,6 +812,78 @@ TEST_F(KernelCacheTest, AllKernelsAreCachedBecauseSplitModuleUsesRoundRobin) { EXPECT_EQ(CacheEntryCount(), 4); } +TEST_F(KernelCacheTest, CachingWorksWithLoadedExecutables) { + const std::string kHloAdd1 = R"( +add1 { + p = s32[] parameter(0) + c = s32[] constant(1) + ROOT a = s32[] add(p, c) +} + +ENTRY e { + p = s32[] parameter(0) + ROOT r = s32[] fusion(p), kind=kLoop, calls=add1 +})"; + + const std::string kHloAdd2 = R"( +add2 { + p = s32[] parameter(0) + c = s32[] constant(2) + ROOT a = s32[] add(p, c) +} + +ENTRY e { + p = s32[] parameter(0) + ROOT r = s32[] fusion(p), kind=kLoop, calls=add2 +})"; + + TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + se::PlatformManager::PlatformWithName("cuda")); + TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, + platform->ExecutorForDevice(0)); + + Compiler* compiler = backend().compiler(); + AotCompilationOptions aot_options(compiler->PlatformId()); + aot_options.set_executor(stream_exec); + + auto test = [this, &compiler, &aot_options](absl::string_view hlo, int input, + int expected_result) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto module_group = std::make_unique(std::move(module)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector> aot_results, + compiler->CompileAheadOfTime(std::move(module_group), aot_options)); + + TF_ASSERT_OK_AND_ASSIGN(std::string serialized_aot_result, + aot_results[0]->SerializeAsString()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr aot_result, + compiler->LoadAotCompilationResult(serialized_aot_result)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + aot_result->LoadExecutable(compiler, aot_options.executor())); + + const xla::Literal literal_input = + xla::LiteralUtil::CreateR0(input); + const xla::Literal literal_expected_result = + xla::LiteralUtil::CreateR0(expected_result); + + TF_ASSERT_OK_AND_ASSIGN(Literal result, + GetHloRunner().value()->ExecuteWithExecutable( + executable.get(), {&literal_input})); + + EXPECT_TRUE(LiteralTestUtil::Equal(result, literal_expected_result)); + }; + + test(kHloAdd1, 1, 2); + test(kHloAdd2, 1, 3); + // The test used to fail on the second execution of the second module when it + // was already cached. + test(kHloAdd2, 1, 3); +} + class KernelCacheTestSingleThreaded : public KernelCacheTest { public: DebugOptions GetDebugOptionsForTest() override { diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index 25fe6510f0baa8..40055de2fa0f16 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -878,9 +878,7 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( } module_allocations_[executor][i] = buffer_allocations.GetDeviceAddress(i); - VLOG(5) << "Gpu address changed for module " << module_name_ - << ", allocation info: \n" - << allocations[i].ToShortString(); + VLOG(5) << "Gpu address changed for module " << module_name_; } } } diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc deleted file mode 100644 index 566c0068f5dbba..00000000000000 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc +++ /dev/null @@ -1,719 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/gpu_fused_mha_runner.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "Eigen/Core" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/shape.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/dnn.h" -#include "xla/stream_executor/lazy_op_runner.h" -#include "xla/stream_executor/stream.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -namespace { -using se::DeviceMemory; -using se::DeviceMemoryBase; -using se::dnn::DataType; -using se::dnn::MatmulTensorDescriptor; -using se::dnn::TensorDescriptor; - -template -absl::Status RunFusedMHA(GpufMHAParams params, se::Stream *stream, - RunFusedMHAOptions options, - DeviceMemory lhs_bmm1_buffer, - DeviceMemory rhs_bmm1_buffer, - DeviceMemory rhs_bmm2_buffer, - DeviceMemory output_buffer, - DeviceMemoryBase bias_buffer, - DeviceMemoryBase scratch_memory, - DeviceMemoryBase activation_output, - DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) { - se::dnn::LazyOpRunner *lazy_runner = - options.runner_cache->AsFusedMHARunner(); - std::optional> local_runner; - if (!lazy_runner) { - local_runner.emplace(params.config->algorithm); - lazy_runner = &*local_runner; - } - std::optional dropout_rate; - if (params.config->dropout_rate) { - dropout_rate = *params.config->dropout_rate; - } - - std::optional seed; - if (params.config->seed) { - seed = *params.config->seed; - } - - TF_ASSIGN_OR_RETURN(se::dnn::FusedMHAOp::Config config, - params.config->AsDnnFusedMHAOpConfig()); - TF_ASSIGN_OR_RETURN(auto *runner, - lazy_runner->GetOrCreateRunner(config, stream)); - return (*runner)(stream, options.profile_result, scratch_memory, - lhs_bmm1_buffer, rhs_bmm1_buffer, rhs_bmm2_buffer, - output_buffer, bias_buffer, activation_output, seqlen_q, - seqlen_k); -} - -template -absl::Status RunGpuFMHAImpl(const GpufMHAParams ¶ms, se::Stream *stream, - se::DeviceMemoryBase scratch_memory, - RunFusedMHAOptions options) { - auto lhs_bmm1_buffer = se::DeviceMemory(params.lhs_bmm1_buffer); - auto rhs_bmm1_buffer = se::DeviceMemory(params.rhs_bmm1_buffer); - auto rhs_bmm2_buffer = se::DeviceMemory(params.rhs_bmm2_buffer); - auto output_buffer = se::DeviceMemory(params.output_buffer); - auto activation_buffer = - params.activation_buffer.has_value() - ? se::DeviceMemory(*params.activation_buffer) - : se::DeviceMemoryBase(); - auto bias_buffer = params.bias_buffer.has_value() - ? se::DeviceMemory(*params.bias_buffer) - : se::DeviceMemoryBase(); - auto seqlen_q_buffer = - params.seqlen_q_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_q_buffer) - : se::DeviceMemoryBase(); - auto seqlen_k_buffer = - params.seqlen_k_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_k_buffer) - : se::DeviceMemoryBase(); - se::dnn::AlgorithmDesc algorithm = params.config->algorithm; - if (options.runner_cache) { - algorithm = options.runner_cache->ToAlgorithmDesc(); - } - - absl::Status run_status = absl::OkStatus(); - switch (params.config->kind) { - case CudnnfMHAKind::kSoftmaxDropout: - case CudnnfMHAKind::kSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmaxDropout: - run_status = RunFusedMHA( - params, stream, options, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, bias_buffer, scratch_memory, - activation_buffer, seqlen_q_buffer, seqlen_k_buffer); - break; - default: - return Internal("Invalid cuDNN fMHA kind"); - } - - if (!run_status.ok()) { - return run_status; - } - - if (!stream->ok()) { - return Internal("Unable to launch FMHA with type %s and algorithm %s", - CudnnfMHAKindToString(params.config->kind), - algorithm.ToString()); - } - - return absl::OkStatus(); -} - -template -absl::Status RunFusedMHABackward( - GpufMHABackwardParams params, se::Stream *stream, - RunFusedMHABackwardOptions options, - DeviceMemory bmm1_grad_gemm1_rhs_buffer, - DeviceMemory bmm1_grad_gemm2_rhs_buffer, - DeviceMemory bmm2_grad_gemm1_lhs_buffer, - DeviceMemory bmm2_grad_gemm2_rhs_buffer, - DeviceMemory d_output_buffer, - DeviceMemory d_bmm1_lhs_buffer, - DeviceMemory d_bmm1_rhs_buffer, - DeviceMemory d_bmm2_rhs_buffer, DeviceMemoryBase d_s_buffer, - DeviceMemoryBase d_bias_buffer, DeviceMemoryBase fwd_output_buffer, - DeviceMemoryBase bias_buffer, DeviceMemoryBase scratch_memory, - DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) { - se::dnn::LazyOpRunner *lazy_runner = - options.runner_cache->AsFusedMHABackwardRunner(); - std::optional> - local_runner; - if (!lazy_runner) { - local_runner.emplace(params.config->algorithm); - lazy_runner = &*local_runner; - } - std::optional dropout_rate; - if (params.config->dropout_rate) { - dropout_rate = *params.config->dropout_rate; - } - - std::optional seed; - if (params.config->seed) { - seed = *params.config->seed; - } - - TF_ASSIGN_OR_RETURN(se::dnn::FusedMHABackwardOp::Config config, - params.config->AsDnnFusedMHABackwardOpConfig()); - TF_ASSIGN_OR_RETURN(auto *runner, - lazy_runner->GetOrCreateRunner(config, stream)); - // TODO: pass in real softmax_sum, dQ_accum, fwd_output - return (*runner)(stream, options.profile_result, scratch_memory, - bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, - d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, - d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer, seqlen_q, seqlen_k); - return absl::OkStatus(); -} - -template -absl::Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, - se::Stream *stream, - se::DeviceMemoryBase scratch_memory, - RunFusedMHABackwardOptions options) { - auto bmm1_grad_gemm1_rhs_buffer = - se::DeviceMemory(params.bmm1_grad_gemm1_rhs_buffer); - auto bmm1_grad_gemm2_rhs_buffer = - se::DeviceMemory(params.bmm1_grad_gemm2_rhs_buffer); - auto bmm2_grad_gemm1_lhs_buffer = - se::DeviceMemory(params.bmm2_grad_gemm1_lhs_buffer); - auto bmm2_grad_gemm2_rhs_buffer = - se::DeviceMemory(params.bmm2_grad_gemm2_rhs_buffer); - auto d_output_buffer = se::DeviceMemory(params.d_output_buffer); - auto d_bmm1_lhs_buffer = - se::DeviceMemory(params.d_bmm1_lhs_buffer); - auto d_bmm1_rhs_buffer = - se::DeviceMemory(params.d_bmm1_rhs_buffer); - auto d_bmm2_rhs_buffer = - se::DeviceMemory(params.d_bmm2_rhs_buffer); - - // optional buffers - auto d_s_buffer = params.d_s_buffer.has_value() - ? se::DeviceMemory(*params.d_s_buffer) - : se::DeviceMemoryBase(); - - auto d_bias_buffer = params.d_bias_buffer.has_value() - ? se::DeviceMemory(*params.d_bias_buffer) - : se::DeviceMemoryBase(); - - auto fwd_output_buffer = - params.fwd_output_buffer.has_value() - ? se::DeviceMemory(*params.fwd_output_buffer) - : se::DeviceMemoryBase(); - - auto bias_buffer = params.bias_buffer.has_value() - ? se::DeviceMemory(*params.bias_buffer) - : se::DeviceMemoryBase(); - - auto seqlen_q_buffer = - params.seqlen_q_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_q_buffer) - : se::DeviceMemoryBase(); - - auto seqlen_k_buffer = - params.seqlen_k_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_k_buffer) - : se::DeviceMemoryBase(); - - se::dnn::AlgorithmDesc algorithm = params.config->algorithm; - if (options.runner_cache) { - algorithm = options.runner_cache->ToAlgorithmDesc(); - } - - absl::Status run_status = absl::OkStatus(); - switch (params.config->kind) { - case CudnnfMHAKind::kBackwardSoftmaxDropout: - case CudnnfMHAKind::kBackwardSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: - run_status = RunFusedMHABackward( - params, stream, options, bmm1_grad_gemm1_rhs_buffer, - bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, - bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer, - d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer, scratch_memory, seqlen_q_buffer, - seqlen_k_buffer); - break; - default: - return Internal("Invalid cuDNN fMHA kind"); - } - - if (!run_status.ok()) { - return run_status; - } - - if (!stream->ok()) { - return Internal("Unable to launch FMHA with type %s and algorithm %s", - CudnnfMHAKindToString(params.config->kind), - algorithm.ToString()); - } - - return run_status; -} -} // namespace - -/*static*/ absl::StatusOr GpufMHAConfig::For( - const GpufMHADescriptor &desc) { - // Get shapes from desc. - const Shape &lhs_bmm1_shape = desc.lhs_bmm1_shape; - const Shape &rhs_bmm1_shape = desc.rhs_bmm1_shape; - const Shape &rhs_bmm2_shape = desc.rhs_bmm2_shape; - const Shape &intermediate_lhs_bmm2_shape = desc.intermediate_lhs_bmm2_shape; - const Shape &output_shape = desc.output_shapes[0]; - - // Get DNN dtype from primtive types - TF_ASSIGN_OR_RETURN( - DataType lhs_bmm1_type, - GetDNNDataTypeFromPrimitiveType(lhs_bmm1_shape.element_type())); - TF_ASSIGN_OR_RETURN( - DataType rhs_bmm1_type, - GetDNNDataTypeFromPrimitiveType(rhs_bmm1_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType rhs_bmm2_type, - GetDNNDataTypeFromPrimitiveType(rhs_bmm2_shape.element_type())); - TF_ASSIGN_OR_RETURN(DataType lhs_bmm2_type, - GetDNNDataTypeFromPrimitiveType( - intermediate_lhs_bmm2_shape.element_type())); - TF_ASSIGN_OR_RETURN(DataType output_type, GetDNNDataTypeFromPrimitiveType( - output_shape.element_type())); - GpufMHAConfig config; - config.input_type = lhs_bmm1_shape.element_type(); - config.output_type = output_shape.element_type(); - - // Get MatmulTensorDescriptors for BMM1 - config.lhs_bmm1 = - MatmulTensorDescriptor::For(lhs_bmm1_type, lhs_bmm1_shape.dimensions(), - desc.lhs_bmm1_shape.layout().minor_to_major(), - desc.bmm1_dnums.lhs_batch_dimensions(), - desc.bmm1_dnums.lhs_contracting_dimensions()); - config.rhs_bmm1 = - MatmulTensorDescriptor::For(rhs_bmm1_type, rhs_bmm1_shape.dimensions(), - desc.rhs_bmm1_shape.layout().minor_to_major(), - desc.bmm1_dnums.rhs_batch_dimensions(), - desc.bmm1_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for BMM2 - config.rhs_bmm2 = - MatmulTensorDescriptor::For(rhs_bmm2_type, rhs_bmm2_shape.dimensions(), - desc.rhs_bmm2_shape.layout().minor_to_major(), - desc.bmm2_dnums.rhs_batch_dimensions(), - desc.bmm2_dnums.rhs_contracting_dimensions()); - - config.intermediate_lhs_bmm2 = MatmulTensorDescriptor::For( - lhs_bmm2_type, intermediate_lhs_bmm2_shape.dimensions(), - desc.intermediate_lhs_bmm2_shape.layout().minor_to_major(), - desc.bmm2_dnums.lhs_batch_dimensions(), - desc.bmm2_dnums.lhs_contracting_dimensions()); - - config.output = TensorDescriptor::For(output_type, output_shape.dimensions(), - output_shape.layout().minor_to_major()); - - if (desc.output_shapes.size() > 1) { - const Shape &activation_shape = desc.output_shapes.back(); - // Generally, activation should have same type as output, but set it - // explicityly just to be safe. - TF_ASSIGN_OR_RETURN( - DataType activation_type, - GetDNNDataTypeFromPrimitiveType(activation_shape.element_type())); - config.activation = - TensorDescriptor::For(activation_type, activation_shape.dimensions(), - activation_shape.layout().minor_to_major()); - } - - if (desc.mask_shape) { - const Shape &mask_shape = *desc.mask_shape; - TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType( - mask_shape.element_type())); - config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), - mask_shape.layout().minor_to_major()); - } - - if (desc.bias_shape) { - const Shape &bias_shape = *desc.bias_shape; - TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( - bias_shape.element_type())); - config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), - bias_shape.layout().minor_to_major()); - } - config.kind = desc.kind; - config.mask_type = desc.mask_type; - const CudnnfMHABackendConfig &backend_config = desc.backend_config; - config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); - config.fmha_scale.emplace(backend_config.fmha_scale()); - config.dropout_rate.emplace(backend_config.dropout_rate()); - config.seed.emplace(backend_config.seed()); - return config; -} - -absl::StatusOr -GpufMHAConfig::AsDnnFusedMHAOpConfig() const { - double scale = 1.0; - if (fmha_scale.has_value()) { - scale = *fmha_scale; - } - TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type)); - - return se::dnn::FusedMHAOp::Config{ - scale, lhs_bmm1, rhs_bmm1, rhs_bmm2, intermediate_lhs_bmm2, - output, bias, activation, dropout_rate, seed, - mask_type}; -} - -/*static*/ absl::StatusOr GpufMHABackwardConfig::For( - const GpufMHABackwardDescriptor &desc) { - // Get shapes from desc. - - const Shape &bmm1_grad_gemm1_rhs_shape = desc.bmm1_grad_gemm1_rhs_shape; - const Shape &bmm1_grad_gemm2_rhs_shape = desc.bmm1_grad_gemm2_rhs_shape; - const Shape &bmm2_grad_gemm1_lhs_shape = desc.bmm2_grad_gemm1_lhs_shape; - const Shape &bmm2_grad_gemm2_rhs_shape = desc.bmm2_grad_gemm2_rhs_shape; - const Shape &d_output_shape = desc.d_output_shape; - const Shape &d_bmm1_lhs_shape = desc.d_bmm1_lhs_shape; - const Shape &d_bmm1_rhs_shape = desc.d_bmm1_rhs_shape; - const Shape &d_bmm2_rhs_shape = desc.d_bmm2_rhs_shape; - // Get DNN dtype from primtive types - TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm1_rhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm1_grad_gemm1_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm2_rhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm1_grad_gemm2_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm1_lhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm2_grad_gemm1_lhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm2_rhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm2_grad_gemm2_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_output_type, - GetDNNDataTypeFromPrimitiveType(d_output_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_bmm1_lhs_type, - GetDNNDataTypeFromPrimitiveType(d_bmm1_lhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_bmm1_rhs_type, - GetDNNDataTypeFromPrimitiveType(d_bmm1_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_bmm2_rhs_type, - GetDNNDataTypeFromPrimitiveType(d_bmm2_rhs_shape.element_type())); - - GpufMHABackwardConfig config; - config.input_type = bmm1_grad_gemm1_rhs_shape.element_type(); - config.output_type = d_bmm1_lhs_shape.element_type(); - - // Get MatmulTensorDescriptors for lhs of BMM1 grad GEMM 1 - config.bmm1_grad_gemm1_rhs = MatmulTensorDescriptor::For( - bmm1_grad_gemm1_rhs_type, bmm1_grad_gemm1_rhs_shape.dimensions(), - desc.bmm1_grad_gemm1_rhs_shape.layout().minor_to_major(), - desc.bmm1_grad_gemm1_dnums.rhs_batch_dimensions(), - desc.bmm1_grad_gemm1_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for rhs of BMM1 grad GEMM 2 - config.bmm1_grad_gemm2_rhs = MatmulTensorDescriptor::For( - bmm1_grad_gemm2_rhs_type, bmm1_grad_gemm2_rhs_shape.dimensions(), - desc.bmm1_grad_gemm2_rhs_shape.layout().minor_to_major(), - desc.bmm1_grad_gemm2_dnums.rhs_batch_dimensions(), - desc.bmm1_grad_gemm2_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for BMM2 grad GEMM 1 - config.bmm2_grad_gemm1_lhs = MatmulTensorDescriptor::For( - bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(), - desc.bmm2_grad_gemm1_lhs_shape.layout().minor_to_major(), - desc.bmm2_grad_gemm1_dnums.lhs_batch_dimensions(), - desc.bmm2_grad_gemm1_dnums.lhs_contracting_dimensions()); - - config.d_output = MatmulTensorDescriptor::For( - d_output_type, d_output_shape.dimensions(), - desc.d_output_shape.layout().minor_to_major(), - desc.bmm2_grad_gemm1_dnums.rhs_batch_dimensions(), - desc.bmm2_grad_gemm1_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for BMM2 grad GEMM 2 - config.bmm2_grad_gemm2_rhs = MatmulTensorDescriptor::For( - bmm2_grad_gemm2_rhs_type, bmm2_grad_gemm2_rhs_shape.dimensions(), - desc.bmm2_grad_gemm2_rhs_shape.layout().minor_to_major(), - desc.bmm2_grad_gemm2_dnums.rhs_batch_dimensions(), - desc.bmm2_grad_gemm2_dnums - .rhs_contracting_dimensions()); // FMHA TODO: transpose here? - - config.d_bmm1_lhs = - TensorDescriptor::For(d_bmm1_lhs_type, d_bmm1_lhs_shape.dimensions(), - d_bmm1_lhs_shape.layout().minor_to_major()); - config.d_bmm1_rhs = - TensorDescriptor::For(d_bmm1_rhs_type, d_bmm1_rhs_shape.dimensions(), - d_bmm1_rhs_shape.layout().minor_to_major()); - config.d_bmm2_rhs = - TensorDescriptor::For(d_bmm2_rhs_type, d_bmm2_rhs_shape.dimensions(), - d_bmm2_rhs_shape.layout().minor_to_major()); - config.d_s = TensorDescriptor::For( - bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(), - bmm2_grad_gemm1_lhs_shape.layout().minor_to_major()); - - if (desc.d_bias_shape) { - const Shape &d_bias_shape = *desc.d_bias_shape; - // Get DNN dtype from primtive types - TF_ASSIGN_OR_RETURN(DataType d_bias_type, GetDNNDataTypeFromPrimitiveType( - d_bias_shape.element_type())); - config.d_bias = - TensorDescriptor::For(d_bias_type, d_bias_shape.dimensions(), - d_bias_shape.layout().minor_to_major()); - } - - if (desc.mask_shape) { - const Shape &mask_shape = *desc.mask_shape; - TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType( - mask_shape.element_type())); - config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), - mask_shape.layout().minor_to_major()); - } - if (desc.fwd_output_shape) { - const Shape &fwd_output_shape = *desc.fwd_output_shape; - TF_ASSIGN_OR_RETURN( - DataType fwd_output_type, - GetDNNDataTypeFromPrimitiveType(fwd_output_shape.element_type())); - config.fwd_output = - TensorDescriptor::For(fwd_output_type, fwd_output_shape.dimensions(), - fwd_output_shape.layout().minor_to_major()); - } - - if (desc.bias_shape) { - const Shape &bias_shape = *desc.bias_shape; - TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( - bias_shape.element_type())); - config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), - bias_shape.layout().minor_to_major()); - } - - config.kind = desc.kind; - config.mask_type = desc.mask_type; - config.force_deterministic = desc.force_deterministic; - const CudnnfMHABackendConfig &backend_config = desc.backend_config; - config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); - config.fmha_scale.emplace(backend_config.fmha_scale()); - config.dropout_rate.emplace(backend_config.dropout_rate()); - config.seed.emplace(backend_config.seed()); - return config; -} - -absl::StatusOr -GpufMHABackwardConfig::AsDnnFusedMHABackwardOpConfig() const { - double scale = 1.0; - if (fmha_scale.has_value()) { - scale = *fmha_scale; - } - TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type)); - - return se::dnn::FusedMHABackwardOp::Config{scale, - bmm1_grad_gemm1_rhs, - bmm1_grad_gemm2_rhs, - bmm2_grad_gemm1_lhs, - bmm2_grad_gemm2_rhs, - d_output, - d_bmm1_lhs, - d_bmm1_rhs, - d_bmm2_rhs, - d_s, - d_bias, - fwd_output, - bias, - dropout_rate, - seed, - mask_type, - force_deterministic}; -} - -/*static*/ absl::StatusOr GpufMHAParams::For( - const GpufMHAConfig &config, se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer) { - GpufMHAParams params; - params.config = &config; - params.lhs_bmm1_buffer = lhs_bmm1_buffer; - params.rhs_bmm1_buffer = rhs_bmm1_buffer; - params.rhs_bmm2_buffer = rhs_bmm2_buffer; - params.output_buffer = output_buffer; - params.activation_buffer = activation_buffer; - params.bias_buffer = bias_buffer; - params.seqlen_q_buffer = seqlen_q_buffer; - params.seqlen_k_buffer = seqlen_k_buffer; - return params; -} - -/*static*/ absl::StatusOr GpufMHABackwardParams::For( - const GpufMHABackwardConfig &config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer) { - GpufMHABackwardParams params; - params.config = &config; - params.bmm1_grad_gemm1_rhs_buffer = bmm1_grad_gemm1_rhs_buffer; - params.bmm1_grad_gemm2_rhs_buffer = bmm1_grad_gemm2_rhs_buffer; - params.bmm2_grad_gemm1_lhs_buffer = bmm2_grad_gemm1_lhs_buffer; - params.bmm2_grad_gemm2_rhs_buffer = bmm2_grad_gemm2_rhs_buffer; - params.d_output_buffer = d_output_buffer; - params.d_bmm1_lhs_buffer = d_bmm1_lhs_buffer; - params.d_bmm1_rhs_buffer = d_bmm1_rhs_buffer; - params.d_bmm2_rhs_buffer = d_bmm2_rhs_buffer; - params.d_s_buffer = d_s_buffer; - params.d_bias_buffer = d_bias_buffer; - params.fwd_output_buffer = fwd_output_buffer; - params.bias_buffer = bias_buffer; - params.seqlen_q_buffer = seqlen_q_buffer; - params.seqlen_k_buffer = seqlen_k_buffer; - return params; -} - -absl::Status RunGpuFMHA(const GpufMHAConfig &fmha_config, - se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase scratch_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, - se::Stream *stream, RunFusedMHAOptions options) { - TF_ASSIGN_OR_RETURN( - GpufMHAParams params, - GpufMHAParams::For(fmha_config, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, bias_buffer, - activation_buffer, seqlen_q_buffer, seqlen_k_buffer)); - PrimitiveType input_primitive_type = fmha_config.input_type; - switch (input_primitive_type) { - case F16: - return RunGpuFMHAImpl( - params, stream, scratch_buffer, options); - case BF16: - return RunGpuFMHAImpl( - params, stream, scratch_buffer, options); - default: - return absl::UnimplementedError(absl::StrFormat( - "Unimplemented fused MHA with %s", ToString(fmha_config))); - } - return absl::OkStatus(); -} - -absl::Status RunGpuFMHABackward( - const GpufMHABackwardConfig &fmha_config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, se::Stream *stream, - RunFusedMHABackwardOptions options) { - TF_ASSIGN_OR_RETURN( - GpufMHABackwardParams params, - GpufMHABackwardParams::For( - fmha_config, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, - d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, - d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, fwd_output_buffer, - bias_buffer, seqlen_q_buffer, seqlen_k_buffer)); - PrimitiveType input_primitive_type = fmha_config.input_type; - switch (input_primitive_type) { - case F16: - return RunGpuFMHABackwardImpl( - params, stream, scratch_buffer, options); - case BF16: - return RunGpuFMHABackwardImpl(params, stream, - scratch_buffer, options); - default: - return Unimplemented("Unimplemented fused MHA backward"); - } - return absl::OkStatus(); -} - -std::string ToString(const GpufMHAConfig &config) { - std::string result = "GpufMHAConfig:\n"; - absl::StrAppend(&result, - "input_type: ", PrimitiveType_Name(config.input_type), ", "); - absl::StrAppend( - &result, "output_type: ", PrimitiveType_Name(config.output_type), ", "); - absl::StrAppend(&result, "Kind: ", CudnnfMHAKindToString(config.kind), ", "); - if (config.fmha_scale) { - absl::StrAppend(&result, "fmha_scale: ", *config.fmha_scale, ", "); - } - if (config.dropout_rate) { - absl::StrAppend(&result, "dropout_rate: ", *config.dropout_rate, ", "); - } - if (config.seed) { - absl::StrAppend(&result, "seed: ", *config.seed, ", "); - } - absl::StrAppend(&result, "Algorithm Desc: ", config.algorithm.ToString(), - "\n"); - absl::StrAppend(&result, "lhs_bmm1: ", config.lhs_bmm1.ToString(), "\n"); - absl::StrAppend(&result, "rhs_bmm1: ", config.rhs_bmm1.ToString(), "\n"); - absl::StrAppend(&result, "rhs_bmm2: ", config.rhs_bmm2.ToString(), "\n"); - absl::StrAppend(&result, "intermediate_lhs_bmm2: ", - config.intermediate_lhs_bmm2.ToString(), "\n"); - absl::StrAppend(&result, "output: ", config.output.ToString(), "\n"); - - if (config.mask) { - absl::StrAppend(&result, "mask: ", (*config.mask).ToString(), "\n"); - } - - if (config.bias) { - absl::StrAppend(&result, "bias: ", (*config.bias).ToString(), "\n"); - } - - return result; -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h deleted file mode 100644 index d0621cbdff6d74..00000000000000 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h +++ /dev/null @@ -1,431 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ -#define XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/shape.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/dnn.h" -#include "xla/stream_executor/lazy_op_runner.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace gpu { - -inline absl::StatusOr AsCudnnFmhaMaskKind( - xla::gpu::CudnnfMHABackendConfig_MaskType mask_type) { - switch (mask_type) { - case xla::gpu::CudnnfMHABackendConfig::NO_MASK: - return xla::gpu::CudnnfMHAMaskKind::kNoMask; - case xla::gpu::CudnnfMHABackendConfig::PADDING: - return xla::gpu::CudnnfMHAMaskKind::kPadding; - case xla::gpu::CudnnfMHABackendConfig::CAUSAL: - return xla::gpu::CudnnfMHAMaskKind::kCausal; - case xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL: - return xla::gpu::CudnnfMHAMaskKind::kPaddingCausal; - case xla::gpu::CudnnfMHABackendConfig::ALIBI: - return xla::gpu::CudnnfMHAMaskKind::kAlibi; - default: - return xla::Internal("Unknown fmha mask kind."); - } -} - -// This is an interim structure to hold the parameters to construct a -// GpufMHAConfig. -// Struct to describe properties of a FMHA without being tied to specific -// IR. Will be used to help build FMHA thunks from either XLA HLO or -// LHLO GPU dialect in MLIR. -struct GpufMHADescriptor { - CudnnfMHAKind kind; - CudnnfMHABackendConfig backend_config; - CudnnfMHAMaskKind mask_type; - Shape lhs_bmm1_shape; - Shape rhs_bmm1_shape; - Shape rhs_bmm2_shape; - Shape intermediate_lhs_bmm2_shape; - // This will contain both output shape and activation shape - absl::InlinedVector output_shapes; - DotDimensionNumbers bmm1_dnums; - DotDimensionNumbers bmm2_dnums; - - std::optional mask_shape; - std::optional bias_shape; -}; - -struct GpufMHABackwardDescriptor { - CudnnfMHAKind kind; - CudnnfMHABackendConfig backend_config; - CudnnfMHAMaskKind mask_type; - Shape bmm1_grad_gemm1_rhs_shape; - Shape bmm1_grad_gemm2_rhs_shape; - Shape bmm2_grad_gemm1_lhs_shape; - Shape bmm2_grad_gemm2_rhs_shape; - Shape d_output_shape; - Shape d_bmm1_lhs_shape; - Shape d_bmm1_rhs_shape; - Shape d_bmm2_rhs_shape; - DotDimensionNumbers bmm1_grad_gemm1_dnums; - DotDimensionNumbers bmm1_grad_gemm2_dnums; - DotDimensionNumbers bmm2_grad_gemm1_dnums; - DotDimensionNumbers bmm2_grad_gemm2_dnums; - - std::optional d_s_shape; - std::optional fwd_output_shape; - std::optional mask_shape; - std::optional d_bias_shape; - std::optional bias_shape; - bool force_deterministic; -}; - -// Structure to describe static properties of a GPU fused Multi-Headed -// Attention. -struct GpufMHAConfig { - static absl::StatusOr For(const GpufMHADescriptor& fmha_desc); - - absl::StatusOr AsDnnFusedMHAOpConfig() const; - - PrimitiveType - input_type; // Capture the primitive type of one of the inputs of BMM1 - PrimitiveType output_type; - CudnnfMHAKind kind; - std::optional fmha_scale; - std::optional dropout_rate; - std::optional seed; - - se::dnn::AlgorithmDesc algorithm; - CudnnfMHAMaskKind mask_type; - // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len] - // mask -> [batch_size, 1, q_seq_len, kv_seq_len] - se::dnn::MatmulTensorDescriptor lhs_bmm1; - se::dnn::MatmulTensorDescriptor rhs_bmm1; - se::dnn::MatmulTensorDescriptor rhs_bmm2; - se::dnn::MatmulTensorDescriptor intermediate_lhs_bmm2; - se::dnn::TensorDescriptor output; - - std::optional activation; - std::optional mask; - std::optional bias; -}; - -// Structure to describe static properties of a GPU fused Multi-Headed -// Attention backward. -struct GpufMHABackwardConfig { - static absl::StatusOr For( - const GpufMHABackwardDescriptor& fmha_desc); - - absl::StatusOr - AsDnnFusedMHABackwardOpConfig() const; - - PrimitiveType - input_type; // Capture the primitive type of one of the inputs of BMM1 - PrimitiveType output_type; - CudnnfMHAKind kind; - std::optional fmha_scale; - std::optional dropout_rate; - std::optional seed; - - se::dnn::AlgorithmDesc algorithm; - CudnnfMHAMaskKind mask_type; - // mask -> [batch_size, 1, q_seq_len, kv_seq_len] - // d_bias -> [1, num_heads, q_seq_len, kv_seq_len] - se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs; - se::dnn::MatmulTensorDescriptor bmm1_grad_gemm2_rhs; - se::dnn::MatmulTensorDescriptor bmm2_grad_gemm1_lhs; - se::dnn::MatmulTensorDescriptor bmm2_grad_gemm2_rhs; - se::dnn::MatmulTensorDescriptor d_output; - se::dnn::TensorDescriptor d_bmm1_lhs; - se::dnn::TensorDescriptor d_bmm1_rhs; - se::dnn::TensorDescriptor d_bmm2_rhs; - std::optional d_s; - std::optional mask; - std::optional d_bias; - std::optional fwd_output; - std::optional bias; - bool force_deterministic; -}; - -// Implementation struct exposed for debugging and log analysis. -struct GpufMHAParams { - static absl::StatusOr For( - const GpufMHAConfig& config, se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, se::DeviceMemoryBase output_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer); - - const GpufMHAConfig* config; // Not owned - se::DeviceMemoryBase lhs_bmm1_buffer; - se::DeviceMemoryBase rhs_bmm1_buffer; - se::DeviceMemoryBase rhs_bmm2_buffer; - se::DeviceMemoryBase output_buffer; - std::optional activation_buffer; - std::optional bias_buffer; - std::optional seqlen_q_buffer; - std::optional seqlen_k_buffer; -}; - -struct GpufMHABackwardParams { - static absl::StatusOr For( - const GpufMHABackwardConfig& config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer); - - const GpufMHABackwardConfig* config; // Not owned - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer; - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer; - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer; - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer; - se::DeviceMemoryBase d_output_buffer; - se::DeviceMemoryBase d_bmm1_lhs_buffer; - se::DeviceMemoryBase d_bmm1_rhs_buffer; - se::DeviceMemoryBase d_bmm2_rhs_buffer; - std::optional d_s_buffer; - std::optional d_bias_buffer; - std::optional fwd_output_buffer; - std::optional bias_buffer; - std::optional seqlen_q_buffer; - std::optional seqlen_k_buffer; -}; - -class FusedMultiHeadedAttentionRunner { - public: - using Repr = - std::variant>>; - - FusedMultiHeadedAttentionRunner() = default; - - explicit FusedMultiHeadedAttentionRunner( - std::unique_ptr> runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionRunner(Repr runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionRunner(const GpufMHAConfig& config) - : FusedMultiHeadedAttentionRunner(CreateRunner(config)) { - if (std::holds_alternative(repr_)) { - CHECK(false) << "Cannot construct FusedMultiHeadedAttentionRunner with " - "std::monostate"; - } - } - - se::dnn::AlgorithmDesc ToAlgorithmDesc() const { - return std::visit(ToAlgorithmDescVisitor{}, repr_); - } - - se::dnn::LazyOpRunner* AsFusedMHARunner() { - CHECK(std::holds_alternative< - std::unique_ptr>>(repr_)); - return std::get< - std::unique_ptr>>( - repr_) - .get(); - } - - private: - // The CreateRunner function is defined as static because it - // doesn't need access to any non-static member variables of the - // FusedMultiHeadedAttentionRunner class. Defining it static makes it easy to - // use and makes it clear that it is a utility function that doesn't rely on - // the state of any specific instance of the class. - static Repr CreateRunner(const GpufMHAConfig& config) { - switch (config.kind) { - case CudnnfMHAKind::kSoftmaxDropout: - case CudnnfMHAKind::kSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmaxDropout: - return std::make_unique>( - config.algorithm); - default: - LOG(FATAL) << "Internal error: unsupported CUDNN MHA kind in " - "FusedMultiHeadedAttentionRunner"; - } - } - - struct ToAlgorithmDescVisitor { - template - se::dnn::AlgorithmDesc operator()(const RunnerPtr& runner) { - return runner->ToAlgorithmDesc(); - } - - se::dnn::AlgorithmDesc operator()(const std::monostate&) { - CHECK(false) << "Internal error: uninitialized runner in ToAlgorithmDesc"; - } - }; - - Repr repr_; -}; - -class FusedMultiHeadedAttentionBackwardRunner { - public: - using Repr = std::variant< - std::monostate, // To allow XXX default ctor - std::unique_ptr>>; - - FusedMultiHeadedAttentionBackwardRunner() = default; - - explicit FusedMultiHeadedAttentionBackwardRunner( - std::unique_ptr> - runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionBackwardRunner(Repr runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionBackwardRunner( - const GpufMHABackwardConfig& config) - : FusedMultiHeadedAttentionBackwardRunner(CreateRunner(config)) { - if (std::holds_alternative(repr_)) { - CHECK(false) - << "Cannot construct FusedMultiHeadedAttentionBackwardRunner with " - "std::monostate"; - } - } - - se::dnn::AlgorithmDesc ToAlgorithmDesc() const { - return std::visit(ToAlgorithmDescVisitor{}, repr_); - } - - se::dnn::LazyOpRunner* - AsFusedMHABackwardRunner() { - CHECK(std::holds_alternative< - std::unique_ptr>>( - repr_)); - return std::get>>(repr_) - .get(); - } - - private: - // The CreateRunner function is defined as static because it - // doesn't need access to any non-static member variables of the - // FusedMultiHeadedAttentionBackwardRunner class. Defining it static makes it - // easy to use and makes it clear that it is a utility function that doesn't - // rely on the state of any specific instance of the class. - static Repr CreateRunner(const GpufMHABackwardConfig& config) { - switch (config.kind) { - case CudnnfMHAKind::kBackwardSoftmaxDropout: - case CudnnfMHAKind::kBackwardSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: - return std::make_unique< - se::dnn::LazyOpRunner>( - config.algorithm); - default: - LOG(FATAL) << "Internal error: unsupported CUDNN MHA kind in " - "FusedMultiHeadedAttentionBackwardRunner"; - } - } - - struct ToAlgorithmDescVisitor { - template - se::dnn::AlgorithmDesc operator()(const RunnerPtr& runner) { - return runner->ToAlgorithmDesc(); - } - - se::dnn::AlgorithmDesc operator()(const std::monostate&) { - CHECK(false) << "Internal error: uninitialized runner in ToAlgorithmDesc"; - } - }; - - Repr repr_; -}; - -struct RunFusedMHAOptions { - // Nullable output-parameter pointer for profiling results. - // Profile results remain unused for now since cuDNN FMHA has only one - // algorithm for now. - se::dnn::ProfileResult* profile_result = nullptr; - - // Use this runner cache (and its configured algorithm), instead of the one - // from the instruction. - FusedMultiHeadedAttentionRunner* runner_cache; -}; - -struct RunFusedMHABackwardOptions { - // Nullable output-parameter pointer for profiling results. - // Profile results remain unused for now since cuDNN FMHA has only one - // algorithm for now. - se::dnn::ProfileResult* profile_result = nullptr; - - // Use this runner cache (and its configured algorithm), instead of the one - // from the instruction. - FusedMultiHeadedAttentionBackwardRunner* runner_cache; -}; - -absl::Status RunGpuFMHA(const GpufMHAConfig& fmha_config, - se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase scratch_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, - se::Stream* stream, RunFusedMHAOptions = {}); - -absl::Status RunGpuFMHABackward( - const GpufMHABackwardConfig& fmha_config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, se::Stream* stream, - RunFusedMHABackwardOptions = {}); - -std::string ToString(const GpufMHAConfig& config); - -} // namespace gpu -} // namespace xla -#endif // XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc index e8710fe452e642..f637e67e562113 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc @@ -78,84 +78,6 @@ int ComputeMaxUnrollFactor(int64_t num_elements) { return 1; } -// Determines if we enable the row optimized codegen. When we have a fusion with -// only pointwise operations, scalar broadcasting and row broadcasting, we can -// trigger a kernel that vectorizes the row loads. This speeds up the kernel, in -// particular on A100. The int is the number of inputs with rank `out_rank`. Its -// value is only defined if row vectorization is enabled. -std::pair RowVectorizationEnabled( - const HloFusionAdaptor& fusion, int64_t out_rank) { - auto roots = fusion.GetRoots(); - const auto is_row_major = [](const HloInstruction* instr) { - // Only tested when the inputs are row-major. So only enable that case. - // Maybe it would work if only the inner dimensions is contiguous. - return LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout()); - }; - bool row_vectorized = roots.size() == 1 && !roots[0].shape().IsTuple() && - is_row_major(&roots[0].instruction()); - if (!row_vectorized) { - return {false, 0}; - } - - // Check that the operations in the fusion are supported. Each - // supported operation (or category) must be manually vetted as XLA - // only unrolls and relies on LLVM to vectorize. But this is brittle. - // Currently tested and supported operations: - // Elementwise, scalar and row broadcasting. - // - // We also detect at the same time if there is a row broadcasting - // operation. - int num_big_inputs = 0; - bool some_row_broadcasting = false; - HloBfsConsumersFirstTraversal( - roots, fusion, [&](auto node) -> TraversalResult { - if (!row_vectorized) { - return TraversalResult::kInterrupt; - } - - if (node.instruction().IsElementwise()) { - return TraversalResult::kAdvance; - } - - switch (node.opcode()) { - case HloOpcode::kConstant: - return TraversalResult::kSkip; - case HloOpcode::kParameter: - return TraversalResult::kAdvance; - case HloOpcode::kBroadcast: { - auto dims = node.instruction().dimensions(); - if (dims.empty()) { - return TraversalResult::kAdvance; - } - - if (dims.size() == 1 && dims.front() == node.shape().rank() - 1) { - some_row_broadcasting = true; - return TraversalResult::kAdvance; - } - TF_FALLTHROUGH_INTENDED; - } - default: - VLOG(2) << "Row vectorization not enabled due to: " - << node.ToString(); - row_vectorized = false; - return TraversalResult::kInterrupt; - } - }); - if (row_vectorized) { - for (const HloInstruction* argument : fusion.GetParameters()) { - if (argument->shape().rank() == out_rank) { - ++num_big_inputs; - } - if (!is_row_major(argument)) { - row_vectorized = false; - } - }; - } - // Trigger only when there is a row broadcasting. - return std::make_pair(row_vectorized && some_row_broadcasting, - num_big_inputs); -} - } // namespace bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) { @@ -1046,9 +968,6 @@ static void GetFusionRootsRec(const HloInstruction* root, GetFusionRootsRec(root->operand(i), out); } } else { - CHECK(!absl::c_linear_search(out, root)) - << "Fusion root contains instruction " << root->ToString() - << " multiple times"; out.push_back(root); } } @@ -1163,39 +1082,7 @@ LaunchDimensionsConfig ComputeLoopFusionConfig( CHECK(absl::has_single_bit(static_cast(unroll_factor))); VLOG(2) << "Unroll factor: " << unroll_factor; - bool row_vectorized; - int num_big_inputs; - std::tie(row_vectorized, num_big_inputs) = - RowVectorizationEnabled(analysis.fusion(), element_shape.rank()); - bool few_waves = !HloAnyOf(analysis.fusion(), [&](auto instr) { - if (instr.opcode() == HloOpcode::kParameter || - instr.opcode() == HloOpcode::kConstant || - HloInstruction::IsOpElementwise(instr.opcode())) { - return false; - } - if (auto broadcast = - DynCast(&instr.instruction())) { - if (broadcast->dimensions().empty() || - // More than 3 big inputs cause a speed regression. - (row_vectorized && num_big_inputs <= 3)) { - return false; - } - } - VLOG(2) << "few_waves not enabled due to: " - << instr.instruction().ToString(); - return true; - }); - - LaunchDimensionsConfig launch_config{unroll_factor, few_waves, - row_vectorized}; - // Check that the shapes is supported. - if (launch_config.row_vectorized && - ThreadsPerBlockRowVectorized(element_shape, analysis.device_info(), - launch_config) <= 0) { - VLOG(2) << "Cancelling row_vectorization as the shape isn't supported."; - launch_config.row_vectorized = false; - launch_config.few_waves = false; - } + LaunchDimensionsConfig launch_config{unroll_factor}; return launch_config; } diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.h b/third_party/xla/xla/service/gpu/gpu_fusible.h index 688e4eb0415a0b..0dadbfa36f5476 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.h +++ b/third_party/xla/xla/service/gpu/gpu_fusible.h @@ -34,7 +34,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" // TODO(b/112957171): Extract logic to determine fusibility of HLO ops from -// GpuInstructionFusion, FusionMerger, and GpuMultiOutputFusion. +// GpuInstructionFusion, FusionMerger, and MultiOutputFusion. namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc index aa3713dab0c746..735709cbd346f8 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc @@ -544,7 +544,9 @@ TEST_F(GpuFusibleTest, FusionHeroesAreCompatible_TransposeFusionNotCompatible) { fused_computation_1 { p0.1 = f32[64,32]{1,0} parameter(0) neg = f32[64,32]{1,0} negate(p0.1) - ROOT transpose = f32[32,64]{1,0} transpose(neg), dimensions={1,0} + bc = f32[1,64,32]{2,1,0} bitcast(neg) + transpose = f32[1,32,64]{2,1,0} transpose(bc), dimensions={0,2,1} + ROOT bc2 = f32[32,64]{1,0} bitcast(transpose) } fused_computation_2 { @@ -562,10 +564,12 @@ TEST_F(GpuFusibleTest, FusionHeroesAreCompatible_TransposeFusionNotCompatible) { const HloInstruction* fusion_1 = module->entry_computation()->root_instruction(); const HloInstruction* fusion_2 = fusion_1->operand(0); - EXPECT_FALSE(FusionHeroesAreCompatible(fusion_1->fused_expression_root(), - fusion_2->fused_expression_root())); - EXPECT_FALSE(FusionHeroesAreCompatible(fusion_2->fused_expression_root(), - fusion_1->fused_expression_root())); + EXPECT_FALSE( + FusionHeroesAreCompatible(fusion_1->fused_expression_root(), + fusion_2->fused_expression_root()->operand(0))); + EXPECT_FALSE( + FusionHeroesAreCompatible(fusion_2->fused_expression_root()->operand(0), + fusion_1->fused_expression_root())); } TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_LoopFusions) { @@ -1520,9 +1524,9 @@ TEST_F(GpuFusibleTest, ChooseFusionKind) { HloModule module ENTRY computation { - p = f32[5000,6000]{1,0} parameter(0) - c = f32[6000,5000] transpose(p), dimensions={1,0} - ROOT r = f32[300,20,5000] reshape(c) + p = f32[1,5000,6000]{2,1,0} parameter(0) + c = f32[1,6000,5000]{2,1,0} transpose(p), dimensions={0,2,1} + ROOT r = f32[300,20,5000]{2,1,0} reshape(c) } )") .value(); @@ -1700,6 +1704,33 @@ TEST_F(GpuFusibleTest, GetFusionRootsWithMakeTupleGTESequence) { EXPECT_EQ(roots, expected_result); } +TEST_F(GpuFusibleTest, GetFusionRootsWithTupleMultipleSameOperands) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + + fusion { + p1 = s32[32] parameter(0) + add0 = s32[32] add(p1, p1) + ROOT _ = (s32[32], s32[32]) tuple(add0, add0) + } + + ENTRY entry { + p0 = s32[32] parameter(0) + ROOT fusion = (s32[32], s32[32]) fusion(p0), kind=kCustom, calls=fusion + } + )") + .value(); + + auto called_computations = + module->entry_computation()->root_instruction()->called_computations(); + ASSERT_EQ(called_computations.size(), 1); + + auto fusion = called_computations.front(); + auto roots = GetFusionRoots(*fusion); + auto add0 = fusion->root_instruction()->operand(0); + EXPECT_THAT(GetFusionRoots(*fusion), ElementsAre(add0, add0)); +} + TEST_F(GpuFusibleTest, GetFusibleComputations) { auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( fused_reduce { diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index 218a151c160a57..588abc6297fadd 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -46,9 +46,10 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" -#include "xla/service/gpu/gpu_schedule_postprocessing.h" #include "xla/service/gpu/model/analytical_latency_estimator.h" -#include "xla/service/gpu/scheduling_instruction_annotator.h" +#include "xla/service/gpu/transforms/pgle_accuracy_checker.h" +#include "xla/service/gpu/transforms/schedule_postprocessing.h" +#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" #include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/latency_hiding_scheduler.h" @@ -63,6 +64,7 @@ limitations under the License. #include "tsl/platform/path.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" namespace xla { namespace gpu { @@ -430,6 +432,7 @@ static int64_t GetSchedulerMemoryLimit( absl::StatusOr ScheduleGpuModule( HloModule* module, int64_t pointer_size, const se::DeviceDescription& gpu_device_info) { + tsl::profiler::TraceMe traceme("GpuCompiler::CompileToBackendResult"); int64_t memory_limit = GetSchedulerMemoryLimit(module, gpu_device_info, pointer_size); if (module->has_schedule()) { @@ -476,6 +479,7 @@ absl::StatusOr ScheduleGpuModule( module->config() .debug_options() .xla_gpu_enable_analytical_latency_estimator(); + HloPassPipeline pipeline("latency-hiding-scheduler"); if (profile.has_value()) { auto aggregator = std::make_unique(); auto pg_latency_estimator = std::make_unique( @@ -484,7 +488,7 @@ absl::StatusOr ScheduleGpuModule( LOG(INFO) << "Found profile, using profile guided latency estimator. Profile:\n" << profile->DebugString(); - TF_RETURN_IF_ERROR(pg_latency_estimator->CheckAccuracy(*module)); + pipeline.AddPass(*pg_latency_estimator); latency_estimator = std::move(pg_latency_estimator); } else if (enable_analytical_latency_estimator) { latency_estimator = std::make_unique( @@ -509,7 +513,6 @@ absl::StatusOr ScheduleGpuModule( auto shape_size_in_bytes = [pointer_size](const Shape& shape) { return GetSizeOfShape(shape, pointer_size); }; - HloPassPipeline pipeline("latency-hiding-scheduler"); auto scheduler_core = std::make_unique( shape_size_in_bytes, async_tracker.get(), latency_estimator.get(), config); @@ -520,8 +523,8 @@ absl::StatusOr ScheduleGpuModule( TF_RETURN_IF_ERROR(pipeline.Run(module).status()); - HloPassPipeline postprocessing_pipeline("gpu-schedule-postprocessing"); - postprocessing_pipeline.AddPass(); + HloPassPipeline postprocessing_pipeline("schedule-postprocessing"); + postprocessing_pipeline.AddPass(); TF_RETURN_IF_ERROR(postprocessing_pipeline.Run(module).status()); return ScheduleMetadata{memory_limit}; diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc index 3582d40ef9ee0f..60b80d656aa7fc 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -47,6 +47,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -56,7 +57,6 @@ namespace xla { namespace gpu { using ::testing::ElementsAre; -using ::testing::HasSubstr; using ::tsl::testing::StatusIs; class GpuHloScheduleTest : public HloTestBase { @@ -494,6 +494,112 @@ TEST_F(GpuHloScheduleTest, ProfileGuidedCostModel) { } } +TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelFailsWithIncompleteProfile) { + const absl::string_view kHloString = R"( + HloModule m + + apply_op { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT apply_op = f32[] add(x, y) + } + + ENTRY ar { + p0 = f32[32] parameter(0) + p1 = f32[32,32] parameter(1) + p2 = f32[32,32] parameter(2) + p3 = f32[32] parameter(3) + + dot0 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + dot1 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + add0 = f32[32,32] add(dot0, dot1) + + ar-start = f32[32] all-reduce-start(p0), to_apply=apply_op + ar-done = f32[32] all-reduce-done(ar-start) + + ar-start1 = f32[32] all-reduce-start(p3), to_apply=apply_op + ar-done1 = f32[32] all-reduce-done(ar-start1) + + ROOT t = (f32[32],f32[32],f32[32,32]) tuple(ar-done, ar-done1, add0) + })"; + + // Profile string, cost does not matter. + const absl::string_view kProfile = R"pb( + costs { name: "dot0" cost_us: 100.0 } + costs { name: "add0" cost_us: 10.0 } + costs { name: "ar-start" cost_us: 1000.0 } + )pb"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + kHloString, GetModuleConfig(/*enable_latency_hiding_scheduler=*/true, + /*enable_gpu_async_tracker=*/true, + /*fdo_profile=*/kProfile))); + + // `dot1` and `ar-start1` are missing from the profile. + EXPECT_THAT(ScheduleGpuModule( + module.get(), /*pointer_size=*/8, + backend().default_stream_executor()->GetDeviceDescription()) + .status(), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F( + GpuHloScheduleTest, + ProfileGuidedCostModelDoesNotFailWithIncompleteProfileIfAccuracyCheckerIsDisabled) { // NOLINT(whitespace/line_length) + const absl::string_view kHloString = R"( + HloModule m + + apply_op { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT apply_op = f32[] add(x, y) + } + + ENTRY ar { + p0 = f32[32] parameter(0) + p1 = f32[32,32] parameter(1) + p2 = f32[32,32] parameter(2) + p3 = f32[32] parameter(3) + + dot0 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + dot1 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + add0 = f32[32,32] add(dot0, dot1) + + ar-start = f32[32] all-reduce-start(p0), to_apply=apply_op + ar-done = f32[32] all-reduce-done(ar-start) + + ar-start1 = f32[32] all-reduce-start(p3), to_apply=apply_op + ar-done1 = f32[32] all-reduce-done(ar-start1) + + ROOT t = (f32[32],f32[32],f32[32,32]) tuple(ar-done, ar-done1, add0) + })"; + + // Profile string, cost does not matter. + const absl::string_view kProfile = R"pb( + costs { name: "dot0" cost_us: 100.0 } + costs { name: "add0" cost_us: 10.0 } + costs { name: "ar-start" cost_us: 1000.0 } + )pb"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + kHloString, GetModuleConfig(/*enable_latency_hiding_scheduler=*/true, + /*enable_gpu_async_tracker=*/true, + /*fdo_profile=*/kProfile))); + + // `dot1` and `ar-start1` are missing from the profile but we disable the + // pass. + module->mutable_config().mutable_debug_options().add_xla_disable_hlo_passes( + "pgle-accuracy-checker"); + TF_EXPECT_OK(ScheduleGpuModule( + module.get(), /*pointer_size=*/8, + backend().default_stream_executor()->GetDeviceDescription()) + .status()); +} + TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelWithRematData) { const char* hlo_text = R"( HloModule AsyncAR diff --git a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc index 928011cb4b76a1..215609c7e288cb 100644 --- a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/layout.h" #include "xla/service/buffer_value.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/stream_attribute_annotator.h" +#include "xla/service/gpu/transforms/stream_attribute_annotator.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_rematerialization.h" diff --git a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc index d84797d21c462e..06e1e6fa1594b0 100644 --- a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc +++ b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc @@ -27,7 +27,7 @@ limitations under the License. #include "xla/service/algebraic_simplifier.h" #include "xla/service/conditional_simplifier.h" #include "xla/service/gather_expander.h" -#include "xla/service/gpu/gpu_algebraic_simplifier.h" +#include "xla/service/gpu/transforms/algebraic_simplifier.h" #include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_module_config.h" diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index b3be135ce0e4c2..d7e7129b2f26ad 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -61,9 +61,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" @@ -546,37 +546,41 @@ static std::optional FindTiledTranspose( return std::nullopt; } - if (std::optional tr = ShapeUtil::GetNormalizedTransposeShape( - instr.operand(0)->shape(), instr.shape(), Vector3{0, 2, 1})) { + absl::InlinedVector permutation; + auto tr = ShapeUtil::GetNormalizedTransposeShape(instr.operand(0)->shape(), + instr.shape(), permutation); + if (!tr.has_value()) { + return std::nullopt; + } + if (permutation == absl::InlinedVector{0, 2, 1}) { if ((tr->at(1) >= kMinDimensionToTransposeTiled && tr->at(2) >= kMinDimensionToTransposeTiled) || (tr->at(1) >= kMinDimensionToTransposeTiled2 && tr->at(2) >= kMinDimensionToTransposeTiled2 && tr->at(1) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) { - return TransposeDescription{&instr, *tr, - /*permutation=*/Vector3{0, 2, 1}}; + return TransposeDescription{ + &instr, *tr, + /*permutation=*/absl::InlinedVector{0, 2, 1}}; } - } - if (std::optional tr = ShapeUtil::GetNormalizedTransposeShape( - instr.operand(0)->shape(), instr.shape(), Vector3{2, 1, 0})) { + } else if (permutation == absl::InlinedVector{2, 1, 0}) { if ((tr->at(0) >= kMinDimensionToTransposeTiled && tr->at(2) >= kMinDimensionToTransposeTiled) || (tr->at(0) >= kMinDimensionToTransposeTiled2 && tr->at(2) >= kMinDimensionToTransposeTiled2 && tr->at(0) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) { - return TransposeDescription{&instr, *tr, - /*permutation=*/Vector3{2, 1, 0}}; + return TransposeDescription{ + &instr, *tr, + /*permutation=*/absl::InlinedVector{2, 1, 0}}; } - } - if (IsMlirTransposeEmitterEnabled(instr)) { - if (std::optional tr = ShapeUtil::GetNormalizedTransposeShape( - instr.operand(0)->shape(), instr.shape(), Vector3{1, 0, 2})) { + } else if (IsMlirTransposeEmitterEnabled(instr)) { + if (permutation == absl::InlinedVector{1, 0, 2}) { auto byte_width = primitive_util::ByteWidth(instr.shape().element_type()); if (byte_width * tr->at(2) <= kMaxBytesInMostMinorDimension && byte_width * tr->at(2) * std::min(tr->at(0), tr->at(1)) >= kMinDimensionToTransposeTiled) { - return TransposeDescription{&instr, *tr, - /*permutation=*/Vector3{1, 0, 2}}; + return TransposeDescription{ + &instr, *tr, + /*permutation=*/absl::InlinedVector{1, 0, 2}}; } } } @@ -590,43 +594,47 @@ static std::optional FindTiledLogicalTranspose( return std::nullopt; } - // TODO(cheshire): avoid code duplication. - if (std::optional tr = ShapeUtil::GetNormalizedLogicalTransposeShape( - instr.operand(0)->shape(), instr.shape(), instr.dimensions(), - Vector3{0, 2, 1})) { - if ((tr->at(1) >= kMinDimensionToTransposeTiled && - tr->at(2) >= kMinDimensionToTransposeTiled) || - (tr->at(1) >= kMinDimensionToTransposeTiled2 && - tr->at(2) >= kMinDimensionToTransposeTiled2 && - tr->at(1) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) { - return TransposeDescription{&instr, *tr, - /*permutation=*/Vector3{0, 2, 1}}; - } + // We can assume that TransposeDimensionGrouper pass has run, so no need to + // call GetNormalizedLogicalTransposeShape here. + absl::InlinedVector permutation(instr.dimensions().begin(), + instr.dimensions().end()); + // A real transpose needs at least 2 transpose dimensions. + if (permutation.size() < 2) { + return std::nullopt; } - if (std::optional tr = ShapeUtil::GetNormalizedLogicalTransposeShape( - instr.operand(0)->shape(), instr.shape(), instr.dimensions(), - Vector3{2, 1, 0})) { - if ((tr->at(0) >= kMinDimensionToTransposeTiled && - tr->at(2) >= kMinDimensionToTransposeTiled) || - (tr->at(0) >= kMinDimensionToTransposeTiled2 && - tr->at(2) >= kMinDimensionToTransposeTiled2 && - tr->at(0) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) { - return TransposeDescription{&instr, *tr, - /*permutation=*/Vector3{2, 1, 0}}; + absl::InlinedVector dimensions(instr.shape().dimensions().begin(), + instr.shape().dimensions().end()); + int64_t operand_most_minor_dim = + instr.operand(0)->shape().dimensions().back(); + if (permutation == absl::InlinedVector{0, 2, 1} || + permutation == absl::InlinedVector{2, 1, 0}) { + if ((dimensions.back() >= kMinDimensionToTransposeTiled && + operand_most_minor_dim >= kMinDimensionToTransposeTiled) || + (dimensions.back() >= kMinDimensionToTransposeTiled2 && + operand_most_minor_dim >= kMinDimensionToTransposeTiled2 && + dimensions.back() * operand_most_minor_dim >= + kMinTotalDimensionsToTransposeTiled)) { + return TransposeDescription{&instr, dimensions, permutation}; } - } - if (IsMlirTransposeEmitterEnabled(instr)) { - if (std::optional tr = - ShapeUtil::GetNormalizedLogicalTransposeShape( - instr.operand(0)->shape(), instr.shape(), instr.dimensions(), - Vector3{1, 0, 2})) { + } else if (IsMlirTransposeEmitterEnabled(instr)) { + if (permutation.back() == dimensions.size() - 1) { + operand_most_minor_dim = + instr.operand(0)->shape().dimensions(dimensions.size() - 2); auto byte_width = primitive_util::ByteWidth(instr.shape().element_type()); - if (byte_width * tr->at(2) <= kMaxBytesInMostMinorDimension && - byte_width * tr->at(2) * std::min(tr->at(0), tr->at(1)) >= + if (byte_width * dimensions.back() <= kMaxBytesInMostMinorDimension && + byte_width * dimensions.back() * + std::min(operand_most_minor_dim, + dimensions[dimensions.size() - 2]) >= kMinDimensionToTransposeTiled) { - return TransposeDescription{&instr, *tr, - /*permutation=*/Vector3{1, 0, 2}}; + return TransposeDescription{&instr, dimensions, permutation}; } + } else if ((operand_most_minor_dim >= kMinDimensionToTransposeTiled && + dimensions.back() >= kMinDimensionToTransposeTiled) || + (operand_most_minor_dim >= kMinDimensionToTransposeTiled2 && + dimensions.back() >= kMinDimensionToTransposeTiled2 && + operand_most_minor_dim * dimensions.back() >= + kMinTotalDimensionsToTransposeTiled)) { + return TransposeDescription{&instr, dimensions, permutation}; } } return std::nullopt; @@ -634,12 +642,6 @@ static std::optional FindTiledLogicalTranspose( std::optional GetDescriptionForTiledTransposeEmitter( const HloInstruction& root, const HloInstruction& hero) { - // TODO(b/284431534): Figure out how to make the shared memory transpose - // emitter faster for this case. - if (hero.shape().element_type() == F32 && root.shape().element_type() == S8) { - return std::nullopt; - } - if (auto d1 = FindTiledTranspose(hero)) { return d1; } diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index bed6830bfa4979..3dcf0bce20ae23 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -164,16 +165,18 @@ struct TransposeDescription { const HloInstruction* instr; // Normalized transpose dimensions. - Vector3 dimensions; + absl::InlinedVector dimensions; // Permutations of normalized transpose dimensions. - Vector3 permutation; + absl::InlinedVector permutation; - TransposeDescription(Vector3 dimensions, Vector3 permutation) + TransposeDescription(absl::InlinedVector dimensions, + absl::InlinedVector permutation) : TransposeDescription(/*instr=*/nullptr, dimensions, permutation) {} - TransposeDescription(const HloInstruction* instr, Vector3 dimensions, - Vector3 permutation) + TransposeDescription(const HloInstruction* instr, + absl::InlinedVector dimensions, + absl::InlinedVector permutation) : instr(instr), dimensions(dimensions), permutation(permutation) {} // Transpose instruction input shape. diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc index 4900e474498cda..80407b0835d9eb 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc @@ -20,8 +20,11 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "xla/hlo/ir/backend_config.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" @@ -30,7 +33,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/types.h" -#include "xla/util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -41,26 +43,29 @@ namespace gpu { using ::tsl::testing::IsOkAndHolds; using IrEmissionUtilsTest = HloTestBase; +using InlinedVector = absl::InlinedVector; TEST_F(IrEmissionUtilsTest, FindTiledLogicalTranspose) { const char* hlo = R"( HloModule module ENTRY entry { - p = f32[32,48,64]{2,1,0} parameter(0) - ROOT t = f32[64,32,48]{2,1,0} transpose(p), dimensions={2,0,1} + p = f32[1536,64]{1,0} parameter(0) + ROOT t = f32[64,1536]{1,0} transpose(p), dimensions={1,0} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo)); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); HloInstruction* tr = module->entry_computation()->root_instruction(); auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, tr); - EXPECT_EQ(result->dimensions, Vector3({1, 64, 1536})); - EXPECT_EQ(result->permutation, Vector3({0, 2, 1})); + EXPECT_EQ(result->dimensions, InlinedVector({64, 1536})); + EXPECT_EQ(result->permutation, InlinedVector({1, 0})); } TEST_F(IrEmissionUtilsTest, FindTiledLogical102Transpose) { @@ -82,8 +87,8 @@ ENTRY entry { auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, tr); - EXPECT_EQ(result->dimensions, Vector3({48, 32, 2})); - EXPECT_EQ(result->permutation, Vector3({1, 0, 2})); + EXPECT_EQ(result->dimensions, InlinedVector({48, 32, 2})); + EXPECT_EQ(result->permutation, InlinedVector({1, 0, 2})); } TEST_F(IrEmissionUtilsTest, FindTiledLogical102TransposeTooMuchMemoryRequired) { @@ -106,6 +111,52 @@ ENTRY entry { EXPECT_FALSE(result.has_value()); } +TEST_F(IrEmissionUtilsTest, FindTiledLogical2103Transpose) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = f32[33,48,32,2]{3,2,1,0} parameter(0) + ROOT t = f32[32,48,33,2]{3,2,1,0} transpose(p), dimensions={2,1,0,3} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + + HloInstruction* tr = module->entry_computation()->root_instruction(); + + auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result->instr, tr); + EXPECT_EQ(result->dimensions, InlinedVector({32, 48, 33, 2})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0, 3})); +} + +TEST_F(IrEmissionUtilsTest, FindTiledLogical1320Transpose) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = f32[33,48,32,34]{3,2,1,0} parameter(0) + ROOT t = f32[48,34,32,33]{3,2,1,0} transpose(p), dimensions={1,3,2,0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + + HloInstruction* tr = module->entry_computation()->root_instruction(); + + auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result->instr, tr); + EXPECT_EQ(result->dimensions, InlinedVector({48, 34, 32, 33})); + EXPECT_EQ(result->permutation, InlinedVector({1, 3, 2, 0})); +} + TEST_F(IrEmissionUtilsTest, FindTiled102Transpose) { const char* hlo = R"( HloModule module @@ -125,8 +176,8 @@ ENTRY entry { auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, tr); - EXPECT_EQ(result->dimensions, Vector3({48, 32, 4})); - EXPECT_EQ(result->permutation, Vector3({1, 0, 2})); + EXPECT_EQ(result->dimensions, InlinedVector({48, 32, 4})); + EXPECT_EQ(result->permutation, InlinedVector({1, 0, 2})); } TEST_F(IrEmissionUtilsTest, FindTiled102TransposeTooMuchMemoryRequired) { @@ -165,8 +216,8 @@ ENTRY entry { auto result = GetDescriptionForTiledTransposeEmitter(*r, *r); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, r); - EXPECT_EQ(result->dimensions, Vector3({64, 48, 32})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithIntermediateUnaryOp) { @@ -186,8 +237,8 @@ ENTRY entry { auto result = GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, r->operand(0)); - EXPECT_EQ(result->dimensions, Vector3({64, 48, 32})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithIntermediateUnaryOpS8) { @@ -210,11 +261,11 @@ ENTRY main { HloInstruction* r = module->entry_computation()->root_instruction()->fused_expression_root(); - // TODO(b/284431534): Update this test when the shared memory transpose - // emitter is fast for S8 output. - EXPECT_FALSE( - GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0)).has_value()); - EXPECT_EQ(FindNonTrivialHero(*r).name(), "t"); + auto result = GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0)); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result->instr, r->operand(0)); + EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusion) { @@ -344,8 +395,8 @@ ENTRY entry { auto result = GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, r->operand(0)); - EXPECT_EQ(result->dimensions, Vector3({64, 48, 32})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithTwoIntermediateBinaryOps) { @@ -375,8 +426,8 @@ ENTRY main { GetDescriptionForTiledTransposeEmitter(*r, FindNonTrivialHero(*r)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, r->operand(0)->operand(0)); - EXPECT_EQ(result->dimensions, Vector3({64, 48, 32})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, @@ -388,8 +439,10 @@ fusion { p = f32[32,48,64]{2,1,0} parameter(0) p2 = f32[48,32,64]{2,1,0} parameter(1) t = f32[64,48,32]{2,1,0} transpose(p), dimensions={2,1,0} - t2 = f32[64,48,32]{2,1,0} transpose(p2), dimensions={2,0,1} - ROOT add = f32[64,48,32]{2,1,0} add(t, t2) + bc = f32[1,1536,64]{2,1,0} bitcast(p2) + t2 = f32[1,64,1536]{2,1,0} transpose(bc), dimensions={0,2,1} + bc2 = f32[64,48,32]{2,1,0} bitcast(t2) + ROOT add = f32[64,48,32]{2,1,0} add(t, bc2) } ENTRY main { @@ -561,8 +614,8 @@ ENTRY main { GetDescriptionForTiledTransposeEmitter(*copy, FindNonTrivialHero(*copy)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, copy); - EXPECT_EQ(result->dimensions, Vector3({8, 12, 1100})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({8, 12, 1100})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOneSwapDimIsSmall) { @@ -570,13 +623,13 @@ TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOneSwapDimIsSmall) { HloModule module fusion { - p = f32[100,11,12,8]{3,2,1,0} parameter(0) - ROOT t = f32[8,12,100,11]{3,2,1,0} transpose(p), dimensions={3,2,0,1} + p = f32[1100,12,8]{2,1,0} parameter(0) + ROOT t = f32[8,12,1100]{2,1,0} transpose(p), dimensions={2,1,0} } ENTRY main { - param = f32[100,11,12,8]{3,2,1,0} parameter(0) - ROOT fusion = f32[8,12,100,11]{3,2,1,0} fusion(param), kind=kInput, calls=fusion + param = f32[1100,12,8]{2,1,0} parameter(0) + ROOT fusion = f32[8,12,1100]{2,1,0} fusion(param), kind=kInput, calls=fusion } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -588,8 +641,8 @@ ENTRY main { GetDescriptionForTiledTransposeEmitter(*tr, FindNonTrivialHero(*tr)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, tr); - EXPECT_EQ(result->dimensions, Vector3({8, 12, 1100})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({8, 12, 1100})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindTiledTransposeOtherSwapDimIsSmall) { @@ -615,8 +668,8 @@ ENTRY main { GetDescriptionForTiledTransposeEmitter(*copy, FindNonTrivialHero(*copy)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, copy); - EXPECT_EQ(result->dimensions, Vector3({1100, 12, 8})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({1100, 12, 8})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOtherSwapDimIsSmall) { @@ -624,13 +677,13 @@ TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOtherSwapDimIsSmall) { HloModule module fusion { - p = f32[8,12,100,11]{3,2,1,0} parameter(0) - ROOT t = f32[100,11,12,8]{3,2,1,0} transpose(p), dimensions={2,3,1,0} + p = f32[8,12,1100]{2,1,0} parameter(0) + ROOT t = f32[1100,12,8]{2,1,0} transpose(p), dimensions={2,1,0} } ENTRY main { - param = f32[8,12,100,11]{3,2,1,0} parameter(0) - ROOT fusion = f32[100,11,12,8]{3,2,1,0} fusion(param), kind=kInput, calls=fusion + param = f32[8,12,1100]{2,1,0} parameter(0) + ROOT fusion = f32[1100,12,8]{2,1,0} fusion(param), kind=kInput, calls=fusion } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -642,8 +695,8 @@ ENTRY main { GetDescriptionForTiledTransposeEmitter(*tr, FindNonTrivialHero(*tr)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, tr); - EXPECT_EQ(result->dimensions, Vector3({1100, 12, 8})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({1100, 12, 8})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, IsContiguousSlice) { diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 03690caeaa8121..b73225a1bd3c56 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -26,7 +26,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" @@ -99,7 +98,6 @@ limitations under the License. #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/gpu_conv_runner.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" #include "xla/service/gpu/gpu_norm_runner.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -122,7 +120,6 @@ limitations under the License. #include "xla/service/gpu/runtime/copy_thunk.h" #include "xla/service/gpu/runtime/custom_call_thunk.h" #include "xla/service/gpu/runtime/fft_thunk.h" -#include "xla/service/gpu/runtime/fused_mha_thunk.h" #include "xla/service/gpu/runtime/gemm_thunk.h" #include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/service/gpu/runtime/infeed_thunk.h" @@ -173,6 +170,7 @@ limitations under the License. #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/service/gpu/runtime/cholesky_thunk.h" #include "xla/service/gpu/runtime/cub_sort_thunk.h" +#include "xla/service/gpu/runtime/cudnn_thunk.h" #include "xla/service/gpu/runtime/triangular_solve_thunk.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -955,221 +953,17 @@ absl::Status IrEmitterUnnested::EmitNormThunk( return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitFusedMHAThunk( +absl::Status IrEmitterUnnested::EmitCuDnnThunk( const HloCustomCallInstruction* instr) { - const HloInstruction* lhs_bmm1 = instr->operand(0); - const HloInstruction* rhs_bmm1 = instr->operand(1); - const HloInstruction* rhs_bmm2 = instr->operand(2); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_bmm1_slice, - GetAllocationSliceForHlo(lhs_bmm1)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm1_slice, - GetAllocationSliceForHlo(rhs_bmm1)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm2_slice, - GetAllocationSliceForHlo(rhs_bmm2)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, - GetAllocationSliceForHlo(instr, {0})); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, - GetAllocationSliceForHlo( - instr, {instr->shape().tuple_shapes_size() - 1})); - BufferAllocation::Slice activation_slice; - bool has_activation = xla::ShapeUtil::TupleElementCount(instr->shape()) == 3; - if (has_activation) { - TF_ASSIGN_OR_RETURN(activation_slice, GetAllocationSliceForHlo(instr, {1})); - } - - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(instr)); - BufferAllocation::Slice mask_slice, bias_slice; - BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice; - std::optional mask_shape, bias_shape; - { - bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax || - kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout; - - if (has_bias) { - const HloInstruction* bias = instr->operand(3); - TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSliceForHlo(bias)); - bias_shape = bias->shape(); - } - int64_t seqlen_qk_operand_index = 3 + has_bias; - bool has_seqlen_qk = seqlen_qk_operand_index == instr->operand_count() - 2; - if (has_seqlen_qk) { - const HloInstruction* seqlen_q = instr->operand(seqlen_qk_operand_index); - TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q)); - const HloInstruction* seqlen_k = - instr->operand(seqlen_qk_operand_index + 1); - TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k)); - } - } - - TF_ASSIGN_OR_RETURN(const auto gpu_config, - instr->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); - absl::InlinedVector output_shapes = { - ShapeUtil::GetSubshape(instr->shape(), {0})}; - if (has_activation) { - output_shapes.push_back(ShapeUtil::GetSubshape(instr->shape(), {1})); - } - TF_ASSIGN_OR_RETURN(const auto mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - GpufMHADescriptor descriptor = {kind, - config, - mask_type, - lhs_bmm1->shape(), - rhs_bmm1->shape(), - rhs_bmm2->shape(), - intermediate_tensor_shape, - output_shapes, - config.bmm1_dot_dimension_numbers(), - config.bmm2_dot_dimension_numbers(), - mask_shape, - bias_shape}; - - TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, - GpufMHAConfig::For(descriptor)); - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(fmha_config), - lhs_bmm1_slice, rhs_bmm1_slice, rhs_bmm2_slice, output_slice, - scratch_slice, mask_slice, bias_slice, activation_slice, seqlen_q_slice, - seqlen_k_slice)); - return absl::OkStatus(); -} - -absl::Status IrEmitterUnnested::EmitFusedMHABackwardThunk( - const HloCustomCallInstruction* instr) { - TF_ASSIGN_OR_RETURN(const auto gpu_config, - instr->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - - int input_index = 0; - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm1_grad_gemm1_rhs_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm1_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm2_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm2_grad_gemm1_lhs_shape; - - Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); - bmm2_grad_gemm1_lhs_shape = intermediate_tensor_shape; - input_index++; - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_output_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape d_output_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, GetCudnnfMHAKind(instr)); - BufferAllocation::Slice mask_slice; - std::optional mask_shape; - - bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || - kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); - BufferAllocation::Slice bias_slice; - std::optional bias_shape; - if (has_bias) { - TF_ASSIGN_OR_RETURN(bias_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - bias_shape = instr->operand(input_index++)->shape(); - } - - BufferAllocation::Slice fwd_output_slice; - std::optional fwd_output_shape; - - TF_ASSIGN_OR_RETURN(fwd_output_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - fwd_output_shape = instr->operand(input_index++)->shape(); - - BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice; - bool has_seqlen_qk = input_index == instr->operand_count() - 2; - if (has_seqlen_qk) { - const HloInstruction* seqlen_q = instr->operand(input_index); - TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q)); - const HloInstruction* seqlen_k = instr->operand(input_index + 1); - TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k)); - input_index += 2; - } - TF_RET_CHECK(input_index == instr->operand_count()); - - int output_index = 0; - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_lhs_slice, - GetAllocationSliceForHlo(instr, {output_index})); - Shape d_bmm1_lhs_shape = - ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_rhs_slice, - GetAllocationSliceForHlo(instr, {output_index})); - Shape d_bmm1_rhs_shape = - ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm2_rhs_slice, - GetAllocationSliceForHlo(instr, {output_index})); - Shape d_bmm2_rhs_shape = - ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - - BufferAllocation::Slice d_s_slice; - std::optional d_s_shape; - - bool has_dbias = instr->shape().tuple_shapes().size() == 5; - BufferAllocation::Slice d_bias_slice; - std::optional d_bias_shape; - if (has_dbias) { - TF_ASSIGN_OR_RETURN(d_bias_slice, - GetAllocationSliceForHlo(instr, {output_index})); - d_bias_shape = ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - } - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, - GetAllocationSliceForHlo(instr, {output_index++})); - TF_RET_CHECK(output_index == instr->shape().tuple_shapes().size()); - TF_ASSIGN_OR_RETURN(const auto mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - bool force_deterministic = config.force_deterministic(); - GpufMHABackwardDescriptor descriptor = { - kind, - config, - mask_type, - bmm1_grad_gemm1_rhs_shape, - bmm1_grad_gemm2_rhs_shape, - bmm2_grad_gemm1_lhs_shape, - bmm2_grad_gemm2_rhs_shape, - d_output_shape, - d_bmm1_lhs_shape, - d_bmm1_rhs_shape, - d_bmm2_rhs_shape, - config.bmm1_grad_gemm1_dot_dimension_numbers(), - config.bmm1_grad_gemm2_dot_dimension_numbers(), - config.bmm2_grad_gemm1_dot_dimension_numbers(), - config.bmm2_grad_gemm2_dot_dimension_numbers(), - d_s_shape, - fwd_output_shape, - mask_shape, - d_bias_shape, - bias_shape, - force_deterministic}; - - TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_backward_config, - GpufMHABackwardConfig::For(descriptor)); - - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), - std::move(fmha_backward_config), bmm1_grad_gemm1_rhs_slice, - bmm1_grad_gemm2_rhs_slice, bmm2_grad_gemm1_lhs_slice, - bmm2_grad_gemm2_rhs_slice, d_output_slice, scratch_slice, - d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, d_s_slice, - mask_slice, d_bias_slice, fwd_output_slice, bias_slice, seqlen_q_slice, - seqlen_k_slice)); - + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr, + instr->operands())); + TF_ASSIGN_OR_RETURN(const std::string fingerprint, + FingerprintWithBackendConfig(*instr)); + AddThunkToThunkSequence(std::make_unique( + fingerprint, Thunk::ThunkInfo::WithProfileAnnotation(instr), + kernel_arguments.args())); return absl::OkStatus(); } @@ -2921,11 +2715,8 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( if (IsCustomCallToDnnNorm(*instr)) { return EmitNormThunk(custom_call); } - if (IsFwdCustomCallTofMHA(*instr)) { - return EmitFusedMHAThunk(custom_call); - } - if (IsBwdCustomCallTofMHA(*instr)) { - return EmitFusedMHABackwardThunk(custom_call); + if (IsCustomCallTofMHA(*instr)) { + return EmitCuDnnThunk(custom_call); } #endif // GOOGLE_CUDA if (IsCustomCallToTopK(*instr)) { diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index f97f106ddfc0df..d19dd5d9c4172c 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -147,8 +147,7 @@ class IrEmitterUnnested : public IrEmitter { absl::Status EmitConvolutionReorderThunk( const HloCustomCallInstruction* instr); absl::Status EmitNormThunk(const HloCustomCallInstruction* instr); - absl::Status EmitFusedMHAThunk(const HloCustomCallInstruction* instr); - absl::Status EmitFusedMHABackwardThunk(const HloCustomCallInstruction* instr); + absl::Status EmitCuDnnThunk(const HloCustomCallInstruction* instr); #endif // GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM absl::Status EmitCubDeviceRadixSort(const HloCustomCallInstruction* instr); diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index a1ce0cfa2b76b5..737b47db113eba 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -8,7 +8,7 @@ load( load("//xla:xla.bzl", "xla_cc_binary") load("//xla/service/gpu:build_defs.bzl", "gpu_kernel_library") load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") -load("//xla/tests:build_defs.bzl", "xla_test") +load("//xla/tests:build_defs.bzl", "DEFAULT_DISABLED_BACKENDS", "xla_test") load("//xla/tsl:tsl.bzl", "if_windows") package( @@ -73,14 +73,22 @@ cc_library( # a single dependency. cc_library( name = "custom_fusion_library", + tags = [ + "gpu", + "no_rocm", + ], visibility = [":friends"], - deps = [":cutlass_gemm_fusion"], + deps = if_cuda_is_configured([":cutlass_gemm_fusion"]), ) cc_library( name = "cutlass_gemm_fusion", srcs = ["cutlass_gemm_fusion.cc"], hdrs = ["cutlass_gemm_fusion.h"], + tags = [ + "gpu", + "no_rocm", + ], deps = [ ":custom_kernel", ":custom_kernel_fusion", @@ -108,7 +116,7 @@ xla_test( srcs = ["cutlass_gemm_fusion_test.cc"], backends = ["gpu"], # TODO(b/332820384): Enable when it passes on H100. - disabled_backends = ["gpu_h100"], + disabled_backends = DEFAULT_DISABLED_BACKENDS + ["gpu_h100"], tags = ["no_rocm"], deps = [ ":custom_kernel", @@ -133,13 +141,12 @@ xla_test( cc_library( name = "topk_kernel", - srcs = if_gpu_is_configured(["topk_kernel.cc"]), - hdrs = if_gpu_is_configured(["topk_kernel.h"]), + srcs = ["topk_kernel.cc"], + hdrs = ["topk_kernel.h"], compatible_with = [], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), + tags = ["gpu"], deps = [ + ":topk_kernel_gpu", "//xla:shape_util", "//xla:types", "//xla:util", @@ -156,19 +163,17 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", - ] + if_gpu_is_configured([ - ":topk_kernel_gpu", - ]), + ], ) gpu_kernel_library( name = "topk_kernel_gpu", - srcs = if_gpu_is_configured([ + srcs = [ + "topk_kernel.cu.h", "topk_kernel_bfloat16.cu.cc", "topk_kernel_float.cu.cc", - "topk_kernel.cu.h", - ]), - hdrs = if_gpu_is_configured(["topk_kernel_common.h"]), + ], + hdrs = ["topk_kernel_common.h"], compatible_with = [], deps = [ "//xla:types", @@ -179,7 +184,7 @@ gpu_kernel_library( xla_test( name = "topk_kernel_test", - srcs = if_gpu_is_configured(["topk_kernel_test.cc"]), + srcs = ["topk_kernel_test.cc"], backends = ["gpu"], deps = [ ":topk_kernel", @@ -228,7 +233,7 @@ cc_library( xla_test( name = "topk_custom_kernel_test", - srcs = if_gpu_is_configured(["topk_custom_kernel_test.cc"]), + srcs = ["topk_custom_kernel_test.cc"], backends = ["gpu"], deps = [ ":topk_custom_kernel", @@ -256,11 +261,12 @@ xla_test( cc_library( name = "cutlass_gemm_custom_kernel", - srcs = if_cuda_is_configured( - ["cutlass_gemm_custom_kernel.cc"], - ["cutlass_gemm_custom_kernel_stub.cc"], - ), + srcs = ["cutlass_gemm_custom_kernel.cc"], hdrs = ["cutlass_gemm_custom_kernel.h"], + tags = [ + "gpu", + "no_rocm", + ], deps = [ ":custom_kernel", ":cutlass_gemm", @@ -278,9 +284,10 @@ cc_library( xla_test( name = "cutlass_gemm_custom_kernel_test", - srcs = if_cuda_is_configured(["cutlass_gemm_custom_kernel_test.cc"]), + srcs = ["cutlass_gemm_custom_kernel_test.cc"], backends = ["gpu"], data = [":cutlass_gemm_kernel_f32xf32_to_f32.so"], + tags = ["no_rocm"], deps = [ ":cutlass_gemm_custom_kernel", "//xla:xla_data_proto_cc", @@ -299,7 +306,11 @@ xla_test( xla_cc_binary( name = "cutlass_gemm_custom_kernel_benchmarks", testonly = 1, - srcs = if_cuda_is_configured(["cutlass_gemm_custom_kernel_benchmarks.cc"]), + srcs = ["cutlass_gemm_custom_kernel_benchmarks.cc"], + tags = [ + "gpu", + "no_rocm", + ], deps = [ ":cutlass_gemm_custom_kernel", "//xla:xla_data_proto_cc", @@ -329,22 +340,24 @@ cc_library( cuda_library( name = "cutlass_gemm_adaptor", - hdrs = if_cuda_is_configured(["cutlass_gemm_adaptor.cu.h"]), + hdrs = ["cutlass_gemm_adaptor.cu.h"], copts = if_windows( [], ["-Wno-unknown-attributes"], ), # __grid_constant__ is not supported by clang - deps = if_cuda_is_configured([ + tags = ["no_rocm"], + deps = [ ":cutlass_gemm", "@cutlass_archive//:cutlass", - ]), + ], ) cuda_library( name = "cutlass_gemm_epilogue", + tags = ["no_rocm"], # TODO(ezhulenev): Update to regular hdrs after fixing CUTLASS headers. - textual_hdrs = if_cuda_is_configured(["cutlass_gemm_epilogue.cu.h"]), - deps = if_cuda_is_configured(["@cutlass_archive//:cutlass"]), + textual_hdrs = ["cutlass_gemm_epilogue.cu.h"], + deps = ["@cutlass_archive//:cutlass"], ) #===--------------------------------------------------------------------------------------------===# @@ -356,6 +369,10 @@ cuda_library( cc_library( name = "cutlass_gemm_kernels", + tags = [ + "gpu", + "no_rocm", + ], deps = [ ":cutlass_gemm_kernel_bf16xbf16_to_bf16", ":cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80", @@ -378,7 +395,7 @@ cc_library( cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_bf16", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16.cu.cc"]), + srcs = ["cutlass_gemm_kernel_bf16xbf16_to_bf16.cu.cc"], copts = [ "-mllvm", "-unroll-threshold=100000", @@ -386,16 +403,17 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - deps = if_cuda_is_configured([ + tags = ["no_rocm"], + deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80.cu.cc"]), + srcs = ["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80.cu.cc"], copts = [ "-mllvm", "-unroll-threshold=100000", @@ -403,16 +421,17 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - deps = if_cuda_is_configured([ + tags = ["no_rocm"], + deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90.cu.cc"]), + srcs = ["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90.cu.cc"], copts = [ "-mllvm", "-unroll-threshold=100000", @@ -423,31 +442,33 @@ cuda_library( "-Wno-unknown-attributes", ], ), - deps = if_cuda_is_configured([ + tags = ["no_rocm"], + deps = [ ":cutlass_gemm_adaptor", ":cutlass_gemm_epilogue", "@cutlass_archive//:cutlass", "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) cuda_library( name = "cutlass_gemm_kernel_f32xf32_to_f32", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xf32_to_f32.cu.cc"]), + srcs = ["cutlass_gemm_kernel_f32xf32_to_f32.cu.cc"], copts = if_windows( [], ["-Wno-unknown-attributes"], ), - deps = if_cuda_is_configured([ + tags = ["no_rocm"], + deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_f32", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc"]), + srcs = ["cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc"], copts = [ "-mllvm", "-unroll-threshold=100000", @@ -455,16 +476,17 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - deps = if_cuda_is_configured([ + tags = ["no_rocm"], + deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) cuda_library( name = "cutlass_gemm_kernel_bf16xf32_to_f32", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc"]), + srcs = ["cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc"], copts = [ "-mllvm", "-unroll-threshold=100000", @@ -472,16 +494,17 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - deps = if_cuda_is_configured([ + tags = ["no_rocm"], + deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) cuda_library( name = "cutlass_gemm_kernel_f32xbf16_to_f32", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc"]), + srcs = ["cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc"], copts = [ "-mllvm", "-unroll-threshold=100000", @@ -489,22 +512,27 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - deps = if_cuda_is_configured([ + tags = ["no_rocm"], + deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) cuda_library( name = "cutlass_gemm_kernel_bf16xs8_to_f32", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc"]), + srcs = ["cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc"], copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"], - deps = if_cuda_is_configured([ + tags = [ + "gpu", + "no_rocm", + ], + deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) #===--------------------------------------------------------------------------------------------===# @@ -513,8 +541,12 @@ cuda_library( cc_binary( name = "cutlass_gemm_kernel_f32xf32_to_f32.so", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xf32_to_f32.cc"]), + srcs = ["cutlass_gemm_kernel_f32xf32_to_f32.cc"], linkshared = True, linkstatic = False, + tags = [ + "gpu", + "no_rocm", + ], deps = [":cutlass_gemm"], ) diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc deleted file mode 100644 index d95241b0abdec9..00000000000000 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "absl/status/statusor.h" -#include "xla/service/gpu/kernels/custom_kernel.h" -#include "xla/service/gpu/kernels/cutlass_gemm.h" -#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" -#include "xla/stream_executor/device_description.h" -#include "xla/xla_data.pb.h" - -namespace xla::gpu::kernel::gemm_universal { - -absl::StatusOr> GetCutlassGemmKernels( - std::string name, PrimitiveType dot_type, PrimitiveType lhs_type, - PrimitiveType rhs_type, int32_t m, int32_t n, int32_t k, - const ArgsIndices& indices, const DynamicSliceIndices& slices, - const se::DeviceDescription& device) { - return absl::InternalError("XLA compiled without CUDA support"); -} - -absl::StatusOr LoadCutlassGemmKernel( - std::string name, const std::string& library_path, PrimitiveType dtype, - int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, - const DynamicSliceIndices& slices, const se::DeviceDescription& device) { - return absl::InternalError("XLA compiled without CUDA support"); -} - -} // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/launch_dimensions.cc b/third_party/xla/xla/service/gpu/launch_dimensions.cc index 89b322f6708556..f9e28995d09960 100644 --- a/third_party/xla/xla/service/gpu/launch_dimensions.cc +++ b/third_party/xla/xla/service/gpu/launch_dimensions.cc @@ -16,13 +16,8 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include -#include #include -#include -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/str_format.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" @@ -32,139 +27,6 @@ limitations under the License. namespace xla { namespace gpu { -static int64_t ThreadsPerBlockLimit( - const se::DeviceDescription& gpu_device_info) { - int64_t threads_per_block = gpu_device_info.threads_per_block_limit(); - if (threads_per_block <= 0) { - static std::atomic log_count{0}; - if (log_count.fetch_add(1) < 8) { - LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " - "without full information about its capabilities. " - "StreamExecutor's PopulateDeviceDescription should be " - "updated for this device."; - } - threads_per_block = gpu_device_info.threads_per_warp(); - if (threads_per_block == 0) { - // Fall back to *something* if we can't even get num threads per warp. - threads_per_block = 32; - } - } - return threads_per_block; -} - -int64_t ThreadsPerBlockRowVectorized( - const Shape& shape, const se::DeviceDescription& gpu_device_info, - LaunchDimensionsConfig dim_config) { - if (shape.dimensions().empty()) { - return -1; - } - int64_t threads_per_block_row_vectorized = - shape.dimensions().back() / dim_config.unroll_factor; - if (dim_config.row_vectorized && - shape.dimensions().back() % dim_config.unroll_factor == 0 && - // If the row size is a multiple of 256, then use the old code - // path that use a block size of 256. This give small speed up on V100. - // Vectorization of the row load was already happening. - (shape.dimensions().back() % 256) != 0 && - // We do not support row that do not fit in one block. - threads_per_block_row_vectorized <= - gpu_device_info.threads_per_block_limit()) { - return threads_per_block_row_vectorized; - } - return -1; -} - -namespace { - -struct BlockSizes { - int64_t threads_per_block_x; - int64_t threads_per_block_y; - int64_t block_count; -}; - -BlockSizes GetBlockSizes(LaunchDimensionsConfig dim_config, - const se::DeviceDescription& gpu_device_info, - const Shape& shape, int64_t num_elements) { - if (!dim_config.row_vectorized && !dim_config.few_waves) { - BlockSizes result; - const int kWarpSchedulers = 4; - result.threads_per_block_x = std::min( - gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements); - result.threads_per_block_y = 1; - result.block_count = CeilOfRatio( - num_elements, result.threads_per_block_x * result.threads_per_block_y); - return result; - } - - int64_t threads_per_block_row_vectorized = - ThreadsPerBlockRowVectorized(shape, gpu_device_info, dim_config); - // If row vectorized, threads_per_block_x is the vectorized size. - // Otherwise, we unroll kernels to make use of vectorized - // loads/stores. This means we need more registers to hold - // intermediate values. Reduce the number of threads per block to - // increase the number of registers available to ptxas. Make sure - // we still have a multiple of 32. - BlockSizes result; - int64_t max_threads_per_block_x = - threads_per_block_row_vectorized > 0 - ? threads_per_block_row_vectorized - : RoundUpTo(ThreadsPerBlockLimit(gpu_device_info) / - dim_config.unroll_factor, - int64_t{32}); - result.threads_per_block_x = std::min(num_elements, max_threads_per_block_x); - // threads_per_block_y > 1 when we row vectorize and have small row size. - result.threads_per_block_y = - threads_per_block_row_vectorized > 0 && - threads_per_block_row_vectorized < 128 && num_elements > 128 - ? CeilOfRatio(static_cast(128), - threads_per_block_row_vectorized) - : 1; - VLOG(2) << "Set # of threads per block to (.x=" << result.threads_per_block_x - << ", .y=" << result.threads_per_block_y << ")"; - - result.block_count = CeilOfRatio( - num_elements, result.threads_per_block_x * result.threads_per_block_y); - if (dim_config.few_waves) { - if (dim_config.row_vectorized) { - // This multiple of 32 was tuned to not cause regression on multiple - // benchmarks. It isn't a value that is optimal for all kernels. Maybe - // looking at the arithmetic intensity of the kernels can specialize the - // multiple per kernel. - int64_t max_block_count = - 32 * gpu_device_info.core_count() * - (gpu_device_info.threads_per_core_limit() / - (result.threads_per_block_x * result.threads_per_block_y)); - int64_t capped_block_count = result.block_count; - while (capped_block_count > max_block_count) { - capped_block_count /= 2; - } - if (capped_block_count < result.block_count) { - result.block_count = capped_block_count; - VLOG(2) << "Update # of blocks to " << result.block_count - << " as few_waves is enabled."; - } - } else { - int64_t capped_threads_per_block_x = - std::min(result.threads_per_block_x, 128); - int64_t capped_block_count = - gpu_device_info.core_count() * - (gpu_device_info.threads_per_core_limit() / - (capped_threads_per_block_x * result.threads_per_block_y)); - if (capped_block_count < result.block_count) { - result.threads_per_block_x = capped_threads_per_block_x; - result.block_count = capped_block_count; - VLOG(2) << "Update the # of blocks to " << result.block_count - << " and the # of threads per blocks to " - << result.threads_per_block_x - << " as the few_waves mode is enabled."; - } - } - } - return result; -} - -} // namespace - LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& gpu_device_info, LaunchDimensionsConfig dim_config) { @@ -173,12 +35,13 @@ LaunchDimensions CalculateLaunchDimensions( return LaunchDimensions(); } num_elements = CeilOfRatio(num_elements, int64_t{dim_config.unroll_factor}); - BlockSizes sizes = - GetBlockSizes(dim_config, gpu_device_info, shape, num_elements); - return LaunchDimensions( - se::BlockDim(sizes.block_count, 1, 1), - se::ThreadDim(sizes.threads_per_block_x, sizes.threads_per_block_y, 1)); + const int kWarpSchedulers = 4; + int64_t threads_per_block = std::min( + gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements); + int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block); + return LaunchDimensions(se::BlockDim(num_blocks, 1, 1), + se::ThreadDim(threads_per_block, 1, 1)); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/launch_dimensions.h b/third_party/xla/xla/service/gpu/launch_dimensions.h index e0c53f9b266f4c..7295048fcdd45c 100644 --- a/third_party/xla/xla/service/gpu/launch_dimensions.h +++ b/third_party/xla/xla/service/gpu/launch_dimensions.h @@ -85,17 +85,6 @@ struct LaunchDimensionsConfig { // The kernel implementation will be unrolled if `unroll_factor` is // greater than one. int unroll_factor = 1; - // A wave is a group of blocks that execute at the same time on the - // GPU. If there are more blocks then the number that can run - // concurrently, there are multiple waves of blocks running - // sequentially. If `few_waves` is true, each thread will loop over - // a block of unroll_factor elements. Otherwise each thread will - // handle only unroll_factor. - bool few_waves = false; - // If `row_vectorized` is true, then the block size will equal to - // `hlo.shape().dimensions().back()/unroll_factor`. - // Currently few_waves and row_vectorized do not work together. - bool row_vectorized = false; }; // Returns -1 if the shape doesn't allow the row vectorization code path. diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD index 8951c2719cb290..8fc3db56945a8e 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD @@ -2,6 +2,10 @@ load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", ) +load( + "@local_config_sycl//sycl:build_defs.bzl", + "if_sycl_is_configured", +) load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -88,6 +92,8 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", "@llvm-project//llvm:AMDGPUCodeGen", "@llvm-project//llvm:AMDGPUAsmParser", + ]) + if_sycl_is_configured([ + "@spirv_llvm_translator//:spirv_llvm_translator", ]), ) @@ -106,3 +112,16 @@ xla_cc_test( "@local_tsl//tsl/platform:test", ], ) + +xla_cc_test( + name = "gpu_backend_lib_test", + size = "small", + srcs = ["gpu_backend_lib_test.cc"], + deps = [ + ":llvm_gpu_backend", + "//xla/stream_executor:device_description", + "//xla/tests:xla_internal_test_main", + "@llvm-project//llvm:Core", + "@local_tsl//tsl/platform:test", + ], +) diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 721cae6ac3c269..696e3608297b95 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -106,6 +106,11 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_asm_compiler.h" #endif +#if TENSORFLOW_USE_SYCL +#include "LLVMSPIRVLib.h" +#include "LLVMSPIRVOpts.h" +#endif // TENSORFLOW_USE_SYCL + namespace xla { namespace gpu { namespace { @@ -118,41 +123,6 @@ const int kAMDGPUInlineThreshold = 0x100000; // Default inline threshold value to use in llvm. const int kDefaultInlineThreshold = 1100; -// Gets the GPU name as it's known to LLVM for a given compute -// capability. If we see an unrecognized compute capability, we -// return the highest one that is known and below the selected device. -static std::string GetSmName(se::CudaComputeCapability compute_capability) { - int compute_capability_version = - compute_capability.major * 10 + compute_capability.minor; - int sm_version = 30; - // If the current compute capability isn't known, fallback to the - // most recent version before it. - int supported_versions[] = {90, 89, 87, 86, 80, 75, 72, 70, 62, - 61, 60, 53, 52, 50, 37, 35, 32, 30}; - for (int v : supported_versions) { - if (v <= compute_capability_version) { - sm_version = v; - break; - } - } - - // If the current CC isn't supported by LLVM and it is newer then - // the max supported LLVM version, do not warn about it. The end - // user can't do anything about this. E.g., PTX compiled for SM75 will - // run on SM80 too. - if (sm_version != compute_capability_version && - compute_capability_version < supported_versions[0]) { - LOG(WARNING) << "Unknown compute capability " - << compute_capability.ToString() - << ". Defaulting to telling LLVM that we're compiling for sm_" - << sm_version; - } - // If the target is sm_90, hard code it to sm_90a so that all instructions - // can be used. We don't need the portability that sm_90 gives. - std::string_view extension = sm_version == 90 ? "a" : ""; - return absl::StrCat("sm_", sm_version, extension); -} - // NOLINTBEGIN: clang-diagnostic-unused-function // Convenience function for producing a name of a temporary compilation product // from the input filename. @@ -379,7 +349,7 @@ std::unique_ptr NVPTXGetTargetMachine( #else std::string feature_str; #endif // GOOGLE_CUDA - return GetTargetMachine(target_triple, GetSmName(compute_capability), + return GetTargetMachine(target_triple, nvptx::GetSmName(compute_capability), debug_options, feature_str); } @@ -453,7 +423,9 @@ absl::Status LinkAndOptimizeModule( llvm::CGSCCAnalysisManager cgam; llvm::ModuleAnalysisManager mam; - fam.registerPass([&] { return target_machine->getTargetIRAnalysis(); }); + if (target_machine) { + fam.registerPass([&] { return target_machine->getTargetIRAnalysis(); }); + } llvm::PipelineTuningOptions pto; pto.SLPVectorization = true; @@ -570,6 +542,40 @@ void NVPTXBackendInit(const DebugOptions& debug_options) { namespace nvptx { +std::string GetSmName(se::CudaComputeCapability compute_capability) { + int compute_capability_version = + compute_capability.major * 10 + compute_capability.minor; + int sm_version = 30; + // If the current compute capability isn't known, fallback to the + // most recent version before it. + int supported_versions[] = {90, 89, 87, 86, 80, 75, 72, 70, 62, + 61, 60, 53, 52, 50, 37, 35, 32, 30}; + for (int v : supported_versions) { + if (v <= compute_capability_version) { + sm_version = v; + break; + } + } + + // If the current CC isn't supported by LLVM and it is newer then + // the max supported LLVM version, do not warn about it. The end + // user can't do anything about this. E.g., PTX compiled for SM75 will + // run on SM80 too. + if (sm_version != compute_capability_version && + compute_capability_version < supported_versions[0]) { + LOG(WARNING) << "Unknown compute capability " + << compute_capability.ToString() + << ". Defaulting to telling LLVM that we're compiling for sm_" + << sm_version; + } + // On Hopper, default to sm_90a so that all instructions can be used. But + // only sm_90 is forward compatible, so don't use sm_90a with newer hardware: + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility + std::string_view extension = + (compute_capability.major == 9 && sm_version == 90) ? "a" : ""; + return absl::StrCat("sm_", sm_version, extension); +} + std::string CantFindCudaMessage(absl::string_view msg, absl::string_view xla_gpu_cuda_data_dir) { return absl::StrCat( @@ -1138,5 +1144,95 @@ absl::StatusOr> CompileToHsaco( } // namespace amdgpu +namespace { + +std::unique_ptr SPIRGetTargetMachine( + llvm::Triple target_triple, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options) { + return nullptr; +} + +absl::Status SPIRTargetModuleLinker( + llvm::Module* module, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options, + const std::string& device_bitcode_dir_path) { + return absl::OkStatus(); +} + +absl::StatusOr EmitModuleToSpir( + llvm::Module* module, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options) { +#if TENSORFLOW_USE_SYCL + SPIRV::TranslatorOpts::ExtensionsStatusMap ExtensionsStatus; + SPIRV::TranslatorOpts opts(SPIRV::VersionNumber::MaximumVersion, + ExtensionsStatus); + opts.enableAllExtensions(); // enable all SPIR-V extension first + + std::ostringstream oss; + std::string err; + bool success = llvm::writeSpirv(module, opts, oss, err); + if (!success) { + return xla::Internal("Fails to convert LLVM as SPIR-V: %s", err); + } + return oss.str(); +#else + return absl::UnimplementedError("Not implemented for SYCL"); +#endif +} + +void SPIRBackendInit(const DebugOptions& debug_options) { + FeedLLVMWithFlags({ + "-slp-vectorize-hor=false", + "-slp-min-reg-size=64", + "-slp-max-reg-size=64", + }); + + llvm_ir::InitializeLLVMCommandLineOptions( + debug_options.xla_backend_extra_options()); + + llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry(); + InitializePasses(registry); +} + +} // namespace + +namespace spir { + +absl::StatusOr> CompileToSpir( + llvm::Module* module, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options) { + std::string libdevice_dir_path; + static absl::once_flag backend_init_flag; + absl::call_once(backend_init_flag, SPIRBackendInit, debug_options); + + std::string spir; + { + XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str()); + + // If the module has no functions or globals, there's nothing to compile. + if (module->empty() && module->global_empty()) { + VLOG(2) << "Module '" << module->getName().str() + << "' is empty. Skipping compilation."; + return std::vector(); + } + + llvm::Triple default_target_triple("spir64-unknown-unknown"); + std::unique_ptr target_machine = + SPIRGetTargetMachine(default_target_triple, gpu_version, debug_options); + + TF_RETURN_IF_ERROR(LinkAndOptimizeModule( + module, gpu_version, debug_options, libdevice_dir_path, + SPIRTargetModuleLinker, default_target_triple, target_machine.get(), + kDefaultInlineThreshold)); + + // Lower optimized LLVM module to SPIR. + TF_ASSIGN_OR_RETURN(spir, + EmitModuleToSpir(module, gpu_version, debug_options)); + } + return std::vector(spir.begin(), spir.end()); +} + +} // namespace spir + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h index 3ab5d6d84db1b3..1814291beae184 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h @@ -37,6 +37,11 @@ namespace gpu { namespace nvptx { +// Gets the GPU name as it's known to LLVM for a given compute +// capability. If we see an unrecognized compute capability, we +// return the highest one that is known and below the selected device. +std::string GetSmName(se::CudaComputeCapability compute_capability); + std::string CantFindCudaMessage(absl::string_view msg, absl::string_view xla_gpu_cuda_data_dir); @@ -73,6 +78,13 @@ absl::StatusOr> CompileToHsaco( const std::string& module_config_cache_key); } // namespace amdgpu +namespace spir { +// Compiles the argument module and returns it. +absl::StatusOr> CompileToSpir( + llvm::Module* module, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options); +} // namespace spir + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc new file mode 100644 index 00000000000000..9e65f34a296cb6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc @@ -0,0 +1,38 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" + +#include "xla/stream_executor/device_description.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { +namespace se = ::stream_executor; + +TEST(UtilsTest, TestGetSmName) { + se::CudaComputeCapability cc_hopper(9, 0); + ASSERT_EQ(nvptx::GetSmName(cc_hopper), "sm_90a"); + // Do not default to sm90_a after Hopper, because it is not forward + // compatible. + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility + se::CudaComputeCapability cc_next(10, 0); + ASSERT_EQ(nvptx::GetSmName(cc_next), "sm_90"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc index fe4982e9a223b9..49270de65ecd3f 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils.cc @@ -456,7 +456,11 @@ bool IsTf32Allowed(PrecisionConfig::Algorithm algorithm, const HloInstruction* gemm) { TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, gemm->backend_config()); - const GemmBackendConfig& config = gpu_config.gemm_backend_config(); + return For(gemm, gpu_config.gemm_backend_config()); +} + +/*static*/ absl::StatusOr GemmConfig::For( + const HloInstruction* gemm, const GemmBackendConfig& config) { std::optional algorithm; if (config.algorithm_case() != GemmBackendConfig::ALGORITHM_NOT_SET) { algorithm = config.selected_algorithm(); diff --git a/third_party/xla/xla/service/gpu/matmul_utils.h b/third_party/xla/xla/service/gpu/matmul_utils.h index 22d7f178133835..5f128e418af58c 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.h +++ b/third_party/xla/xla/service/gpu/matmul_utils.h @@ -108,6 +108,11 @@ struct GemmConfig : public se::gpu::GemmConfig { static absl::StatusOr For(const HloInstruction* gemm); + // Gets the GemmConfig of the `gemm` instruction with overridden + // GemmBackendConfig. + static absl::StatusOr For(const HloInstruction* gemm, + const GemmBackendConfig& config); + static absl::StatusOr For( const Shape& lhs_shape, absl::Span lhs_batch_dims, absl::Span lhs_contracting_dims, const Shape& rhs_shape, diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 6173fa558a0e4c..48e3b1ccafb363 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -350,6 +350,7 @@ cc_library( ":indexing_analysis", ":symbolic_tile_analysis", ":tiled_hlo_computation", + ":triton_emitter_constraints", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", @@ -732,6 +733,7 @@ xla_cc_test( ":indexing_test_utils", ":symbolic_tile", ":symbolic_tile_analysis", + ":symbolic_tiled_hlo_instruction", ":tiled_hlo_computation", ":tiled_hlo_instruction", "//xla:util", @@ -752,6 +754,40 @@ xla_cc_test( ], ) +cc_library( + name = "triton_emitter_constraints", + srcs = ["triton_emitter_constraints.cc"], + hdrs = ["triton_emitter_constraints.h"], + deps = [ + ":affine_map_evaluator", + ":symbolic_tile_analysis", + ":symbolic_tiled_hlo_instruction", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +xla_cc_test( + name = "triton_emitter_constraints_test", + srcs = ["triton_emitter_constraints_test.cc"], + deps = [ + ":symbolic_tile_analysis", + ":triton_emitter_constraints", + "//xla/hlo/ir:hlo", + "//xla/service:instruction_fusion", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/log", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "coalescing_analysis", srcs = ["coalescing_analysis.cc"], diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc index 6a182882ba192e..aefe84294472a2 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc @@ -167,13 +167,13 @@ TEST_F(CoalescingTest, Transpose) { HloModule module fusion { - %input = f32[100, 64, 32] parameter(0) - ROOT transpose = f32[32, 100, 64] transpose(%input), dimensions={2, 0, 1} + %input = f32[1, 6400, 32] parameter(0) + ROOT transpose = f32[1, 32, 6400] transpose(%input), dimensions={0, 2, 1} } ENTRY entry { - %input = f32[100, 64, 32] parameter(0) - ROOT %fusion = f32[32, 100, 64] fusion(%input), kind=kLoop, calls=fusion + %input = f32[1, 6400, 32] parameter(0) + ROOT %fusion = f32[1, 32, 6400] fusion(%input), kind=kLoop, calls=fusion })"; // thread_x to linearized input mapping for thread_x in [0, 31]: // Operand 1: (thread_x)[s0] -> (thread_x + s0 * 128) for s0 in [0, 7] @@ -185,15 +185,15 @@ TEST_F(CoalescingTest, TransposeOfBroadcastHeuristic) { HloModule module fusion { - input = f32[32, 100, 64] parameter(0) - ROOT slice = f32[32, 100, 1] slice(input), slice={[0:32:1], [0:100:1], [0:1:1]} + input = f32[1, 32, 6400] parameter(0) + ROOT slice = f32[1, 32, 100] slice(input), slice={[0:1:1], [0:32:1], [0:6400:64]} } ENTRY entry { p0 = f32[32] parameter(0) - broadcast = f32[100, 64, 32] broadcast(p0), dimensions={2} - transpose = f32[32, 100, 64] transpose(broadcast), dimensions={2, 0, 1} - ROOT %fusion = f32[32, 100, 1] fusion(transpose), kind=kLoop, calls=fusion + broadcast = f32[1, 6400, 32] broadcast(p0), dimensions={2} + transpose = f32[1, 32, 6400] transpose(broadcast), dimensions={0, 2, 1} + ROOT %fusion = f32[1, 32, 100] fusion(transpose), kind=kLoop, calls=fusion })"; EXPECT_TRUE(IsReadCoalescedHeuristic(ir)); } diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc index f4369df4bbe7df..5e80fe7bca8b7e 100644 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc @@ -79,7 +79,6 @@ const HloFusionAnalysis& HloFusionAnalysisCache::Get( } void HloFusionAnalysisCache::Invalidate(const HloInstruction& instruction) { - absl::MutexLock lock(&mutex_); analyses_.erase(instruction.unique_id()); if (auto consumers = @@ -97,8 +96,6 @@ void HloFusionAnalysisCache::Invalidate(const HloInstruction& instruction) { } void HloFusionAnalysisCache::Clear() { - absl::MutexLock lock(&mutex_); - analyses_.clear(); producer_consumer_analyses_.clear(); consumers_for_producers_.clear(); diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h index 4cf6053e03fed6..9eacee0a933aad 100644 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h @@ -28,9 +28,9 @@ limitations under the License. namespace xla::gpu { -// Caches HloFusionAnalyses. Thread-compatible, if no threads concurrently `Get` -// and `Invalidate` the same key. Analyses are cached based on unique_ids, no -// checking or tracking of changes is done. +// Caches HloFusionAnalyses. `Get` can be called concurrently, but `Invalidate` +// and `Clear` shouldn't. Analyses are cached based on unique_ids, no checking +// or tracking of changes is done. class HloFusionAnalysisCache { public: explicit HloFusionAnalysisCache( diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index f3417c9729b29e..3b0cae33ea7acc 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/model/triton_emitter_constraints.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -369,7 +370,9 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion( absl::Span tile_sizes) { // TODO(b/332714755): Add caching for SymbolicTileAnalysis. SymbolicTileAnalysisOrError analysis_or_error = - SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_); + SymbolicTileAnalysis::AnalyzeFusion( + fusion_adaptor, mlir_context_, + TritonEmitterConstraints::GetBuilder()); if (const auto* fusion_decision = std::get_if(&analysis_or_error)) { return absl::FailedPreconditionError(absl::StrCat( @@ -429,7 +432,9 @@ absl::StatusOr GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion( const HloFusionAdaptor& fusion_adaptor) { SymbolicTileAnalysisOrError analysis_or_error = - SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_); + SymbolicTileAnalysis::AnalyzeFusion( + fusion_adaptor, mlir_context_, + TritonEmitterConstraints::GetBuilder()); if (const auto* fusion_decision = std::get_if(&analysis_or_error)) { diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc index b64580055884c2..0ece419e1b009f 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc @@ -216,8 +216,8 @@ ENTRY entry_computation { auto launch_dimensions = GpuPerformanceModelBase::EstimateFusionLaunchDimensions(fusion_analysis); - EXPECT_EQ(launch_dimensions.num_blocks(), 16); - EXPECT_EQ(launch_dimensions.num_threads_per_block(), 1024); + EXPECT_EQ(launch_dimensions.num_blocks(), 128); + EXPECT_EQ(launch_dimensions.num_threads_per_block(), 128); } TEST_F(GpuPerformanceModelBaseTest, diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc index 7c88789345b95b..2335906b13f544 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc @@ -161,7 +161,7 @@ ENTRY e { auto reification_cost = root->backend_config() ->fusion_backend_config() .reification_cost(); - EXPECT_NEAR(reification_cost.end_to_end_cycles(), 257.7, 0.1); + EXPECT_NEAR(reification_cost.end_to_end_cycles(), 38.4, 0.1); EXPECT_NEAR(reification_cost.exec_time_us(), 0, 1); auto indexing_t = indexing_cost_model_.EstimateRunTimes(root); diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc index 6a8ed6538e8edb..550e9e7a31ffdf 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc @@ -39,7 +39,7 @@ TEST_F(HloOpProfilerTest, BasicMeasurementsAreCorrect) { EXPECT_GT(profiler.MeasureClockCyclesPerOp(HloOpcode::kDivide, F64) .value() .clock_cycles(), - 400); + 300); // c128 sqrt is slow. EXPECT_GT(profiler.MeasureClockCyclesPerOp(HloOpcode::kSqrt, C128) .value() diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 7aabefcea76e14..da21c3464b16b5 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -1127,28 +1127,6 @@ bool IndexingMap::IsSymbolConstrained(int64_t symbol_id) const { return false; } -llvm::DenseMap IndexingMap::GetConstraintsForSymbol( - int symbol_id) const { - llvm::DenseMap constraints; - for (auto const& [exp, interval] : GetConstraints()) { - if (exp.isFunctionOfSymbol(symbol_id)) { - constraints.insert({exp, interval}); - } - } - return constraints; -} - -llvm::DenseMap IndexingMap::GetConstraintsForDim( - int dim_id) const { - llvm::DenseMap constraints; - for (auto const& [exp, interval] : GetConstraints()) { - if (exp.isFunctionOfDim(dim_id)) { - constraints.insert({exp, interval}); - } - } - return constraints; -} - RangeEvaluator::RangeEvaluator(const IndexingMap& indexing_map, MLIRContext* mlir_context, bool use_constraints) : mlir_context_(mlir_context), diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h index 8be6177717f699..2e6cc1374f505b 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map.h @@ -372,14 +372,6 @@ class IndexingMap { // Returns true if there is a constraint on the given symbol. bool IsSymbolConstrained(int64_t symbol_id) const; - // Returns the constraints for the given dimension. - llvm::DenseMap GetConstraintsForDim( - int dim_id) const; - - // Returns the constraints for the given symbol. - llvm::DenseMap GetConstraintsForSymbol( - int symbol_id) const; - // Returns true if the domain is empty. If it returns false, that does not // mean that the domain is not effectively empty. // For example, if there are two constraints 0 <= d0 mod 7 <= 0 and diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 1677bf7747b1b2..3624a01eb44b6b 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -46,8 +46,6 @@ namespace { using ::mlir::AffineMap; using ::testing::AnyOf; using ::testing::ElementsAre; -using ::testing::Pair; -using ::testing::UnorderedElementsAre; class IndexingMapTest : public HloTestBase { public: @@ -1885,74 +1883,6 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { /*instr=*/nullptr, zero_dim_map}})}); } -TEST_F(IndexingMapTest, GetConstraintsForSymbol) { - auto map = IndexingMap::GetUndefined(); - map.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), - Interval{0, 1}); - map.AddConstraint(ParseAffineExpr("s1 mod 4", &mlir_context_), - Interval{0, 2}); - map.AddConstraint(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}); - map.AddConstraint(ParseAffineExpr("s1 + d0", &mlir_context_), Interval{0, 4}); - map.AddConstraint(ParseAffineExpr("d0 mod 4", &mlir_context_), - Interval{0, 5}); - map.AddConstraint(ParseAffineExpr("d1 mod 32", &mlir_context_), - Interval{0, 6}); - - EXPECT_THAT( - map.GetConstraintsForSymbol(1), - UnorderedElementsAre( - Pair(ParseAffineExpr("s1 mod 4", &mlir_context_), Interval{0, 2}), - Pair(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}), - Pair(ParseAffineExpr("s1 + d0", &mlir_context_), Interval{0, 4}))); - - EXPECT_THAT( - map.GetConstraintsForSymbol(0), - UnorderedElementsAre( - Pair(ParseAffineExpr("s0 mod 4", &mlir_context_), Interval{0, 1}), - Pair(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}))); -} - -TEST_F(IndexingMapTest, GetConstraintsForSymbolEmpty) { - auto map = IndexingMap(AffineMap::get(&mlir_context_), {}, {}, {}); - EXPECT_THAT(map.GetConstraintsForSymbol(1), UnorderedElementsAre()); - map.AddConstraint(ParseAffineExpr("d0 mod 4", &mlir_context_), - Interval{0, 5}); - EXPECT_THAT(map.GetConstraintsForSymbol(1), UnorderedElementsAre()); -} - -TEST_F(IndexingMapTest, GetConstraintsForDim) { - auto map = IndexingMap(AffineMap::get(&mlir_context_), {}, {}, {}); - map.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), - Interval{0, 1}); - map.AddConstraint(ParseAffineExpr("s1 mod 4", &mlir_context_), - Interval{0, 2}); - map.AddConstraint(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}); - map.AddConstraint(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4}); - map.AddConstraint(ParseAffineExpr("d0 mod 4", &mlir_context_), - Interval{0, 5}); - map.AddConstraint(ParseAffineExpr("d1 mod 32", &mlir_context_), - Interval{0, 6}); - - EXPECT_THAT( - map.GetConstraintsForDim(1), - UnorderedElementsAre( - Pair(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4}), - Pair(ParseAffineExpr("d1 mod 32", &mlir_context_), Interval{0, 6}))); - - EXPECT_THAT( - map.GetConstraintsForDim(0), - UnorderedElementsAre( - Pair(ParseAffineExpr("d0 mod 4", &mlir_context_), Interval{0, 5}))); -} - -TEST_F(IndexingMapTest, GetConstraintsForDimEmpty) { - auto map = IndexingMap(AffineMap::get(&mlir_context_), {}, {}, {}); - EXPECT_THAT(map.GetConstraintsForDim(1), UnorderedElementsAre()); - map.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), - Interval{0, 5}); - EXPECT_THAT(map.GetConstraintsForDim(1), UnorderedElementsAre()); -} - } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc index dc5c2f28c6f3f1..1ea107ce9903a8 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc @@ -602,23 +602,8 @@ std::optional ExtractSizeAndStride( AffineExpr strided_indexing, absl::Span dimension_intervals, absl::Span symbol_intervals) { MLIRContext* ctx = strided_indexing.getContext(); - // Deal with the symbol case (capturing a whole untiled dimension). - // TODO(b/330906085): concatenating across a reduction dimension needs to be - // handled by this code. - if (auto symbol = llvm::dyn_cast(strided_indexing)) { - const Interval& symbol_interval = symbol_intervals[symbol.getPosition()]; - if (symbol_interval.lower != 0) { - return std::nullopt; - } - - return SizeAndStrideExpression( - /*size=*/getAffineConstantExpr(symbol_interval.upper + 1, ctx), - /*stride=*/getAffineConstantExpr(1, ctx)); - } - AffineMapPrinter printer; - // TODO(b/328427138): support multivariate size expressions. switch (strided_indexing.getKind()) { case AffineExprKind::DimId: return SizeAndStrideExpression(/*size=*/strided_indexing, @@ -626,23 +611,15 @@ std::optional ExtractSizeAndStride( case mlir::AffineExprKind::Mul: { const auto mul = llvm::cast(strided_indexing); AffineExpr lhs = mul.getLHS(); - // The stride may not be fully collapsed if it is negative; in that case, - // we need to extract the negative multiplier first. - if (const auto rhs = llvm::dyn_cast(mul.getRHS()); - rhs && rhs.getValue() == -1) { - std::optional maybe_size_and_stride = - ExtractSizeAndStride(lhs, dimension_intervals, symbol_intervals); - if (!maybe_size_and_stride.has_value()) { - return std::nullopt; - } - - return SizeAndStrideExpression( - /*size=*/maybe_size_and_stride->size, - /*stride=*/maybe_size_and_stride->stride * rhs); + std::optional maybe_size_and_stride = + ExtractSizeAndStride(lhs, dimension_intervals, symbol_intervals); + if (!maybe_size_and_stride.has_value()) { + return std::nullopt; } - CHECK(lhs.getKind() == AffineExprKind::DimId); - return SizeAndStrideExpression(/*size=*/lhs, - /*stride=*/mul.getRHS()); + + return SizeAndStrideExpression( + /*size=*/maybe_size_and_stride->size, + /*stride=*/maybe_size_and_stride->stride * mul.getRHS()); } case mlir::AffineExprKind::Mod: { auto mod = llvm::cast(strided_indexing); @@ -656,15 +633,18 @@ std::optional ExtractSizeAndStride( case mlir::AffineExprKind::Constant: return SizeAndStrideExpression(/*size=*/getAffineConstantExpr(1, ctx), /*stride=*/getAffineConstantExpr(0, ctx)); - case mlir::AffineExprKind::SymbolId: - VLOG(1) << "Encountered complex size expression involving symbol " - << printer.ToString(strided_indexing); - // It's currently not checked separately, but RTVars shouldn't appear in - // the strided indexing expressions. - return std::nullopt; + case mlir::AffineExprKind::SymbolId: { + auto symbol = llvm::cast(strided_indexing); + const Interval& symbol_interval = symbol_intervals[symbol.getPosition()]; + if (symbol_interval.lower != 0) { + return std::nullopt; + } + + return SizeAndStrideExpression( + /*size=*/getAffineConstantExpr(symbol_interval.upper + 1, ctx), + /*stride=*/getAffineConstantExpr(1, ctx)); + } case mlir::AffineExprKind::Add: { - // TODO(b/328427138): this should only be necessary in the multivariate - // case, and will be implemented later. std::optional> maybe_sizes_and_strides = ExtractSizesAndStridesFromMultivariateSummation( diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index 6aed0c540ba910..7025fe46fc4c47 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -300,13 +300,16 @@ void SortTiledHloInstructionsInPostOrder( } // namespace /*static*/ SymbolicTileAnalysisOrError SymbolicTileAnalysis::AnalyzeComputation( - const HloComputation& computation, MLIRContext* ctx) { + const HloComputation& computation, MLIRContext* ctx, + EmitterSpecificConstraintsBuilder emitter_specific_constraints_builder) { auto fusion = HloFusionAdaptor::ForComputation(&computation); - return SymbolicTileAnalysis::AnalyzeFusion(*fusion, ctx); + return SymbolicTileAnalysis::AnalyzeFusion( + *fusion, ctx, emitter_specific_constraints_builder); } /*static*/ SymbolicTileAnalysisOrError SymbolicTileAnalysis::AnalyzeFusion( - const HloFusionAdaptor& fusion, MLIRContext* ctx) { + const HloFusionAdaptor& fusion, MLIRContext* ctx, + EmitterSpecificConstraintsBuilder emitter_specific_constraints_builder) { OrderedUniquePtrValueHashSet tiled_hlo_instructions_set; @@ -383,12 +386,20 @@ void SortTiledHloInstructionsInPostOrder( return std::get(constraints_or); } + // Create emitter-specific constraints if a builder was provided. + std::unique_ptr emitter_specific_constraints; + if (emitter_specific_constraints_builder != nullptr) { + emitter_specific_constraints = + emitter_specific_constraints_builder(tiled_hlo_instructions); + } + // Order instructions in def-before-use order. SortTiledHloInstructionsInPostOrder(tiled_hlo_instructions, root_tiled_hlo); return SymbolicTileAnalysis( std::move(tiled_hlo_instructions), - std::get(std::move(constraints_or)), ctx); + std::get(std::move(constraints_or)), + std::move(emitter_specific_constraints), ctx); } absl::StatusOr SymbolicTileAnalysis::ParametersSatisfyConstraints( @@ -399,11 +410,6 @@ absl::StatusOr SymbolicTileAnalysis::ParametersSatisfyConstraints( "This should never happen."); } - // Handle the unconstrained case. - if (constraints_.IsAlwaysSatisfied()) { - return true; - } - if (tile_parameters.size() != num_tile_parameters()) { return absl::InvalidArgumentError(absl::StrFormat( "Failed to check if tile parameters satisfy constraints. Number of " @@ -412,6 +418,21 @@ absl::StatusOr SymbolicTileAnalysis::ParametersSatisfyConstraints( tile_parameters.size(), num_tile_parameters())); } + if (emitter_specific_constraints_ != nullptr) { + TF_ASSIGN_OR_RETURN( + bool constraints_are_satisfied, + emitter_specific_constraints_->ParametersSatisfyConstraints( + tile_parameters)); + if (!constraints_are_satisfied) { + return false; + } + } + + // Handle the unconstrained case. + if (constraints_.IsAlwaysSatisfied()) { + return true; + } + // TODO(bchetioui): replace with convenience methods in // `ConstraintExpression`. bool constraints_are_satisfied = false; @@ -443,9 +464,9 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions( TF_ASSIGN_OR_RETURN(bool constraints_are_satisfied, ParametersSatisfyConstraints(tile_parameters)); if (!constraints_are_satisfied) { - return absl::InvalidArgumentError(absl::StrCat( - "Tile parameters ", absl::StrJoin(tile_parameters, ", "), - " do not satisfy the SymbolicTileAnalysis's constraints.")); + return absl::InvalidArgumentError( + absl::StrCat("Tile parameters ", absl::StrJoin(tile_parameters, ", "), + " do not satisfy constraints.")); } } diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h index df56d2325dd641..692e88db11b998 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h @@ -44,6 +44,21 @@ class SymbolicTileAnalysis; using SymbolicTileAnalysisOrError = std::variant; +// An interface to implement additional emitter-specific constraints. This +// interface can be used as an extension point to further constrain the set of +// given limitations of a particular codegen solution. +class EmitterSpecificConstraints { + public: + virtual ~EmitterSpecificConstraints() = default; + + virtual absl::StatusOr ParametersSatisfyConstraints( + absl::Span tile_parameters) const = 0; +}; + +using EmitterSpecificConstraintsBuilder = + std::function( + const std::vector>&)>; + // Constructs and holds symbolic tiles for all the instructions within a // computation. We may hold several different symbolic tiles for the same // instruction if the instruction is indexed in several different ways in order @@ -59,10 +74,17 @@ class SymbolicTileAnalysis { // Tries to construct a symbolic tile analysis from a computation. Returns // a diagnostic if the construction fails for any reason. + // + // If `emitter_specific_constraints_builder` is provided, it will be used to + // construct emitter-specific constraints for the analysis. static SymbolicTileAnalysisOrError AnalyzeComputation( - const HloComputation& computation, mlir::MLIRContext* ctx); + const HloComputation& computation, mlir::MLIRContext* ctx, + EmitterSpecificConstraintsBuilder emitter_specific_constraints_builder = + nullptr); static SymbolicTileAnalysisOrError AnalyzeFusion( - const HloFusionAdaptor& fusion, mlir::MLIRContext* ctx); + const HloFusionAdaptor& fusion, mlir::MLIRContext* ctx, + EmitterSpecificConstraintsBuilder emitter_specific_constraints_builder = + nullptr); // Returns a graph of HLO instructions tiled with the given tile parameters. // The provided tile parameters must satisfy the analysis's constraints. @@ -101,7 +123,8 @@ class SymbolicTileAnalysis { const ConstraintExpression& GetConstraints() const { return constraints_; } // Returns true if a list of tile parameters satisfies the symbolic tile - // analysis's constraints. + // analysis's constraints. If provided, also checks the emitter-specific + // constraints. // // Returns false if the constraints are not satisfied but can be evaluated // correctly. Returns an error if the constraints cannot be evaluated @@ -127,13 +150,16 @@ class SymbolicTileAnalysis { absl::StatusOr> GetGoodTilings() const; private: - SymbolicTileAnalysis(std::vector> - symbolic_tiled_hlo_instructions, - ConstraintExpression constraints, - mlir::MLIRContext* context) + SymbolicTileAnalysis( + std::vector> + symbolic_tiled_hlo_instructions, + ConstraintExpression constraints, + std::unique_ptr emitter_specific_constraints, + mlir::MLIRContext* context) : symbolic_tiled_hlo_instructions_( std::move(symbolic_tiled_hlo_instructions)), constraints_(std::move(constraints)), + emitter_specific_constraints_(std::move(emitter_specific_constraints)), context_(context) {} // The tiled HLO instructions in def-before-use order. @@ -143,6 +169,10 @@ class SymbolicTileAnalysis { // See the documentation of GetConstraints(). ConstraintExpression constraints_; + // Additional emitter-specific constraints on tile parameters. May be null if + // no builder was provided when constructing the analysis. + std::unique_ptr emitter_specific_constraints_; + mlir::MLIRContext* context_; }; diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index b0287b572c1f51..c7fecd6525e2f2 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/service/gpu/model/symbolic_tile.h" +#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/service/instruction_fusion.h" @@ -78,15 +79,44 @@ Matcher MatchTiledHloInstruction( tile_offsets_indexing); } +// Fake emitter-specific constraints for testing. Requires that the tile size +// along the first dimension is exactly half the size of the axis. +class FakeEmitterSpecificConstraints : public EmitterSpecificConstraints { + public: + absl::StatusOr ParametersSatisfyConstraints( + absl::Span tile_parameters) const override { + return tile_parameters[0] == dim0_tile_size_; + } + + static EmitterSpecificConstraintsBuilder GetBuilder() { + return [](const std::vector>& + instructions) { + const SymbolicTiledHloInstruction* root = instructions[0].get(); + int64_t dim0_size = root->hlo()->shape().dimensions(0); + return std::make_unique( + /*dim0_tile_size=*/dim0_size / 2); + }; + } + + explicit FakeEmitterSpecificConstraints(int64_t dim0_tile_size) + : dim0_tile_size_(dim0_tile_size) {} + + private: + int64_t dim0_tile_size_; +}; + class SymbolicTileAnalysisTest : public HloTestBase { public: - std::optional TryAnalyzeModule(HloModule* module) { + std::optional TryAnalyzeModule( + HloModule* module, + EmitterSpecificConstraintsBuilder emitter_specific_constraints_builder = + nullptr) { SymbolicTileAnalysisOrError analysis_or_error = SymbolicTileAnalysis::AnalyzeComputation( *module->entry_computation() ->root_instruction() ->fused_instructions_computation(), - &mlir_context_); + &mlir_context_, emitter_specific_constraints_builder); if (std::holds_alternative(analysis_or_error)) { return std::get(std::move(analysis_or_error)); @@ -507,6 +537,35 @@ ENTRY main { impossible_tile_parameters, /*constraints_are_known_satisfied=*/true)); } +TEST_F(SymbolicTileAnalysisTest, EmitterSpecificConstraintsAreUsedCorrectly) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( + fusion { + p0 = f32[16,32] parameter(0) + ROOT add = f32[16,32] add(p0, p0) + } + + ENTRY main { + p0 = f32[16,32] parameter(0) + ROOT fusion = f32[16,32] fusion(p0), kind=kLoop, calls=fusion + })")); + + std::optional analysis = TryAnalyzeModule( + module.get(), FakeEmitterSpecificConstraints::GetBuilder()); + + ASSERT_TRUE(analysis.has_value()); + + // FakeEmitterSpecificConstraints require that the tile size along the first + // dimension is exactly half the size of the axis. Tile sizes {5, 32} do not + // satisfy emitter-specific constraints. + EXPECT_THAT(analysis->ParametersSatisfyConstraints({5, 32}), + IsOkAndHolds(false)); + + // However, tile sizes {8, 32} do satisfy emitter-specific constraints. + EXPECT_THAT(analysis->ParametersSatisfyConstraints({8, 32}), + IsOkAndHolds(true)); +} + TEST_F(SymbolicTileAnalysisTest, ConstraintsAreAggregatedCorrectly) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index 1db55375c0cc84..92c851d94ea929 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -549,6 +549,61 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughSplitReshapeOfReverse) { )"))); } +TEST_F(SymbolicTileTest, CanPropagateTileThroughSplitReductionOfSplittedAxis) { + // A split reshape of a reverse creates a sum of strided symbols. + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + computation { + p0 = f32[18] parameter(0) + bitcast = f32[9,2] bitcast(p0) + c0 = f32[] constant(0) + reduce_0 = f32[9] reduce(bitcast, c0), dimensions={1}, to_apply=add + ROOT reduce_1 = f32[] reduce(reduce_0, c0), dimensions={0}, to_apply=add + } + + ENTRY e { + p0 = f32[18] parameter(0) + ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=computation + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTileString(R"( + Symbolic tile with + offset_map: () -> (0) + size_map: () -> (18) + stride_map: () -> (1) + )"))); +} + +TEST_F(SymbolicTileTest, CanPropagateTileThroughSummationOfSymbols) { + // Such an indexing map is representative of a sequence of HLOs containing a + // bitcast followed by two sequential reductions of the split axis, i.e. + // something like + // p0 = f32[18] parameter(0) + // bitcast = f32[9,2] bitcast(p0) + // reduce_0 = f32[9] reduce(bitcast), dimensions={1} + // reduce_1 = f32[] reduce(reduce_0), dimensions={0} + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("()[s0, s1] -> (s1 * 2 + s0)", &mlir_context_), {}, + {2, 9}); + + EXPECT_THAT(SymbolicTile::FromIndexingMap(indexing_map), + Optional(MatchSymbolicTileString(R"( + Symbolic tile with + offset_map: () -> (0) + size_map: () -> (18) + stride_map: () -> (1) + )"))); +} + TEST_F(SymbolicTileTest, FailsGracefullyAtPropagatingTileThroughSliceOfSplitReshape) { // TODO(b/349487906): constraints should allow us to unblock this use case. diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc new file mode 100644 index 00000000000000..6ccc3db120c697 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc @@ -0,0 +1,77 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/triton_emitter_constraints.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/IR/AffineMap.h" +#include "xla/service/gpu/model/affine_map_evaluator.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" + +namespace xla { +namespace gpu { + +namespace { + +// Triton enforces that all tensors in the program have less than 1048576 +// elements, otherwise it will fail to compile. +constexpr int64_t kMaxTensorNumElements = 1048576; + +} // namespace + +/*static*/ EmitterSpecificConstraintsBuilder +TritonEmitterConstraints::GetBuilder() { + return [](const std::vector>& + instructions) { + llvm::DenseSet unique_tile_size_maps; + for (const auto& tiled_hlo_instruction : instructions) { + unique_tile_size_maps.insert( + tiled_hlo_instruction->symbolic_tile().size_map()); + } + + return std::make_unique( + llvm::SmallVector(unique_tile_size_maps.begin(), + unique_tile_size_maps.end())); + }; +} + +absl::StatusOr TritonEmitterConstraints::ParametersSatisfyConstraints( + absl::Span tile_parameters) const { + // Verify that the tile sizes are not too big. + for (const auto& tile_size_map : tile_size_maps_) { + int64_t tile_size = 1; + for (auto expr : tile_size_map.getResults()) { + tile_size *= llvm::PowerOf2Ceil( + EvaluateAffineExpr(expr, /*dim_values=*/tile_parameters)); + } + + if (tile_size > kMaxTensorNumElements) { + return false; + } + } + return true; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h new file mode 100644 index 00000000000000..d5281bd12f0e98 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h @@ -0,0 +1,54 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/AffineMap.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" + +#ifndef XLA_SERVICE_GPU_MODEL_TRITON_EMITTER_CONSTRAINTS_H_ +#define XLA_SERVICE_GPU_MODEL_TRITON_EMITTER_CONSTRAINTS_H_ + +namespace xla { +namespace gpu { + +// Triton-specific constraints on tile sizes. +class TritonEmitterConstraints : public EmitterSpecificConstraints { + public: + static EmitterSpecificConstraintsBuilder GetBuilder(); + + explicit TritonEmitterConstraints( + llvm::SmallVector tile_size_maps) + : tile_size_maps_(std::move(tile_size_maps)) {} + + absl::StatusOr ParametersSatisfyConstraints( + absl::Span tile_parameters) const override; + + private: + // A collection of unique size maps from all the SymbolicTiledHloInstructions. + // + // Different TiledHloInstructions often have the same size map, so we keep a + // collection of unique maps to improve compilation time. + llvm::SmallVector tile_size_maps_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_TRITON_EMITTER_CONSTRAINTS_H_ diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc new file mode 100644 index 00000000000000..827c2fa488a307 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/triton_emitter_constraints.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/log/log.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/instruction_fusion.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +using ::tsl::testing::IsOkAndHolds; + +class TritonEmitterConstraintsTest : public HloTestBase { + public: + std::optional TryAnalyzeModule(HloModule* module) { + SymbolicTileAnalysisOrError analysis_or_error = + SymbolicTileAnalysis::AnalyzeComputation( + *module->entry_computation() + ->root_instruction() + ->fused_instructions_computation(), + &mlir_context_, TritonEmitterConstraints::GetBuilder()); + + if (std::holds_alternative(analysis_or_error)) { + return std::get(std::move(analysis_or_error)); + } + VLOG(1) << "Cannot analyze module: " + << std::get(analysis_or_error).Explain(); + return std::nullopt; + } + + mlir::MLIRContext mlir_context_; +}; + +TEST_F(TritonEmitterConstraintsTest, TritonSpecificConstraintsAreEnforced) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +max_computation { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(param_0, param_1) +} + +fused_computation { + param_0 = f32[8192,50304] parameter(0) + constant = f32[] constant(-inf) + reduce = f32[8192] reduce(param_0, constant), dimensions={1}, to_apply=max_computation + broadcast = f32[8192,50304] broadcast(reduce), dimensions={0} + ROOT subtract = f32[8192,50304] subtract(param_0, broadcast) +} + +ENTRY entry_computation { + param_0 = f32[8192,50304] parameter(0) + ROOT fusion = f32[8192,50304] fusion(param_0), kind=kCustom, calls=fused_computation, backend_config={"fusion_backend_config":{"kind":"__triton"}} +} +)")); + + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + + // The biggest tile in the program has 8 * 65536 = 524288 elements. + EXPECT_THAT(analysis->ParametersSatisfyConstraints({8, 128}), + IsOkAndHolds(true)); + + // The biggest tile in the program is 18 * 50304 = 905472 elements which is + // smaller than the limit of 1048576, but since Triton requires all tile sizes + // to be a power of 2, the actual tile will be 32 * 65536 = 2097152 elements. + EXPECT_THAT(analysis->ParametersSatisfyConstraints({18, 50304}), + IsOkAndHolds(false)); + + // Because of reduce, we need to load full rows from param_0 and the load tile + // will be 1024 * 65536 = 67108864 elements, that is larger than the limit of + // 1048576. + EXPECT_THAT(analysis->ParametersSatisfyConstraints({1024, 1}), + IsOkAndHolds(false)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index 2f84317065f0f9..2044115f3bbc44 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -59,17 +59,17 @@ limitations under the License. #include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h" #include "xla/service/gpu/buffer_sharing.h" #include "xla/service/gpu/cublas_padding_requirements.h" -#include "xla/service/gpu/gpu_algebraic_simplifier.h" #include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/gpu_compiler.h" -#include "xla/service/gpu/gpu_conv_padding_legalization.h" -#include "xla/service/gpu/gpu_conv_rewriter.h" -#include "xla/service/gpu/gpu_sort_rewriter.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/metrics.h" #include "xla/service/gpu/target_constants.h" +#include "xla/service/gpu/transforms/algebraic_simplifier.h" +#include "xla/service/gpu/transforms/conv_padding_legalization.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" #include "xla/service/gpu/transforms/cublas_pad_for_gemms.h" +#include "xla/service/gpu/transforms/cudnn_custom_call_compiler.h" #include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h" #include "xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h" #include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h" @@ -78,10 +78,10 @@ limitations under the License. #include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h" #include "xla/service/gpu/transforms/cudnn_simplify_padding.h" #include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h" -#include "xla/service/gpu/transforms/cudnn_workspace_rewriter.h" #include "xla/service/gpu/transforms/dot_sparsity_rewriter.h" #include "xla/service/gpu/transforms/gpusolver_rewriter.h" -#include "xla/service/gpu/triangular_solve_rewriter.h" +#include "xla/service/gpu/transforms/sort_rewriter.h" +#include "xla/service/gpu/transforms/triangular_solve_rewriter.h" #include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_cse.h" #include "xla/service/hlo_dataflow_analysis.h" @@ -191,7 +191,7 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( auto cuda_compute_capability = std::get(gpu_version); // Convert convolutions into CustomCalls to cudnn, then canonicalize them - // (GpuConvPaddingLegalization). Also expand cuSolver calls. + // (ConvPaddingLegalization). Also expand cuSolver calls. HloPassPipeline pipeline("conv_canonicalization"); pipeline.AddInvariantCheckerDebug( /*layout_sensitive=*/false, @@ -206,10 +206,10 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(&matmul_bf16_support); pipeline.AddPass(); - pipeline.AddPass(cuda_compute_capability); + pipeline.AddPass(cuda_compute_capability); pipeline.AddPass(cuda_compute_capability, dnn_version, GetToolkitVersion()); - pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(cuda_compute_capability); pipeline.AddPass(cuda_compute_capability, dnn_version); @@ -234,7 +234,7 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( // e.g. clean up unnecessary nop `convert`s. pipeline.AddPass(); - // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and + // tf2xla bridge, DepthwiseConvolutionConverter, ConvRewriter, and // CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover // to a fixed point. Include algsimp because ReshapeMover relies on it. [&, &pipeline = pipeline.AddPass>( @@ -256,7 +256,7 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(algsimp_options, gpu_version); }(); - // GpuConvRewriter, GpuConvPaddingLegalization and + // ConvRewriter, ConvPaddingLegalization and // CudnnConvPadForTensorCores may add instructions which can be simplified // by constant folding. pipeline.AddPass(); @@ -342,9 +342,6 @@ absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( // Transform TriangularSolve ops into custom-calls, so we can add temp // memory. post_pipeline.AddPass(); - if (stream_exec) { - post_pipeline.AddPass(*stream_exec); - } TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status()); return absl::OkStatus(); @@ -390,20 +387,22 @@ absl::Status NVPTXCompiler::AddGemmFusionAutotuningPasses( absl::Status NVPTXCompiler::AddCustomKernelReplacementPasses( HloPassPipeline* pipeline, const DebugOptions& debug_options) { if (debug_options.xla_gpu_enable_cub_radix_sort()) { - pipeline->AddPass(); + pipeline->AddPass(); } return absl::OkStatus(); } -absl::Status NVPTXCompiler::RunCudnnFusionCompilerPass( +absl::Status NVPTXCompiler::RunCudnnCompilerPasses( HloModule* module, se::StreamExecutor* stream_exec, BinaryMap* dnn_compiled_graphs) { tsl::profiler::ScopedAnnotation annotation([&] { return absl::StrFormat("XlaCompileCudnnFusion:#module=%s,program_id=%d#", module->name(), module->unique_id()); }); - CuDnnFusionCompiler cudnn_compiler(*stream_exec, *dnn_compiled_graphs); - return cudnn_compiler.Run(module).status(); + CuDnnFusionCompiler fusion_compiler(*stream_exec, *dnn_compiled_graphs); + TF_RETURN_IF_ERROR(fusion_compiler.Run(module).status()); + CuDnnCustomCallCompiler call_compiler(*stream_exec, *dnn_compiled_graphs); + return call_compiler.Run(module).status(); } namespace { @@ -581,10 +580,12 @@ NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config, VLOG(2) << "Deferring the PTX to CUBIN compilation of the relocatable " "module to the linking step."; std::vector binary; - binary.reserve(sizeof(kPtxPrefix) + ptx.size() + 1); - binary.insert(binary.end(), kPtxPrefix, kPtxPrefix + sizeof(kPtxPrefix)); - binary.insert(binary.end(), ptx.begin(), ptx.end()); - binary.emplace_back('\0'); + if (!ptx.empty()) { + binary.reserve(sizeof(kPtxPrefix) + ptx.size() + 1); + binary.insert(binary.end(), kPtxPrefix, kPtxPrefix + sizeof(kPtxPrefix)); + binary.insert(binary.end(), ptx.begin(), ptx.end()); + binary.emplace_back('\0'); + } return BackendCompileResult{std::move(ptx), std::move(binary)}; } diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h index fb74f553d67967..6d84deb4398176 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.h +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h @@ -84,9 +84,9 @@ class NVPTXCompiler : public GpuCompiler { absl::Status AddCustomKernelReplacementPasses( HloPassPipeline* pipeline, const DebugOptions& debug_options) override; - absl::Status RunCudnnFusionCompilerPass( - HloModule* module, se::StreamExecutor* stream_exec, - BinaryMap* dnn_compiled_graphs) override; + absl::Status RunCudnnCompilerPasses(HloModule* module, + se::StreamExecutor* stream_exec, + BinaryMap* dnn_compiled_graphs) override; HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const override; diff --git a/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc b/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc index 21d2590e763176..16fdf7c8fbe6ff 100644 --- a/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc +++ b/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc @@ -166,16 +166,6 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, // "It is guaranteed that [...] 0 <= %ctaid.x < %nctaid.x" // // %nctaid.x is currently specified as 2147483647. - if (launch_dimensions_.thread_counts_per_block().y > 1) { - // When blockDim.y > 1, then we are in the small row case. Each - // blockDim.x do exatly to one row and blockDim.y map to some - // consecutive row. This prevents too small block size that isn't - // efficient. - CHECK(launch_config_.row_vectorized); - CHECK_EQ(shape_.dimensions().back(), - launch_dimensions_.thread_counts_per_block().x * - launch_config_.unroll_factor); - } CHECK_EQ(launch_dimensions_.thread_counts_per_block().z, 1); CHECK_EQ(launch_dimensions_.block_counts().y, 1); CHECK_EQ(launch_dimensions_.block_counts().z, 1); @@ -189,14 +179,6 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm::Value* linear_index_base = linear_base_and_thread_idx.linear_base; - llvm::Value* row_index = - launch_config_.row_vectorized - ? b_->CreateMul(linear_base_and_thread_idx.thread_idx, - llvm::ConstantInt::get(index_type, - launch_config_.unroll_factor), - "row_index", /*HasNUW=*/true, /*HasNSW=*/true) - : nullptr; - std::vector multidim(shape_.rank(), nullptr); for (int i = 0; i < launch_config_.unroll_factor; ++i) { // The add operation is needed even if the offset is 0, since when the @@ -207,17 +189,6 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, b_->CreateAdd(linear_index_base, llvm::ConstantInt::get(index_type, i), absl::StrCat("linear_index", i), /*HasNUW=*/true, /*HasNSW=*/true); - if (launch_config_.row_vectorized) { - // This lets us avoid emitting the division for the last dimension of the - // index. The check for i > 0 is here for historical reasons, it might not - // do anything. - multidim.back() = - i == 0 ? row_index - : b_->CreateAdd( - row_index, llvm::ConstantInt::get(index_type, i), - absl::StrCat("row_index_plus", i), /*HasNUW=*/true, - /*HasNSW=*/true); - } array_indices.emplace_back(linear_index, multidim, shape_, b_); } diff --git a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc index 9fb7fd4697478d..7ced52cbd17ca0 100644 --- a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc +++ b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc @@ -22,10 +22,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/copy_insertion.h" #include "xla/service/cpu_gpu_shape_verifier.h" -#include "xla/service/gpu/gpu_sanitize_constant_names.h" -#include "xla/service/gpu/horizontal_loop_fusion.h" #include "xla/service/gpu/transforms/alias_passthrough_params.h" #include "xla/service/gpu/transforms/copy_fusion.h" +#include "xla/service/gpu/transforms/horizontal_loop_fusion.h" +#include "xla/service/gpu/transforms/sanitize_constant_names.h" #include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_pass_pipeline.h" @@ -78,14 +78,14 @@ HloPassPipeline PrepareHloModuleForIrEmittingPipeline( } // We are using a sub-pipeline here, so that the verifier only runs after both - // GpuHorizontalLoopFusion and HloDCE. + // HorizontalLoopFusion and HloDCE. auto& sub_pipeline = pipeline.AddPass("horizontal-loop-fusion-for-copy"); // To fuse the copy. sub_pipeline.AddPass(); - sub_pipeline.AddPass("copy_"); + sub_pipeline.AddPass("copy_"); sub_pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(); return pipeline; } diff --git a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc index cd0be0b867612a..03adc0b2cdaea5 100644 --- a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc +++ b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc @@ -102,12 +102,20 @@ ENTRY e { "num_ctas":1}}} })"; +constexpr std::string_view kResultsInNoPtxHlo = R"( + ENTRY e { + a = f32[5,5] parameter(0) + ROOT _ = f32[5,5] custom-call(a, a), custom_call_target="__cublas$gemm", + backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" + })"; + std::string_view GetHlo(std::string_view name) { static const absl::flat_hash_map* const kHloMap = new absl::flat_hash_map( {{"simple", kSimpleHlo}, {"parallel_compilation", kParallelCompilationHlo}, - {"requires_sm90a", kSM90AHlo}}); + {"requires_sm90a", kSM90AHlo}, + {"results_in_no_ptx", kResultsInNoPtxHlo}}); return kHloMap->at(name); } @@ -288,15 +296,20 @@ TEST_P(NVPTXCompilationTests, CompareBinaryOutput) { absl::Span reference_binary = static_cast(reference.value().get())->binary(); - if (executable_binary != reference_binary) { - std::string test_name = - GenerateParametrizedTestname(name, compilation_method, linking_method); - DumpArtifactIfEnabled(absl::StrCat(test_name, "_executable.bin"), - executable_binary); - DumpArtifactIfEnabled(absl::StrCat(test_name, "_reference.bin"), - reference_binary); + if (executable_binary == reference_binary) { + // If the binaries are exactly the same, we can short circuit and don't need + // to parse them. + SUCCEED(); + return; } + std::string test_name = + GenerateParametrizedTestname(name, compilation_method, linking_method); + DumpArtifactIfEnabled(absl::StrCat(test_name, "_executable.bin"), + executable_binary); + DumpArtifactIfEnabled(absl::StrCat(test_name, "_reference.bin"), + reference_binary); + auto get_text_sections = [&](absl::Span binary) -> absl::StatusOr> { auto buffer = llvm::MemoryBuffer::getMemBuffer( @@ -341,14 +354,15 @@ TEST_P(NVPTXCompilationTests, CompareBinaryOutput) { INSTANTIATE_TEST_SUITE_P( NVPTXCompilationTest, NVPTXCompilationTests, - ::testing::Combine( - ::testing::Values("simple", "parallel_compilation", "requires_sm90a"), - ::testing::Values(PtxCompilationMethod::kNvPtxCompiler, - PtxCompilationMethod::kPtxas, - PtxCompilationMethod::kNvJitLink), - ::testing::Values(PtxLinkingMethod::kNone, PtxLinkingMethod::kNvLink, - PtxLinkingMethod::kDriver, - PtxLinkingMethod::kNvJitLink)), + ::testing::Combine(::testing::Values("simple", "parallel_compilation", + "requires_sm90a", "results_in_no_ptx"), + ::testing::Values(PtxCompilationMethod::kNvPtxCompiler, + PtxCompilationMethod::kPtxas, + PtxCompilationMethod::kNvJitLink), + ::testing::Values(PtxLinkingMethod::kNone, + PtxLinkingMethod::kNvLink, + PtxLinkingMethod::kDriver, + PtxLinkingMethod::kNvJitLink)), [](const ::testing::TestParamInfo>& info) { return GenerateParametrizedTestname(std::get<0>(info.param), diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 861fcca6b45eb4..5926ed1066bd74 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -79,7 +79,6 @@ cc_library( "//xla/service:executable", "//xla/service:global_device_id", "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:gpu_fused_mha_runner", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:stream_executor_util", @@ -121,7 +120,6 @@ cc_library( ":copy_thunk", ":cudnn_thunk", ":custom_call_thunk", - ":fused_mha_thunk", ":gemm_thunk", ":gpublas_lt_matmul_thunk", ":kernel_thunk", @@ -163,11 +161,12 @@ xla_test( "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/stream_executor/gpu:gpu_test_kernels", + "//xla/stream_executor/gpu:gpu_test_kernels_fatbin", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -296,6 +295,7 @@ cc_library( "@com_google_absl//absl/crc:crc32c", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", @@ -424,7 +424,6 @@ cc_library( name = "command_buffer_thunk", srcs = ["command_buffer_thunk.cc"], hdrs = ["command_buffer_thunk.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":annotation", ":command_buffer_cmd", @@ -451,7 +450,7 @@ cc_library( xla_test( name = "command_buffer_thunk_test", - srcs = if_gpu_is_configured(["command_buffer_thunk_test.cc"]), + srcs = ["command_buffer_thunk_test.cc"], backend_tags = { "gpu_a100": if_google(["config-cuda-only"]), "gpu_v100": if_google(["config-cuda-only"]), @@ -485,10 +484,12 @@ xla_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_test_kernels", + "//xla/stream_executor/gpu:gpu_test_kernels_fatbin", "//xla/stream_executor/gpu:gpu_types_header", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -660,28 +661,6 @@ cc_library( ], ) -cc_library( - name = "fused_mha_thunk", - srcs = ["fused_mha_thunk.cc"], - hdrs = ["fused_mha_thunk.h"], - deps = [ - ":thunk", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:gpu_fused_mha_runner", - "//xla/stream_executor", - "//xla/stream_executor:lazy_op_runner", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "gemm_thunk", srcs = ["gemm_thunk.cc"], diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index d913871332933d..aceb2cdbb94666 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -1167,314 +1167,6 @@ CommandBufferCmd::BufferUsageVector GemmCmd::buffers() { {workspace_, MemoryAccess::kWrite}}; } -//===----------------------------------------------------------------------===// -// FusedMHACmd -//===----------------------------------------------------------------------===// - -FusedMHACmd::FusedMHACmd( - ExecutionStreamId execution_stream_id, GpufMHAConfig config, - BufferAllocation::Slice lhs_bmm1, BufferAllocation::Slice rhs_bmm1, - BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output, - BufferAllocation::Slice scratch, BufferAllocation::Slice mask, - BufferAllocation::Slice bias, BufferAllocation::Slice activation, - BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k) - : TracedCommandBufferCmd(CommandBufferCmdType::kFusedMHACmd, - execution_stream_id), - config_(std::move(config)), - lhs_bmm1_buffer_(lhs_bmm1), - rhs_bmm1_buffer_(rhs_bmm1), - rhs_bmm2_buffer_(rhs_bmm2), - output_buffer_(output), - scratch_buffer_(scratch), - bias_buffer_(bias), - activation_buffer_(activation), - seqlen_q_buffer_(seqlen_q), - seqlen_k_buffer_(seqlen_k) {} - -FusedMultiHeadedAttentionRunner& FusedMHACmd::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mutex_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, std::make_unique( - config_)}) - .first; - } - return *it->second; -} - -absl::Status FusedMHACmd::Initialize(const Thunk::InitializeParams& params, - StateManager& state) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(params.command_buffer_trace_stream).AsFusedMHARunner(); - TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHAOpConfig()); - return lazy_runner - ->GetOrCreateRunner(config, params.command_buffer_trace_stream) - .status(); -} - -absl::Status FusedMHACmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(execute_params.command_buffer_trace_stream) - .AsFusedMHARunner(); - CHECK(lazy_runner) << "FusedMHA lazy runner cache should have been populated"; - - const auto& buffer_allocations = *execute_params.buffer_allocations; - se::DeviceMemoryBase lhs_bmm1_buffer = - buffer_allocations.GetDeviceAddress(lhs_bmm1_buffer_); - se::DeviceMemoryBase rhs_bmm1_buffer = - buffer_allocations.GetDeviceAddress(rhs_bmm1_buffer_); - se::DeviceMemoryBase rhs_bmm2_buffer = - buffer_allocations.GetDeviceAddress(rhs_bmm2_buffer_); - se::DeviceMemoryBase output_buffer = - buffer_allocations.GetDeviceAddress(output_buffer_); - se::DeviceMemoryBase scratch_buffer = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - std::optional bias_buffer = - AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional activation_buffer = - AssignBufferIfNotNull(buffer_allocations, activation_buffer_); - std::optional seqlen_q_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); - std::optional seqlen_k_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); - - ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); - VLOG(5) << "FusedMHACmd with execution_scope_id: " - << execution_scope_id.value(); - VLOG(5) << " lhs_bmm1_buffer: " << lhs_bmm1_buffer_.ToString(); - VLOG(5) << " rhs_bmm1_buffer: " << rhs_bmm1_buffer_.ToString(); - VLOG(5) << " rhs_bmm2_buffer: " << rhs_bmm2_buffer_.ToString(); - VLOG(5) << " output_buffer: " << output_buffer_.ToString(); - VLOG(5) << " scratch_buffer: " << scratch_buffer_.ToString(); - VLOG(5) << " bias_buffer: " << bias_buffer_.ToString(); - VLOG(5) << " activation_buffer: " << activation_buffer_.ToString(); - VLOG(5) << " seqlen_q_buffer: " << seqlen_q_buffer_.ToString(); - VLOG(5) << " seqlen_k_buffer: " << seqlen_k_buffer_.ToString(); - - RunFusedMHAOptions opts; - opts.runner_cache = - &GetOrCreateRunner(execute_params.command_buffer_trace_stream); - return AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { - return RunGpuFMHA(config_, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, scratch_buffer, - bias_buffer, activation_buffer, seqlen_q_buffer, - seqlen_k_buffer, stream, opts); - }); -} - -FusedMHACmd::BufferUsageVector FusedMHACmd::buffers() { - BufferUsageVector buffer_usage; - buffer_usage.reserve(9); - buffer_usage.push_back({lhs_bmm1_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({rhs_bmm1_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({rhs_bmm2_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({output_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({scratch_buffer_, MemoryAccess::kWrite}); - if (bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead}); - } - if (activation_buffer_.allocation() != nullptr) { - buffer_usage.push_back({activation_buffer_, MemoryAccess::kRead}); - } - if (seqlen_q_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_q_buffer_, MemoryAccess::kRead}); - } - if (seqlen_k_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_k_buffer_, MemoryAccess::kRead}); - } - return buffer_usage; -} - -//===----------------------------------------------------------------------===// -// FusedMHABackwardCmd -//===----------------------------------------------------------------------===// - -FusedMHABackwardCmd::FusedMHABackwardCmd( - ExecutionStreamId execution_stream_id, GpufMHABackwardConfig config, - BufferAllocation::Slice bmm1_grad_gemm1_rhs, - BufferAllocation::Slice bmm1_grad_gemm2_rhs, - BufferAllocation::Slice bmm2_grad_gemm1_lhs, - BufferAllocation::Slice bmm2_grad_gemm2_rhs, - BufferAllocation::Slice d_output, BufferAllocation::Slice scratch, - BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs, - BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, - BufferAllocation::Slice d_bias, BufferAllocation::Slice fwd_output, - BufferAllocation::Slice bias, BufferAllocation::Slice seqlen_q, - BufferAllocation::Slice seqlen_k) - : TracedCommandBufferCmd(CommandBufferCmdType::kFusedMHABackwardCmd, - execution_stream_id), - config_(std::move(config)), - bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs), - bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs), - bmm2_grad_gemm1_lhs_buffer_(bmm2_grad_gemm1_lhs), - bmm2_grad_gemm2_rhs_buffer_(bmm2_grad_gemm2_rhs), - d_output_buffer_(d_output), - scratch_buffer_(scratch), - d_bmm1_lhs_buffer_(d_bmm1_lhs), - d_bmm1_rhs_buffer_(d_bmm1_rhs), - d_bmm2_rhs_buffer_(d_bmm2_rhs), - d_s_buffer_(d_s), - d_bias_buffer_(d_bias), - fwd_output_buffer_(fwd_output), - bias_buffer_(bias), - seqlen_q_buffer_(seqlen_q), - seqlen_k_buffer_(seqlen_k) {} - -FusedMultiHeadedAttentionBackwardRunner& FusedMHABackwardCmd::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mutex_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, - std::make_unique( - config_)}) - .first; - } - return *it->second; -} - -absl::Status FusedMHABackwardCmd::Initialize( - const Thunk::InitializeParams& params, StateManager& state) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(params.command_buffer_trace_stream) - .AsFusedMHABackwardRunner(); - TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHABackwardOpConfig()); - return lazy_runner - ->GetOrCreateRunner(config, params.command_buffer_trace_stream) - .status(); -} - -absl::Status FusedMHABackwardCmd::Record( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(execute_params.command_buffer_trace_stream) - .AsFusedMHABackwardRunner(); - CHECK(lazy_runner) - << "FusedMHABackward lazy runner cache should have been populated"; - - const auto& buffer_allocations = *execute_params.buffer_allocations; - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm1_grad_gemm1_rhs_buffer_); - - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm1_grad_gemm2_rhs_buffer_); - - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer = - buffer_allocations.GetDeviceAddress(bmm2_grad_gemm1_lhs_buffer_); - - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm2_grad_gemm2_rhs_buffer_); - - se::DeviceMemoryBase d_output_buffer = - buffer_allocations.GetDeviceAddress(d_output_buffer_); - - se::DeviceMemoryBase scratch_buffer = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - se::DeviceMemoryBase d_bmm1_lhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm1_lhs_buffer_); - - se::DeviceMemoryBase d_bmm1_rhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm1_rhs_buffer_); - - se::DeviceMemoryBase d_bmm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_); - - std::optional d_s_buffer = - AssignBufferIfNotNull(buffer_allocations, d_s_buffer_); - std::optional d_bias_buffer = - AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_); - std::optional fwd_output_buffer = - AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_); - std::optional bias_buffer = - AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional seqlen_q_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); - std::optional seqlen_k_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); - - ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); - VLOG(5) << "FusedMHABackwardCmd with execution_scope_id: " - << execution_scope_id.value(); - VLOG(5) << "bmm1_grad_gemm1_rhs_buffer" - << bmm1_grad_gemm1_rhs_buffer_.ToString(); - VLOG(5) << "bmm1_grad_gemm2_rhs_buffer" - << bmm1_grad_gemm2_rhs_buffer_.ToString(); - VLOG(5) << "bmm2_grad_gemm1_lhs_buffer" - << bmm2_grad_gemm1_lhs_buffer_.ToString(); - VLOG(5) << "bmm2_grad_gemm2_rhs_buffer" - << bmm2_grad_gemm2_rhs_buffer_.ToString(); - VLOG(5) << "d_output_buffer" << d_output_buffer_.ToString(); - VLOG(5) << "scratch_buffer" << scratch_buffer_.ToString(); - VLOG(5) << "d_bmm1_lhs_buffer" << d_bmm1_lhs_buffer_.ToString(); - VLOG(5) << "d_bmm1_rhs_buffer" << d_bmm1_rhs_buffer_.ToString(); - VLOG(5) << "d_bmm2_rhs_buffer" << d_bmm2_rhs_buffer_.ToString(); - VLOG(5) << "d_s_buffer" << d_s_buffer_.ToString(); - VLOG(5) << "d_bias_buffer" << d_bias_buffer_.ToString(); - VLOG(5) << "fwd_output_buffer" << fwd_output_buffer_.ToString(); - VLOG(5) << "bias_buffer" << bias_buffer_.ToString(); - VLOG(5) << "seqlen_q_buffer" << seqlen_q_buffer_.ToString(); - VLOG(5) << "seqlen_k_buffer" << seqlen_k_buffer_.ToString(); - - RunFusedMHABackwardOptions opts; - opts.runner_cache = - &GetOrCreateRunner(execute_params.command_buffer_trace_stream); - return AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { - return RunGpuFMHABackward( - config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, - d_output_buffer, scratch_buffer, d_bmm1_lhs_buffer, - d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer, seqlen_q_buffer, seqlen_k_buffer, - stream, opts); - }); -} - -FusedMHABackwardCmd::BufferUsageVector FusedMHABackwardCmd::buffers() { - BufferUsageVector buffer_usage; - buffer_usage.reserve(15); - - buffer_usage.push_back({bmm1_grad_gemm1_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({bmm1_grad_gemm2_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({bmm2_grad_gemm1_lhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({bmm2_grad_gemm2_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_output_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({scratch_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({d_bmm1_lhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_bmm1_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_bmm2_rhs_buffer_, MemoryAccess::kRead}); - - if (d_s_buffer_.allocation() != nullptr) { - buffer_usage.push_back({d_s_buffer_, MemoryAccess::kRead}); - }; - if (d_bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({d_bias_buffer_, MemoryAccess::kRead}); - }; - if (fwd_output_buffer_.allocation() != nullptr) { - buffer_usage.push_back({fwd_output_buffer_, MemoryAccess::kRead}); - }; - if (bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead}); - }; - if (seqlen_q_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_q_buffer_, MemoryAccess::kRead}); - }; - if (seqlen_k_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_k_buffer_, MemoryAccess::kRead}); - }; - - return buffer_usage; -} - //===----------------------------------------------------------------------===// // CublasLtCmd //===----------------------------------------------------------------------===// @@ -1920,30 +1612,16 @@ absl::Status CollectiveCmd::BarrierIfAsync( absl::Status CollectiveCmd::Prepare( const Thunk::PrepareParams& params, Thunk::ResourceRequests& resource_requests) { - const Thunk::CollectiveExecuteParams* collectives = params.collective_params; - TF_ASSIGN_OR_RETURN( - std::vector participants, - GetParticipatingDevices(collectives->global_device_id, - *collectives->device_assn, + NcclCliqueKey clique_key, + GetNcclCliqueKey(*params.collective_params, config().replica_groups, + config().group_mode, nccl_stream_id(), + GetAsyncStreamKind())); + TF_ASSIGN_OR_RETURN( + size_t num_local_participants, + GetNumLocalParticipants(*params.collective_params, config().replica_groups, config().group_mode)); - - std::vector local_devices; - if (collectives->global_device_id_map) { - local_devices.reserve(collectives->global_device_id_map->size()); - for (const auto& entry : *collectives->global_device_id_map) { - local_devices.push_back(entry.second); - } - } - - size_t num_local_participants = GetNumLocalParticipants( - participants, - collectives->global_device_id_map ? &local_devices : nullptr); - - return resource_requests.AddClique( - NcclCliqueKey(std::move(participants), nccl_stream_id(), - GetAsyncStreamKind()), - num_local_participants); + return resource_requests.AddClique(clique_key, num_local_participants); } absl::Status CollectiveCmd::AddTracedCommandBuffer( diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h index b7a077e81a9e4f..27e8fea0d86366 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h @@ -40,7 +40,6 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" @@ -81,8 +80,6 @@ namespace xla::gpu { V(kReduceScatter, "ReduceScatterCmd") \ V(kAllGatherCmd, "AllGatherCmd") \ V(kCollectiveBroadcastCmd, "CollectiveBroadcastCmd") \ - V(kFusedMHACmd, "FusedMHACmd") \ - V(kFusedMHABackwardCmd, "FusedMHABackwardCmd") \ V(kUnknownCmd, "UnknownCmd") \ // clang-format on @@ -782,112 +779,6 @@ class GemmCmd : public TracedCommandBufferCmd { const bool deterministic_; }; -//===----------------------------------------------------------------------===// -// FusedMHACmd -//===----------------------------------------------------------------------===// - -class FusedMHACmd : public TracedCommandBufferCmd { - public: - FusedMHACmd(ExecutionStreamId execution_stream_id, GpufMHAConfig config, - BufferAllocation::Slice lhs_bmm1, - BufferAllocation::Slice rhs_bmm1, - BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output, - BufferAllocation::Slice scratch, BufferAllocation::Slice mask, - BufferAllocation::Slice bias, BufferAllocation::Slice activation, - BufferAllocation::Slice seqlen_q, - BufferAllocation::Slice seqlen_k); - - absl::Status Initialize(const Thunk::InitializeParams& params, - StateManager& state) override; - - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; - - BufferUsageVector buffers() override; - - bool IsNestedCommandBuffer() const final { return true; } - - private: - FusedMultiHeadedAttentionRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); - - const GpufMHAConfig config_; - BufferAllocation::Slice lhs_bmm1_buffer_; - BufferAllocation::Slice rhs_bmm1_buffer_; - BufferAllocation::Slice rhs_bmm2_buffer_; - BufferAllocation::Slice output_buffer_; - BufferAllocation::Slice scratch_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice activation_buffer_; - BufferAllocation::Slice seqlen_q_buffer_; - BufferAllocation::Slice seqlen_k_buffer_; - - // FusedMHA config - absl::Mutex mutex_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mutex_); -}; - -//===----------------------------------------------------------------------===// -// FusedMHABackwardCmd -//===----------------------------------------------------------------------===// - -class FusedMHABackwardCmd : public TracedCommandBufferCmd { - public: - FusedMHABackwardCmd( - ExecutionStreamId execution_stream_id, GpufMHABackwardConfig config, - BufferAllocation::Slice bmm1_grad_gemm1_rhs, - BufferAllocation::Slice bmm1_grad_gemm2_rhs, - BufferAllocation::Slice bmm2_grad_gemm1_lhs, - BufferAllocation::Slice bmm2_grad_gemm2_rhs, - BufferAllocation::Slice d_output, BufferAllocation::Slice scratch, - BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs, - BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, - BufferAllocation::Slice d_bias, BufferAllocation::Slice fwd_output, - BufferAllocation::Slice bias, BufferAllocation::Slice seqlen_q, - BufferAllocation::Slice seqlen_k); - - absl::Status Initialize(const Thunk::InitializeParams& params, - StateManager& state) override; - - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; - - BufferUsageVector buffers() override; - - bool IsNestedCommandBuffer() const final { return true; } - - private: - FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); - - const GpufMHABackwardConfig config_; - BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer_; - BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer_; - BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer_; - BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer_; - BufferAllocation::Slice d_output_buffer_; - BufferAllocation::Slice scratch_buffer_; - BufferAllocation::Slice d_bmm1_lhs_buffer_; - BufferAllocation::Slice d_bmm1_rhs_buffer_; - BufferAllocation::Slice d_bmm2_rhs_buffer_; - BufferAllocation::Slice d_s_buffer_; - BufferAllocation::Slice d_bias_buffer_; - BufferAllocation::Slice fwd_output_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice seqlen_q_buffer_; - BufferAllocation::Slice seqlen_k_buffer_; - - // FusedMHA config - absl::Mutex mutex_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mutex_); -}; - //===----------------------------------------------------------------------===// // CublasLtCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc index 54e01fab8e1109..230d050856fcc2 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc @@ -29,7 +29,6 @@ limitations under the License. #include "xla/service/gpu/runtime/copy_thunk.h" #include "xla/service/gpu/runtime/cudnn_thunk.h" #include "xla/service/gpu/runtime/custom_call_thunk.h" -#include "xla/service/gpu/runtime/fused_mha_thunk.h" #include "xla/service/gpu/runtime/gemm_thunk.h" #include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/service/gpu/runtime/kernel_thunk.h" @@ -143,27 +142,6 @@ static absl::StatusOr Convert(const CublasLtMatmulThunk& thunk) { thunk.workspace().value()); } -static absl::StatusOr Convert(const FusedMHAThunk& thunk) { - return std::make_unique( - thunk.execution_stream_id(), thunk.config(), thunk.lhs_bmm1_buffer(), - thunk.rhs_bmm1_buffer(), thunk.rhs_bmm2_buffer(), thunk.output_buffer(), - thunk.scratch_buffer(), BufferAllocation::Slice(), thunk.bias_buffer(), - thunk.activation_buffer(), thunk.seqlen_q_buffer(), - thunk.seqlen_k_buffer()); -} - -static absl::StatusOr Convert(const FusedMHABackwardThunk& thunk) { - return std::make_unique( - thunk.execution_stream_id(), thunk.config(), - thunk.bmm1_grad_gemm1_rhs_buffer(), thunk.bmm1_grad_gemm2_rhs_buffer(), - thunk.bmm2_grad_gemm1_lhs_buffer(), thunk.bmm2_grad_gemm2_rhs_buffer(), - thunk.d_output_buffer(), thunk.scratch_buffer(), - thunk.d_bmm1_lhs_buffer(), thunk.d_bmm1_rhs_buffer(), - thunk.d_bmm2_rhs_buffer(), thunk.d_s_buffer(), thunk.d_bias_buffer(), - thunk.fwd_output_buffer(), thunk.bias_buffer(), thunk.seqlen_q_buffer(), - thunk.seqlen_k_buffer()); -} - static absl::StatusOr Convert( const ConditionalThunk& thunk, CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { @@ -276,10 +254,6 @@ static absl::Status AppendCommands( return append(Convert(thunk)); case Thunk::Kind::kCustomKernel: return append(Convert(thunk)); - case Thunk::Kind::kFusedMHA: - return append(Convert(thunk)); - case Thunk::Kind::kFusedMHABackward: - return append(Convert(thunk)); case Thunk::Kind::kKernel: return append(Convert(thunk)); case Thunk::Kind::kGemm: diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc index 40ba5e35bd9d73..90b6e0666c8adf 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/strings/ascii.h" +#include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/launch_dimensions.h" @@ -30,7 +31,7 @@ limitations under the License. #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/gpu/gpu_test_kernels_fatbin.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" @@ -352,20 +353,15 @@ TEST(CommandBufferCmdTest, LaunchCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; - commands.Emplace(s0, "add", args, args_access, + commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); // Initialize command sequence and load device kernels. - Thunk::ExecutableSource source = { -#if defined(GOOGLE_CUDA) - /*text=*/se::gpu::internal::kAddI32Kernel, - /*binary=*/{} -#elif defined(TENSORFLOW_USE_ROCM) - /*text=*/{}, - /*binary=*/se::gpu::internal::kAddI32KernelModule -#endif - }; + TF_ASSERT_OK_AND_ASSIGN(std::vector fatbin, + se::gpu::GetGpuTestKernelsFatbin()); + Thunk::ExecutableSource source = {/*text=*/{}, + /*binary=*/fatbin}; CommandBufferCmd::StateManager state; TF_ASSERT_OK(commands.Initialize({executor, source}, state)); diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc index f4fc9e22c62c4f..bce9d1927d05ea 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -19,12 +19,14 @@ limitations under the License. #include #include #include +#include #include // NOLINT #include #include #include "absl/status/statusor.h" #include "absl/strings/ascii.h" +#include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/launch_dimensions.h" @@ -41,6 +43,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/gpu/gpu_test_kernels_fatbin.h" #include "xla/stream_executor/gpu/gpu_types.h" // IWYU pragma: keep #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" @@ -64,27 +67,29 @@ namespace xla::gpu { using MemoryAccess = CommandBufferCmd::MemoryAccess; using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; -static se::StreamExecutor* GpuExecutor() { +namespace { +se::StreamExecutor* GpuExecutor() { auto name = absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); auto* platform = se::PlatformManager::PlatformWithName(name).value(); return platform->ExecutorForDevice(0).value(); } -static Thunk::ExecutableSource ExecutableSource() { - Thunk::ExecutableSource source = { -#if defined(GOOGLE_CUDA) - /*text=*/se::gpu::internal::kAddI32Kernel, - /*binary=*/{} -#elif defined(TENSORFLOW_USE_ROCM) - /*text=*/{}, - /*binary=*/se::gpu::internal::kAddI32KernelModule -#endif - }; - return source; +struct OwningExecutableSource { + std::string text; + std::vector binary; + + explicit operator Thunk::ExecutableSource() const { return {text, binary}; } +}; + +absl::StatusOr ExecutableSource() { + TF_ASSIGN_OR_RETURN(std::vector fatbin, + se::gpu::GetGpuTestKernelsFatbin()); + return OwningExecutableSource{/*text=*/{}, + /*binary=*/fatbin}; } -static KernelArgsPacking CreateDefaultArgsPacking() { +KernelArgsPacking CreateDefaultArgsPacking() { using Packed = absl::StatusOr>; return [=](const se::Kernel& kernel, const se::KernelArgs& args) -> Packed { @@ -96,7 +101,7 @@ static KernelArgsPacking CreateDefaultArgsPacking() { } // Some of the tests rely on CUDA 12.3+ features. -static bool IsAtLeastCuda12300() { +bool IsAtLeastCuda12300() { #if defined(TENSORFLOW_USE_ROCM) return false; #endif @@ -107,8 +112,9 @@ static bool IsAtLeastCuda12300() { } // Give a short aliases to execution threads. -static constexpr auto s0 = ExecutionStreamId(0); -static constexpr auto s1 = ExecutionStreamId(1); +constexpr auto s0 = ExecutionStreamId(0); +constexpr auto s1 = ExecutionStreamId(1); +} // namespace TEST(CommandBufferThunkTest, MemcpyCmd) { se::StreamExecutor* executor = GpuExecutor(); @@ -428,7 +434,7 @@ TEST(CommandBufferThunkTest, LaunchCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; - commands.Emplace(s0, "add", args, args_access, + commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); @@ -442,9 +448,10 @@ TEST(CommandBufferThunkTest, LaunchCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -498,7 +505,7 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { spec.AddInProcessSymbol(se::gpu::internal::GetAddI32Kernel(), "add"); auto custom_kernel = - CustomKernel("add", std::move(spec), se::BlockDim(), + CustomKernel("AddI32", std::move(spec), se::BlockDim(), se::ThreadDim(4, 1, 1), /*shared_memory_bytes=*/0); int64_t length = 4; @@ -524,7 +531,7 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; - commands.Emplace(s0, "add", args, args_access, + commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); @@ -538,9 +545,10 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -880,10 +888,10 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; - commands.Emplace(s0, "add", args, args_access, + commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); - commands.Emplace(s0, "add", args_1, args_access, + commands.Emplace(s0, "AddI32", args_1, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); @@ -897,9 +905,10 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -994,7 +1003,7 @@ TEST(CommandBufferThunkTest, IfCmd) { // Prepare commands sequence for `then` branch. CommandBufferCmdSequence then_commands; - then_commands.Emplace(s0, "add", args, args_access, + then_commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); @@ -1012,9 +1021,10 @@ TEST(CommandBufferThunkTest, IfCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -1084,14 +1094,14 @@ TEST(CommandBufferThunkTest, IfElseCmd) { { // Then: b = a + a auto args = {slice_a, slice_a, slice_b}; - then_commands.Emplace(s0, "add", args, args_access, + then_commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); } { // Else: b = b + b auto args = {slice_b, slice_b, slice_b}; - else_commands.Emplace(s0, "add", args, args_access, + else_commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); } @@ -1111,9 +1121,10 @@ TEST(CommandBufferThunkTest, IfElseCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -1174,14 +1185,14 @@ TEST(CommandBufferThunkTest, CaseCmd) { { // Case 0: b = a + a auto args = {slice_a, slice_a, slice_b}; - branches[0].Emplace(s0, "add", args, args_access, + branches[0].Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); } { // Case 1: b = b + b auto args = {slice_b, slice_b, slice_b}; - branches[1].Emplace(s0, "add", args, args_access, + branches[1].Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); } @@ -1200,9 +1211,10 @@ TEST(CommandBufferThunkTest, CaseCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -1260,7 +1272,7 @@ TEST(CommandBufferThunkTest, ForCmd) { // Prepare commands sequence for loop `body`. CommandBufferCmdSequence body_commands; - body_commands.Emplace(s0, "add", args, args_access, + body_commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); @@ -1279,9 +1291,10 @@ TEST(CommandBufferThunkTest, ForCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value 10 times. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); diff --git a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc b/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc index 604bce14592727..eda2bc6e2c3462 100644 --- a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc +++ b/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc @@ -74,8 +74,6 @@ void ForAllThunks(absl::FunctionRef fn, case Thunk::kCustomKernel: case Thunk::kCuDnn: case Thunk::kFft: - case Thunk::kFusedMHA: - case Thunk::kFusedMHABackward: case Thunk::kGemm: case Thunk::kInfeed: case Thunk::kKernel: diff --git a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc b/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc deleted file mode 100644 index ee13689fbbb578..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc +++ /dev/null @@ -1,230 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/fused_mha_thunk.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/synchronization/mutex.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" -#include "xla/service/gpu/runtime/thunk.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/lazy_op_runner.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -FusedMHAThunk::FusedMHAThunk( - ThunkInfo thunk_info, GpufMHAConfig config, - BufferAllocation::Slice lhs_bmm1, BufferAllocation::Slice rhs_bmm1, - BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output, - BufferAllocation::Slice scratch, BufferAllocation::Slice mask, - BufferAllocation::Slice bias, BufferAllocation::Slice activation, - BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k) - : Thunk(Kind::kFusedMHA, thunk_info), - lhs_bmm1_buffer_(lhs_bmm1), - rhs_bmm1_buffer_(rhs_bmm1), - rhs_bmm2_buffer_(rhs_bmm2), - output_buffer_(output), - scratch_buffer_(scratch), - bias_buffer_(bias), - activation_buffer_(activation), - seqlen_q_buffer_(seqlen_q), - seqlen_k_buffer_(seqlen_k), - config_(std::move(config)) {} - -FusedMultiHeadedAttentionRunner& FusedMHAThunk::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mu_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, std::make_unique( - config_)}) - .first; - } - return *it->second; -} - -std::optional AssignBufferIfNotNull( - const BufferAllocations& buffer_allocations, - BufferAllocation::Slice& slice) { - return slice.allocation() != nullptr - ? std::optional{buffer_allocations - .GetDeviceAddress(slice)} - : std::nullopt; -} - -absl::Status FusedMHAThunk::Initialize(const InitializeParams& params) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(params.stream).AsFusedMHARunner(); - TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHAOpConfig()); - return lazy_runner->GetOrCreateRunner(config, params.stream).status(); -} - -absl::Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) { - const auto& buffer_allocations = *params.buffer_allocations; - se::DeviceMemoryBase lhs_bmm1_buffer = - buffer_allocations.GetDeviceAddress(lhs_bmm1_buffer_); - se::DeviceMemoryBase rhs_bmm1_buffer = - buffer_allocations.GetDeviceAddress(rhs_bmm1_buffer_); - se::DeviceMemoryBase rhs_bmm2_buffer = - buffer_allocations.GetDeviceAddress(rhs_bmm2_buffer_); - se::DeviceMemoryBase output_buffer = - buffer_allocations.GetDeviceAddress(output_buffer_); - se::DeviceMemoryBase scratch_buffer = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - std::optional bias_buffer = - AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional activation_buffer = - AssignBufferIfNotNull(buffer_allocations, activation_buffer_); - std::optional seqlen_q_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); - std::optional seqlen_k_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); - RunFusedMHAOptions opts; - opts.runner_cache = &GetOrCreateRunner(params.stream); - TF_RETURN_IF_ERROR(RunGpuFMHA(config_, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, scratch_buffer, - bias_buffer, activation_buffer, seqlen_q_buffer, - seqlen_k_buffer, params.stream, opts)); - - if (!params.stream->ok()) { - return Internal("FusedMHAThunk::ExecuteOnStream failed."); - } - return absl::OkStatus(); -} -FusedMHABackwardThunk::FusedMHABackwardThunk( - ThunkInfo thunk_info, GpufMHABackwardConfig config, - BufferAllocation::Slice bmm1_grad_gemm1_rhs, - BufferAllocation::Slice bmm1_grad_gemm2_rhs, - BufferAllocation::Slice bmm2_grad_gemm1_lhs, - BufferAllocation::Slice bmm2_grad_gemm2_rhs, - BufferAllocation::Slice d_output, BufferAllocation::Slice scratch, - BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs, - BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, - BufferAllocation::Slice mask, BufferAllocation::Slice d_bias, - BufferAllocation::Slice fwd_output, BufferAllocation::Slice bias, - BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k) - : Thunk(Kind::kFusedMHABackward, thunk_info), - bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs), - bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs), - bmm2_grad_gemm1_lhs_buffer_(bmm2_grad_gemm1_lhs), - bmm2_grad_gemm2_rhs_buffer_(bmm2_grad_gemm2_rhs), - d_output_buffer_(d_output), - scratch_buffer_(scratch), - d_bmm1_lhs_buffer_(d_bmm1_lhs), - d_bmm1_rhs_buffer_(d_bmm1_rhs), - d_bmm2_rhs_buffer_(d_bmm2_rhs), - d_s_buffer_(d_s), - d_bias_buffer_(d_bias), - fwd_output_buffer_(fwd_output), - bias_buffer_(bias), - seqlen_q_buffer_(seqlen_q), - seqlen_k_buffer_(seqlen_k), - config_(std::move(config)) {} - -FusedMultiHeadedAttentionBackwardRunner& -FusedMHABackwardThunk::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mu_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, - std::make_unique( - config_)}) - .first; - } - return *it->second; -} - -absl::Status FusedMHABackwardThunk::Initialize(const InitializeParams& params) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(params.stream).AsFusedMHABackwardRunner(); - TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHABackwardOpConfig()); - return lazy_runner->GetOrCreateRunner(config, params.stream).status(); -} - -absl::Status FusedMHABackwardThunk::ExecuteOnStream( - const ExecuteParams& params) { - const auto& buffer_allocations = *params.buffer_allocations; - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm1_grad_gemm1_rhs_buffer_); - - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm1_grad_gemm2_rhs_buffer_); - - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer = - buffer_allocations.GetDeviceAddress(bmm2_grad_gemm1_lhs_buffer_); - - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm2_grad_gemm2_rhs_buffer_); - - se::DeviceMemoryBase d_output_buffer = - buffer_allocations.GetDeviceAddress(d_output_buffer_); - - se::DeviceMemoryBase scratch_buffer = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - se::DeviceMemoryBase d_bmm1_lhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm1_lhs_buffer_); - - se::DeviceMemoryBase d_bmm1_rhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm1_rhs_buffer_); - - se::DeviceMemoryBase d_bmm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_); - - std::optional d_s_buffer = - AssignBufferIfNotNull(buffer_allocations, d_s_buffer_); - std::optional d_bias_buffer = - AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_); - std::optional fwd_output_buffer = - AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_); - std::optional bias_buffer = - AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional seqlen_q_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); - std::optional seqlen_k_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); - RunFusedMHABackwardOptions opts; - - opts.runner_cache = &GetOrCreateRunner(params.stream); - - TF_RETURN_IF_ERROR(RunGpuFMHABackward( - config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, - scratch_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, - d_s_buffer, d_bias_buffer, fwd_output_buffer, bias_buffer, - seqlen_q_buffer, seqlen_k_buffer, params.stream, opts)); - if (!params.stream->ok()) { - return Internal("FusedMHABackwardThunk::ExecuteOnStream failed."); - } - return absl::OkStatus(); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h b/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h deleted file mode 100644 index 99a8327269499e..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h +++ /dev/null @@ -1,184 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_ - -#include - -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/synchronization/mutex.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" -#include "xla/service/gpu/runtime/thunk.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace gpu { - -// This class stores everything that StreamExecutor needs to launch a DNN -// fMHA. It is generated by IrEmitter. -// -// This is thread-compatible. -class FusedMHAThunk : public Thunk { - public: - // Constructs a thunk for launching a DNN FMHA. - FusedMHAThunk(ThunkInfo thunk_info, GpufMHAConfig config, - BufferAllocation::Slice lhs_bmm1_slice, - BufferAllocation::Slice rhs_bmm1_slice, - BufferAllocation::Slice rhs_bmm2_slice, - BufferAllocation::Slice output_slice, - BufferAllocation::Slice scratch_slice, - BufferAllocation::Slice mask_slice, /* may be null */ - BufferAllocation::Slice bias_slice /* may be null */, - BufferAllocation::Slice activation_slice /* may be null */, - BufferAllocation::Slice seqlen_q_slice /* may be null */, - BufferAllocation::Slice seqlen_k_slice /* may be null */); - - FusedMHAThunk(const FusedMHAThunk&) = delete; - FusedMHAThunk& operator=(const FusedMHAThunk&) = delete; - - BufferAllocation::Slice lhs_bmm1_buffer() const { return lhs_bmm1_buffer_; } - BufferAllocation::Slice rhs_bmm1_buffer() const { return rhs_bmm1_buffer_; } - BufferAllocation::Slice rhs_bmm2_buffer() const { return rhs_bmm2_buffer_; } - BufferAllocation::Slice output_buffer() const { return output_buffer_; } - BufferAllocation::Slice scratch_buffer() const { return scratch_buffer_; } - BufferAllocation::Slice bias_buffer() const { return bias_buffer_; } - BufferAllocation::Slice activation_buffer() const { - return activation_buffer_; - } - BufferAllocation::Slice seqlen_q_buffer() const { return seqlen_q_buffer_; } - BufferAllocation::Slice seqlen_k_buffer() const { return seqlen_k_buffer_; } - - GpufMHAConfig config() const { return config_; } - absl::Status Initialize(const InitializeParams& params) override; - absl::Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - BufferAllocation::Slice lhs_bmm1_buffer_; - BufferAllocation::Slice rhs_bmm1_buffer_; - BufferAllocation::Slice rhs_bmm2_buffer_; - BufferAllocation::Slice output_buffer_; - BufferAllocation::Slice scratch_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice activation_buffer_; - BufferAllocation::Slice seqlen_q_buffer_; - BufferAllocation::Slice seqlen_k_buffer_; - - FusedMultiHeadedAttentionRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); - - // FusedMHA config - const GpufMHAConfig config_; - absl::Mutex mu_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mu_); -}; - -class FusedMHABackwardThunk : public Thunk { - public: - // Constructs a thunk for launching a DNN FMHA backward. - FusedMHABackwardThunk(ThunkInfo thunk_info, GpufMHABackwardConfig config, - BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice, - BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice, - BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice, - BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice, - BufferAllocation::Slice d_output_slice, - BufferAllocation::Slice scratch_slice, - BufferAllocation::Slice d_bmm1_lhs_slice, - BufferAllocation::Slice d_bmm1_rhs_slice, - BufferAllocation::Slice d_bmm2_rhs_slice, - BufferAllocation::Slice d_s_slice, - BufferAllocation::Slice mask_slice, - BufferAllocation::Slice d_bias_slice, - BufferAllocation::Slice fwd_output_slice, - BufferAllocation::Slice bias_slice, - BufferAllocation::Slice seqlen_q_slice, - BufferAllocation::Slice seqlen_k_slice); - - FusedMHABackwardThunk(const FusedMHABackwardThunk&) = delete; - FusedMHABackwardThunk& operator=(const FusedMHABackwardThunk&) = delete; - - BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer() const { - return bmm1_grad_gemm1_rhs_buffer_; - } - BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer() const { - return bmm1_grad_gemm2_rhs_buffer_; - } - BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer() const { - return bmm2_grad_gemm1_lhs_buffer_; - } - BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer() const { - return bmm2_grad_gemm2_rhs_buffer_; - } - BufferAllocation::Slice d_output_buffer() const { return d_output_buffer_; } - BufferAllocation::Slice scratch_buffer() const { return scratch_buffer_; } - BufferAllocation::Slice d_bmm1_lhs_buffer() const { - return d_bmm1_lhs_buffer_; - } - BufferAllocation::Slice d_bmm1_rhs_buffer() const { - return d_bmm1_rhs_buffer_; - } - BufferAllocation::Slice d_bmm2_rhs_buffer() const { - return d_bmm2_rhs_buffer_; - } - BufferAllocation::Slice d_s_buffer() const { return d_s_buffer_; } - BufferAllocation::Slice d_bias_buffer() const { return d_bias_buffer_; } - BufferAllocation::Slice fwd_output_buffer() const { - return fwd_output_buffer_; - } - BufferAllocation::Slice bias_buffer() const { return bias_buffer_; } - BufferAllocation::Slice seqlen_q_buffer() const { return seqlen_q_buffer_; } - BufferAllocation::Slice seqlen_k_buffer() const { return seqlen_k_buffer_; } - - GpufMHABackwardConfig config() const { return config_; } - - absl::Status Initialize(const InitializeParams& params) override; - absl::Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer_; - BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer_; - BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer_; - BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer_; - BufferAllocation::Slice d_output_buffer_; - BufferAllocation::Slice scratch_buffer_; - BufferAllocation::Slice d_bmm1_lhs_buffer_; - BufferAllocation::Slice d_bmm1_rhs_buffer_; - BufferAllocation::Slice d_bmm2_rhs_buffer_; - BufferAllocation::Slice d_s_buffer_; - BufferAllocation::Slice d_bias_buffer_; - BufferAllocation::Slice fwd_output_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice seqlen_q_buffer_; - BufferAllocation::Slice seqlen_k_buffer_; - - FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); - - // FusedMHA backward config - const GpufMHABackwardConfig config_; - absl::Mutex mu_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mu_); -}; -} // namespace gpu -} // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc index a7b068c4a9a0b4..9bbc6f4019eab1 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/service/global_device_id.h" @@ -36,12 +37,14 @@ namespace xla::gpu { // NcclCliqueKey //===----------------------------------------------------------------------===// -NcclCliqueKey::NcclCliqueKey(std::vector devices, - NcclStreamId stream_id, - AsyncStreamKind stream_kind) +NcclCliqueKey::NcclCliqueKey( + std::vector devices, NcclStreamId stream_id, + AsyncStreamKind stream_kind, + std::vector> participant_groups) : devices_(std::move(devices)), stream_id_(stream_id), - stream_kind_(stream_kind) {} + stream_kind_(stream_kind), + participant_groups_(std::move(participant_groups)) {} absl::Span NcclCliqueKey::devices() const { return devices_; @@ -64,12 +67,23 @@ bool NcclCliqueKey::IsSubsetOf(const NcclCliqueKey& other) const { } std::string NcclCliqueKey::ToString() const { - return absl::StrFormat("devices=[%s]; stream=%d", - GlobalDeviceIdsToString(devices_), stream_id_.value()); + std::string group_string = ""; + if (!participant_groups_.empty()) { + std::vector values; + values.reserve(participant_groups_.size()); + for (const auto& group : participant_groups_) { + values.push_back("[" + GlobalDeviceIdsToString(group) + "]"); + } + group_string = absl::StrFormat("; groups=[%s]", absl::StrJoin(values, ",")); + } + return absl::StrFormat("devices=[%s]; stream=%d%s", + GlobalDeviceIdsToString(devices_), stream_id_.value(), + group_string); } bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) { - return a.devices_ == b.devices_ && a.stream_id_ == b.stream_id_; + return a.devices_ == b.devices_ && a.stream_id_ == b.stream_id_ && + a.participant_groups_ == b.participant_groups_; } bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b) { diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h index 56c9b81f81e2ba..0946ce62ef7275 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h @@ -82,7 +82,8 @@ class NcclCliqueKey { explicit NcclCliqueKey( std::vector devices, NcclStreamId stream_id = NcclStreamId(0), - AsyncStreamKind stream_kind = AsyncStreamKind::kCollective); + AsyncStreamKind stream_kind = AsyncStreamKind::kCollective, + std::vector> participant_groups = {}); absl::Span devices() const; @@ -113,11 +114,23 @@ class NcclCliqueKey { std::vector devices_; NcclStreamId stream_id_; AsyncStreamKind stream_kind_; + // The full list of groups across all devices which this clique is a part of. + // When enable_nccl_comm_splitting is enabled, this is used to distinguish + // which cliques can be reused from the cache or must be split in order to + // prevent a deadlock situation. + // For example, imagine we have a communicator with devices = [0,1] and groups + // = [0, 1] Later on, we may want to create communicators [0, 1] and [2, 3] by + // splitting [0, 1, 2, 3] If ranks 0 and 1 reuse the exisiting [0, 1] clique + // but ranks 2 and 3 initiate a split, there will be a deadlock since ranks 2, + // 3 and will be waiting forever for 0, 1 to join the split. Having the + // particating groups as part of the cache key will prevent such situations + std::vector> participant_groups_; }; template H AbslHashValue(H h, const NcclCliqueKey& k) { - return H::combine(std::move(h), k.devices_, k.stream_id_); + return H::combine(std::move(h), k.devices_, k.stream_id_, + k.participant_groups_); } bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc index 4346f544db20bc..c72c5115252865 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "absl/container/btree_map.h" #include "xla/service/global_device_id.h" @@ -53,6 +54,26 @@ TEST(NcclCliqueKeyTest, Compare) { EXPECT_GT(key1, key0); } +TEST(NcclCliqueKeyTest, CompareWithParticipantGroups) { + GlobalDeviceId id0 = GlobalDeviceId(0); + GlobalDeviceId id1 = GlobalDeviceId(1); + GlobalDeviceId id2 = GlobalDeviceId(2); + GlobalDeviceId id3 = GlobalDeviceId(3); + + // The keys are not equal because the replica groups are different. + NcclCliqueKey key0({id0, id1}, NcclStreamId(0), AsyncStreamKind::kCollective, + std::vector>{{id0, id1}}); + NcclCliqueKey key1( + {id0, id1}, NcclStreamId(0), AsyncStreamKind::kCollective, + std::vector>{{id0, id1}, {id2, id3}}); + EXPECT_FALSE(key0 == key1); + + // With no replica groups, the keys are equal + NcclCliqueKey key0_nogroups({id0, id1}, NcclStreamId(0)); + NcclCliqueKey key1_nogroups({id0, id1}, NcclStreamId(0)); + EXPECT_EQ(key0_nogroups, key1_nogroups); +} + TEST(NcclCliqueKeyTest, BtreeIterationOrder) { GlobalDeviceId id0 = GlobalDeviceId(0); GlobalDeviceId id1 = GlobalDeviceId(1); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc index 7582c18c292e72..93b113b3a25627 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc @@ -217,7 +217,7 @@ NcclCollectiveThunk::NcclCollectiveThunk(Kind kind, ThunkInfo thunk_info, nccl_api_(nccl_api), async_events_(is_sync ? nullptr : new AsyncEvents()) {} -static absl::StatusOr GetNcclCliqueKey( +absl::StatusOr GetNcclCliqueKey( const Thunk::CollectiveExecuteParams& params, const std::vector& replica_groups, CollectiveOpGroupMode group_mode, NcclStreamId stream_id, @@ -229,6 +229,18 @@ static absl::StatusOr GetNcclCliqueKey( GetParticipatingDevices(global_device_id, *params.device_assn, replica_groups, group_mode)); + // If splitting is enabled, particpating groups must match in order for a + // clique to be reused from the cache. We can ignore the particpating groups + // otherwise. + static const int64_t enable_nccl_comm_splitting = + xla::GetDebugOptionsFromFlags().xla_gpu_enable_nccl_comm_splitting(); + std::vector> participant_groups; + if (enable_nccl_comm_splitting) { + TF_ASSIGN_OR_RETURN(participant_groups, + GetParticipatingDevicesGroups( + *params.device_assn, replica_groups, group_mode)); + } + if (IsGlobalNcclConfig() && (participants.size() != params.device_assn->replica_count())) { return InvalidArgument( @@ -240,7 +252,7 @@ static absl::StatusOr GetNcclCliqueKey( return NcclCliqueKey(std::move(participants), enable_per_stream_comms ? stream_id : kNoStreamId, - stream_kind); + stream_kind, std::move(participant_groups)); } absl::StatusOr GetNcclComm( @@ -373,33 +385,16 @@ absl::StatusOr NcclCollectiveThunk::AsyncEvents::GetEvent( absl::Status NcclCollectiveThunk::Prepare(const PrepareParams& params, ResourceRequests& resource_requests) { - const CollectiveExecuteParams* collectives = params.collective_params; - TF_ASSIGN_OR_RETURN( - std::vector participants, - GetParticipatingDevices(collectives->global_device_id, - *collectives->device_assn, + NcclCliqueKey clique_key, + GetNcclCliqueKey(*params.collective_params, config().replica_groups, + config().group_mode, nccl_stream_id(), + GetAsyncStreamKind())); + TF_ASSIGN_OR_RETURN( + size_t num_local_participants, + GetNumLocalParticipants(*params.collective_params, config().replica_groups, config().group_mode)); - - std::vector local_devices; - if (collectives->global_device_id_map) { - local_devices.reserve(collectives->global_device_id_map->size()); - for (const auto& entry : *collectives->global_device_id_map) { - local_devices.push_back(entry.second); - } - } - - size_t num_local_participants = GetNumLocalParticipants( - participants, - collectives->global_device_id_map ? &local_devices : nullptr); - AsyncStreamKind stream_kind = GetAsyncStreamKind(); - static const bool enable_per_stream_comms = - xla::GetDebugOptionsFromFlags().xla_gpu_enable_nccl_per_stream_comms(); - return resource_requests.AddClique( - NcclCliqueKey(std::move(participants), - enable_per_stream_comms ? nccl_stream_id() : kNoStreamId, - stream_kind), - num_local_participants); + return resource_requests.AddClique(clique_key, num_local_participants); } absl::Status NcclCollectiveThunk::Initialize(const InitializeParams& params) { @@ -537,13 +532,26 @@ absl::Status IsValidOperand(Shape shape, Thunk::Kind reduction_op) { return absl::OkStatus(); } -size_t GetNumLocalParticipants( - const std::vector& participants, - const std::vector* local_devices) { - if (local_devices == nullptr) return participants.size(); +absl::StatusOr GetNumLocalParticipants( + const Thunk::CollectiveExecuteParams& params, + const std::vector& replica_groups, + CollectiveOpGroupMode group_mode) { + TF_ASSIGN_OR_RETURN( + std::vector participants, + GetParticipatingDevices(params.global_device_id, *params.device_assn, + replica_groups, group_mode)); + if (!params.global_device_id_map) { + return participants.size(); + } + + std::vector local_devices; + local_devices.reserve(params.global_device_id_map->size()); + for (const auto& entry : *params.global_device_id_map) { + local_devices.push_back(entry.second); + } return absl::c_count_if(participants, [&](const GlobalDeviceId& device_id) { - return absl::c_linear_search(*local_devices, device_id); + return absl::c_linear_search(local_devices, device_id); }); } diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h index ccaffb35c308a8..2a549cdd81f520 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h @@ -283,9 +283,16 @@ absl::Status AddOpDescription(absl::Status status, OpT op, //===----------------------------------------------------------------------===// -size_t GetNumLocalParticipants( - const std::vector& participants, - const std::vector* local_devices); // may be null +absl::StatusOr GetNcclCliqueKey( + const Thunk::CollectiveExecuteParams& params, + const std::vector& replica_groups, + CollectiveOpGroupMode group_mode, NcclStreamId stream_id, + AsyncStreamKind stream_kind); + +absl::StatusOr GetNumLocalParticipants( + const Thunk::CollectiveExecuteParams& params, + const std::vector& replica_groups, + CollectiveOpGroupMode group_mode); // Returns a nccl comm handle and a flag indicating if // it's a local communicator. diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc index fc3c0cff8741c5..6f3081a90eb234 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc @@ -286,8 +286,6 @@ Thunk::ExecuteParams::ExecuteParams( CASE(kSequential); CASE(kTriangularSolve); CASE(kWhile); - CASE(kFusedMHA); - CASE(kFusedMHABackward); CASE(kWaitForStreams); CASE(kCuDnn); } diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index 346664976a2d9c..cd26323ee70fea 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -165,8 +165,6 @@ class Thunk { kSendDone, kTriangularSolve, kWhile, - kFusedMHA, - kFusedMHABackward, kWaitForStreams, kCuDnn }; diff --git a/third_party/xla/xla/service/gpu/runtime_intrinsics.cc b/third_party/xla/xla/service/gpu/runtime_intrinsics.cc index 879ca6faf7c671..33bbac0f90f373 100644 --- a/third_party/xla/xla/service/gpu/runtime_intrinsics.cc +++ b/third_party/xla/xla/service/gpu/runtime_intrinsics.cc @@ -28,10 +28,11 @@ limitations under the License. #include "xla/service/custom_call_target_registry.h" #include "xla/service/platform_util.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_finder.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -51,11 +52,8 @@ absl::Status AssertOnGpu(void* stream_handle, void* buffer, TF_ASSIGN_OR_RETURN( se::Platform * platform, se::PlatformManager::PlatformWithName(GetGpuPlatformName())); - se::StreamExecutorConfig config; - config.gpu_stream = stream_handle; - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - platform->GetExecutor(config)); - se::Stream* stream = executor->FindAllocatedStream(stream_handle); + TF_ASSIGN_OR_RETURN(se::Stream * stream, + stream_executor::FindStream(platform, stream_handle)); if (!stream) { return Internal("Stream not found for: %p", stream_handle); } diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc index 54b29378adec0c..ea5607516d6e3d 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc @@ -40,7 +40,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include "xla/layout.h" #include "xla/literal_util.h" -#include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/triton_fusion_analysis.h" diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index cc4bd98f78ed80..991a0b2c3bf8ca 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -162,44 +162,6 @@ xla_test( ], ) -xla_cc_test( - name = "gpu_reduce_scatter_creator_test", - srcs = ["gpu_reduce_scatter_creator_test.cc"], - deps = [ - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service/gpu:gpu_reduce_scatter_creator", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "gpu_all_gather_optimizer_test", - srcs = ["gpu_all_gather_optimizer_test.cc"], - deps = [ - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service/gpu:gpu_all_gather_optimizer", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ], -) - xla_test( name = "gpu_spmd_e2e_compile_test", size = "small", @@ -238,51 +200,6 @@ xla_test( ], ) -xla_cc_test( - name = "reduction_degenerate_dim_remover_test", - srcs = [ - "reduction_degenerate_dim_remover_test.cc", - ], - deps = [ - "//xla/service/gpu:reduction_degenerate_dim_remover", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_test( - name = "reduction_layout_normalizer_test", - srcs = [ - "reduction_layout_normalizer_test.cc", - ], - backends = ["gpu"], - deps = [ - "//xla:error_spec", - "//xla/service/gpu:reduction_layout_normalizer", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_cc_test( - name = "tree_reduction_rewriter_test", - srcs = [ - "tree_reduction_rewriter_test.cc", - ], - deps = [ - "//xla/service/gpu:tree_reduction_rewriter", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - xla_test( name = "swap_conv_operands_test", srcs = [ @@ -316,20 +233,6 @@ xla_test( ], ) -xla_cc_test( - name = "reduction_dimension_grouper_test", - srcs = [ - "reduction_dimension_grouper_test.cc", - ], - deps = [ - "//xla/service/gpu:reduction_dimension_grouper", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - xla_test( name = "parallel_reduction_test", srcs = [ @@ -581,7 +484,7 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:gpu_fusible", - "//xla/service/gpu:instruction_fusion", + "//xla/service/gpu/transforms:instruction_fusion", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:test_main", ], @@ -597,9 +500,9 @@ xla_test( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_pass_pipeline", "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:instruction_fusion", - "//xla/service/gpu:multi_output_fusion", "//xla/service/gpu/transforms:fusion_merger", + "//xla/service/gpu/transforms:instruction_fusion", + "//xla/service/gpu/transforms:multi_output_fusion", "//xla/stream_executor:device_description", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:test_main", @@ -721,11 +624,9 @@ lit_test_suite( "copy.hlo", "dot_bf16.hlo", "dynamic_update_slice_inplace.hlo", - "element_wise_row_vectorization.hlo", "fused_scatter.hlo", "fused_slice.hlo", "kernel_reuse.hlo", - "launch_dimensions.hlo", "pad_to_static.hlo", "reduce_atomic_min.hlo", "reduce_column_layout_change.hlo", @@ -769,6 +670,7 @@ lit_test_suite( "//xla/tools/hlo_opt:gpu_specs/v100.txtpb", ], default_tags = tf_cuda_tests_tags(), + hermetic_cuda_data_dir = "%S/../../../../../cuda_nvcc", tags_override = { "element_wise_row_vectorization.hlo": ["no_rocm"], "scatter_bf16.hlo": ["no_rocm"], @@ -798,6 +700,7 @@ lit_test_suite( # name = "xla-opt", # srcs = ["xla-opt.cc"], # deps = [ +# "//xla/service/gpu/fusions/transforms:passes", # "//xla/service/gpu/fusions/triton:passes", # "@llvm-project//mlir:AllExtensions", # "@llvm-project//mlir:MlirOptLib", @@ -912,7 +815,7 @@ xla_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service/gpu:gpu_sort_rewriter", + "//xla/service/gpu/transforms:sort_rewriter", "//xla/tests:hlo_test_base", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", diff --git a/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo b/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo deleted file mode 100644 index 3e75fceb48f530..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo +++ /dev/null @@ -1,292 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/v100.txtpb --split-input-file | FileCheck %s -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK-LLVM %s -// We check that the row loads are vectorized. - -HloModule SimpleAddRowBroadcasting, is_scheduled=true - -%fused_computation.0 (param_0: f32[672], param_1: f32[512,14,14,672]) -> f32[512,14,14,672]{ - %param_0 = f32[672]{0} parameter(0) - %broadcast = f32[512,14,14,672]{3,2,1,0} broadcast(%param_0), dimensions={3} - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - ROOT %add = f32[512,14,14,672]{3,2,1,0} add(%broadcast, %param_1) -} - -ENTRY main { - %param_0 = f32[672]{0} parameter(0) - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - - ROOT %fusion.0 = f32[512,14,14,672]{3,2,1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.0 -} - -// CHECK-LABEL: fusion_0 -// CHECK: .reqntid 168, 1, 1 -// CHECK-NOT: ld.global.nc.f -// CHECK-NOT: ld.global.nc.b - -// ----- - -HloModule SimpleAddSmallRowBroadcasting, is_scheduled=true - -%fused_computation.0 (param_0: f32[48], param_1: f32[512,14,14,48]) -> f32[512,14,14,48]{ - %param_0 = f32[48]{0} parameter(0) - %broadcast = f32[512,14,14,48]{3,2,1,0} broadcast(%param_0), dimensions={3} - %param_1 = f32[512,14,14,48]{3,2,1,0} parameter(1) - ROOT %add = f32[512,14,14,48]{3,2,1,0} add(%broadcast, %param_1) -} - -ENTRY main { - %param_0 = f32[48]{0} parameter(0) - %param_1 = f32[512,14,14,48]{3,2,1,0} parameter(1) - - ROOT %fusion.0_small = f32[512,14,14,48]{3,2,1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.0 -} - -// CHECK-LABEL: fusion_0_small -// CHECK: .reqntid 12, 11, 1 -// CHECK-NOT: ld.global.nc.f -// CHECK-NOT: ld.global.nc.b - -// ----- - -// This test an BatchNorm fused kernel found in EfficientNet. -HloModule EfficientNetSwish, is_scheduled=true - -%fused_computation.1 (param_0.89: f32[672], param_1: f32[672], param_2: f32[672], param_3: f32[672], param_4: f16[512,14,14,672], param_5: f32[672], param_6: f16[512,14,14,672], param_7: f32[672]) -> f16[512,14,14,672] { - %param_2 = f32[672]{0} parameter(2) - %constant_157 = f32[] constant(1), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %broadcast.186 = f32[672]{0} broadcast(f32[] %constant_157), dimensions={}, metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %param_5 = f32[672]{0} parameter(5) - %constant_56 = f32[] constant(9.96492327e-06) - %broadcast.185 = f32[672]{0} broadcast(f32[] %constant_56), dimensions={} - %multiply.155 = f32[672]{0} multiply(f32[672]{0} %param_5, f32[672]{0} %broadcast.185), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %param_3 = f32[672]{0} parameter(3) - %multiply.154 = f32[672]{0} multiply(f32[672]{0} %param_3, f32[672]{0} %broadcast.185), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %multiply.153 = f32[672]{0} multiply(f32[672]{0} %multiply.154, f32[672]{0} %multiply.154), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %subtract.15 = f32[672]{0} subtract(f32[672]{0} %multiply.155, f32[672]{0} %multiply.153), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %constant_155 = f32[] constant(0.001), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %broadcast.184 = f32[672]{0} broadcast(f32[] %constant_155), dimensions={} - %add.14 = f32[672]{0} add(f32[672]{0} %subtract.15, f32[672]{0} %broadcast.184), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %rsqrt.23 = f32[672]{0} rsqrt(f32[672]{0} %add.14), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %multiply.152 = f32[672]{0} multiply(f32[672]{0} %rsqrt.23, f32[672]{0} %rsqrt.23), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %divide.14 = f32[672]{0} divide(f32[672]{0} %broadcast.186, f32[672]{0} %multiply.152), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %rsqrt.7 = f32[672]{0} rsqrt(f32[672]{0} %divide.14), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %multiply.29 = f32[672]{0} multiply(f32[672]{0} %param_2, f32[672]{0} %rsqrt.7), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %multiply.28 = f32[672]{0} multiply(f32[672]{0} %multiply.29, f32[672]{0} %broadcast.185), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %broadcast.47 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %multiply.28), dimensions={3} - %param_6 = f16[512,14,14,672]{3,2,1,0} parameter(6) - %constant_194 = f16[] constant(1), metadata={op_type="AddV2" op_name="add"} - %broadcast.256 = f16[512,14,14,672]{3,2,1,0} broadcast(f16[] %constant_194), dimensions={} - %param_4 = f16[512,14,14,672]{3,2,1,0} parameter(4) - %convert.66 = f32[512,14,14,672]{3,2,1,0} convert(f16[512,14,14,672]{3,2,1,0} %param_4), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %broadcast.254 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %multiply.154), dimensions={3}, metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %subtract.82 = f32[512,14,14,672]{3,2,1,0} subtract(f32[512,14,14,672]{3,2,1,0} %convert.66, f32[512,14,14,672]{3,2,1,0} %broadcast.254), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %broadcast.251 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %rsqrt.23), dimensions={3} - %multiply.219 = f32[512,14,14,672]{3,2,1,0} multiply(f32[512,14,14,672]{3,2,1,0} %subtract.82, f32[512,14,14,672]{3,2,1,0} %broadcast.251), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %broadcast.250 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_2), dimensions={3}, metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %multiply.218 = f32[512,14,14,672]{3,2,1,0} multiply(f32[512,14,14,672]{3,2,1,0} %multiply.219, f32[512,14,14,672]{3,2,1,0} %broadcast.250), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %param_7 = f32[672]{0} parameter(7) - %broadcast.249 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_7), dimensions={3}, metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %add.79 = f32[512,14,14,672]{3,2,1,0} add(f32[512,14,14,672]{3,2,1,0} %multiply.218, f32[512,14,14,672]{3,2,1,0} %broadcast.249), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %convert.65 = f16[512,14,14,672]{3,2,1,0} convert(f32[512,14,14,672]{3,2,1,0} %add.79), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %negate.12 = f16[512,14,14,672]{3,2,1,0} negate(f16[512,14,14,672]{3,2,1,0} %convert.65) - %exponential.10 = f16[512,14,14,672]{3,2,1,0} exponential(f16[512,14,14,672]{3,2,1,0} %negate.12) - %add.78 = f16[512,14,14,672]{3,2,1,0} add(f16[512,14,14,672]{3,2,1,0} %broadcast.256, f16[512,14,14,672]{3,2,1,0} %exponential.10) - %divide.20 = f16[512,14,14,672]{3,2,1,0} divide(f16[512,14,14,672]{3,2,1,0} %broadcast.256, f16[512,14,14,672]{3,2,1,0} %add.78), metadata={op_type="Sigmoid" op_name="foo/activation/Sigmoid"} - %subtract.77 = f16[512,14,14,672]{3,2,1,0} subtract(f16[512,14,14,672]{3,2,1,0} %broadcast.256, f16[512,14,14,672]{3,2,1,0} %divide.20), metadata={op_type="Sub" op_name="sub"} - %multiply.211 = f16[512,14,14,672]{3,2,1,0} multiply(f16[512,14,14,672]{3,2,1,0} %convert.65, f16[512,14,14,672]{3,2,1,0} %subtract.77), metadata={op_type="Mul" op_name="mul"} - %add.75 = f16[512,14,14,672]{3,2,1,0} add(f16[512,14,14,672]{3,2,1,0} %broadcast.256, f16[512,14,14,672]{3,2,1,0} %multiply.211), metadata={op_type="AddV2" op_name="add"} - %multiply.210 = f16[512,14,14,672]{3,2,1,0} multiply(f16[512,14,14,672]{3,2,1,0} %divide.20, f16[512,14,14,672]{3,2,1,0} %add.75), metadata={op_type="Mul" op_name="mul_1"} - %multiply.209 = f16[512,14,14,672]{3,2,1,0} multiply(f16[512,14,14,672]{3,2,1,0} %param_6, f16[512,14,14,672]{3,2,1,0} %multiply.210), metadata={op_type="Mul" op_name="mul_2"} - %convert.8 = f32[512,14,14,672]{3,2,1,0} convert(f16[512,14,14,672]{3,2,1,0} %multiply.209), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %constant_48 = f32[] constant(100352), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %broadcast.46 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[] %constant_48), dimensions={}, metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %multiply.27 = f32[512,14,14,672]{3,2,1,0} multiply(f32[512,14,14,672]{3,2,1,0} %convert.8, f32[512,14,14,672]{3,2,1,0} %broadcast.46), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %param_1 = f32[672]{0} parameter(1) - %broadcast.45 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_1), dimensions={3}, metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %subtract.10 = f32[512,14,14,672]{3,2,1,0} subtract(f32[512,14,14,672]{3,2,1,0} %multiply.27, f32[512,14,14,672]{3,2,1,0} %broadcast.45), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %param_0.89 = f32[672]{0} parameter(0) - %broadcast.44 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_0.89), dimensions={3}, metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %multiply.26 = f32[512,14,14,672]{3,2,1,0} multiply(f32[512,14,14,672]{3,2,1,0} %broadcast.44, f32[512,14,14,672]{3,2,1,0} %subtract.82), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %broadcast.42 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %divide.14), dimensions={3} - %divide.6 = f32[512,14,14,672]{3,2,1,0} divide(f32[512,14,14,672]{3,2,1,0} %multiply.26, f32[512,14,14,672]{3,2,1,0} %broadcast.42), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %subtract.9 = f32[512,14,14,672]{3,2,1,0} subtract(f32[512,14,14,672]{3,2,1,0} %subtract.10, f32[512,14,14,672]{3,2,1,0} %divide.6), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %multiply.25 = f32[512,14,14,672]{3,2,1,0} multiply(f32[512,14,14,672]{3,2,1,0} %broadcast.47, f32[512,14,14,672]{3,2,1,0} %subtract.9), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - ROOT %convert.7 = f16[512,14,14,672]{3,2,1,0} convert(f32[512,14,14,672]{3,2,1,0} %multiply.25), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} -} - -ENTRY main { - %param_0 = f32[672]{0} parameter(0) - %param_1 = f32[672]{0} parameter(1) - %param_2 = f32[672]{0} parameter(2) - %param_3 = f32[672]{0} parameter(3) - %param_4 = f16[512,14,14,672]{3,2,1,0} parameter(4) - %param_5 = f32[672]{0} parameter(5) - %param_6 = f16[512,14,14,672]{3,2,1,0} parameter(6) - %param_7 = f32[672]{0} parameter(7) - - ROOT %fusion.1 = f16[512,14,14,672]{3,2,1,0} fusion(f32[672]{0} %param_0, f32[672]{0} %param_1, f32[672]{0} %param_2, f32[672]{0} %param_3, f16[512,14,14,672]{3,2,1,0} %param_4, f32[672]{0} %param_5, f16[512,14,14,672]{3,2,1,0} %param_6, f32[672]{0} %param_7), kind=kLoop, calls=%fused_computation.1, metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} -} - -// CHECK-LABEL: fusion_1 -// CHECK: .reqntid 168, 1, 1 -// CHECK-NOT: ld.global.nc.f -// CHECK-NOT: ld.global.nc.b - -// ----- - -HloModule TransposeOutput, is_scheduled=true - -%fused_computation.2 (param_0: f32[672], param_1: f32[512,14,14,672]) -> f32[512,14,14,672] { - %param_0 = f32[672]{0} parameter(0) - %broadcast = f32[512,14,14,672]{3,2,1,0} broadcast(%param_0), dimensions={3} - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - %add = f32[512,14,14,672]{3,2,1,0} add(%broadcast, %param_1) - ROOT %copy = f32[512,14,14,672]{0,2,3,1} copy(%add) -} - -ENTRY main { - %param_0 = f32[672]{0} parameter(0) - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - - ROOT %fusion.2 = f32[512,14,14,672]{0,2,3,1} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.2 -} -// Check that we didn't do anything. The block size didn't change. -// CHECK-LABEL: fusion_2 -// CHECK: .reqntid 128, 1, 1 -// CHECK: ld.global.nc.f - -// ----- - -HloModule TransposeInput, is_scheduled=true - -%fused_computation.3 (param_0: f32[672], param_1: f32[512,14,14,672]) -> f32[512,14,14,672] { - %param_0 = f32[672]{0} parameter(0) - %broadcast = f32[512,14,14,672]{3,2,1,0} broadcast(%param_0), dimensions={3} - %param_1 = f32[512,14,14,672]{0,2,3,1} parameter(1) - %copy = f32[512,14,14,672]{3,2,1,0} copy(%param_1) - ROOT %add = f32[512,14,14,672]{3,2,1,0} add(%broadcast, %copy) -} - -ENTRY main { - %param_0 = f32[672]{0} parameter(0) - %param_1 = f32[512,14,14,672]{0,2,3,1} parameter(1) - - ROOT %fusion.3 = f32[512,14,14,672]{3,2,1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.3 -} -// Check that we didn't do anything. The block size didn't change. -// CHECK-LABEL: fusion_3 -// CHECK: .reqntid 128, 1, 1 -// CHECK: ld.global.nc.f - -// ----- - -HloModule ScalarBroadcasting, is_scheduled=true - -%fused_computation.5 (param_0: f32[], param_1: f32[512,14,14,672]) -> f32[512,14,14,672] { - %param_0 = f32[] parameter(0) - %broadcast = f32[512,14,14,672]{3,2,1,0} broadcast(%param_0), dimensions={} - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - ROOT %add = f32[512,14,14,672]{3,2,1,0} add(%broadcast, %param_1) -} - -ENTRY main { - %param_0 = f32[] parameter(0) - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - - ROOT %fusion.5 = f32[512,14,14,672] fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.5 -} - -// CHECK-LABEL: fusion_5 -// CHECK: .reqntid 128, 1, 1 -// CHECK: ld.global.nc.f - -// ----- - -HloModule NotSupportedBroadcasting, is_scheduled=true - -%fused_computation.6 (param_0: f32[14,672], param_1: f32[512,14,14,672]) -> f32[512,14,14,672] { - %param_0 = f32[14,672]{1,0} parameter(0) - %broadcast = f32[512,14,14,672]{3,2,1,0} broadcast(%param_0), dimensions={2,3} - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - ROOT %add = f32[512,14,14,672]{3,2,1,0} add(%broadcast, %param_1) -} - -ENTRY main { - %param_0 = f32[14,672]{1,0} parameter(0) - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - - ROOT %fusion.6 = f32[512,14,14,672] fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.6 -} - -// Check that we didn't do anything. The block size didn't change. -// CHECK-LABEL: fusion_6 -// CHECK: .reqntid 128, 1, 1 -// CHECK: ld.global.nc.f - -// ----- -HloModule Module, is_scheduled=true - -%fused_computation.7 { - %constant_2 = f32[] constant(0) - %broadcast.1 = f32[32,7,7,352]{2,1,3,0} broadcast(f32[] %constant_2), dimensions={} - %param_1.2 = f32[32,7,7,320]{2,1,3,0} parameter(1) - %param_2.1 = f32[32,7,7,224]{2,1,3,0} parameter(2) - %param_3.1 = f32[32,7,7,128]{2,1,3,0} parameter(3) - %tmp_8.1 = f32[32,7,7,1024]{2,1,3,0} concatenate(f32[32,7,7,352]{2,1,3,0} %broadcast.1, f32[32,7,7,320]{2,1,3,0} %param_1.2, f32[32,7,7,224]{2,1,3,0} %param_2.1, f32[32,7,7,128]{2,1,3,0} %param_3.1), dimensions={3} - %param_0.1 = f32[32,7,7,1024]{2,1,3,0} parameter(0) - ROOT %tmp_10.1 = f32[32,7,7,1024]{2,1,3,0} add(f32[32,7,7,1024]{2,1,3,0} %tmp_8.1, f32[32,7,7,1024]{2,1,3,0} %param_0.1) -} - -ENTRY %computation { - %tmp_0 = u8[32,224,224,3]{3,2,1,0} parameter(0) - %tmp_9 = f32[32,7,7,1024]{2,1,3,0} constant({...}) - %tmp_5 = f32[32,7,7,320]{2,1,3,0} constant({...}) - %tmp_6 = f32[32,7,7,224]{2,1,3,0} constant({...}) - %tmp_7 = f32[32,7,7,128]{2,1,3,0} constant({...}) - ROOT %fusion.7 = f32[32,7,7,1024]{2,1,3,0} fusion(f32[32,7,7,1024]{2,1,3,0} %tmp_9, f32[32,7,7,320]{2,1,3,0} %tmp_5, f32[32,7,7,224]{2,1,3,0} %tmp_6, f32[32,7,7,128]{2,1,3,0} %tmp_7), kind=kLoop, calls=%fused_computation.7 -} - - -// This graph triggered a bug where the new indexing was generated -// CHECK-LLVM-LABEL: @fusion_7 -// CHECK-LLVM-NOT: row_index - -// ----- -HloModule RowToLong, is_scheduled=true - -%fused_computation.1 { - %p0 = f32[2025]{0} parameter(0) - ROOT %r = f32[3025,2025]{1,0} broadcast(%p0), dimensions={1} -} - -ENTRY main { - %param_0 = f32[2025]{0} parameter(0) - ROOT %fusion.8 = f32[3025,2025]{1,0} fusion(%param_0), kind=kLoop, calls=%fused_computation.1 - -} -// Check that we didn't emit the simpler row broadcasting. -// CHECK-LLVM-LABEL: @fusion_8 -// CHECK-LLVM-NOT: row_index - -// ----- - -HloModule module, is_scheduled=true - -%fused_computation.1 { - %p0 = f16[5000,64,64,32] parameter(0) - %p1 = f16[] parameter(1) - ROOT %pad1 = f16[5000,65,65,32] pad(%p0, %p1), padding=0_0x0_1x0_1x0_0 -} - -ENTRY computation { - p0 = f16[5000,64,64,32] parameter(0) - zero = f16[] constant(0) - - ROOT %fusion.9 = f16[5000,65,65,32] fusion(p0, zero), kind=kLoop, calls=%fused_computation.1 -} - -// Our codegen can't emit a vectorized load here, but it can emit a vectorized -// store. -// CHECK-LABEL: .visible .entry fusion_9 -// CHECK-COUNT-4: ld.global.nc.u16 -// CHECK: st.global.v4.b16 diff --git a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc index 2e5db538d8a0a5..b4124f2673d958 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/primitive_util.h" -#include "xla/service/gpu/gpu_sort_rewriter.h" +#include "xla/service/gpu/transforms/sort_rewriter.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -35,7 +35,7 @@ namespace { bool HloWasRewrittenToUseCubSort(const HloModule& module) { for (const auto& pass_metadata : module.metadata().proto().pass_metadata()) { - if (pass_metadata.pass_name() == "gpu-sort-rewriter") { + if (pass_metadata.pass_name() == "sort-rewriter") { return pass_metadata.module_changed(); } } @@ -50,13 +50,13 @@ class CubSortKeysTest : public HloTestBase, public: void SetUp() override { HloTestBase::SetUp(); - GpuSortRewriter::SetSortSizeThresholdForTestingOnly(33000); + SortRewriter::SetSortSizeThresholdForTestingOnly(33000); } }; TEST_P(CubSortKeysTest, CompareToReference) { int batch_size = std::get<2>(GetParam()); - int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size; + int segment_size = SortRewriter::SortSizeThreshold() / batch_size; const char* kHloTpl = R"( HloModule TestSortKeys @@ -103,7 +103,7 @@ ENTRY m { })"; int batch_size = std::get<2>(GetParam()); - int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size; + int segment_size = SortRewriter::SortSizeThreshold() / batch_size; std::string hlo_str = absl::Substitute( kHloTpl, primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())), @@ -138,13 +138,13 @@ class CubSortPairsTest public: void SetUp() override { HloTestBase::SetUp(); - GpuSortRewriter::SetSortSizeThresholdForTestingOnly(33000); + SortRewriter::SetSortSizeThresholdForTestingOnly(33000); } }; TEST_P(CubSortPairsTest, CompareToReference) { int batch_size = std::get<3>(GetParam()); - int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size; + int segment_size = SortRewriter::SortSizeThreshold() / batch_size; const char* kHloTpl = R"( HloModule TestSortPairs @@ -216,7 +216,7 @@ ENTRY m { })"; int batch_size = std::get<3>(GetParam()); - int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size; + int segment_size = SortRewriter::SortSizeThreshold() / batch_size; std::string hlo_str = absl::Substitute( kHloTpl, primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())), diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc index 3550704716772c..aed017cbefb2fa 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -134,7 +134,7 @@ class MultiHeadedAttentionTest : public GpuCodegenTest { EXPECT_TRUE( LiteralTestUtil::Near(expected_result, actual_result, mha_error_spec_)); - // Run FusedMHA/FuseMHABackward thunk through command buffer + // Run through command buffer DebugOptions debug_options = GetDebugOptionsForTest(); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN); debug_options.set_xla_gpu_graph_min_graph_size(1); diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc index 4419cddb862d5e..cc20ef8b8484e8 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/instruction_fusion.h" -#include "xla/service/gpu/multi_output_fusion.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/gpu/transforms/fusion_merger.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" +#include "xla/service/gpu/transforms/multi_output_fusion.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/shape.h" @@ -51,8 +51,7 @@ class GpuFusionPipelineTest : public GpuCodegenTest { device_info); pipeline.AddPass(/*may_duplicate=*/true, device_info); pipeline.AddPass(device_info, ShapeSizeBytesFunction()); - pipeline.AddPass(device_info, - ShapeSizeBytesFunction()); + pipeline.AddPass(device_info, ShapeSizeBytesFunction()); RunAndFilecheckHloRewrite(hlo, std::move(pipeline), expected); } @@ -65,15 +64,17 @@ HloModule module ENTRY computation { p = f32[5000,6000]{1,0} parameter(0) e = f32[5000,6000]{1,0} sqrt(p) - c = f32[6000,5000] transpose(p), dimensions={1,0} + b = f32[1,5000,6000] reshape(p) + c = f32[1,6000,5000] transpose(b), dimensions={0,2,1} r = f32[300,20,5000] reshape(c) ROOT out = (f32[5000,6000], f32[300,20,5000]) tuple(e,r) } )"; CheckGpuFusionPipeline(hlo, R"( -// CHECK: %fused_computation (param_0.1: f32[5000,6000]) -> (f32[300,20,5000], f32[5000,6000]) { +// CHECK: %fused_computation ({{[^:]+}}: f32[5000,6000]) -> (f32[300,20,5000], f32[5000,6000]) { // CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[5000,6000]{1,0} parameter(0) -// CHECK-NEXT: [[c_1_1:%[^ ]+]] = f32[6000,5000]{1,0} transpose([[param_0_1_0]]), dimensions={1,0} +// CHECK-NEXT: [[bc:%[^ ]+]] = f32[1,5000,6000]{2,1,0} reshape([[param_0_1_0]]) +// CHECK-NEXT: [[c_1_1:%[^ ]+]] = f32[1,6000,5000]{2,1,0} transpose([[bc]]), dimensions={0,2,1} // CHECK-NEXT: [[r_1_2:%[^ ]+]] = f32[300,20,5000]{2,1,0} reshape([[c_1_1]]) // CHECK-NEXT: [[e_1_3:%[^ ]+]] = f32[5000,6000]{1,0} sqrt([[param_0_1_0]]) // CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[300,20,5000]{2,1,0}, f32[5000,6000]{1,0}) tuple([[r_1_2]], [[e_1_3]]) diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc index 849cf1dcaf5bba..43c6a509239ccd 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_fusible.h" -#include "xla/service/gpu/instruction_fusion.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -181,18 +181,18 @@ TEST_F(TransposeFusionTest, ElementaryLogical) { HloModule module ENTRY main { - p = f32[16,32]{1,0} parameter(0) - s = sqrt(p) - ROOT c = f32[32,16]{1,0} transpose(s), dimensions={1,0} + p = f32[1,16,32]{2,1,0} parameter(0) + s = f32[1,16,32]{2,1,0} sqrt(p) + ROOT c = f32[1,32,16]{2,1,0} transpose(s), dimensions={0,2,1} } )"; CheckGpuFusion(hlo, R"( -// CHECK: %fused_computation (param_0.1: f32[16,32]) -> f32[32,16] { -// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) -// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]]) -// CHECK-NEXT: ROOT [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]), dimensions={1,0} -// CHECK: ROOT [[fusion_3:%[^ ]+]] = f32[32,16]{1,0} fusion([[p_4:%[^ ]+]]), kind=kInput, calls=[[fused_computation_5:%[^ ]+]] +// CHECK: %fused_computation ({{[^:]+}}: f32[1,16,32]) -> f32[1,32,16] { +// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[1,16,32]{2,1,0} parameter(0) +// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[1,16,32]{2,1,0} sqrt([[param_0_1_0]]) +// CHECK-NEXT: ROOT [[c_1_2:%[^ ]+]] = f32[1,32,16]{2,1,0} transpose([[s_1_1]]), dimensions={0,2,1} +// CHECK: ROOT [[fusion_3:%[^ ]+]] = f32[1,32,16]{2,1,0} fusion([[p_4:%[^ ]+]]), kind=kInput, calls=[[fused_computation_5:%[^ ]+]] )"); } diff --git a/third_party/xla/xla/service/gpu/tests/launch_dimensions.hlo b/third_party/xla/xla/service/gpu/tests/launch_dimensions.hlo deleted file mode 100644 index 3d05dcf9892ad5..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/launch_dimensions.hlo +++ /dev/null @@ -1,338 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s -// This tests that we do not increase the grid launch size when -// few_waves is enabled. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @wrapped_b -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-DAG: ![[ctaid_range]] = !{i32 0, i32 2} -// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 1024} - - -HloModule Test, is_scheduled=true - -fused_computation { - param_0 = f32[100,20]{1,0} parameter(0) - ROOT b.1 = f32[100,20]{1,0} round-nearest-even(f32[100,20]{1,0} param_0) -} - -ENTRY main { - a = f32[100, 20]{1,0} parameter(0) - ROOT wrapped_b = f32[100,20]{1,0} fusion(f32[100,20]{1,0} a), kind=kLoop, calls=fused_computation -} - -// ----- - -// This tests that we cap grid launch code when few_waves is enabled. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @wrapped_b -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} -// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128} - -HloModule Test, is_scheduled=true - -fused_computation { - param_0 = f32[10000,10000]{1,0} parameter(0) - ROOT b.1 = f32[10000,10000]{1,0} round-nearest-even(f32[10000,10000]{1,0} param_0) -} - -ENTRY main { - a = f32[10000, 10000]{1,0} parameter(0) - ROOT wrapped_b = f32[10000,10000]{1,0} fusion(f32[10000,10000]{1,0} a), kind=kLoop, calls=fused_computation -} - -// ----- - -// This tests that we cap grid launch code when few_waves is enabled -// and scalar broadcast are present. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion_3 -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} - -HloModule ScalarBroadcast, is_scheduled=true - -%fused_computation.3 (param_0: f32[], param_1: f32[10000, 10000]) -> f32[10000, 10000] { - %param_0 = f32[] parameter(0) - %broadcast = f32[10000, 10000]{1,0} broadcast(%param_0), dimensions={} - %param_1 = f32[10000, 10000]{1,0} parameter(1) - ROOT %add = f32[10000, 10000]{1,0} add(%broadcast, %param_1) -} - -ENTRY main { - %param_0 = f32[] parameter(0) - %param_1 = f32[10000, 10000]{1,0} parameter(1) - - ROOT %fusion.3 = f32[10000, 10000]{1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.3 -} - -// ----- - -// This tests that we enable few_waves in a simple fusion. It is the baseline -// for the tests below. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} - -HloModule SimpleFusion, is_scheduled=true - -%fused_computation (param_0: f32[], param_1: f32[10000, 10000]) -> f32[10000, 10000] { - %param_0 = f32[10000,10000] parameter(0) - %param_1 = f32[10000, 10000]{1,0} parameter(1) - ROOT %add = f32[10000, 10000]{1,0} add(%param_0, %param_1) -} - -ENTRY main { - %param_0 = f32[10000, 10000]{1,0} parameter(0) - %param_1 = f32[10000, 10000]{1,0} parameter(1) - - ROOT %fusion = f32[10000, 10000]{1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation -} - -// ----- - -// This tests that we keep few_waves enabled for large constants. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} - -HloModule LargeConstant, is_scheduled=true - -%fused_computation (param_0: f32[], param_1: f32[10000, 10000]) -> f32[10000, 10000] { - %param_0 = f32[10000,10000] parameter(0) - %c0 = f32[10000,10000] constant(0) - ROOT %add = f32[10000, 10000]{1,0} add(%param_0, %c0) -} - -ENTRY main { - %param_0 = f32[10000, 10000] parameter(0) - - ROOT %fusion = f32[10000, 10000]{1,0} fusion(%param_0), kind=kLoop, calls=%fused_computation -} - -// ----- - -// This tests that we disable few_waves if a non-elementwise op is present. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 195313} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 97657} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 256} - -HloModule NonElementwise, is_scheduled=true - -%fused_computation (param_0: f32[], param_1: f32[10000, 10000]) -> f32[10000, 10000] { - %param_0 = f32[10000,10000] parameter(0) - %reverse = f32[10000,10000]{1,0} reverse(%param_0), dimensions={0,1} - %param_1 = f32[10000, 10000]{1,0} parameter(1) - ROOT %add = f32[10000, 10000]{1,0} add(%reverse, %param_1) -} - -ENTRY main { - %param_0 = f32[10000, 10000]{1,0} parameter(0) - %param_1 = f32[10000, 10000]{1,0} parameter(1) - - ROOT %fusion = f32[10000, 10000]{1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation -} - -// ----- - -// This tests that we disable few_waves if -// - a tensor broadcast is present -// - at least four big inputs are present -// - the fusion is not row-vectorizable -// It serves as a baseline for the tests below. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 7813} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 3907} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 256} - -HloModule NoFewWaves, is_scheduled=true - -%fused_computation (param_0: f32[], param_1: f32[2000, 2000]) -> f32[2000, 2000] { - %param_0 = f32[2000] parameter(0) - %broadcast = f32[2000, 2000]{1,0} broadcast(%param_0), dimensions={0} - %param_1 = f32[2000, 2000]{1,0} parameter(1) - %param_2 = f32[2000, 2000]{0,1} parameter(2) - %param_3 = f32[2000, 2000]{0,1} parameter(3) - %param_4 = f32[2000, 2000]{0,1} parameter(4) - - %sum.0 = f32[2000, 2000] add(%param_1, %param_2) - %sum.1 = f32[2000, 2000] add(%sum.0, %param_3) - %sum.2 = f32[2000, 2000] add(%sum.1, %param_4) - ROOT %add = f32[2000, 2000]{1,0} add(%sum.2, %broadcast) -} - -ENTRY main { - %param_0 = f32[2000]{0} parameter(0) - %param_1 = f32[2000, 2000]{1,0} parameter(1) - %param_2 = f32[2000, 2000]{0,1} parameter(2) - %param_3 = f32[2000, 2000]{0,1} parameter(3) - %param_4 = f32[2000, 2000]{0,1} parameter(4) - - ROOT %fusion = f32[2000, 2000]{1,0} fusion(%param_0, %param_1, %param_2, %param_3, %param_4), kind=kLoop, calls=%fused_computation -} - -// ----- - -// This tests that we enable few_waves if -// - a tensor broadcast is present -// - THREE big inputs are present -// - the fusion IS row-vectorizable -// In this case, the block count is changed from 7813 to 2000. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 2000} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 500} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 2000} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 500} - -HloModule RowVectorizable, is_scheduled=true - -%fused_computation (param_0: f32[], param_1: f32[2000, 2000]) -> f32[2000, 2000] { - %param_0 = f32[2000] parameter(0) - %broadcast = f32[2000, 2000]{1,0} broadcast(%param_0), dimensions={1} - %param_1 = f32[2000, 2000]{1,0} parameter(1) - %param_2 = f32[2000, 2000]{1,0} parameter(2) - %param_3 = f32[2000, 2000]{1,0} parameter(3) - - %sum.0 = f32[2000, 2000] add(%param_1, %param_2) - %sum.1 = f32[2000, 2000] add(%sum.0, %param_3) - ROOT %add = f32[2000, 2000]{1,0} add(%sum.1, %broadcast) -} - -ENTRY main { - %param_0 = f32[2000]{0} parameter(0) - %param_1 = f32[2000, 2000]{1,0} parameter(1) - %param_2 = f32[2000, 2000]{1,0} parameter(2) - %param_3 = f32[2000, 2000]{1,0} parameter(3) - - ROOT %fusion = f32[2000, 2000]{1,0} fusion(%param_0, %param_1, %param_2, %param_3), kind=kLoop, calls=%fused_computation -} - -// ----- - -// This tests that we enable few_waves if -// - a SCALAR broadcast is present -// - four big inputs are present -// - the fusion is not row-vectorizable -// In this case, the block count is changed from 7813 to 1008. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} - -HloModule ScalarBroadcastFourInputs, is_scheduled=true - -%fused_computation (param_0: f32[], param_1: f32[2000, 2000]) -> f32[2000, 2000] { - %param_0 = f32[] parameter(0) - %broadcast = f32[2000, 2000]{1,0} broadcast(%param_0), dimensions={} - %param_1 = f32[2000, 2000]{1,0} parameter(1) - %param_2 = f32[2000, 2000]{1,0} parameter(2) - %param_3 = f32[2000, 2000]{1,0} parameter(3) - %param_4 = f32[2000, 2000]{1,0} parameter(4) - - %sum.0 = f32[2000, 2000] add(%param_1, %param_2) - %sum.1 = f32[2000, 2000] add(%sum.0, %param_3) - %sum.2 = f32[2000, 2000] add(%sum.1, %param_4) - ROOT %add = f32[2000, 2000]{1,0} add(%sum.2, %broadcast) -} - -ENTRY main { - %param_0 = f32[] parameter(0) - %param_1 = f32[2000, 2000]{1,0} parameter(1) - %param_2 = f32[2000, 2000]{1,0} parameter(2) - %param_3 = f32[2000, 2000]{1,0} parameter(3) - %param_4 = f32[2000, 2000]{1,0} parameter(4) - - ROOT %fusion = f32[2000, 2000]{1,0} fusion(%param_0, %param_1, %param_2, %param_3, %param_4), kind=kLoop, calls=%fused_computation -} - -// ----- -// This tests the GELU kernel. The original kernel that -// motivated few_waves implementation. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} -// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128} - -HloModule Test, is_scheduled=true - -%fused_computation (param_0: f16[6,512,4096]) -> f16[6,512,4096] { - %param_0 = f16[6,512,4096]{2,1,0} parameter(0) - %power.tmp.1 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %param_0, f16[6,512,4096]{2,1,0} %param_0) - %power.0 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %power.tmp.1, f16[6,512,4096]{2,1,0} %param_0) - %constant_4 = f16[] constant(0.044708), metadata={op_type="Mul" op_name="mul"} - %broadcast.3 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_4), dimensions={}, metadata={op_type="Mul" op_name="mul"} - %multiply.3 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %power.0, f16[6,512,4096]{2,1,0} %broadcast.3), metadata={op_type="Mul" op_name="mul"} - %add.1 = f16[6,512,4096]{2,1,0} add(f16[6,512,4096]{2,1,0} %param_0, f16[6,512,4096]{2,1,0} %multiply.3), metadata={op_type="AddV2" op_name="add"} - %constant_2 = f16[] constant(0.79785), metadata={op_type="Mul" op_name="mul_1"} - %broadcast.2 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_2), dimensions={}, metadata={op_type="Mul" op_name="mul_1"} - %multiply.2 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %add.1, f16[6,512,4096]{2,1,0} %broadcast.2), metadata={op_type="Mul" op_name="mul_1"} - %tanh.0 = f16[6,512,4096]{2,1,0} tanh(f16[6,512,4096]{2,1,0} %multiply.2), metadata={op_type="Tanh" op_name="Tanh"} - %constant_1 = f16[] constant(1), metadata={op_type="AddV2" op_name="add_1"} - %broadcast.1 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_1), dimensions={}, metadata={op_type="AddV2" op_name="add_1"} - %add.0 = f16[6,512,4096]{2,1,0} add(f16[6,512,4096]{2,1,0} %tanh.0, f16[6,512,4096]{2,1,0} %broadcast.1), metadata={op_type="AddV2" op_name="add_1"} - %constant_0 = f16[] constant(0.5), metadata={op_type="Mul" op_name="mul_2"} - %broadcast.0 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_0), dimensions={}, metadata={op_type="Mul" op_name="mul_2"} - %multiply.1 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %add.0, f16[6,512,4096]{2,1,0} %broadcast.0), metadata={op_type="Mul" op_name="mul_2"} - ROOT %multiply.0 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %param_0, f16[6,512,4096]{2,1,0} %multiply.1), metadata={op_type="Mul" op_name="mul_3"} -} - -ENTRY %cluster_0__XlaCompiledKernel_true__XlaNumConstantArgs_0__XlaNumResourceArgs_0_.24 (arg0.1: f16[6,512,4096]) -> f16[6,512,4096] { - %arg0.1 = f16[6,512,4096]{2,1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"} - ROOT %fusion = f16[6,512,4096]{2,1,0} fusion(f16[6,512,4096]{2,1,0} %arg0.1), kind=kLoop, calls=%fused_computation, metadata={op_type="Mul" op_name="mul_3"} -} diff --git a/third_party/xla/xla/service/gpu/tests/transpose_emitter_test.cc b/third_party/xla/xla/service/gpu/tests/transpose_emitter_test.cc index b8a263df629a69..b999d1c05fa079 100644 --- a/third_party/xla/xla/service/gpu/tests/transpose_emitter_test.cc +++ b/third_party/xla/xla/service/gpu/tests/transpose_emitter_test.cc @@ -61,7 +61,7 @@ TEST_F(TransposeEmitterTest, SimpleLogicalTranspose) { )"; CompileAndVerifyIr(kHloString, MakePlatformSpecificLlvm(expected_ir), /*match_optimized_ir=*/true, - /*run_optimization_passes=*/false); + /*run_optimization_passes=*/true); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -90,7 +90,8 @@ HloModule m %fused_computation { %param_0.1 = f32[16,32]{1,0} parameter(0) %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %t.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} + bc = f32[1,16,32]{2,1,0} bitcast(%s.1) + %t.1 = f32[1,32,16]{2,1,0} transpose(bc), dimensions={0,2,1} b = f32[32,16,1]{2,1,0} bitcast(%t.1) ROOT o = f32[32,16,1]{2,1,0} sqrt(b) } @@ -116,8 +117,10 @@ HloModule m %fused_computation { %param_0.1 = f32[16,32]{1,0} parameter(0) %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %t.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} - %t1.1 = f32[32,16]{1,0} transpose(%param_0.1), dimensions={1,0} + %bc.1 = f32[1,16,32]{2,1,0} bitcast(%s.1) + %bc.2 = f32[1,16,32]{2,1,0} bitcast(%param_0.1) + %t.1 = f32[1,32,16]{2,1,0} transpose(%bc.1), dimensions={0,2,1} + %t1.1 = f32[1,32,16]{2,1,0} transpose(%bc.2), dimensions={0,2,1} %r.1 = f32[32,16,1]{2,1,0} reshape(%t.1) %r1.1 = f32[32,16,1]{2,1,0} reshape(%t1.1) ROOT %tuple = (f32[32,16,1]{2,1,0}, f32[32,16,1]{2,1,0}) tuple(%r.1, %r1.1) @@ -170,14 +173,16 @@ HloModule m %fused_computation { %param_0.1 = f32[16,32]{1,0} parameter(0) %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %c.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} - %c1.1 = f32[32,16]{1,0} transpose(%param_0.1), dimensions={1,0} - ROOT %tuple = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(%c.1, %c1.1) + %bc.1 = f32[1,16,32]{2,1,0} bitcast(%s.1) + %bc.2 = f32[1,16,32]{2,1,0} bitcast(%param_0.1) + %c.1 = f32[1,32,16]{2,1,0} transpose(%bc.1), dimensions={0,2,1} + %c1.1 = f32[1,32,16]{2,1,0} transpose(%bc.2), dimensions={0,2,1} + ROOT %tuple = (f32[1,32,16]{2,1,0}, f32[1,32,16]{2,1,0}) tuple(%c.1, %c1.1) } ENTRY main { %p = f32[16,32]{1,0} parameter(0) - ROOT %fusion = (f32[32,16]{1,0}, f32[32,16]{1,0}) fusion(%p), kind=kInput, calls=%fused_computation + ROOT %fusion = (f32[1,32,16]{2,1,0}, f32[1,32,16]{2,1,0}) fusion(%p), kind=kInput, calls=%fused_computation } )"; @@ -251,14 +256,15 @@ HloModule m %fused_computation { %param_0.1 = f32[16,32]{1,0} parameter(0) %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %c.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} + %bc.1 = f32[1,16,32]{2,1,0} bitcast(%s.1) + %c.1 = f32[1,32,16]{2,1,0} transpose(%bc.1), dimensions={0,2,1} %c1.1 = f32[16,32]{1,0} exponential(%param_0.1) - ROOT %tuple = (f32[32,16]{1,0}, f32[16,32]{1,0}) tuple(%c.1, %c1.1) + ROOT %tuple = (f32[1,32,16]{2,1,0}, f32[16,32]{1,0}) tuple(%c.1, %c1.1) } ENTRY entry { %p = f32[16,32]{1,0} parameter(0) - ROOT %fusion = (f32[32,16]{1,0}, f32[16,32]{1,0}) fusion(%p), kind=kInput, calls=%fused_computation + ROOT %fusion = (f32[1,32,16]{2,1,0}, f32[16,32]{1,0}) fusion(%p), kind=kInput, calls=%fused_computation } )"; diff --git a/third_party/xla/xla/service/gpu/tests/xla-opt.cc b/third_party/xla/xla/service/gpu/tests/xla-opt.cc index cd5eeb835f143b..f27b6f82366230 100644 --- a/third_party/xla/xla/service/gpu/tests/xla-opt.cc +++ b/third_party/xla/xla/service/gpu/tests/xla-opt.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mlir/InitAllExtensions.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/fusions/triton/passes.h" #include "third_party/triton/bin/RegisterTritonDialects.h" @@ -23,6 +24,7 @@ int main(int argc, char **argv) { mlir::registerAllExtensions(registry); registerTritonDialects(registry); // This registers all passes as well. xla::gpu::registerTritonFusionTransformsPasses(); + xla::gpu::registerGpuFusionTransformsPasses(); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "xla-opt modular optimizer driver\n", registry)); diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index d17acdd8176a24..ea47be050b7385 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -20,6 +20,45 @@ package( licenses = ["notice"], ) +cc_library( + name = "algebraic_simplifier", + srcs = [ + "algebraic_simplifier.cc", + ], + hdrs = [ + "algebraic_simplifier.h", + ], + deps = [ + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:algebraic_simplifier", + "//xla/service:hlo_pass", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "algebraic_simplifier_test", + srcs = ["algebraic_simplifier_test.cc"], + deps = [ + ":algebraic_simplifier", + "//xla/hlo/ir:hlo", + "//xla/service:algebraic_simplifier", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + ], +) + # End-to-end tested via //third_party/tensorflow/compiler/xla/service/gpu:dot_algorithm_support_test cc_library( name = "algorithm_checker", @@ -74,6 +113,40 @@ xla_cc_test( ], ) +cc_library( + name = "all_gather_optimizer", + srcs = ["all_gather_optimizer.cc"], + hdrs = ["all_gather_optimizer.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "all_gather_optimizer_test", + srcs = ["all_gather_optimizer_test.cc"], + deps = [ + ":all_gather_optimizer", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "all_reduce_blueconnect", srcs = ["all_reduce_blueconnect.cc"], @@ -121,17 +194,102 @@ xla_cc_test( ], ) +cc_library( + name = "all_reduce_splitter", + srcs = ["all_reduce_splitter.cc"], + hdrs = ["all_reduce_splitter.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:collective_opt_utils", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "all_reduce_splitter_test", + srcs = ["all_reduce_splitter_test.cc"], + deps = [ + ":all_reduce_splitter", + ":reduce_scatter_creator", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass_pipeline", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "async_collective_annotator", + srcs = ["async_collective_annotator.cc"], + hdrs = ["async_collective_annotator.h"], + deps = [ + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "async_collective_annotator_test", + srcs = ["async_collective_annotator_test.cc"], + deps = [ + ":async_collective_annotator", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "async_wrapper", srcs = ["async_wrapper.cc"], hdrs = ["async_wrapper.h"], deps = [ "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", - "//xla/service:hlo_proto_cc", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", ], @@ -273,6 +431,139 @@ xla_cc_test( ], ) +cc_library( + name = "conv_padding_legalization", + srcs = ["conv_padding_legalization.cc"], + hdrs = ["conv_padding_legalization.h"], + deps = [ + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service:shape_inference", + "//xla/service/gpu:cublas_cudnn", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "conv_padding_legalization_test", + srcs = ["conv_padding_legalization_test.cc"], + deps = [ + ":conv_padding_legalization", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:cublas_cudnn", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "conv_rewriter", + srcs = ["conv_rewriter.cc"], + hdrs = ["conv_rewriter.h"], + deps = [ + "//xla:permutation_util", + "//xla:shape_util", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "conv_rewriter_test", + srcs = ["conv_rewriter_test.cc"], + deps = [ + ":conv_rewriter", + "//xla:array4d", + "//xla:literal_util", + "//xla:protobuf_util", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service:shape_inference", + "//xla/service/gpu:cublas_cudnn", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "convert_async_collectives_to_sync", + srcs = ["convert_async_collectives_to_sync.cc"], + hdrs = ["convert_async_collectives_to_sync.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:convert_async_collectives_to_sync", + "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "convert_async_collectives_to_sync_test", + srcs = ["convert_async_collectives_to_sync_test.cc"], + deps = [ + ":convert_async_collectives_to_sync", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "copy_fusion", srcs = ["copy_fusion.cc"], @@ -435,6 +726,7 @@ xla_test( ]), shard_count = 10, deps = [ + ":conv_rewriter", ":cudnn_fused_conv_rewriter", "//xla:comparison_util", "//xla:error_spec", @@ -450,7 +742,6 @@ xla_test( "//xla/service:reshape_mover", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", - "//xla/service/gpu:gpu_conv_rewriter", "//xla/service/gpu:stream_executor_util", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", @@ -834,9 +1125,9 @@ xla_cc_test( # TODO(b/358278858): Currently lacking test coverage. cc_library( - name = "cudnn_workspace_rewriter", - srcs = if_cuda_is_configured(["cudnn_workspace_rewriter.cc"]), - hdrs = if_cuda_is_configured(["cudnn_workspace_rewriter.h"]), + name = "cudnn_custom_call_compiler", + srcs = if_cuda_is_configured(["cudnn_custom_call_compiler.cc"]), + hdrs = if_cuda_is_configured(["cudnn_custom_call_compiler.h"]), deps = if_cuda_is_configured([ "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -851,9 +1142,9 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "//xla/service/gpu/runtime:cudnn_thunk", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", - "//xla/service/gpu:gpu_fused_mha_runner", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:stream_executor_util", "//xla/stream_executor:dnn", @@ -942,28 +1233,73 @@ xla_test( ) cc_library( - name = "dot_sparsity_rewriter", - srcs = ["dot_sparsity_rewriter.cc"], - hdrs = ["dot_sparsity_rewriter.h"], + name = "dot_operand_converter", + srcs = ["dot_operand_converter.cc"], + hdrs = ["dot_operand_converter.h"], deps = [ + "//xla:shape_util", "//xla:util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", + "//xla/service:op_expander_pass", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:errors", ], ) -xla_cc_test( - name = "dot_sparsity_rewriter_test", - srcs = ["dot_sparsity_rewriter_test.cc"], - deps = [ - ":dot_sparsity_rewriter", +xla_test( + name = "dot_operand_converter_test", + srcs = if_gpu_is_configured(["dot_operand_converter_test.cc"]), + backends = [ + "gpu_a100", + "gpu_p100", + "gpu_v100", + "gpu_amd_any", + ], + deps = if_gpu_is_configured( + [ + ":dot_operand_converter", + "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:pattern_matcher", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/platform:statusor", + ], + ) + [ + # b/317293391 + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "dot_sparsity_rewriter", + srcs = ["dot_sparsity_rewriter.cc"], + hdrs = ["dot_sparsity_rewriter.h"], + deps = [ + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "dot_sparsity_rewriter_test", + srcs = ["dot_sparsity_rewriter_test.cc"], + deps = [ + ":dot_sparsity_rewriter", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", @@ -1027,10 +1363,10 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla/ffi:ffi_api", - "//xla/ffi/api:c_api", "//xla/hlo/ir:hlo", "//xla/service:custom_call_target_registry", "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:gpu_constants", @@ -1053,7 +1389,11 @@ cc_library( xla_cc_test( name = "dynamic_slice_fusion_rewriter_test", - srcs = if_cuda_is_configured(["dynamic_slice_fusion_rewriter_test.cc"]), + srcs = ["dynamic_slice_fusion_rewriter_test.cc"], + tags = [ + "gpu", + "no_rocm", + ], deps = [ ":dynamic_slice_fusion_rewriter", "//xla:shape_util", @@ -1399,3 +1739,1305 @@ cc_library( "@local_tsl//tsl/platform:statusor", ]), ) + +cc_library( + name = "horizontal_input_fusion", + srcs = ["horizontal_input_fusion.cc"], + hdrs = ["horizontal_input_fusion.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service/gpu:gpu_fusible", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "horizontal_input_fusion_test", + srcs = ["horizontal_input_fusion_test.cc"], + backends = ["gpu"], + deps = [ + ":horizontal_input_fusion", + "//xla:error_spec", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", + "//xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "horizontal_loop_fusion", + srcs = ["horizontal_loop_fusion.cc"], + hdrs = ["horizontal_loop_fusion.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service:sub_byte_normalization", + "//xla/service/gpu:gpu_fusible", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "horizontal_loop_fusion_test", + srcs = ["horizontal_loop_fusion_test.cc"], + backends = ["gpu"], + deps = [ + ":horizontal_loop_fusion", + ":instruction_fusion", + "//xla:error_spec", + "//xla:shape_util", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_dce", + "//xla/service:hlo_parser", + "//xla/service:hlo_pass", + "//xla/service:hlo_pass_pipeline", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + ], +) + +cc_library( + name = "instruction_fusion", + srcs = ["instruction_fusion.cc"], + hdrs = ["instruction_fusion.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:fusion_node_indexing_evaluation", + "//xla/service:fusion_queue", + "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", + "//xla/service/gpu:gpu_fusible", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "instruction_fusion_test", + srcs = ["instruction_fusion_test.cc"], + tags = [ + "nomsan", + "not_run:arm", + ], + deps = [ + ":instruction_fusion", + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:gpu_fusible", + "//xla/tests:hlo_test_base", + "//xla/tests:test_utils", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "layout_assignment", + srcs = ["layout_assignment.cc"], + hdrs = ["layout_assignment.h"], + deps = [ + "//xla:shape_layout", + "//xla:shape_util", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:computation_layout", + "//xla/service:host_memory_offload_annotations_hdr", + "//xla/service:layout_assignment", + "//xla/service:logical_buffer", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:reduction_utils", + "//xla/service/gpu:stream_executor_util", + "//xla/stream_executor", + "//xla/stream_executor:dnn", + "//xla/tsl/util:env_var", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "layout_assignment_test", + srcs = ["layout_assignment_test.cc"], + deps = [ + ":layout_assignment", + "//xla:shape_layout", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:computation_layout", + "//xla/service:hlo_parser", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:stream_executor_util", + "//xla/stream_executor:device_description", + "//xla/stream_executor:dnn", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "move_copy_to_users", + srcs = ["move_copy_to_users.cc"], + hdrs = ["move_copy_to_users.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "move_copy_to_users_test", + srcs = ["move_copy_to_users_test.cc"], + deps = [ + ":move_copy_to_users", + "//xla/service:layout_assignment", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//xla:debug_options_flags", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_dfs_reachability", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_graph_dumper", + "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/model:gpu_performance_model", + "//xla/service/gpu/model:gpu_performance_model_base", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "multi_output_fusion_test", + srcs = ["multi_output_fusion_test.cc"], + tags = [ + "nomsan", + ], + deps = [ + ":multi_output_fusion", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:gpu_fusible", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "pipelined_p2p_rewriter", + srcs = ["pipelined_p2p_rewriter.cc"], + hdrs = ["pipelined_p2p_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:collective_ops_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "pipelined_p2p_rewriter_test", + srcs = ["pipelined_p2p_rewriter_test.cc"], + deps = [ + ":pipelined_p2p_rewriter", + "//xla/hlo/ir:hlo", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "priority_fusion", + srcs = ["priority_fusion.cc"], + hdrs = ["priority_fusion.h"], + deps = [ + "//xla:debug_options_flags", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:dump", + "//xla/service:fusion_queue", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_graph_dumper", + "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:fusion_process_dump_proto_cc", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/service/gpu/model:fusion_analysis_cache", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/model:gpu_indexing_performance_model", + "//xla/service/gpu/model:gpu_performance_model", + "//xla/service/gpu/model:gpu_performance_model_base", + "//xla/service/gpu/model:symbolic_tile_analysis", + "//xla/service/gpu/model:tiled_hlo_computation", + "//xla/service/gpu/model:triton_emitter_constraints", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:blocking_counter", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "priority_fusion_test", + srcs = ["priority_fusion_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + tags = ["no_pip"], + deps = [ + ":priority_fusion", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "reduce_scatter_creator", + srcs = ["reduce_scatter_creator.cc"], + hdrs = ["reduce_scatter_creator.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:collective_opt_utils", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "reduce_scatter_creator_test", + srcs = ["reduce_scatter_creator_test.cc"], + deps = [ + ":reduce_scatter_creator", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "reduction_degenerate_dim_remover", + srcs = ["reduction_degenerate_dim_remover.cc"], + hdrs = ["reduction_degenerate_dim_remover.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "reduction_degenerate_dim_remover_test", + srcs = [ + "reduction_degenerate_dim_remover_test.cc", + ], + deps = [ + ":reduction_degenerate_dim_remover", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "reduction_dimension_grouper", + srcs = ["reduction_dimension_grouper.cc"], + hdrs = ["reduction_dimension_grouper.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "reduction_dimension_grouper_test", + srcs = [ + "reduction_dimension_grouper_test.cc", + ], + deps = [ + ":reduction_dimension_grouper", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "reduction_layout_normalizer", + srcs = ["reduction_layout_normalizer.cc"], + hdrs = ["reduction_layout_normalizer.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "reduction_layout_normalizer_test", + srcs = [ + "reduction_layout_normalizer_test.cc", + ], + backends = ["gpu"], + deps = [ + ":reduction_layout_normalizer", + "//xla:error_spec", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "reduction_splitter", + srcs = ["reduction_splitter.cc"], + hdrs = ["reduction_splitter.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:reduction_utils", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "reduction_splitter_test", + srcs = ["reduction_splitter_test.cc"], + deps = [ + ":reduction_splitter", + "//xla:shape_util", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "rename_fusions", + srcs = ["rename_fusions.cc"], + hdrs = ["rename_fusions.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "rename_fusions_test", + srcs = ["rename_fusions_test.cc"], + deps = [ + ":rename_fusions", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "sanitize_constant_names", + srcs = ["sanitize_constant_names.cc"], + hdrs = ["sanitize_constant_names.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:name_uniquer", + "//xla/service/llvm_ir:buffer_assignment_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "sanitize_constant_names_test", + srcs = ["sanitize_constant_names_test.cc"], + deps = [ + ":sanitize_constant_names", + "//xla:literal_util", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "scatter_slice_simplifier", + srcs = ["scatter_slice_simplifier.cc"], + hdrs = ["scatter_slice_simplifier.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +# TODO(b/358278858): Currently lacking test coverage. +cc_library( + name = "scatter_expander", + srcs = ["scatter_expander.cc"], + hdrs = ["scatter_expander.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:scatter_expander", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "scatter_slice_simplifier_test", + srcs = ["scatter_slice_simplifier_test.cc"], + deps = [ + ":scatter_slice_simplifier", + "//xla:shape_util", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "schedule_postprocessing", + srcs = ["schedule_postprocessing.cc"], + hdrs = ["schedule_postprocessing.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "schedule_postprocessing_test", + srcs = ["schedule_postprocessing_test.cc"], + deps = [ + ":schedule_postprocessing", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "scheduling_instruction_annotator", + srcs = ["scheduling_instruction_annotator.cc"], + hdrs = ["scheduling_instruction_annotator.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "scheduling_instruction_annotator_test", + srcs = ["scheduling_instruction_annotator_test.cc"], + deps = [ + ":scheduling_instruction_annotator", + "//xla/hlo/ir:hlo", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "softmax_rewriter_triton", + srcs = ["softmax_rewriter_triton.cc"], + hdrs = ["softmax_rewriter_triton.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/service/gpu/model:fusion_analysis_cache", + "//xla/service/gpu/model:gpu_indexing_performance_model", + "//xla/service/gpu/model:symbolic_tile_analysis", + "//xla/service/gpu/model:tiled_hlo_computation", + "//xla/service/gpu/model:triton_emitter_constraints", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "softmax_rewriter_triton_test", + srcs = ["softmax_rewriter_triton_test.cc"], + deps = [ + ":softmax_rewriter_triton", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:instruction_fusion", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "sort_rewriter", + srcs = if_gpu_is_configured( + ["sort_rewriter.cc"], + ["sort_rewriter_stub.cc"], + ), + hdrs = ["sort_rewriter.h"], + deps = [ + "//xla:comparison_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:stable_sort_expander", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu/runtime:cub_sort_thunk", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "sort_rewriter_test", + srcs = if_cuda_is_configured(["sort_rewriter_test.cc"]), + backends = ["gpu"], + tags = ["no_oss"], + deps = [ + ":sort_rewriter", + "//xla:error_spec", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:cublas_cudnn", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "stream_attribute_annotator", + srcs = ["stream_attribute_annotator.cc"], + hdrs = ["stream_attribute_annotator.h"], + deps = [ + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu/runtime:thunk", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "stream_attribute_annotator_test", + srcs = ["stream_attribute_annotator_test.cc"], + deps = [ + ":stream_attribute_annotator", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "stream_attribute_async_wrapper", + srcs = ["stream_attribute_async_wrapper.cc"], + hdrs = ["stream_attribute_async_wrapper.h"], + deps = [ + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu/runtime:thunk", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "stream_attribute_async_wrapper_test", + srcs = ["stream_attribute_async_wrapper_test.cc"], + deps = [ + ":stream_attribute_async_wrapper", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "topk_specializer", + srcs = ["topk_specializer.cc"], + hdrs = ["topk_specializer.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:hlo_proto_cc", + "//xla/service:tuple_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_test( + name = "topk_specializer_test", + srcs = ["topk_specializer_test.cc"], + backends = ["gpu"], + deps = [ + ":topk_specializer", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:platform_util", + "//xla/service:topk_rewriter", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "topk_splitter", + srcs = ["topk_splitter.cc"], + hdrs = ["topk_splitter.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "topk_splitter_test", + srcs = ["topk_splitter_test.cc"], + deps = [ + ":topk_splitter", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_dce", + "//xla/service:pattern_matcher", + "//xla/service:topk_rewriter", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "transpose_dimension_grouper", + srcs = ["transpose_dimension_grouper.cc"], + hdrs = ["transpose_dimension_grouper.h"], + deps = [ + "//xla:permutation_util", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "transpose_dimension_grouper_test", + srcs = [ + "transpose_dimension_grouper_test.cc", + ], + deps = [ + ":transpose_dimension_grouper", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "tree_reduction_rewriter", + srcs = ["tree_reduction_rewriter.cc"], + hdrs = ["tree_reduction_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service/gpu:reduction_utils", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "tree_reduction_rewriter_test", + srcs = [ + "tree_reduction_rewriter_test.cc", + ], + deps = [ + ":tree_reduction_rewriter", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +# TODO(b/358278858): Currently lacking test coverage. +cc_library( + name = "triangular_solve_rewriter", + srcs = ["triangular_solve_rewriter.cc"], + hdrs = ["triangular_solve_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service/gpu:cublas_cudnn", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "triton_fusion_numerics_verifier", + srcs = ["triton_fusion_numerics_verifier.cc"], + hdrs = ["triton_fusion_numerics_verifier.h"], + tags = ["gpu"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:executable", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service:shaped_buffer", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:buffer_comparator", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/autotuning:autotuner_compile_util", + "//xla/service/gpu/autotuning:autotuner_util", + "//xla/stream_executor:stream", + "//xla/tools:hlo_decomposer_lib", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "triton_fusion_numerics_verifier_test", + srcs = ["triton_fusion_numerics_verifier_test.cc"], + backend_tags = {"gpu": [ + "requires-gpu-sm80", + ]}, + backends = ["gpu"], + deps = [ + ":triton_fusion_numerics_verifier", + "//xla:shape_util", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/service:platform_util", + "//xla/service/gpu/autotuning:autotuner_compile_util", + "//xla/service/gpu/autotuning:autotuner_util", + "//xla/stream_executor:platform", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "variadic_op_splitter", + srcs = ["variadic_op_splitter.cc"], + hdrs = ["variadic_op_splitter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "variadic_op_splitter_test", + srcs = ["variadic_op_splitter_test.cc"], + tags = [ + "nomsan", + ], + deps = [ + ":variadic_op_splitter", + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/service:pattern_matcher", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "windowed_einsum_handler", + srcs = ["windowed_einsum_handler.cc"], + hdrs = ["windowed_einsum_handler.h"], + deps = [ + "//xla:literal", + "//xla:literal_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service:shape_inference", + "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "windowed_einsum_handler_test", + srcs = ["windowed_einsum_handler_test.cc"], + deps = [ + ":windowed_einsum_handler", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "pgle_accuracy_checker", + srcs = ["pgle_accuracy_checker.cc"], + hdrs = ["pgle_accuracy_checker.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:profile_guided_latency_estimator", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "pgle_accuracy_checker_test", + srcs = ["pgle_accuracy_checker_test.cc"], + deps = [ + ":pgle_accuracy_checker", + "//xla/hlo/ir:hlo", + "//xla/service:latency_hiding_scheduler", + "//xla/service:profile_guided_latency_estimator", + "//xla/service/gpu:gpu_latency_hiding_scheduler", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc similarity index 94% rename from third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc rename to third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc index 21e8d6ca7c0bce..d59ae2b6a1d039 100644 --- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_algebraic_simplifier.h" +#include "xla/service/gpu/transforms/algebraic_simplifier.h" #include "absl/log/check.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h similarity index 93% rename from third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h rename to third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h index 855359654395a0..f29b31e8bb737b 100644 --- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_ -#define XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALGEBRAIC_SIMPLIFIER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ALGEBRAIC_SIMPLIFIER_H_ #include @@ -75,4 +75,4 @@ class GpuAlgebraicSimplifier : public AlgebraicSimplifier { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_ALGEBRAIC_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc rename to third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc index 135ddb12ddf0db..c1e52e90a417c0 100644 --- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_algebraic_simplifier.h" +#include "xla/service/gpu/transforms/algebraic_simplifier.h" #include diff --git a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc rename to third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.cc index fe2d2d1e145140..2f7c130fab449e 100644 --- a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_all_gather_optimizer.h" +#include "xla/service/gpu/transforms/all_gather_optimizer.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.h b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h similarity index 88% rename from third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.h rename to third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h index e28e42246910f9..988c1f6a1bd5ba 100644 --- a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.h +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_ -#define XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_OPTIMIZER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_OPTIMIZER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -43,4 +43,4 @@ class AllGatherOptimizer : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_OPTIMIZER_H_ diff --git a/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc rename to third_party/xla/xla/service/gpu/transforms/all_gather_optimizer_test.cc index 5db5ffd47def70..27f6d65df781a1 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_all_gather_optimizer.h" +#include "xla/service/gpu/transforms/all_gather_optimizer.h" #include #include diff --git a/third_party/xla/xla/service/all_reduce_splitter.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc similarity index 99% rename from third_party/xla/xla/service/all_reduce_splitter.cc rename to third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc index ce1e0e2bc37fcd..51f71c06c800af 100644 --- a/third_party/xla/xla/service/all_reduce_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_reduce_splitter.h" +#include "xla/service/gpu/transforms/all_reduce_splitter.h" #include #include diff --git a/third_party/xla/xla/service/all_reduce_splitter.h b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.h similarity index 94% rename from third_party/xla/xla/service/all_reduce_splitter.h rename to third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.h index ac8dec7afa7833..91e081163035b1 100644 --- a/third_party/xla/xla/service/all_reduce_splitter.h +++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_ALL_REDUCE_SPLITTER_H_ -#define XLA_SERVICE_ALL_REDUCE_SPLITTER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_SPLITTER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_SPLITTER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -74,4 +74,4 @@ class AllReduceSplitter : public HloModulePass { } // namespace xla -#endif // XLA_SERVICE_ALL_REDUCE_SPLITTER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_SPLITTER_H_ diff --git a/third_party/xla/xla/service/all_reduce_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc similarity index 99% rename from third_party/xla/xla/service/all_reduce_splitter_test.cc rename to third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc index 3902a97c439724..ec2e66d1b66100 100644 --- a/third_party/xla/xla/service/all_reduce_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_reduce_splitter.h" +#include "xla/service/gpu/transforms/all_reduce_splitter.h" #include #include @@ -29,7 +29,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/gpu_reduce_scatter_creator.h" +#include "xla/service/gpu/transforms/reduce_scatter_creator.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/tests/filecheck.h" diff --git a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.cc b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.cc similarity index 94% rename from third_party/xla/xla/service/gpu/gpu_async_collective_annotator.cc rename to third_party/xla/xla/service/gpu/transforms/async_collective_annotator.cc index c2f6c04e5c274a..aa76aff4dfec49 100644 --- a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.cc +++ b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_async_collective_annotator.h" +#include "xla/service/gpu/transforms/async_collective_annotator.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -29,7 +29,7 @@ limitations under the License. namespace xla { namespace gpu { -absl::StatusOr GpuAsyncCollectiveAnnotator::Run( +absl::StatusOr AsyncCollectiveAnnotator::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.h b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.h similarity index 78% rename from third_party/xla/xla/service/gpu/gpu_async_collective_annotator.h rename to third_party/xla/xla/service/gpu/transforms/async_collective_annotator.h index 4000fbcbdd4991..1b41d5056b29d2 100644 --- a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.h +++ b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_ASYNC_COLLECTIVE_ANNOTATOR_H_ -#define XLA_SERVICE_GPU_GPU_ASYNC_COLLECTIVE_ANNOTATOR_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ASYNC_COLLECTIVE_ANNOTATOR_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ASYNC_COLLECTIVE_ANNOTATOR_H_ #include @@ -29,12 +29,12 @@ namespace xla { namespace gpu { // Annotate async collectives with CollectiveBackendConfig. -class GpuAsyncCollectiveAnnotator : public HloModulePass { +class AsyncCollectiveAnnotator : public HloModulePass { public: - explicit GpuAsyncCollectiveAnnotator(HloPredicate is_collective_async) + explicit AsyncCollectiveAnnotator(HloPredicate is_collective_async) : is_collective_async_(std::move(is_collective_async)) {} absl::string_view name() const override { - return "gpu-async-collective-annotator"; + return "async-collective-annotator"; } using HloPassInterface::Run; @@ -49,4 +49,4 @@ class GpuAsyncCollectiveAnnotator : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_ASYNC_COLLECTIVE_ANNOTATOR_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_ASYNC_COLLECTIVE_ANNOTATOR_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator_test.cc similarity index 93% rename from third_party/xla/xla/service/gpu/gpu_async_collective_annotator_test.cc rename to third_party/xla/xla/service/gpu/transforms/async_collective_annotator_test.cc index f874a7e565ea73..6622a7b2d20035 100644 --- a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_async_collective_annotator.h" +#include "xla/service/gpu/transforms/async_collective_annotator.h" #include #include @@ -97,18 +97,18 @@ struct TestCase { absl::flat_hash_set expected_sync; }; -class GpuAsyncCollectiveAnnotatorTest +class AsyncCollectiveAnnotatorTest : public HloTestBase, public ::testing::WithParamInterface {}; -XLA_TEST_P(GpuAsyncCollectiveAnnotatorTest, Test) { +XLA_TEST_P(AsyncCollectiveAnnotatorTest, Test) { const TestCase& test_case = GetParam(); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString, /*replica_count=*/2)); TF_ASSERT_OK_AND_ASSIGN( - bool changed, GpuAsyncCollectiveAnnotator(test_case.is_async_predicate) - .Run(module.get())); + bool changed, + AsyncCollectiveAnnotator(test_case.is_async_predicate).Run(module.get())); EXPECT_TRUE(changed); // Assert that all async collectives are annotated with the backend config. @@ -175,8 +175,8 @@ std::string TestCaseName(const ::testing::TestParamInfo& test_case) { return test_case.param.test_name; } -INSTANTIATE_TEST_SUITE_P(GpuAsyncCollectiveAnnotatorTest, - GpuAsyncCollectiveAnnotatorTest, +INSTANTIATE_TEST_SUITE_P(AsyncCollectiveAnnotatorTest, + AsyncCollectiveAnnotatorTest, ::testing::ValuesIn(TestCases()), TestCaseName); } // namespace } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc index baeecbaf03fcf6..e80f225e027508 100644 --- a/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc +++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc @@ -21,12 +21,14 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "tsl/platform/errors.h" namespace xla::gpu { @@ -36,6 +38,10 @@ absl::StatusOr AsyncWrapper::Run( const absl::flat_hash_set& execution_threads) { bool changed = false; + XLA_VLOG_LINES( + 1, absl::StrCat("AsyncWrapper will process the following module:\n", + module->ToString())); + std::deque computations; computations.push_back(module->entry_computation()); while (!computations.empty()) { @@ -45,6 +51,10 @@ absl::StatusOr AsyncWrapper::Run( for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (predicate_(instruction)) { + XLA_VLOG_LINES( + 1, absl::StrCat( + "AsyncWrapper will make the following instruction async:\n", + instruction->ToString())); // If the predicate matches, then wrap the instructions in async blocks. TF_RETURN_IF_ERROR( computation @@ -64,6 +74,11 @@ absl::StatusOr AsyncWrapper::Run( } } } + + XLA_VLOG_LINES( + 1, + absl::StrCat("AsyncWrapper finished processing the following module:\n", + module->ToString())); return changed; } diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc index 2d4aa527b62eca..ffa2866ac739db 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -232,6 +232,9 @@ static bool IsAsyncStartCommand(const HloInstruction* hlo, } if (hlo->opcode() == HloOpcode::kAsyncStart) { + if (IsCublasGemm(*hlo->async_wrapped_instruction())) { + return config.enabled_commands.contains(DebugOptions::CUBLAS); + } if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } @@ -248,6 +251,9 @@ static bool IsAsyncDoneCommand(const HloInstruction* hlo, } if (hlo->opcode() == HloOpcode::kAsyncDone) { + if (IsCublasGemm(*hlo->async_wrapped_instruction())) { + return config.enabled_commands.contains(DebugOptions::CUBLAS); + } if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index 43d0dae777cbba..843428f6467909 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -1013,6 +1013,38 @@ ENTRY e { }); } -} // namespace +TEST_F(CommandBufferSchedulingTest, AsyncCustomCall) { + const char* hlo = R"( + HloModule m, is_scheduled=true + + ENTRY %main (a: s32[], b: s32[]) -> f32[2,2] { + %p = f32[2,2]{1,0} parameter(0) + %start1 = ((f32[2,2], f32[2,2]), (f32[2,2], s8[4]), u32[]) custom-call-start(f32[2,2] %p, f32[2,2] %p), custom_call_target="__cublas$gemm" + %start2 = ((f32[2,2], f32[2,2]), (f32[2,2], s8[4]), u32[]) custom-call-start(f32[2,2] %p, f32[2,2] %p), custom_call_target="__cublas$gemm" + %done1 = (f32[2,2], s8[4]) custom-call-done(((f32[2,2], f32[2,2]), (f32[2,2], s8[4]), u32[]) %start1) + %done2 = (f32[2,2], s8[4]) custom-call-done(((f32[2,2], f32[2,2]), (f32[2,2], s8[4]), u32[]) %start2) + %result1 = f32[2,2] get-tuple-element((f32[2,2], s8[4]) %done1), index=0 + %result2 = f32[2,2] get-tuple-element((f32[2,2], s8[4]) %done2), index=0 + ROOT %sum = f32[2,2] add(f32[2,2] %result1, f32[2,2] %result2) + })"; + + const char* expected = R"( +// CHECK: %command_buffer ([[P:.+]]: f32[2,2]) -> ((f32[2,2], s8[4]), (f32[2,2], s8[4])) { +// CHECK: %[[P]] = f32[2,2]{1,0} parameter(0) +// CHECK: %[[S1:.+]] = ((f32[2,2]{1,0}, f32[2,2]{1,0}), (f32[2,2]{1,0}, s8[4]{0}), u32[]) custom-call-start(%[[P]], %[[P]]), custom_call_target="__cublas$gemm" +// CHECK: %[[S2:.+]] = ((f32[2,2]{1,0}, f32[2,2]{1,0}), (f32[2,2]{1,0}, s8[4]{0}), u32[]) custom-call-start(%[[P]], %[[P]]), custom_call_target="__cublas$gemm" +// CHECK: %[[D1:.+]] = (f32[2,2]{1,0}, s8[4]{0}) custom-call-done(%[[S1]]) +// CHECK: %[[D2:.+]] = (f32[2,2]{1,0}, s8[4]{0}) custom-call-done(%[[S2]]) +// CHECK: ROOT %[[T:.+]] = ((f32[2,2]{1,0}, s8[4]{0}), (f32[2,2]{1,0}, s8[4]{0})) tuple(%[[D1]], %[[D2]]) +// CHECK: })"; + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +} // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.cc b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.cc rename to third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc index 0b55f7d264ff00..f072a91307644b 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_conv_padding_legalization.h" +#include "xla/service/gpu/transforms/conv_padding_legalization.h" #include #include @@ -166,7 +166,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, } } // namespace -bool GpuConvPaddingLegalization::CanonicalizeForwardConvolution( +bool ConvPaddingLegalization::CanonicalizeForwardConvolution( HloInstruction* conv) { if (IsForwardConvolutionCanonical(*conv)) { return false; @@ -219,7 +219,7 @@ void IncreasePaddingHighBy(int64_t delta, WindowDimension* window_dim) { } } // namespace -bool GpuConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( +bool ConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( HloInstruction* backward_conv) { CHECK_EQ(backward_conv->custom_call_target(), kCudnnConvBackwardFilterCallTarget); @@ -292,7 +292,7 @@ bool GpuConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( return true; } -bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution( +bool ConvPaddingLegalization::CanonicalizeBackwardInputConvolution( HloInstruction* backward_conv) { if (window_util::HasSymmetricPadding(backward_conv->window())) { return false; @@ -418,7 +418,7 @@ bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution( return true; } -absl::StatusOr GpuConvPaddingLegalization::RunOnComputation( +absl::StatusOr ConvPaddingLegalization::RunOnComputation( HloComputation* computation) { bool changed = false; std::vector convs; @@ -445,7 +445,7 @@ absl::StatusOr GpuConvPaddingLegalization::RunOnComputation( return changed; } -absl::StatusOr GpuConvPaddingLegalization::Run( +absl::StatusOr ConvPaddingLegalization::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.h b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.h similarity index 86% rename from third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.h rename to third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.h index 32e0238bed1b3d..1841c926d9545b 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.h +++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_ -#define XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CONV_PADDING_LEGALIZATION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CONV_PADDING_LEGALIZATION_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -30,10 +30,10 @@ namespace gpu { // An HLO pass that canonicalizes convolution instructions for GPU codegen. It // inserts Pad instructions before Convolution instructions with uncanonicalized // padding, so that they can be lowered to Cudnn/Miopen convolution. -class GpuConvPaddingLegalization : public HloModulePass { +class ConvPaddingLegalization : public HloModulePass { public: absl::string_view name() const override { - return "gpu-conv-padding-legalization"; + return "conv-padding-legalization"; } using HloPassInterface::Run; @@ -52,4 +52,4 @@ class GpuConvPaddingLegalization : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CONV_PADDING_LEGALIZATION_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization_test.cc b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization_test.cc similarity index 93% rename from third_party/xla/xla/service/gpu/gpu_conv_padding_legalization_test.cc rename to third_party/xla/xla/service/gpu/transforms/conv_padding_legalization_test.cc index edaf9d053d77c9..06682e7d1affd6 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_conv_padding_legalization.h" +#include "xla/service/gpu/transforms/conv_padding_legalization.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/cublas_cudnn.h" @@ -32,9 +32,9 @@ namespace { namespace m = ::xla::match; -using GpuConvPaddingLegalizationTest = HloTestBase; +using ConvPaddingLegalizationTest = HloTestBase; -TEST_F(GpuConvPaddingLegalizationTest, BackwardInputConvolve) { +TEST_F(ConvPaddingLegalizationTest, BackwardInputConvolve) { auto module = ParseAndReturnVerifiedModule(R"( HloModule convolution_module ENTRY %convolution (operand f64[2,2,2,3]{3,2,1,0}) -> (f64[2,2,4,4]{3,2,1,0}, u8[0]) { @@ -75,7 +75,7 @@ ENTRY %convolution (operand f64[2,2,2,3]{3,2,1,0}) -> (f64[2,2,4,4]{3,2,1,0}, u8 } )") .value(); - ASSERT_TRUE(GpuConvPaddingLegalization().Run(module.get()).value()); + ASSERT_TRUE(ConvPaddingLegalization().Run(module.get()).value()); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Tuple( m::Slice(m::GetTupleElement( diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc index cb5b1867241e58..e19622dc27911f 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_conv_rewriter.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" #include #include @@ -845,10 +845,10 @@ absl::StatusOr RunOnComputation(HloComputation* computation, } } // namespace -absl::StatusOr GpuConvRewriter::Run( +absl::StatusOr ConvRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), before:\n" + module->ToString()); + XLA_VLOG_LINES(2, "ConvRewriter::Run(), before:\n" + module->ToString()); bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { @@ -856,11 +856,11 @@ absl::StatusOr GpuConvRewriter::Run( RunOnComputation(computation, compute_capability_)); changed |= result; } - XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), after:\n" + module->ToString()); + XLA_VLOG_LINES(2, "ConvRewriter::Run(), after:\n" + module->ToString()); return changed; } -/*static*/ bool GpuConvRewriter::ConvIsLowerable(HloInstruction* conv) { +/*static*/ bool ConvRewriter::ConvIsLowerable(HloInstruction* conv) { return CanImplementAsGpuForwardConv(conv) || MatchBackwardFilter(conv) || MatchBackwardInput(conv); } diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.h b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h similarity index 82% rename from third_party/xla/xla/service/gpu/gpu_conv_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/conv_rewriter.h index 74b860f239872c..69369f1f5cb54a 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_ -#define XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CONV_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CONV_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -34,12 +34,12 @@ namespace gpu { // patterns of ops will be matched and fused into the custom call in // CudnnFusedConvRewriter. -class GpuConvRewriter : public HloModulePass { +class ConvRewriter : public HloModulePass { public: - explicit GpuConvRewriter(const se::GpuComputeCapability& compute_capability) + explicit ConvRewriter(const se::GpuComputeCapability& compute_capability) : compute_capability_(compute_capability) {}; - absl::string_view name() const override { return "gpu-conv-rewriter"; } + absl::string_view name() const override { return "conv-rewriter"; } static bool ConvIsLowerable(HloInstruction* conv); @@ -55,4 +55,4 @@ class GpuConvRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CONV_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc similarity index 95% rename from third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc index f83bae8fc54586..d01ffd1829b7f8 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_conv_rewriter.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" #include #include @@ -45,9 +45,9 @@ namespace { namespace m = ::xla::match; -class GpuConvRewriterTest : public HloTestBase { +class ConvRewriterTest : public HloTestBase { public: - GpuConvRewriterTest() + ConvRewriterTest() : HloTestBase(/*verifier_layout_sensitive=*/true, /*allow_mixed_precision_in_hlo_verifier=*/false) { for (int i = 0; i < 2; ++i) { @@ -103,7 +103,7 @@ class GpuConvRewriterTest : public HloTestBase { } bool RunPass(HloModule* module) { - return GpuConvRewriter(GetComputeCapability()).Run(module).value(); + return ConvRewriter(GetComputeCapability()).Run(module).value(); } // A convolution window with stride 1 and zero padding. The size fields are @@ -113,7 +113,7 @@ class GpuConvRewriterTest : public HloTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(GpuConvRewriterTest, BackwardFilterConvolve) { +TEST_F(ConvRewriterTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -154,8 +154,7 @@ TEST_F(GpuConvRewriterTest, BackwardFilterConvolve) { << md_after_opt.DebugString() << " vs " << metadata.DebugString(); } -TEST_F(GpuConvRewriterTest, - BackwardFilterConvolveEquivalentToForwardConvolution) { +TEST_F(ConvRewriterTest, BackwardFilterConvolveEquivalentToForwardConvolution) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -186,7 +185,7 @@ TEST_F(GpuConvRewriterTest, } // Extracted from block35 training. -TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { +TEST_F(ConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -216,7 +215,7 @@ TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { } // Extracted from inception v3 training. -TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { +TEST_F(ConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -245,7 +244,7 @@ TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0))); } -TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { +TEST_F(ConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -274,7 +273,7 @@ TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0))); } -TEST_F(GpuConvRewriterTest, BackwardInputConvolveEvenPadding) { +TEST_F(ConvRewriterTest, BackwardInputConvolveEvenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -343,7 +342,7 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveEvenPadding) { // Convolve([abc], [x], base_dilation=2) // = Convolve([abc], Reverse([x]), base_dilation=2) // = BackwardInputConvolve([abc], [x], stride=2) -TEST_F(GpuConvRewriterTest, BackwardInputConvolve1x1Filter) { +TEST_F(ConvRewriterTest, BackwardInputConvolve1x1Filter) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. HloInstruction* output = @@ -381,7 +380,7 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolve1x1Filter) { // BackwardInputConvolve([abc], [x], stride=1) is equivalent to // ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input // convolution. -TEST_F(GpuConvRewriterTest, +TEST_F(ConvRewriterTest, BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. @@ -427,7 +426,7 @@ TEST_F(GpuConvRewriterTest, // 20x10x10x192 // // Gradients are padded unevenly. -TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { +TEST_F(ConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -479,7 +478,7 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { // Similar to BackwardInputConvolveUnevenPadding, but the low padding of the // gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused. -TEST_F(GpuConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { +TEST_F(ConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -533,7 +532,7 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { // // We should fuse BC even though padding on activations is uneven, because // GpuConvPaddingLegalization will canonicalize the fusion HLO. -TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { +TEST_F(ConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. HloInstruction* output = @@ -590,7 +589,7 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { // We currently don't fuse BC because GpuConvPaddingLegalization // doesn't support negative padding on the gradients of backward convolution // (b/32744257). -TEST_F(GpuConvRewriterTest, +TEST_F(ConvRewriterTest, BackwardInputConvolveNegativePaddingHighOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. @@ -632,7 +631,7 @@ TEST_F(GpuConvRewriterTest, // Check that we will materialize a reversed version of a constant in order to // pattern-match a backwards input convolution. -TEST_F(GpuConvRewriterTest, BackwardInputConvolveConstantFilter) { +TEST_F(ConvRewriterTest, BackwardInputConvolveConstantFilter) { Array4D constant_arr(4, 4, 2, 2); constant_arr.FillIota(0); std::string constant_str = @@ -659,7 +658,7 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveConstantFilter) { 0))); } -TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternMatch) { +TEST_F(ConvRewriterTest, TestBackwardFilterPatternMatch) { // All filter dimensions are larger than the corresponding output dimensions. // This must be a backward filter convolution. const std::string module_str = absl::StrFormat(R"( @@ -681,7 +680,7 @@ TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternMatch) { 0))); } -TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternNoMatch) { +TEST_F(ConvRewriterTest, TestBackwardFilterPatternNoMatch) { // At least one filter dimension is smaller than the corresponding output // dimension. This must be a forward convolution. const std::string module_str = absl::StrFormat(R"( @@ -703,7 +702,7 @@ TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternNoMatch) { 0))); } -TEST_F(GpuConvRewriterTest, TestConv1dBackwardFilterPatternMatch) { +TEST_F(ConvRewriterTest, TestConv1dBackwardFilterPatternMatch) { // There exist one kernel dimension equal to output dimension, regard // it as backward filter if conv is 1d. const std::string module_str = absl::StrFormat(R"( @@ -726,7 +725,7 @@ TEST_F(GpuConvRewriterTest, TestConv1dBackwardFilterPatternMatch) { 0))); } -TEST_F(GpuConvRewriterTest, TestConv1dBackwardInputPatternMatch) { +TEST_F(ConvRewriterTest, TestConv1dBackwardInputPatternMatch) { // For conv1d backward input, filter may reverse first and then reshape. const std::string module_str = absl::StrFormat(R"( HloModule Test @@ -749,7 +748,7 @@ TEST_F(GpuConvRewriterTest, TestConv1dBackwardInputPatternMatch) { 0))); } -TEST_F(GpuConvRewriterTest, TestInvalidTypes) { +TEST_F(ConvRewriterTest, TestInvalidTypes) { const std::string module_str = absl::StrFormat(R"( HloModule Test @@ -766,8 +765,7 @@ TEST_F(GpuConvRewriterTest, TestInvalidTypes) { TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_with_type)); - absl::Status s = - GpuConvRewriter(GetComputeCapability()).Run(m.get()).status(); + absl::Status s = ConvRewriter(GetComputeCapability()).Run(m.get()).status(); EXPECT_THAT( s, tsl::testing::StatusIs( absl::StatusCode::kUnimplemented, @@ -780,17 +778,14 @@ TEST_F(GpuConvRewriterTest, TestInvalidTypes) { absl::StrReplaceAll(module_str, {{"TYPE", "f8e4m3fn"}}); TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_with_type)); - absl::Status s = GpuConvRewriter(se::CudaComputeCapability::Ampere()) - .Run(m.get()) - .status(); + absl::Status s = + ConvRewriter(se::CudaComputeCapability::Ampere()).Run(m.get()).status(); EXPECT_THAT(s, tsl::testing::StatusIs( absl::StatusCode::kUnimplemented, ::testing::HasSubstr( "FP8 convolutions are only supported on CUDA " "GPUs with compute capability at least 9.0"))); - s = GpuConvRewriter(se::RocmComputeCapability{"gfx942"}) - .Run(m.get()) - .status(); + s = ConvRewriter(se::RocmComputeCapability{"gfx942"}).Run(m.get()).status(); EXPECT_THAT(s, tsl::testing::StatusIs( absl::StatusCode::kUnimplemented, ::testing::HasSubstr( @@ -799,7 +794,7 @@ TEST_F(GpuConvRewriterTest, TestInvalidTypes) { // Test unsupported FP8 type module_with_type = absl::StrReplaceAll(module_str, {{"TYPE", "f8e4m3fnuz"}}); TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(module_with_type)); - s = GpuConvRewriter(GetComputeCapability()).Run(m.get()).status(); + s = ConvRewriter(GetComputeCapability()).Run(m.get()).status(); EXPECT_THAT(s, tsl::testing::StatusIs( absl::StatusCode::kUnimplemented, diff --git a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.cc similarity index 97% rename from third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc rename to third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.cc index b8c87e2e1978a2..a7dc96ebefaf7e 100644 --- a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc +++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_convert_async_collectives_to_sync.h" +#include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.h b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h similarity index 86% rename from third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.h rename to third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h index ea56f7a91914ce..6507080a5fa49b 100644 --- a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.h +++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ -#define XLA_SERVICE_GPU_GPU_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ #include @@ -44,4 +44,4 @@ class GpuConvertAsyncCollectivesToSync : public ConvertAsyncCollectivesToSync { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc rename to third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc index 03f18bd3c5eb6d..d38ab70864ac4c 100644 --- a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_convert_async_collectives_to_sync.h" +#include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h" #include diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc index 43f1242e8a2982..82f88398c2bb0d 100644 --- a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc +++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc @@ -25,7 +25,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal_util.h" -#include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/transforms/gemm_fusion.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc new file mode 100644 index 00000000000000..bbc5ed42e6d9dd --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -0,0 +1,340 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/cudnn_custom_call_compiler.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/stream_executor_util.h" +#include "xla/stream_executor/cuda/cuda_dnn.h" +#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +inline absl::StatusOr AsCudnnFmhaMaskKind( + CudnnfMHABackendConfig_MaskType mask_type) { + switch (mask_type) { + case CudnnfMHABackendConfig::NO_MASK: + return CudnnfMHAMaskKind::kNoMask; + case CudnnfMHABackendConfig::PADDING: + return CudnnfMHAMaskKind::kPadding; + case CudnnfMHABackendConfig::CAUSAL: + return CudnnfMHAMaskKind::kCausal; + case CudnnfMHABackendConfig::PADDING_CAUSAL: + return CudnnfMHAMaskKind::kPaddingCausal; + case CudnnfMHABackendConfig::ALIBI: + return CudnnfMHAMaskKind::kAlibi; + default: + return xla::Internal("Unknown fmha mask kind."); + } +} + +using se::dnn::DataType; +using se::dnn::MatmulTensorDescriptor; +using se::dnn::TensorDescriptor; + +absl::StatusOr TensorDescriptorFor(const Shape &shape) { + TF_ASSIGN_OR_RETURN(const DataType type, + GetDNNDataTypeFromPrimitiveType(shape.element_type())); + return TensorDescriptor::For(type, shape.dimensions(), + shape.layout().minor_to_major()); +} + +enum Side { LHS, RHS }; + +absl::StatusOr MatmulTensorDescriptorFor( + const Shape &shape, const DotDimensionNumbers &dnums, const Side side) { + TF_ASSIGN_OR_RETURN(const DataType type, + GetDNNDataTypeFromPrimitiveType(shape.element_type())); + return MatmulTensorDescriptor::For( + type, shape.dimensions(), shape.layout().minor_to_major(), + (side == LHS) ? dnums.lhs_batch_dimensions() + : dnums.rhs_batch_dimensions(), + (side == LHS) ? dnums.lhs_contracting_dimensions() + : dnums.rhs_contracting_dimensions()); +} + +absl::StatusOr HloCustomCallToCuDnnGraph( + se::dnn::DnnSupport &dnn_support, HloCustomCallInstruction *custom_call) { + if (IsFwdCustomCallTofMHA(*custom_call)) { + TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, + xla::gpu::GetCudnnfMHAKind(custom_call)); + TF_ASSIGN_OR_RETURN( + const auto gpu_config, + custom_call->backend_config()); + const xla::gpu::CudnnfMHABackendConfig &config = + gpu_config.cudnn_fmha_backend_config(); + + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor lhs_bmm1, + MatmulTensorDescriptorFor(custom_call->operand(0)->shape(), + config.bmm1_dot_dimension_numbers(), LHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor rhs_bmm1, + MatmulTensorDescriptorFor(custom_call->operand(1)->shape(), + config.bmm1_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor rhs_bmm2, + MatmulTensorDescriptorFor(custom_call->operand(2)->shape(), + config.bmm2_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + TensorDescriptor output, + TensorDescriptorFor(ShapeUtil::GetSubshape(custom_call->shape(), {0}))); + + std::optional activation; + const bool has_activation = + xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3; + if (has_activation) { + TF_ASSIGN_OR_RETURN( + activation, TensorDescriptorFor( + ShapeUtil::GetSubshape(custom_call->shape(), {1}))); + } + + std::optional bias; + if (kind == CudnnfMHAKind::kScaleBiasSoftmax || + kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout) { + const HloInstruction &bias_hlo = *custom_call->operand(3); + TF_ASSIGN_OR_RETURN(bias, TensorDescriptorFor(bias_hlo.shape())); + } + + const double dropout_rate = config.dropout_rate(); + + TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); + + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + se::gpu::GetCudnnFlashAttentionOperationGraph( + dnn_support, lhs_bmm1, rhs_bmm1, rhs_bmm2, output, bias, activation, + static_cast(config.fmha_scale()), dropout_rate > 0.0, + dropout_rate, dnn_mask_type)); + return std::move(graph); + } else { + TF_ASSIGN_OR_RETURN( + auto gpu_config, + custom_call->backend_config()); + xla::gpu::CudnnfMHABackendConfig &config = + *gpu_config.mutable_cudnn_fmha_backend_config(); + + int input_index = 0; + const Shape &bmm1_grad_gemm1_rhs_shape = + custom_call->operand(input_index++)->shape(); + const Shape &bmm1_grad_gemm2_rhs_shape = + custom_call->operand(input_index++)->shape(); + const Shape &bmm2_grad_gemm2_rhs_shape = + custom_call->operand(input_index++)->shape(); + const Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape()); + ++input_index; + const Shape &d_output_shape = custom_call->operand(input_index++)->shape(); + + TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, + GetCudnnfMHAKind(custom_call)); + + bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || + kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); + std::optional bias_shape; + if (has_bias) { + bias_shape = custom_call->operand(input_index++)->shape(); + } + + // Unused fwd_output_shape + ++input_index; + + if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING || + config.mask_type() == + xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) { + // skip q_seqlen and kv_seqlen + input_index += 2; + } + TF_RET_CHECK(input_index == custom_call->operand_count()); + + int output_index = 0; + const Shape &d_bmm1_lhs_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + const Shape &d_bmm1_rhs_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + const Shape &d_bmm2_rhs_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + bool has_dbias = custom_call->shape().tuple_shapes().size() == 5; + if (has_dbias) { + ++output_index; + } + // The last one is the workspace. + TF_RET_CHECK(output_index == + custom_call->shape().tuple_shapes().size() - 1); + + const DebugOptions &debug_options = + custom_call->GetModule()->config().debug_options(); + bool force_deterministic = + debug_options.xla_gpu_deterministic_ops() || + debug_options.xla_gpu_exclude_nondeterministic_ops(); + config.set_force_deterministic(force_deterministic); + TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); + + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm1_grad_gemm1_rhs, + MatmulTensorDescriptorFor( + bmm1_grad_gemm1_rhs_shape, + config.bmm1_grad_gemm1_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm1_grad_gemm2_rhs, + MatmulTensorDescriptorFor( + bmm1_grad_gemm2_rhs_shape, + config.bmm1_grad_gemm2_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm2_grad_gemm1_lhs, + MatmulTensorDescriptorFor( + bmm2_grad_gemm1_lhs_shape, + config.bmm2_grad_gemm1_dot_dimension_numbers(), LHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm2_grad_gemm2_rhs, + MatmulTensorDescriptorFor( + bmm2_grad_gemm2_rhs_shape, + config.bmm2_grad_gemm2_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor d_output, + MatmulTensorDescriptorFor( + d_output_shape, config.bmm2_grad_gemm1_dot_dimension_numbers(), + RHS)); + + TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_lhs, + TensorDescriptorFor(d_bmm1_lhs_shape)); + TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_rhs, + TensorDescriptorFor(d_bmm1_rhs_shape)); + TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm2_rhs, + TensorDescriptorFor(d_bmm2_rhs_shape)); + + std::optional bias; + if (bias_shape.has_value()) { + TF_ASSIGN_OR_RETURN(bias, TensorDescriptorFor(*bias_shape)); + } + + const double dropout_rate = config.dropout_rate(); + + TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); + + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + se::gpu::GetCudnnFlashAttentionBackwardOperationGraph( + dnn_support, bmm1_grad_gemm1_rhs, bmm1_grad_gemm2_rhs, + bmm2_grad_gemm1_lhs, bmm2_grad_gemm2_rhs, d_output, d_bmm1_lhs, + d_bmm1_rhs, d_bmm2_rhs, bias, dropout_rate, config.seed(), + config.fmha_scale(), dropout_rate > 0.0, bias != std::nullopt, + dnn_mask_type, force_deterministic)); + return std::move(graph); + } +} + +class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor { + public: + explicit CuDnnCustomCallVisitor(se::dnn::DnnSupport &dnn_support, + BinaryMap &compilation_results) + : dnn_support_(dnn_support), compilation_results_(compilation_results) {} + + void AddWorkspace(HloInstruction &hlo, int64_t workspace_size) { + if (workspace_size == 0) { + return; + } + VLOG(4) << "Applying workspace size " << workspace_size << " to " + << hlo.ToString(); + Shape *shape = hlo.mutable_shape(); + shape->mutable_tuple_shapes()->back().set_dimensions(0, workspace_size); + } + + absl::Status HandleCustomCall(HloInstruction *hlo) override { + if (!IsCustomCallTofMHA(*hlo)) { + return absl::OkStatus(); + } + + TF_ASSIGN_OR_RETURN(const std::string fingerprint_without_workspace, + FingerprintWithBackendConfig(*hlo)); + auto workspace_size_it = + workspace_sizes_.find(fingerprint_without_workspace); + if (workspace_size_it == workspace_sizes_.cend()) { + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + HloCustomCallToCuDnnGraph(dnn_support_, + DynCast(hlo))); + + const int64_t workspace_size = graph.Graph().get_workspace_size(); + workspace_sizes_.insert(workspace_size_it, + {fingerprint_without_workspace, workspace_size}); + AddWorkspace(*hlo, workspace_size); + + std::vector serialized_graph; + RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph)); + // Compute a new fingerprint with a potential workspace for the + // compilation results to match a fingerprint computed by the emitter. + TF_ASSIGN_OR_RETURN(const std::string fingerprint_with_workspace, + FingerprintWithBackendConfig(*hlo)); + compilation_results_[fingerprint_with_workspace] = + std::string(reinterpret_cast(serialized_graph.data()), + serialized_graph.size()); + } else { + VLOG(4) << "Cache hit."; + AddWorkspace(*hlo, workspace_size_it->second); + } + + MarkAsChanged(); + return absl::OkStatus(); + } + + private: + se::dnn::DnnSupport &dnn_support_; + BinaryMap &compilation_results_; + absl::flat_hash_map workspace_sizes_; +}; + +} // namespace + +absl::StatusOr CuDnnCustomCallCompiler::Run( + HloModule *module, + const absl::flat_hash_set &execution_threads) { + XLA_SCOPED_LOGGING_TIMER_LEVEL("cuDNN custom call compiler", 8); + return CuDnnCustomCallVisitor(dnn_support_, compilation_results_) + .RunOnModule(module, execution_threads); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h similarity index 61% rename from third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h index 962841289b58dc..810286f91b8472 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_WORKSPACE_REWRITER_H_ -#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_WORKSPACE_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/hlo_pass_interface.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor.h" @@ -27,14 +28,18 @@ limitations under the License. namespace xla { namespace gpu { -// Rewrite cuDNN custom call to have correct workspace size by build graph -// and serialize so we can use it later -class CuDnnWorkspaceRewriter : public HloModulePass { +// Compile cuDNN custom calls to binaries and serialize them. +// Also adjust them in HLO to have correct workspace size. +class CuDnnCustomCallCompiler : public HloModulePass { public: - explicit CuDnnWorkspaceRewriter(se::StreamExecutor& stream_exec) - : dnn_support_(*stream_exec.AsDnn()) {} + explicit CuDnnCustomCallCompiler(se::StreamExecutor& stream_exec, + BinaryMap& compilation_results) + : dnn_support_(*stream_exec.AsDnn()), + compilation_results_(compilation_results) {} - absl::string_view name() const override { return "cudnn-workspace-rewriter"; } + absl::string_view name() const override { + return "cudnn-custom-call-compiler"; + } using HloPassInterface::Run; absl::StatusOr Run( @@ -43,9 +48,10 @@ class CuDnnWorkspaceRewriter : public HloModulePass { private: se::dnn::DnnSupport& dnn_support_; + BinaryMap& compilation_results_; }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_WORKSPACE_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc index 3602d94fb05528..5e22a6f2ec1af3 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc @@ -35,34 +35,33 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/service/hlo_module_config.h" -#include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/dnn.h" -#include "xla/tests/verified_hlo_module.h" -#include "tsl/platform/statusor.h" - -#if GOOGLE_CUDA -#include "third_party/gpus/cuda/include/cuda.h" -#elif TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#endif // GOOGLE_CUDA - #include "xla/service/algebraic_simplifier.h" #include "xla/service/convert_mover.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/gpu_conv_rewriter.h" +#include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" #include "xla/service/hlo_constant_folding.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_fix.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/reshape_mover.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#elif TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#endif // GOOGLE_CUDA namespace xla { namespace gpu { @@ -244,7 +243,7 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { // On older architectures, disregard layout information and only verify // the basic configuration of the convolution Custom Call using the number // of operands and the window_size and serialized graph attributes based - // on the GpuConvRewriter and CudnnFusedConvRewriter passes. + // on the ConvRewriter and CudnnFusedConvRewriter passes. std::string::size_type p0 = custom_call_string.find(':'); std::string::size_type p1 = custom_call_string.find("custom-call"); custom_call_string.erase(p0 + 1, p1 - p0 - 2); @@ -254,8 +253,8 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(pre_hlo_string)); TF_ASSERT_OK_AND_ASSIGN( - bool changed, RunHloPass(GpuConvRewriter(GetCudaComputeCapability()), - module.get())); + bool changed, + RunHloPass(ConvRewriter(GetCudaComputeCapability()), module.get())); EXPECT_TRUE(changed); RunAndFilecheckHloRewrite( module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), @@ -1317,7 +1316,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1351,7 +1350,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1392,7 +1391,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1446,7 +1445,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1492,7 +1491,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1547,7 +1546,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1587,7 +1586,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1628,7 +1627,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1678,7 +1677,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // elu fusion is only active on Ampere+. CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), @@ -1727,7 +1726,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1780,7 +1779,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // relu6 fusion is only enabled on Ampere+. CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), @@ -1824,7 +1823,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1872,7 +1871,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // Leaky-relu fusion is only enabled on Ampere+. CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), @@ -1919,7 +1918,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1967,7 +1966,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2007,7 +2006,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2046,7 +2045,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputThroughRelu) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2085,7 +2084,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasThroughRelu) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2121,7 +2120,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2158,7 +2157,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2192,7 +2191,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2225,7 +2224,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2252,7 +2251,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2279,7 +2278,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingF32ToS8) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2307,7 +2306,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontRemoveConvertDuetoMultpleUser) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2337,7 +2336,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2369,7 +2368,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2412,7 +2411,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2455,7 +2454,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2493,7 +2492,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2536,7 +2535,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2583,7 +2582,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2636,7 +2635,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2692,7 +2691,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2746,7 +2745,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2810,7 +2809,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2853,7 +2852,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc index 519b495e76be30..3ffd74e9e594b7 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc @@ -213,10 +213,13 @@ class GemmDimensionAdapter { return GemmDimensionAdapter{*dot, std::move(analysis)}; } - bool DimensionsAndStrides(const HloInstruction& hlo, - const TritonFusionAnalysis::Scope scope, - std::vector& dimensions, - std::vector& strides) { + struct Result { + std::vector sizes; + std::vector strides; + }; + + std::optional DimensionsAndStrides( + const HloInstruction& hlo, const TritonFusionAnalysis::Scope scope) { const DotDimensionNumbers& dims = dot_.dot_dimension_numbers(); // GEMM fusions require a specific canonical order of dimensions. constexpr int kBatchDimensionIndex = 0; @@ -253,29 +256,33 @@ class GemmDimensionAdapter { case TritonFusionAnalysis::Scope::META: LOG(FATAL) << "Unsupported scope."; } - dimensions.reserve(dim_indices.size()); - strides.reserve(dim_indices.size()); + + Result result; + result.sizes.reserve(dim_indices.size()); + result.strides.reserve(dim_indices.size()); + for (const int index : dim_indices) { const auto* spec = analysis_.IterSpec(scope, &hlo, index); if (spec == nullptr) { - dimensions.push_back(1); - strides.push_back(strides.empty() ? 1 : strides.back()); + result.sizes.push_back(1); + result.strides.push_back( + result.strides.empty() ? 1 : result.strides.back()); continue; } else { if (spec->size() == 1) { // The dimension is not split, nothing to do. } else if (spec->size() == 2) { if (FusionLevel(hlo) < 3) { - return false; + return std::nullopt; } if (!dims.lhs_batch_dimensions().empty()) { VLOG(8) << "Noncontracting dimension split is not compatible with " "batch dimensions."; - return false; + return std::nullopt; } if (index != lhs_noncontracting_index) { VLOG(8) << "Only LHS noncontracting dimension can be split."; - return false; + return std::nullopt; } switch (scope) { case TritonFusionAnalysis::Scope::LHS: @@ -285,40 +292,40 @@ class GemmDimensionAdapter { if (lhs_noncontracting_split_ != spec->back().count) { VLOG(8) << "Output non-contracting dimension has to be split " "the same way as the LHS input one if it is split."; - return false; + return std::nullopt; } break; default: VLOG(8) << "Only LHS noncontracting dimension can be split."; - return false; + return std::nullopt; } // Assign the major part of the noncontracting dimension to the // unused batch one. - CHECK_EQ(dimensions[kBatchDimensionIndex], 1); - dimensions[kBatchDimensionIndex] = spec->back().count; - strides[kBatchDimensionIndex] = spec->back().stride; + CHECK_EQ(result.sizes[kBatchDimensionIndex], 1); + result.sizes[kBatchDimensionIndex] = spec->back().count; + result.strides[kBatchDimensionIndex] = spec->back().stride; } else { VLOG(8) << "The dimension is split multiple times."; - return false; + return std::nullopt; } - dimensions.push_back(spec->front().count); - strides.push_back(spec->front().stride); + result.sizes.push_back(spec->front().count); + result.strides.push_back(spec->front().stride); } } if (lhs_noncontracting_split_ > 1 && scope == TritonFusionAnalysis::Scope::OUTPUT && - dimensions[kBatchDimensionIndex] == 1) { + result.sizes[kBatchDimensionIndex] == 1) { // LHS input noncontracting dimension is split but the corresponding // output one is not. Assign part of the output one to the unused batch // dimension. - dimensions[kBatchDimensionIndex] = lhs_noncontracting_split_; - dimensions[kOutputLHSNonContractingDimensionIndex] /= + result.sizes[kBatchDimensionIndex] = lhs_noncontracting_split_; + result.sizes[kOutputLHSNonContractingDimensionIndex] /= lhs_noncontracting_split_; - strides[kBatchDimensionIndex] = - strides[kOutputLHSNonContractingDimensionIndex] * - dimensions[kOutputLHSNonContractingDimensionIndex]; + result.strides[kBatchDimensionIndex] = + result.strides[kOutputLHSNonContractingDimensionIndex] * + result.sizes[kOutputLHSNonContractingDimensionIndex]; } - return true; + return result; } private: @@ -397,8 +404,7 @@ absl::StatusOr> HloFusionToCuDnnGraph( return std::nullopt; } auto add_parameter = [&](const HloInstruction& parameter, - std::vector& dimensions, - std::vector strides) { + const GemmDimensionAdapter::Result& dims) { const std::optional data_type = ToCudnnDataType(parameter.shape().element_type()); if (!data_type.has_value()) { @@ -407,8 +413,8 @@ absl::StatusOr> HloFusionToCuDnnGraph( } hlo_to_cudnn[¶meter] = graph.tensor( graph::Tensor_attributes() - .set_dim(dimensions) - .set_stride(strides) + .set_dim(dims.sizes) + .set_stride(dims.strides) .set_data_type(*data_type) .set_name(std::string(parameter.name())) .set_uid(se::gpu::CuDnnTensorUID(parameter.parameter_number()))); @@ -419,14 +425,13 @@ absl::StatusOr> HloFusionToCuDnnGraph( TritonFusionAnalysis::Scope::OUTPUT}) { for (const HloInstruction* parameter : adapter->analysis_.ScopeParameters(scope)) { - std::vector dimensions; - std::vector strides; - if (!adapter->DimensionsAndStrides(*parameter, scope, dimensions, - strides)) { + const std::optional dims = + adapter->DimensionsAndStrides(*parameter, scope); + if (!dims.has_value()) { VLOG(3) << "Unsupported dimensions."; return std::nullopt; } - if (!add_parameter(*parameter, dimensions, strides)) { + if (!add_parameter(*parameter, *dims)) { return std::nullopt; } } @@ -507,19 +512,19 @@ absl::StatusOr> HloFusionToCuDnnGraph( // setting output of the unary shapes results in the rejection of // the cuDNN graph. if (hlo->operand(0)->opcode() == HloOpcode::kBroadcast) { - const auto scope = adapter->analysis_.QueryInstructionScope(*hlo); - std::vector dimensions; - std::vector strides; + const std::optional scope = + adapter->analysis_.QueryInstructionScope(*hlo); if (!scope.has_value()) { LOG(FATAL) << "No scope for instruction: " << hlo->ToShortString(); } - if (!adapter->DimensionsAndStrides(*hlo, scope.value(), dimensions, - strides)) { + const std::optional dims = + adapter->DimensionsAndStrides(*hlo, *scope); + if (!dims.has_value()) { VLOG(3) << "Unsupported hlo for querying dimensions: " << hlo->ToShortString(); } else { - hlo_to_cudnn[hlo]->set_dim(dimensions); + hlo_to_cudnn[hlo]->set_dim(dims->sizes); } } } else if (hlo->operand_count() == 2) { @@ -563,17 +568,17 @@ absl::StatusOr> HloFusionToCuDnnGraph( if (instructions.back()->shape().IsTuple()) { output = instructions.back()->operand(0); } - std::vector dimensions; - std::vector strides; - if (!adapter->DimensionsAndStrides( - *output, TritonFusionAnalysis::Scope::OUTPUT, dimensions, strides)) { + const std::optional dims = + adapter->DimensionsAndStrides(*output, + TritonFusionAnalysis::Scope::OUTPUT); + if (!dims.has_value()) { VLOG(3) << "Unsupported dimensions."; return std::nullopt; } hlo_to_cudnn[output] ->set_output(true) - .set_dim(dimensions) - .set_stride(strides) + .set_dim(dims->sizes) + .set_stride(dims->strides) .set_uid(se::gpu::CuDnnTensorUID(fusion.operand_count())); if (!fusion.GetModule()->config().debug_options().xla_dump_to().empty()) { json dump; diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc index d130c087426042..a3dbc71132949a 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc @@ -286,7 +286,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D2) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,4,6,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2} +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[8,8,6]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]]) @@ -298,7 +298,8 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D2) { ; CHECK-DAG: "epsilon":0.001 ; CHECK: } ; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 -; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[8,6,8]{2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: ROOT {{.*}} = f32[2,4,6,8]{3,2,1,0} bitcast([[FUSION]]) )"; TestNorm(hlo_text, optimized_hlo); @@ -346,7 +347,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D2Degenerate1) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,1,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,1,6,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,6]{3,2,1,0} transpose([[P0]]), dimensions={1,0,3,2} +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,6]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]]) @@ -358,7 +359,8 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D2Degenerate1) { ; CHECK-DAG: "epsilon":0.001 ; CHECK: } ; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 -; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2,6,8]{2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: ROOT {{.*}} = f32[2,1,6,8]{3,2,1,0} bitcast([[FUSION]]) )"; TestNorm(hlo_text, optimized_hlo); @@ -406,7 +408,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> f32[2,4,6,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2} +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,24]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]]) @@ -418,7 +420,8 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12) { ; CHECK-DAG: "epsilon":0.001 ; CHECK: } ; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 -; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2,24,8]{2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: ROOT {{.*}} = f32[2,4,6,8]{3,2,1,0} bitcast([[FUSION]]) )"; TestNorm(hlo_text, optimized_hlo); @@ -466,7 +469,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12Degenerate2) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> f32[2,4,1,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1} +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]]) @@ -478,7 +481,8 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12Degenerate2) { ; CHECK-DAG: "epsilon":0.001 ; CHECK: } ; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 -; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2,4,8]{2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: ROOT {{.*}} = f32[2,4,1,8]{3,2,1,0} bitcast([[FUSION]]) )"; TestNorm(hlo_text, optimized_hlo); @@ -757,7 +761,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> (f32[2,4,6,8], f32[2,8], f32[2,8], f32[2,8]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2} +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,24]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]]) @@ -769,13 +773,14 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) { ; CHECK-DAG: "epsilon":0.001 ; CHECK: } ; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 -; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,24,8]{2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-NEXT: [[BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[FUSION0]]) ; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1 ; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]]) ; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2 ; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]]) ; CHECK-NEXT: [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]] -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]]) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[BITCAST]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]]) )"; TestNorm(hlo_text, optimized_hlo); @@ -825,7 +830,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12Degenerate2) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> (f32[2,4,1,8], f32[2,8], f32[2,8], f32[2,8]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1} +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]]) @@ -837,13 +842,14 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12Degenerate2) { ; CHECK-DAG: "epsilon":0.001 ; CHECK: } ; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 -; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,4,8]{2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-NEXT: [[BITCAST:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} bitcast([[FUSION0]]) ; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1 ; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]]) ; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2 ; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]]) ; CHECK-NEXT: [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]] -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]]) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[BITCAST]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]]) )"; TestNorm(hlo_text, optimized_hlo); @@ -1129,7 +1135,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2} +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[8,8,6]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]]) @@ -1142,9 +1148,11 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) { ; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" ; CHECK: } ; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[TRANSPOSE1:%[^ ]+]] = f32[8,6,8]{2,1,0} fusion([[GTE0]]), kind=kLoop, calls={{.*}} +; CHECK-DAG: [[BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[TRANSPOSE1]]) ; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3) -; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P3]]), dimensions={0,1,3,2} -; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]]) +; CHECK-NEXT: [[TRANSPOSE2:%[^ ]+]] = f32[8,8,6]{2,1,0} fusion([[P3]]), kind=kLoop, calls{{.*}} +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE2]]) ; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 ; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 ; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[1,6,1,1]{3,2,1,0}, f32[1,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), @@ -1154,14 +1162,13 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) { ; CHECK-DAG: "kind":"LAYER_BWD" ; CHECK: } ; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 -; CHECK-DAG: [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] -; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0 -; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1 +; CHECK-DAG: [[FUSION:%[^ ]+]] = f32[8,6,8]{2,1,0} fusion([[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-DAG: [[BITCAST2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[FUSION]]) ; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 ; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE4]]) ; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 ; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE5]]) -; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[6]{0}, f32[6]{0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[6]{0}, f32[6]{0}) tuple([[BITCAST]], [[BITCAST2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) )"; TestNorm(hlo_text, optimized_hlo); @@ -1237,7 +1244,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2} +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,8,24]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]]) @@ -1250,9 +1257,11 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) { ; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" ; CHECK: } ; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[TRANSPOSE1:%[^ ]+]] = f32[2,24,8]{2,1,0} fusion([[GTE0]]), kind=kLoop, calls={{.*}} +; CHECK-DAG: [[BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[TRANSPOSE1]]) ; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3) -; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P3]]), dimensions={0,3,1,2} -; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE1]]) +; CHECK-NEXT: [[TRANSPOSE2:%[^ ]+]] = f32[2,8,24]{2,1,0} fusion([[P3]]), kind=kLoop, calls={{.*}} +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE2]]) ; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 ; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 ; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[1,4,6,1]{3,2,1,0}, f32[1,4,6,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), @@ -1262,14 +1271,13 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) { ; CHECK-DAG: "kind":"LAYER_BWD" ; CHECK: } ; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 -; CHECK-DAG: [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] -; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0 -; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1 +; CHECK-DAG: [[FUSION:%[^ ]+]] = f32[2,24,8]{2,1,0} fusion([[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-DAG: [[BITCAST2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[FUSION]]) ; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 ; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE4]]) ; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 ; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE5]]) -; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4,6]{1,0}, f32[4,6]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4,6]{1,0}, f32[4,6]{1,0}) tuple([[BITCAST]], [[BITCAST2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) )"; TestNorm(hlo_text, optimized_hlo); @@ -1345,7 +1353,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1], {{.*}}: f32[2,4,1,8]) -> (f32[2,4,1,8], f32[2,4,1,8], f32[4,1], f32[4,1]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1} +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,8,4]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]]) @@ -1358,9 +1366,11 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) { ; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" ; CHECK: } ; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[TRANSPOSE1:%[^ ]+]] = f32[2,4,8]{2,1,0} fusion([[GTE0]]), kind=kLoop, calls={{.*}} +; CHECK-DAG: [[BITCAST:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} bitcast([[TRANSPOSE1]]) ; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(3) -; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P3]]), dimensions={2,0,3,1} -; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]]) +; CHECK-NEXT: [[TRANSPOSE2:%[^ ]+]] = f32[2,8,4]{2,1,0} fusion([[P3]]), kind=kLoop, calls={{.*}} +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE2]]) ; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 ; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 ; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), @@ -1370,14 +1380,13 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) { ; CHECK-DAG: "kind":"LAYER_BWD" ; CHECK: } ; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 -; CHECK-DAG: [[FUSION0:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] -; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=0 -; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=1 +; CHECK-DAG: [[FUSION0:%[^ ]+]] = f32[2,4,8]{2,1,0} fusion([[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-DAG: [[BITCAST2:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} bitcast([[FUSION0]]) ; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 ; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE4]]) ; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 ; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE5]]) -; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}, f32[4,1]{1,0}, f32[4,1]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}, f32[4,1]{1,0}, f32[4,1]{1,0}) tuple([[BITCAST]], [[BITCAST2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) )"; TestNorm(hlo_text, optimized_hlo); diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.cc deleted file mode 100644 index b5440a8a2af53f..00000000000000 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.cc +++ /dev/null @@ -1,272 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/transforms/cudnn_workspace_rewriter.h" - -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_clone_context.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/cuda/cuda_dnn.h" -#include "xla/stream_executor/dnn.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -namespace { - -// create cuDNN graphs from HloCustomCall -absl::StatusOr HloCustomCallToCuDnnGraph( - se::dnn::DnnSupport& dnn_support, HloCustomCallInstruction* custom_call) { - if (IsFwdCustomCallTofMHA(*custom_call)) { - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(custom_call)); - std::optional mask_shape, bias_shape; - { - bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax || - kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout; - - if (has_bias) { - const HloInstruction* bias = custom_call->operand(3); - bias_shape = bias->shape(); - } - } - - TF_ASSIGN_OR_RETURN( - const auto gpu_config, - custom_call->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); - absl::InlinedVector output_shapes = { - ShapeUtil::GetSubshape(custom_call->shape(), {0})}; - - bool has_activation = - xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3; - if (has_activation) { - output_shapes.push_back( - ShapeUtil::GetSubshape(custom_call->shape(), {1})); - } - - Shape q_shape = custom_call->operand(0)->shape(); - Shape k_shape = custom_call->operand(1)->shape(); - Shape v_shape = custom_call->operand(2)->shape(); - TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - GpufMHADescriptor descriptor = {kind, - config, - cudnn_mask_type, - q_shape, - k_shape, - v_shape, - intermediate_tensor_shape, - output_shapes, - config.bmm1_dot_dimension_numbers(), - config.bmm2_dot_dimension_numbers(), - mask_shape, - bias_shape}; - - TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, - GpufMHAConfig::For(descriptor)); - TF_ASSIGN_OR_RETURN( - se::dnn::FMHAMaskKind dnn_mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - se::gpu::GetCudnnFlashAttentionOperationGraph( - dnn_support, fmha_config.lhs_bmm1, fmha_config.rhs_bmm1, - fmha_config.rhs_bmm2, fmha_config.output, fmha_config.bias, - fmha_config.activation, static_cast(*fmha_config.fmha_scale), - fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, - fmha_config.dropout_rate, dnn_mask_type)); - return std::move(graph); - } else { - TF_ASSIGN_OR_RETURN( - auto gpu_config, - custom_call->backend_config()); - xla::gpu::CudnnfMHABackendConfig& config = - *gpu_config.mutable_cudnn_fmha_backend_config(); - - int input_index = 0; - Shape bmm1_grad_gemm1_rhs_shape = - custom_call->operand(input_index++)->shape(); - Shape bmm1_grad_gemm2_rhs_shape = - custom_call->operand(input_index++)->shape(); - Shape bmm2_grad_gemm2_rhs_shape = - custom_call->operand(input_index++)->shape(); - Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape()); - input_index++; - Shape d_output_shape = custom_call->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, - GetCudnnfMHAKind(custom_call)); - std::optional mask_shape; - - bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || - kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); - std::optional bias_shape; - if (has_bias) { - bias_shape = custom_call->operand(input_index++)->shape(); - } - - std::optional fwd_output_shape = - custom_call->operand(input_index++)->shape(); - if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING || - config.mask_type() == - xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) { - // skip q_seqlen and kv_seqlen - input_index += 2; - } - TF_RET_CHECK(input_index == custom_call->operand_count()); - - int output_index = 0; - Shape d_bmm1_lhs_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - Shape d_bmm1_rhs_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - Shape d_bmm2_rhs_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - std::optional d_s_shape; - std::optional d_bias_shape; - bool has_dbias = custom_call->shape().tuple_shapes().size() == 5; - if (has_dbias) { - d_bias_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - } - // The last one is the workspace. - TF_RET_CHECK(output_index == - custom_call->shape().tuple_shapes().size() - 1); - TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - - const bool force_deterministic = - RequireDeterminism(custom_call->GetModule()->config()); - // set the correct force_deterministic attribute here - config.set_force_deterministic(force_deterministic); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); - - GpufMHABackwardDescriptor descriptor = { - kind, - config, - cudnn_mask_type, - bmm1_grad_gemm1_rhs_shape, - bmm1_grad_gemm2_rhs_shape, - bmm2_grad_gemm1_lhs_shape, - bmm2_grad_gemm2_rhs_shape, - d_output_shape, - d_bmm1_lhs_shape, - d_bmm1_rhs_shape, - d_bmm2_rhs_shape, - config.bmm1_grad_gemm1_dot_dimension_numbers(), - config.bmm1_grad_gemm2_dot_dimension_numbers(), - config.bmm2_grad_gemm1_dot_dimension_numbers(), - config.bmm2_grad_gemm2_dot_dimension_numbers(), - d_s_shape, - fwd_output_shape, - mask_shape, - d_bias_shape, - bias_shape, - force_deterministic}; - - TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_config, - GpufMHABackwardConfig::For(descriptor)); - TF_ASSIGN_OR_RETURN( - se::dnn::FMHAMaskKind dnn_mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); - - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - se::gpu::GetCudnnFlashAttentionBackwardOperationGraph( - dnn_support, fmha_config.bmm1_grad_gemm1_rhs, - fmha_config.bmm1_grad_gemm2_rhs, fmha_config.bmm2_grad_gemm1_lhs, - fmha_config.bmm2_grad_gemm2_rhs, fmha_config.d_output, - fmha_config.d_bmm1_lhs, fmha_config.d_bmm1_rhs, - fmha_config.d_bmm2_rhs, fmha_config.bias, fmha_config.dropout_rate, - fmha_config.seed, *fmha_config.fmha_scale, - fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, - fmha_config.bias != std::nullopt, dnn_mask_type, - force_deterministic)); - return std::move(graph); - } -} - -class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor { - public: - explicit CuDnnCustomCallVisitor(se::dnn::DnnSupport& dnn_support) - : dnn_support_(dnn_support) {} - - absl::Status HandleCustomCall(HloInstruction* hlo) override { - if (!IsCustomCallTofMHA(*hlo)) { - // don't do anything about other cuDNN custom calls - return absl::OkStatus(); - } - TF_ASSIGN_OR_RETURN(auto gpu_config, - hlo->backend_config()); - - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - HloCustomCallToCuDnnGraph(dnn_support_, - DynCast(hlo))); - auto workspace = graph.Graph().get_workspace_size(); - if (workspace != 0) { - // rewrite custom call to have correct workspace size - VLOG(4) << "Rewriting: " << hlo->ToString(); - Shape* shape = hlo->mutable_shape(); - shape->mutable_tuple_shapes(shape->tuple_shapes_size() - 1) - ->set_dimensions(0, workspace); - MarkAsChanged(); - } - return absl::OkStatus(); - } - - private: - se::dnn::DnnSupport& dnn_support_; -}; - -} // namespace - -absl::StatusOr CuDnnWorkspaceRewriter::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - XLA_SCOPED_LOGGING_TIMER("cuDNN workspace rewriter"); - return CuDnnCustomCallVisitor(dnn_support_) - .RunOnModule(module, execution_threads); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/dot_operand_converter.cc b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc similarity index 97% rename from third_party/xla/xla/service/gpu/dot_operand_converter.cc rename to third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc index 2a298e67eaf70e..d9e095e2c57ce0 100644 --- a/third_party/xla/xla/service/gpu/dot_operand_converter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/dot_operand_converter.h" +#include "xla/service/gpu/transforms/dot_operand_converter.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/service/gpu/dot_operand_converter.h b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h similarity index 88% rename from third_party/xla/xla/service/gpu/dot_operand_converter.h rename to third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h index d277a24100c0b3..b269bed8b6a6f3 100644 --- a/third_party/xla/xla/service/gpu/dot_operand_converter.h +++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_ -#define XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_OPERAND_CONVERTER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_DOT_OPERAND_CONVERTER_H_ #include @@ -43,4 +43,4 @@ class DotOperandConverter : public OpExpanderPass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_DOT_OPERAND_CONVERTER_H_ diff --git a/third_party/xla/xla/service/gpu/dot_operand_converter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/dot_operand_converter_test.cc rename to third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc index 63b0017012f419..be05b6767abbfd 100644 --- a/third_party/xla/xla/service/gpu/dot_operand_converter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/dot_operand_converter.h" +#include "xla/service/gpu/transforms/dot_operand_converter.h" #include diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc index a3c4701f013c5c..1eadef692a6839 100644 --- a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/service/gpu/gpu_constants.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -56,6 +57,8 @@ namespace gpu { namespace { +namespace m = ::xla::match; + // A dataflow path flowing from a definition to a user. using DefUseDataflowPath = absl::InlinedVector; @@ -149,6 +152,98 @@ bool IsAlignedSlice(const HloInstruction* slice) { return true; } +// Pattern matches the following IR (generated by `jax.lax.scan`) to check if +// the offset is a loop iteration number: + +// clang-format off +// param = (s32[], s32[], s32[16]{0}, s32[16]{0}) parameter(0) +// // the index in `gte` has to be the loop iteration index +// gte = s32[] get-tuple-element(param), index=0 +// c0 = s32[] constant(0) compare = pred[] compare(gte, c0), direction=LT +// c_trip_count = s32[] constant(16) +// add = s32[] add(gte, c_trip_count) select = s32[] select(compare, add, gte) +// clang-format on + +bool IsLoopIterationNumber(const HloInstruction& offset) { + const HloComputation* parent = offset.parent(); + if (!parent->IsWhileBodyComputation()) return false; + + // Scan loops trip count must be known at compile time as it iterates over the + // leading dimension of the statically shaped input. + const HloInstruction* while_instr = parent->WhileCallInstruction(); + auto config = while_instr->backend_config(); + if (!config.ok() || !config->has_known_trip_count()) return false; + int32_t trip_count = config->known_trip_count().n(); + + // First lets check the offset computation pattern + if (!Match(&offset, m::Select(m::Lt(m::GetTupleElement(m::Parameter(0)), + m::ConstantScalar(0)), + m::Add(m::GetTupleElement(m::Parameter(0)), + m::ConstantScalar(trip_count)), + m::GetTupleElement(m::Parameter())))) { + return false; + } + + // Next, we check that the parameter used in offset computation is the loop + // induction variable + int64_t param_idx = offset.operand(2)->tuple_index(); + const HloInstruction* root = offset.parent()->root_instruction(); + if (root->opcode() != HloOpcode::kTuple) { + return false; + } + // Check the update operation + const HloInstruction* updated_var = + offset.parent()->root_instruction()->operand(param_idx); + if (!Match(updated_var, m::Add(m::GetTupleElement(m::Parameter(0), param_idx), + m::ConstantScalar(1)))) { + return false; + } + // Check that the condition considers this. + const HloInstruction* condition_root = + while_instr->while_condition()->root_instruction(); + if (!Match(condition_root, + m::Lt(m::GetTupleElement(m::Parameter(0), param_idx), + m::ConstantScalar(trip_count)))) { + return false; + } + // Check init + const HloInstruction* init_loop_iter = + while_instr->operand(0)->operand(param_idx); + if (!Match(init_loop_iter, m::ConstantScalar(0))) { + return false; + } + + return true; +} + +// This returns true for the constants that are handled in the dynamic slice +// fusion runtime. These constants do not force a D2H copy and hence preserve +// the cuda graph. +bool IsHandledConstantForDynamicSliceFusion(const HloInstruction& offset) { + if (auto* cst = DynCast(&offset)) { + switch (cst->shape().element_type()) { + case PrimitiveType::S32: + case PrimitiveType::S64: + case PrimitiveType::U32: + case PrimitiveType::U64: + return true; + default: + return false; + }; + } + return false; +} + +// This checks whether a dynamic index operation has all offsets that are either +// constant or loop iteration offsets. +bool HasConstantOrLoopIterationOffsets( + const HloDynamicIndexInstruction& instr) { + return llvm::all_of(instr.index_operands(), [](const HloInstruction* offset) { + return IsLoopIterationNumber(*offset) || + IsHandledConstantForDynamicSliceFusion(*offset); + }); +} + UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { UseDefDataflowPaths sliced_operand_paths; @@ -193,8 +288,15 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { }); if (maybe_slice_instr == std::nullopt) continue; - - if (slice_found || processed_instrs.contains(maybe_slice_instr.value())) { + auto dynamic_index_operation = + DynCast(maybe_slice_instr.value()); + bool valid_slice_found = + slice_found && + ((dynamic_index_operation && + HasConstantOrLoopIterationOffsets(*dynamic_index_operation)) || + (*maybe_slice_instr)->opcode() == HloOpcode::kSlice); + if (valid_slice_found || + processed_instrs.contains(maybe_slice_instr.value())) { // Even in the case of stopping at a match that has been processed, we // still need to add instructions encountered in the sliced operand path // during the latest traversal. @@ -241,7 +343,12 @@ DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) { }, /*visit_operands=*/false); if (maybe_dus_instr == std::nullopt) return; - if (dus_found || processed_instrs.contains(maybe_dus_instr.value())) { + auto dynamic_index_operation = + DynCast(maybe_dus_instr.value()); + bool valid_dus_found = + dus_found && dynamic_index_operation && + HasConstantOrLoopIterationOffsets(*dynamic_index_operation); + if (valid_dus_found || processed_instrs.contains(maybe_dus_instr.value())) { // Even in the case of stopping at a match that has been processed, we // still need to add instructions encountered in the sliced user path // during the latest traversal. @@ -405,8 +512,9 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( const absl::flat_hash_set& execution_threads) { absl::flat_hash_map> - matches; + matches_kv; + std::vector matches; // Collect all potential custom call matches in the non-fusion computations. for (HloComputation* computation : module->computations()) { if (computation->IsFusionComputation()) continue; @@ -433,8 +541,9 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( } if (has_sliced_operand_paths || has_sliced_user_paths) { - matches[instr] = std::make_pair(std::move(sliced_operand_paths), - std::move(sliced_user_paths)); + matches_kv[instr] = std::make_pair(std::move(sliced_operand_paths), + std::move(sliced_user_paths)); + matches.push_back(instr); } } } @@ -442,7 +551,8 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( if (matches.empty()) return false; - for (auto& [hero, paths] : matches) { + for (HloInstruction* hero : matches) { + auto& paths = matches_kv[hero]; auto& [sliced_operand_paths, sliced_user_paths] = paths; std::vector matched_instrs; absl::c_copy(sliced_operand_paths, std::back_inserter(matched_instrs)); diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc index 36fa64bc46e0a9..2bd7168adfc06c 100644 --- a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc @@ -1785,9 +1785,6 @@ TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSConstantOffset) { RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); } -// This is not required to pass, but in current implementation this works by -// forcing a D2H copy. Adding this here to ensure that the change in this -// behaviour is intentional. TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSParameterOffset) { const char* hlo = R"( HloModule test, replica_count=2 @@ -1806,18 +1803,8 @@ TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSParameterOffset) { reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add.clone ROOT dynamic-update-slice = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, param_2, constant_0) })"; - - const char* expected = R"( - // CHECK: %address-computation{{.+}} { - // CHECK: %[[RS:.+]] = f16[64,128]{1,0} reduce-scatter({{.+}}) - // CHECK: ROOT %{{.+}} = f16[128,128]{1,0} dynamic-update-slice(%{{.+}}, %[[RS]], %{{.+}}, %{{.+}}) - // CHECK: } - // CHECK: ENTRY {{.+}} { - // CHECK-NOT: reduce-scatter - // CHECK: ROOT %{{.+}} = {{.+}} fusion(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}), kind=kCustom, calls=%address-computation{{.+}}"name":"dynamic_address_computation" - // CHECK: } - )"; - RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), + std::nullopt); } TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSLoopIterationOffset) { @@ -1881,4 +1868,197 @@ TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSLoopIterationOffset) { RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); } +TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLoopIteration) { + const char* hlo = R"( + HloModule test + + %Body { + param = (f16[1,8,8]{2,1,0}, f16[1,8,8]{2,1,0}, f16[4,8,8]{2,1,0}, u32[]) parameter(0) + p0 = get-tuple-element(param), index=0 + p1 = get-tuple-element(param), index=1 + p2 = get-tuple-element(param), index=2 + loop_iter = get-tuple-element(param), index=3 + + bitcast.41 = f16[8,8]{1,0} bitcast(p0) + bitcast.42 = f16[8,8]{1,0} bitcast(p1) + custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), custom_call_target="__cublas$gemm", backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1) + c0 = u32[] constant(0) + c_trip_count = u32[] constant(11) + compare = pred[] compare(loop_iter, c0), direction=LT + add = u32[] add(loop_iter, c_trip_count) + offset = u32[] select(compare, add, loop_iter) + dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, offset, c0, c0) + c1 = u32[] constant(1) + add2 = u32[] add(loop_iter, c1) + ROOT tuple = tuple(p0, p1, dus, u32[] add2) + } + + %Cond { + %param.1 = (f16[1,8,8]{2,1,0}, f16[1,8,8]{2,1,0}, f16[4,8,8]{2,1,0}, u32[]) parameter(0) + %i.1 = u32[] get-tuple-element(%param.1), index=3 + %trip_count = u32[] constant(11) + ROOT %done = pred[] compare(u32[] %i.1, u32[] %trip_count), direction=LT + } + + ENTRY %test { + %p0.1 = f16[1,8,8]{2,1,0} parameter(0) + %p1.1 = f16[1,8,8]{2,1,0} parameter(1) + %p2.1 = f16[4,8,8]{2,1,0} parameter(2) + %c0.1 = u32[] constant(0) + %initial_tuple = tuple(%p0.1, %p1.1, %p2.1, u32[] %c0.1) + ROOT %while = while(%initial_tuple), condition=%Cond, body=%Body, backend_config={"known_trip_count":{"n":"11"}} + })"; + + const char* expected = R"( + // CHECK: %Body{{.+}}{ + // CHECK: %[[PARAM:.+]] = {{.+}} parameter(0) + // CHECK: %[[LOOP_ITER:.+]] = u32[] get-tuple-element(%[[PARAM]]), index=3 + // CHECK: %[[OFFSET:.+]] = u32[] select({{.+}}) + // CHECK: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, {{.+}}, {{.+}}, %[[OFFSET]], %{{.+}}), kind=kCustom, calls=%address-computation, {{.+}}"name":"dynamic_address_computation" + // CHECK: ROOT %tuple = {{.+}} tuple(%{{.+}}, %{{.+}}, %[[ADDRESS_COMPUTATION]], %{{.+}}) + // CHECK: } + // CHECK: ENTRY %test{{.+}}{ + // CHECK: ROOT %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"11"}} + } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); +} + +TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmParameterOffset) { + const char* hlo = R"( + HloModule test + + ENTRY main.9 { + p0 = f16[1,8,8]{2,1,0} parameter(0) + p1 = f16[1,8,8]{2,1,0} parameter(1) + p2 = f16[4,8,8]{2,1,0} parameter(2) + p3 = s32[] parameter(3) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + bitcast.41 = f16[8,8]{1,0} bitcast(p0) + bitcast.42 = f16[8,8]{1,0} bitcast(p1) + + custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1) + ROOT dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, p3, c0_s32, c0_s32) + })"; + + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), + std::nullopt); +} + +TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLaxScan) { + const char* hlo = R"( + HloModule lax_scan + + // This is the HLO generated for the following: + // + // inp = jax.random.uniform(jax.random.key(128), (128, 128, 128)) + // init = jnp.identity(128) + // ans = jax.lax.scan(lambda carry, x : (init, x@carry), init, inp) + + Body { + arg_tuple.15 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(arg_tuple.15), index=0 + constant.21 = s32[] constant(1) + add.2 = s32[] add(get-tuple-element.16, constant.21) + get-tuple-element.30 = f32[128,128]{1,0} get-tuple-element(arg_tuple.15), index=4 + get-tuple-element.18 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.15), index=2 + get-tuple-element.19 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.15), index=3 + constant.23 = s32[] constant(0) + compare.2 = pred[] compare(get-tuple-element.16, constant.23), direction=LT + constant.22 = s32[] constant(128) + add.3 = s32[] add(get-tuple-element.16, constant.22) + select.1 = s32[] select(compare.2, add.3, get-tuple-element.16) + dynamic-slice.1 = f32[1,128,128]{2,1,0} dynamic-slice(get-tuple-element.19, select.1, constant.23, constant.23), dynamic_slice_sizes={1,128,128} + bitcast.72 = f32[128,128]{1,0} bitcast(dynamic-slice.1) + get-tuple-element.17 = f32[128,128]{1,0} get-tuple-element(arg_tuple.15), index=1 + custom-call.1 = (f32[128,128]{1,0}, s8[131072]{0}) custom-call(bitcast.72, get-tuple-element.17), custom_call_target="__cublas$gemm" + get-tuple-element = f32[128,128]{1,0} get-tuple-element(custom-call.1), index=0 + bitcast.77 = f32[1,128,128]{2,1,0} bitcast(get-tuple-element) + dynamic-update-slice.1 = f32[128,128,128]{2,1,0} dynamic-update-slice(get-tuple-element.18, bitcast.77, select.1, constant.23, constant.23) + ROOT tuple.38 = tuple(add.2, get-tuple-element.30, dynamic-update-slice.1, get-tuple-element.19, get-tuple-element.30) + } // Body + + Cond { + arg_tuple.40 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0) + get-tuple-element.41 = s32[] get-tuple-element(arg_tuple.40), index=0 + constant.46 = s32[] constant(128) + ROOT compare.3 = pred[] compare(get-tuple-element.41, constant.46), direction=LT + } + + ENTRY main { + constant.4 = s32[] constant(0) + Arg_1.2 = f32[128,128]{1,0} parameter(1) + constant.5 = f32[] constant(0) + broadcast.1 = f32[128,128,128]{2,1,0} broadcast(constant.5), dimensions={} + Arg_2.3 = f32[128,128,128]{2,1,0} parameter(2) + Arg_0.1 = f32[128,128]{1,0} parameter(0) + tuple.7 = tuple(constant.4, Arg_1.2, broadcast.1, Arg_2.3, Arg_0.1) + while.48 = while(tuple.7), condition=Cond, body=Body, backend_config={"known_trip_count":{"n":"128"}} + get-tuple-element.50 = f32[128,128]{1,0} get-tuple-element(while.48), index=1 + get-tuple-element.51 = f32[128,128,128]{2,1,0} get-tuple-element(while.48), index=2 + ROOT tuple.54 = (f32[128,128]{1,0}, f32[128,128,128]{2,1,0}) tuple(get-tuple-element.50, get-tuple-element.51) + } // main.55 + +)"; + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + const char* expected = R"( + // CHECK: %address-computation{{.*}} {{.+}} { + // CHECK: {{.+}} = {{.+}}dynamic-slice + // CHECK: {{.+}} = {{.+}}custom-call + // CHECK: {{.+}} = {{.+}}dynamic-update-slice + // CHECK: } + // CHECK: %Body{{.+}}{ + // CHECK: %[[PARAM:.+]] = {{.+}} parameter(0) + // CHECK: %[[LOOP_ITER:.+]] = s32[] get-tuple-element(%[[PARAM]]), index=0 + // CHECK: %[[OFFSET:.+]] = s32[] select({{.+}}) + // CHECK: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, %[[OFFSET]], %{{.+}}), kind=kCustom, calls=%address-computation{{.+}}"name":"dynamic_address_computation" + // CHECK: %[[GTE:.+]] = {{.+}} get-tuple-element(%[[ADDRESS_COMPUTATION]]), index=0 + // CHECK: ROOT %{{.+}} = {{.+}} tuple(%{{.+}}, %[[GTE]], %{{.+}}) + // CHECK: } + // CHECK: ENTRY %main{{.+}}{ + // CHECK: %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"128"}} + // CHECK: } + )"; + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc index 08018305cfbf47..65326580472470 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_padding_requirements.h" #include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/triton_fusion_analysis.h" @@ -783,7 +784,6 @@ absl::StatusOr RunOnComputation( return visitor.changed(); } - } // namespace bool ShouldTritonHandleGEMM(HloDotInstruction& dot, diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc index 85ad2e8f530ef8..f72650ef7ff9aa 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/transforms/gemm_fusion.h" #include +#include #include #include @@ -1329,6 +1330,85 @@ ENTRY main { EXPECT_FALSE(result.ok()); } +constexpr auto kInt4Dot = R"( +ENTRY e { + p0 = s8[16,16] parameter(0) + p1 = s4[16,16] parameter(1) + p1c = bf16[16,16] convert(p1) + ROOT dot = bf16[16,16] dot(p0, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + +TEST_F(SmallDotGemmFusionTest, Int4DotIsRewritten) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kInt4Dot)); + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_triton_gemm_int4(true); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); +} + +TEST_F(SmallDotGemmFusionTest, Int4DotIsNotRewritten) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kInt4Dot)); + EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value()); +} + +TEST_F(SmallDotGemmFusionTest, Int4ConcatPlusConvertIsRewritten) { + const std::string kInt4Dot = R"( + ENTRY main { + lhs1 = s4[4,1024]{1,0} parameter(0) + lhs2 = s4[4,1024]{1,0} parameter(1) + rhs = bf16[1024,4]{1,0} parameter(2) + lhs_concat = s4[8,1024]{1,0} concatenate(lhs1, lhs2), dimensions={0} + lhs_converted = bf16[8,1024]{1,0} convert(lhs_concat) + ROOT dot = bf16[8,4]{1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kInt4Dot)); + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_triton_gemm_int4(true); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + + // Check that the fusion is present and that the lhs is not converted. + MatchHloModule(*module, R"( +CHECK: gemm_fusion_dot_computation +CHECK: %parameter_0 = s4[8,1024]{1,0} parameter(0) +CHECK: ENTRY +CHECK-DAG: ROOT {{.*}} = bf16[8,4]{1,0} fusion(s4[8,1024]{1,0} %lhs_concat, bf16[1024,4]{1,0} %rhs) +})"); +} + +TEST_F(SmallDotGemmFusionTest, Int4ConvertPlusNegateIsRewritten) { + const std::string kInt4Dot = R"( + ENTRY main { + lhs = s4[8,1024]{1,0} parameter(0) + rhs = f32[1024,4]{1,0} parameter(1) + lhs_converted = f32[8,1024]{1,0} convert(lhs) + lhs_negated = f32[8,1024]{1,0} negate(lhs_converted) + ROOT dot = f32[8,4]{1,0} dot(lhs_negated, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kInt4Dot)); + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_triton_gemm_int4(true); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + // Check that the fusion is present and that convert and negation is fused in + // it. + MatchHloModule(*module, R"( +CHECK: gemm_fusion_dot_computation +CHECK: %parameter_0 = s4[8,1024]{1,0} parameter(0) +CHECK: ENTRY +CHECK-DAG: ROOT {{.*}} = f32[8,4]{1,0} fusion(s4[8,1024]{1,0} %lhs, f32[1024,4]{1,0} %rhs) +})"); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc index 1cd214ec275256..82895a5b3ae967 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -44,6 +45,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" @@ -550,10 +552,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { public: explicit GemmRewriterVisitor(const se::GpuComputeCapability &gpu_version, const int32_t toolkit_version, - const bool f8_rewrite) + const GemmRewriterOptions options) : gpu_version_(gpu_version), toolkit_version_(toolkit_version), - f8_rewrite_(f8_rewrite) {} + options_(options) {} absl::Status HandleDot(HloInstruction *instr) override { if (!IsMatrixMultiplication(*instr) && @@ -618,50 +620,54 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { gemm_backend_config.set_lhs_stride(lhs_stride); gemm_backend_config.set_rhs_stride(rhs_stride); - if (f8_rewrite_) { - // Rewrite FP8 GEMMs into a type-specific cublasLT Custom Call. - TF_ASSIGN_OR_RETURN( - bool supported_by_cublaslt, - GemmIsSupportedByCublasLt(*instr, gemm_backend_config)); - std::optional a, b; - if (supported_by_cublaslt && instr->opcode() == HloOpcode::kDot && - (a = MatchFp8Param( - const_cast(instr->operand(0)))) && - (b = MatchFp8Param( - const_cast(instr->operand(1))))) { - if (IsRocm(gpu_version_) && toolkit_version_ < 60200 && - instr->shape().element_type() != F16 && - instr->shape().element_type() != F32) { - TF_ASSIGN_OR_RETURN(instr, - TurnF8DotWithUnsupportedOutputTypeIntoF32(instr)); + switch (options_.dtype) { + case GemmRewriterOptions::DType::kFp8Only: { + // Rewrite FP8 GEMMs into a type-specific cublasLT Custom Call. + TF_ASSIGN_OR_RETURN( + bool supported_by_cublaslt, + GemmIsSupportedByCublasLt(*instr, gemm_backend_config)); + std::optional a, b; + if (supported_by_cublaslt && instr->opcode() == HloOpcode::kDot && + (a = MatchFp8Param( + const_cast(instr->operand(0)))) && + (b = MatchFp8Param( + const_cast(instr->operand(1))))) { + if (IsRocm(gpu_version_) && toolkit_version_ < 60200 && + instr->shape().element_type() != F16 && + instr->shape().element_type() != F32) { + TF_ASSIGN_OR_RETURN( + instr, TurnF8DotWithUnsupportedOutputTypeIntoF32(instr)); + } + TF_ASSIGN_OR_RETURN(bool created_call, + CreateF8CustomCall(instr, gpu_backend_config, + a.value(), b.value())); + if (created_call) { + return absl::OkStatus(); + } } - TF_ASSIGN_OR_RETURN(bool created_call, - CreateF8CustomCall(instr, gpu_backend_config, - a.value(), b.value())); - if (created_call) { - return absl::OkStatus(); + if (IsF8Type(instr->operand(0))) { + // FP8 rewriter couldn't rewrite dot with FP8 inputs into cublasLt + // custom call, so turn into an FP16 dot which may be rewritten as an + // FP16 Triton, cublas or cublasLt call. + TF_ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr)); } + break; } - if (IsF8Type(instr->operand(0))) { - // FP8 rewriter couldn't rewrite dot with FP8 inputs into cublasLt - // custom call, so turn into an FP16 dot which may be rewritten as an - // FP16 Triton, cublas or cublasLt call. - TF_ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr)); - } - } else { - // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call. - TF_ASSIGN_OR_RETURN( - absl::string_view gemm_custom_call_target, - GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); - const Shape &output_shape = instr->shape(); - HloInstruction *gemm_call = - instr->AddInstruction(HloInstruction::CreateCustomCall( - output_shape, - {instr->mutable_operand(0), instr->mutable_operand(1)}, - gemm_custom_call_target)); - TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); - } + case GemmRewriterOptions::DType::kNonFp8Only: { + // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call. + TF_ASSIGN_OR_RETURN( + absl::string_view gemm_custom_call_target, + GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); + const Shape &output_shape = instr->shape(); + HloInstruction *gemm_call = + instr->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, + {instr->mutable_operand(0), instr->mutable_operand(1)}, + gemm_custom_call_target)); + TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); + } break; + }; return absl::OkStatus(); } @@ -757,6 +763,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } absl::Status HandleAdd(HloInstruction *instr) override { + if (options_.bias_mode == GemmRewriterOptions::BiasMode::kNoBias) { + // See comments for `GemmRewriterOptions::BiasMode` for details. + return absl::OkStatus(); + } + HloInstruction *bias, *existing_gemm = nullptr; HloInstruction *optional_slice = nullptr; HloInstruction *optional_convert = nullptr; @@ -1062,8 +1073,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } } - absl::Span batch_dims = + absl::Span a_batch_dims = + gemm_backend_config.dot_dimension_numbers().lhs_batch_dimensions(); + absl::Span b_batch_dims = gemm_backend_config.dot_dimension_numbers().rhs_batch_dimensions(); + const size_t num_batch_dims = a_batch_dims.size(); // cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32 // format. Set the factors to one when no scaling factors were captured. @@ -1129,22 +1143,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { "dimension."; return false; } - if ((a.commutative_ops.empty() ? a.fp8_input - : a.commutative_ops.back().first) - ->shape() - .dimensions_size() - - batch_dims.size() != - 2 || - (b.commutative_ops.empty() ? b.fp8_input - : b.commutative_ops.back().first) - ->shape() - .dimensions_size() - - batch_dims.size() != - 2) { - VLOG(1) << "Failed to rewrite " << instr->ToShortString() - << "into FP8 Custom Call. A and B must have one non-contracting " - "dimension."; - return false; + for (const MatchedFp8Param ¶m : {a, b}) { + const HloInstruction *input = param.commutative_ops.empty() + ? param.fp8_input + : param.commutative_ops.back().first; + if (input->shape().rank() != num_batch_dims + 2) { + VLOG(1) << "Failed to rewrite " << instr->ToShortString() + << "into FP8 Custom Call. Inputs must have exactly one " + "contracting and one non-contracting dimension."; + return false; + } } // Sequentially apply the collected unary, dynamic-slice, pad and select ops @@ -1192,49 +1200,49 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { shift_ops(a.fp8_input, a.commutative_ops); shift_ops(b.fp8_input, b.commutative_ops); - TF_ASSIGN_OR_RETURN(bool a_is_col_major, - MatrixIsColumnMajor(*instr, gemm_backend_config, "a")); - TF_ASSIGN_OR_RETURN(bool b_is_col_major, - MatrixIsColumnMajor(*instr, gemm_backend_config, "b")); + TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, + GemmConfig::For(instr, gemm_backend_config)); DotDimensionNumbers *dim_nums = gemm_backend_config.mutable_dot_dimension_numbers(); - int batch_dim_offset = batch_dims.size(); // cuBLASLt FP8 GEMM kernels currently require the first operand, i.e. A, to // be row-major. If A is column-major, swap the contracting and // non-contracting dimension and transpose the matrix to effectively make it // column-major. // TODO(philipphack): Remove once cuBLASLt supports A being column-major - if (a_is_col_major) { - CHECK(a_contracting_dims[0] == batch_dim_offset || - a_contracting_dims[0] == batch_dim_offset + 1); - if (a_contracting_dims[0] == batch_dim_offset) { - dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset + 1); + if (gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor) { + CHECK(a_contracting_dims[0] == num_batch_dims || + a_contracting_dims[0] == num_batch_dims + 1); + if (a_contracting_dims[0] == num_batch_dims) { + dim_nums->set_lhs_contracting_dimensions(0, num_batch_dims + 1); } else { - dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset); + dim_nums->set_lhs_contracting_dimensions(0, num_batch_dims); } a.fp8_input = - TransposeMatrix(a.fp8_input, a_contracting_dims[0], batch_dims); + TransposeMatrix(a.fp8_input, a_contracting_dims[0], a_batch_dims); } // Similarly, cuBLASLt requires the second operand to be column-major, so // make it column-major if it is currently row-major. - if (!b_is_col_major) { - CHECK(b_contracting_dims[0] == batch_dim_offset || - b_contracting_dims[0] == batch_dim_offset + 1); - if (b_contracting_dims[0] == batch_dim_offset) { - dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset + 1); + if (gemm_config.rhs_layout.order == MatrixLayout::Order::kRowMajor) { + CHECK(b_contracting_dims[0] == num_batch_dims || + b_contracting_dims[0] == num_batch_dims + 1); + if (b_contracting_dims[0] == num_batch_dims) { + dim_nums->set_rhs_contracting_dimensions(0, num_batch_dims + 1); } else { - dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset); + dim_nums->set_rhs_contracting_dimensions(0, num_batch_dims); } b.fp8_input = - TransposeMatrix(b.fp8_input, b_contracting_dims[0], batch_dims); + TransposeMatrix(b.fp8_input, b_contracting_dims[0], b_batch_dims); } - a.fp8_input = PadOperandToMultipleOf16(batch_dims, a.fp8_input); - b.fp8_input = PadOperandToMultipleOf16(batch_dims, b.fp8_input); - Shape new_output_shape = PadShapeToMultipleOf16(instr->shape(), batch_dims); + a.fp8_input = PadOperandToMultipleOf16(a_batch_dims, a.fp8_input); + b.fp8_input = PadOperandToMultipleOf16(b_batch_dims, b.fp8_input); + std::vector out_batch_dims(num_batch_dims); + std::iota(out_batch_dims.begin(), out_batch_dims.end(), 0); + Shape new_output_shape = + PadShapeToMultipleOf16(instr->shape(), out_batch_dims); std::vector operands_list = { a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1], one, one}; @@ -1820,7 +1828,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { private: se::GpuComputeCapability gpu_version_; int32_t toolkit_version_; - bool f8_rewrite_; + GemmRewriterOptions options_; // Choose cublas or cublasLt for the target of the custom call that instr will // be rewritten into. @@ -2120,47 +2128,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { output_dtype)); } - absl::StatusOr MatrixIsColumnMajor( - const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config, - const std::string matrix_name = "output") const { - const HloInstruction *lhs = instr.operand(0); - const HloInstruction *rhs = instr.operand(1); - - const DotDimensionNumbers &dot_dims = - gemm_backend_config.dot_dimension_numbers(); - // We use ALG_UNSET and kDefaultComputePrecision because we don't care about - // the precision, just the layout, since we're just checking if the matrix - // is column-major. - TF_ASSIGN_OR_RETURN( - GemmConfig gemm_config, - GemmConfig::For( - lhs->shape(), dot_dims.lhs_batch_dimensions(), - dot_dims.lhs_contracting_dimensions(), rhs->shape(), - dot_dims.rhs_batch_dimensions(), - dot_dims.rhs_contracting_dimensions(), - /*output_shape=*/instr.shape(), gemm_backend_config.alpha_real(), - gemm_backend_config.alpha_imag(), gemm_backend_config.beta(), - /*precision_algorithm=*/PrecisionConfig::ALG_UNSET, - /*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision, - gemm_backend_config.grad_x(), gemm_backend_config.grad_y())); - - if (matrix_name == "lhs" || matrix_name == "a") { - return gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor; - } else if (matrix_name == "rhs" || matrix_name == "b") { - return gemm_config.rhs_layout.order == MatrixLayout::Order::kColumnMajor; - } else if (matrix_name == "output" || matrix_name == "d") { - return gemm_config.output_layout.order == - MatrixLayout::Order::kColumnMajor; - } else { - return Internal("Invalid matrix name."); - } - } - absl::StatusOr GemmIsSupportedByCublasLt( const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config) const { const HloInstruction *lhs = instr.operand(0); - const HloInstruction *rhs = instr.operand(1); const Shape &output_shape = instr.shape(); TF_ASSIGN_OR_RETURN( @@ -2187,9 +2158,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return false; } - TF_ASSIGN_OR_RETURN(bool output_is_column_major, - MatrixIsColumnMajor(instr, gemm_backend_config)); - if (auto isrocm = std::get_if(&gpu_version_); isrocm) { if (!isrocm->has_hipblaslt()) { @@ -2206,10 +2174,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } if (std::holds_alternative(gpu_version_)) { - auto cuda_compute_capability_ = - std::get(gpu_version_); - if (cuda_compute_capability_.IsAtLeast( - se::CudaComputeCapability::AMPERE)) { + if (std::get(gpu_version_).IsAtLeastAmpere()) { // cuBlasLt has an implementation for complex data with compute type // 32F_FAST_32TF that uses tensor cores and that is free from the // restriction. This implementation only works on Ampere @@ -2217,36 +2182,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return true; } } - // Get the rhs non-contracting dimensions as they will eventually be at the - // cublasLt level. - std::vector rhs_non_contracting_dims; - const DotDimensionNumbers &dot_dims = - gemm_backend_config.dot_dimension_numbers(); - - if (!output_is_column_major) { - // cublasLt's matmul output is column major by default. This gemm requires - // the output to be in row major. Later we will swap lhs & rhs (and - // transpose each operand) of this gemm. Since we care about the rhs at - // the cublasLt level, this swap means that we care about the lhs right - // here. - TF_ASSIGN_OR_RETURN( - rhs_non_contracting_dims, - GetNonContractingDims(lhs->shape(), dot_dims.lhs_batch_dimensions(), - dot_dims.lhs_contracting_dimensions())); - } else { - TF_ASSIGN_OR_RETURN( - rhs_non_contracting_dims, - GetNonContractingDims(rhs->shape(), dot_dims.rhs_batch_dimensions(), - dot_dims.rhs_contracting_dimensions())); - } - const auto lhs_non_contracting_dimension_size = absl::c_accumulate( - rhs_non_contracting_dims, 1, [&](int64_t size, int64_t dim) { - return size * lhs->shape().dimensions(dim); - }); + TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, + GemmConfig::For(&instr, gemm_backend_config)); // Check that the size of the non-contracting dimension is not too large. - return lhs_non_contracting_dimension_size <= kMaxDimensionSize; + return gemm_config.rhs_layout.num_cols <= kMaxDimensionSize; } // Turns an F8 dot with unsupported output type into an F8 dot with F32 @@ -2263,16 +2204,20 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return f32_dot; } - // Turns an F8 dot into an F16 dot, converting operands to F16 and + // Turns an F8 dot into an F16 dot, converting operands to F16 (or BF16) and // converting the output back to F8. absl::StatusOr TurnF8DotIntoF16Dot(HloInstruction *instr) { DCHECK(IsF8Type(instr->operand(0))); DCHECK(IsF8Type(instr->operand(1))); - // Convert operands to F16 + // If the output type is BF16, the input types have to be BF16 as well. + PrimitiveType conv_type = + instr->shape().element_type() == BF16 ? BF16 : F16; + + // Convert operands to F16 (or BF16). for (int i = 0; i < 2; ++i) { Shape operand_f16_shape = instr->operand(i)->shape(); - operand_f16_shape.set_element_type(F16); + operand_f16_shape.set_element_type(conv_type); HloInstruction *convert = instr->AddInstruction(HloInstruction::CreateConvert( operand_f16_shape, instr->mutable_operand(i))); @@ -2395,8 +2340,8 @@ class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor { absl::StatusOr RunOnComputation(HloComputation *computation, se::GpuComputeCapability gpu_version, int32_t toolkit_version, - bool f8_rewrite) { - GemmRewriterVisitor visitor(gpu_version, toolkit_version, f8_rewrite); + GemmRewriterOptions options) { + GemmRewriterVisitor visitor(gpu_version, toolkit_version, options); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); GemmWorkspaceRewriteVisitor workspace_visitor(gpu_version); TF_RETURN_IF_ERROR(computation->Accept(&workspace_visitor)); @@ -2406,10 +2351,10 @@ absl::StatusOr RunOnComputation(HloComputation *computation, } // anonymous namespace GemmRewriter::GemmRewriter(se::GpuComputeCapability gpu_version, - int32_t toolkit_version, bool f8_rewrite) + int32_t toolkit_version, GemmRewriterOptions options) : gpu_version_(gpu_version), toolkit_version_(toolkit_version), - f8_rewrite_(f8_rewrite) {} + options_(options) {} absl::StatusOr GemmRewriter::Run( HloModule *module, @@ -2419,7 +2364,7 @@ absl::StatusOr GemmRewriter::Run( module->MakeNonfusionComputations(execution_threads)) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, gpu_version_, - toolkit_version_, f8_rewrite_)); + toolkit_version_, options_)); changed |= result; } return changed; diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h index e3260fdee45f1a..cce09c45c464f6 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h @@ -45,12 +45,40 @@ namespace gpu { // (we assume transposes are already folded), and rewrites it into a custom call // where (A, B, C) are three operands respectively, and `alpha` and `beta` are // stored in the backend config. + +struct GemmRewriterOptions { + // The DType of the GEMM to rewrite. + enum class DType { kFp8Only, kNonFp8Only }; + DType dtype = DType::kNonFp8Only; + + // Disabling bias prevents using the `beta * C` term the GEMM, which can + // remove dependencies between multiple matrix multiplications. This, in + // turn, can improve the performance of overall computation by allowing + // multiple GEMMs to be scheduled in parallel. + // + // As an example, consider the following computation: `(A * A) + (B * B)`. + // With bias enabled, the `GemmRewriter` will emit the following GEMMs: + // + // AA := GEMM(A * A) + // ROOT := GEMM(B * B + AA) + // + // Because the second GEMM depends on the first, they cannot be scheduled in + // parallel. Instead, with bias disabled, the `GemmRewriter` will emit the + // following: + // + // AA := GEMM(A * A) + // BB := GEMM(B * B) + // ROOT := AA + BB + // + // In this case, the two GEMMs can be scheduled in parallel. + enum class BiasMode { kBias, kNoBias }; + BiasMode bias_mode = BiasMode::kBias; +}; + class GemmRewriter : public HloModulePass { public: - // When f8_rewrite is true, only FP8 GEMMs are rewritten. Otherwise, non-FP8 - // GEMMs are rewritten. GemmRewriter(se::GpuComputeCapability gpu_version, int32_t toolkit_version, - bool f8_rewrite = false); + GemmRewriterOptions options = {}); absl::string_view name() const override { return "cublas-gemm-rewriter"; } using HloPassInterface::Run; @@ -61,7 +89,7 @@ class GemmRewriter : public HloModulePass { private: se::GpuComputeCapability gpu_version_; int32_t toolkit_version_; - bool f8_rewrite_; + GemmRewriterOptions options_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc index 2d07d518616376..c0098fc808f8f6 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc @@ -4805,6 +4805,35 @@ class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTest { static constexpr const char* kF8E4M3AmaxPlaceholder{"<>"}; }; +TEST_P(ParameterizedFp8GemmRewriteTest, SupportsF8NonMajorBatchDim) { + const char* hlo_text = R"( +HloModule t + +ENTRY main { + %bitcast.73421 = f8e4m3fn[16,8,640]{2,1,0} parameter(0) + %parameter_1.5 = f8e4m3fn[8,640,5120]{2,1,0} parameter(1) + %parameter_2 = f8e4m3fn[8,640,5120]{2,1,0} parameter(2) + %concatenate.2145 = f8e4m3fn[8,640,10240]{2,1,0} concatenate( + f8e4m3fn[8,640,5120]{2,1,0} %parameter_1.5, + f8e4m3fn[8,640,5120]{2,1,0} %parameter_2), + dimensions={2} + %dot.6237 = f32[8,16,10240]{2,1,0} dot( + f8e4m3fn[16,8,640]{2,1,0} %bitcast.73421, + f8e4m3fn[8,640,10240]{2,1,0} %concatenate.2145), + lhs_batch_dims={1}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + ROOT %convert.20480 = bf16[8,16,10240]{2,1,0} convert( + f32[8,16,10240]{2,1,0} %dot.6237) +})"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: custom-call({{.*}}"lhs_batch_dimensions":["1"],"rhs_batch_dimensions":["0"] + )"); +} + TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteToF8OnPreAda) { if (HasFp8Support()) { GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300."; @@ -4879,7 +4908,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnsupportedTypesF8) { ErrorSpec{1e-2, 1e-2})); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(Capability(), GetToolkitVersion(), /*f8_rewrite=*/true), + GemmRewriter(Capability(), GetToolkitVersion(), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %unsupported_types ({{.*}}: <>[16,16], {{.*}}: <>[16,16]) -> <>[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,16]{1,0} parameter(0) @@ -4915,7 +4945,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16]) -> <>[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -4977,7 +5007,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: <>[16,16]) -> <>[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -5039,7 +5069,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -5100,7 +5130,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[13,17], {{.*}}: <>[17,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[13,31] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[13,17]{1,0} parameter(0) @@ -5167,7 +5197,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -5205,7 +5235,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDWithConvertF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16]) -> f32[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -5269,7 +5299,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[3], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { @@ -5337,7 +5367,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[3], {{.*}}: <>[32,16]) -> f32[16,16] { @@ -5403,7 +5433,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -5411,7 +5441,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[32,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[32,32]{1,0} parameter(0) @@ -5475,7 +5505,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -5483,7 +5513,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: pred[16,32]) -> f32[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -5551,7 +5581,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_FALSE(changed); } @@ -5588,7 +5618,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[10,16,32], {{.*}}: <>[10,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[10,16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[10,16,32]{2,1,0} parameter(0) @@ -5652,7 +5682,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { @@ -5717,7 +5747,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { @@ -5803,7 +5833,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: bf16[], {{.*}}: bf16[], {{.*}}: bf16[16]) -> bf16[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -5904,7 +5934,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: bf16[], {{.*}}: bf16[]) -> bf16[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -5982,7 +6012,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); @@ -6025,7 +6055,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[16,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { @@ -6093,7 +6123,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[14,31], {{.*}}: <>[31,14], {{.*}}: f32[14,14], {{.*}}: f32[], {{.*}}: f32[]) -> f32[14,14] { @@ -6165,7 +6195,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> <>[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6223,7 +6253,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> f32[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6280,7 +6310,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> f32[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6339,7 +6369,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> f32[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6412,7 +6442,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <>[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6486,7 +6516,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-NOT: divide @@ -6538,7 +6568,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <>[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6626,7 +6656,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<>[16,16], f16[]) { @@ -6704,7 +6734,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> <>[16,16] { @@ -6779,7 +6809,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6848,7 +6878,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6915,7 +6945,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -6929,7 +6959,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,16,16], {{.*}}: <>[16,32], {{.*}}: f32[32], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,16,32] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[4,16,16]{2,1,0} parameter(0) @@ -7001,7 +7031,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -7017,7 +7047,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,15,15], {{.*}}: <>[15,31], {{.*}}: f32[31], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,15,31] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[4,15,15]{2,1,0} parameter(0) @@ -7093,7 +7123,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -7107,7 +7137,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,16,16], {{.*}}: <>[16,32], {{.*}}: f32[4,16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[4,16,32] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[4,16,16]{2,1,0} parameter(0) @@ -7175,7 +7205,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -7191,7 +7221,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[3,15,15], {{.*}}: <>[15,31], {{.*}}: f32[3,15,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[3,15,31] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[3,15,15]{2,1,0} parameter(0) @@ -7267,14 +7297,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[48,16], {{.*}}: <>[16,32], {{.*}}: f32[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[32,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[48,16]{1,0} parameter(0) @@ -7343,7 +7373,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,32] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7411,7 +7441,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7477,7 +7507,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7545,7 +7575,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] { ; CHECK-DAG: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7632,7 +7662,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<>[16,16], f32[]) { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7719,7 +7749,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<>[16,16], f16[]) { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7809,7 +7839,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<>[16,16], f32[]) { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7950,7 +7980,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); @@ -8026,7 +8056,7 @@ ENTRY f { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); @@ -8065,7 +8095,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); @@ -8102,7 +8132,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_FALSE(changed); #endif @@ -8111,7 +8141,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fnuz[16,32], {{.*}}: f8e4m3fnuz[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fnuz[16,32]{1,0} parameter(0) @@ -8207,6 +8237,114 @@ ENTRY main { )"); } +TEST_F(GemmRewriteTest, DotWithBias) { + const char* hlo = R"( + HloModule m + + ENTRY main { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + p2 = f32[1024,1024] parameter(2) + p3 = f32[1024,1024] parameter(3) + dot0 = f32[1024,1024] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot1 = f32[1024,1024] dot(p2, p3), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT root = f32[1024,1024] add(dot0, dot1) + })"; + + const char* expected = R"() + // CHECK: %[[P0:.*]] = f32[1024,1024]{1,0} parameter(0) + // CHECK: %[[P1:.*]] = f32[1024,1024]{1,0} parameter(1) + // CHECK: %[[P2:.*]] = f32[1024,1024]{1,0} parameter(2) + // CHECK: %[[P3:.*]] = f32[1024,1024]{1,0} parameter(3) + // CHECK: %[[TUPLE0:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P2]], %[[P3]]) + // CHECK: %[[S0:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE0]]), index=0 + // CHECK: %[[TUPLE1:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]], %[[S0]]) + // CHECK: ROOT %[[S1:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE1]]), index=0 + })"; + + RunAndFilecheckHloRewrite( + hlo, + GemmRewriter( + se::CudaComputeCapability{}, /*toolkit_version=*/0, + GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only}), + expected); +} + +TEST_F(GemmRewriteTest, DotWithoutBias) { + const char* hlo = R"( + HloModule m + + ENTRY main { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + p2 = f32[1024,1024] parameter(2) + p3 = f32[1024,1024] parameter(3) + dot0 = f32[1024,1024] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot1 = f32[1024,1024] dot(p2, p3), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT root = f32[1024,1024] add(dot0, dot1) + })"; + + const char* expected = R"() + // CHECK: %[[P0:.*]] = f32[1024,1024]{1,0} parameter(0) + // CHECK: %[[P1:.*]] = f32[1024,1024]{1,0} parameter(1) + // CHECK: %[[TUPLE0:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]]) + // CHECK: %[[S0:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE0]]), index=0 + // CHECK: %[[P2:.*]] = f32[1024,1024]{1,0} parameter(2) + // CHECK: %[[P3:.*]] = f32[1024,1024]{1,0} parameter(3) + // CHECK: %[[TUPLE1:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P2]], %[[P3]]) + // CHECK: %[[S1:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE1]]), index=0 + // CHECK: ROOT %[[S2:.*]] = f32[1024,1024]{1,0} add(%[[S0]], %[[S1]]) + })"; + + RunAndFilecheckHloRewrite( + hlo, + GemmRewriter(se::CudaComputeCapability{}, /*toolkit_version=*/0, + GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only, + GemmRewriterOptions::BiasMode::kNoBias}), + expected); +} + +TEST_F(CublasLtGemmRewriteTest, CublasLtSuccessfullyMatchesLargeC64Lhs) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + p0 = c64[2000,3000,3]{2,1,0} parameter(0) + p1 = c64[3,6]{1,0} parameter(1) + ROOT dot = c64[2000,3000,6]{2,1,0} dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={0} +} +)"; + // Large lhs is fine for cuBLASlt. + MatchOptimizedHlo(hlo_text, + R"(; CHECK: custom_call_target="__cublas$lt$matmul")"); +} + +TEST_F(CublasLtGemmRewriteTest, CublasLtOnlyMatchesLargeC64RhsPostAmpere) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + p0 = c64[6,3]{1,0} parameter(0) + p1 = c64[3,2000,3000]{2,1,0} parameter(1) + ROOT dot = c64[6,2000,3000]{2,1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + if (HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) { + // From Ampere onwards, cuBLASlt supports large rhs. + MatchOptimizedHlo(hlo_text, + R"(; CHECK: custom_call_target="__cublas$lt$matmul")"); + } else { + // Rhs with non-contracting dimensions > 4194240 (combined) is not fine for + // C64 type. + MatchOptimizedHlo( + hlo_text, R"(; CHECK-NOT: custom_call_target="__cublas$lt$matmul")"); + } +} + class GemmRewriteAllocationTest : public GpuCodegenTest { public: void CheckNumberOfAllocations(const std::string& hlo, diff --git a/third_party/xla/xla/service/gpu/horizontal_input_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc similarity index 97% rename from third_party/xla/xla/service/gpu/horizontal_input_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc index c6938569686611..befe869ac072df 100644 --- a/third_party/xla/xla/service/gpu/horizontal_input_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/horizontal_input_fusion.h" +#include "xla/service/gpu/transforms/horizontal_input_fusion.h" #include #include @@ -169,13 +169,13 @@ absl::StatusOr HorizontalInputFusionImpl::Run() { } // namespace -absl::StatusOr GpuHorizontalInputFusion::RunOnComputation( +absl::StatusOr HorizontalInputFusion::RunOnComputation( HloComputation* computation) { HorizontalInputFusionImpl horizontal_fusion_impl(computation, device_info_); return horizontal_fusion_impl.Run(); } -absl::StatusOr GpuHorizontalInputFusion::Run( +absl::StatusOr HorizontalInputFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/gpu/horizontal_input_fusion.h b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.h similarity index 74% rename from third_party/xla/xla/service/gpu/horizontal_input_fusion.h rename to third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.h index 370ce7bd0509af..a08168d4c3f5a5 100644 --- a/third_party/xla/xla/service/gpu/horizontal_input_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ -#define XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_INPUT_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_INPUT_FUSION_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -29,24 +29,22 @@ namespace gpu { // This optimization pass horizontally fuses kInput fusions to both reduce the // kernel launch overhead and increase parallelism degree. See -// GpuHorizontalFusion for general description and motivation about horizontal -// fusion. GpuHorizontalFusion deals with kLoop fusions while this pass deals +// HorizontalLoopFusion for general description and motivation about horizontal +// fusion. HorizontalLoopFusion deals with kLoop fusions while this pass deals // with kInput fusions. // -// Following GpuHorizontalFusion, a simple yet effective heuristic is used +// Following HorizontalLoopFusion, a simple yet effective heuristic is used // to search the fusion candidates while avoiding creating cycles. That is, // we simply search for fusion candidates by looking for instructions whose // outputs are all consumed by the same instruction. This catches the typical // target cases; often, the candidate instructions are just consumed by the // ROOT tuple of the entry computation. -class GpuHorizontalInputFusion : public HloModulePass { +class HorizontalInputFusion : public HloModulePass { public: - explicit GpuHorizontalInputFusion(const se::DeviceDescription& d) + explicit HorizontalInputFusion(const se::DeviceDescription& d) : device_info_(d) {} - absl::string_view name() const override { - return "gpu_horizontal_input_fusion"; - } + absl::string_view name() const override { return "horizontal_input_fusion"; } using HloPassInterface::Run; absl::StatusOr Run( @@ -62,4 +60,4 @@ class GpuHorizontalInputFusion : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_INPUT_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/horizontal_input_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/horizontal_input_fusion_test.cc rename to third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc index 2d458f9db452d1..5fc1a54acd8d53 100644 --- a/third_party/xla/xla/service/gpu/horizontal_input_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/horizontal_input_fusion.h" +#include "xla/service/gpu/transforms/horizontal_input_fusion.h" #include #include @@ -42,7 +42,7 @@ class HorizontalInputFusionTest : public GpuCodegenTest { public: se::DeviceDescription device_description_{ TestGpuDeviceInfo::RTXA6000DeviceInfo()}; - GpuHorizontalInputFusion horizontal_input_fusion_{device_description_}; + HorizontalInputFusion horizontal_input_fusion_{device_description_}; }; TEST_F(HorizontalInputFusionTest, BasicTest) { diff --git a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc similarity index 99% rename from third_party/xla/xla/service/gpu/horizontal_loop_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc index 80c46cb7a5d5af..0a3d705103c416 100644 --- a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/horizontal_loop_fusion.h" +#include "xla/service/gpu/transforms/horizontal_loop_fusion.h" #include #include @@ -713,13 +713,13 @@ absl::StatusOr HorizontalLoopFusionImpl::Run() { } // namespace -absl::StatusOr GpuHorizontalLoopFusion::RunOnComputation( +absl::StatusOr HorizontalLoopFusion::RunOnComputation( HloComputation* computation) { HorizontalLoopFusionImpl horizontal_fusion_impl(computation, prefix_); return horizontal_fusion_impl.Run(); } -absl::StatusOr GpuHorizontalLoopFusion::Run( +absl::StatusOr HorizontalLoopFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "Run horizontal fusion."; diff --git a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.h b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h similarity index 92% rename from third_party/xla/xla/service/gpu/horizontal_loop_fusion.h rename to third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h index 5daed0378aa903..f29bcd31044991 100644 --- a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_ -#define XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_LOOP_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_LOOP_FUSION_H_ #include @@ -122,15 +122,12 @@ namespace gpu { // outputs of Mul and Add are row-major. // // Note, reshapes are added only if the tensors isn't already a vector. -class GpuHorizontalLoopFusion : public HloModulePass { +class HorizontalLoopFusion : public HloModulePass { public: - GpuHorizontalLoopFusion() = default; - explicit GpuHorizontalLoopFusion(absl::string_view prefix) - : prefix_(prefix) {} + HorizontalLoopFusion() = default; + explicit HorizontalLoopFusion(absl::string_view prefix) : prefix_(prefix) {} - absl::string_view name() const override { - return "gpu_horizontal_loop_fusion"; - } + absl::string_view name() const override { return "horizontal_loop_fusion"; } using HloPassInterface::Run; absl::StatusOr Run( @@ -145,4 +142,4 @@ class GpuHorizontalLoopFusion : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_LOOP_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/horizontal_loop_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc similarity index 97% rename from third_party/xla/xla/service/gpu/horizontal_loop_fusion_test.cc rename to third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc index 4045183dcf0867..781d27a64d716c 100644 --- a/third_party/xla/xla/service/gpu/horizontal_loop_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/horizontal_loop_fusion.h" +#include "xla/service/gpu/transforms/horizontal_loop_fusion.h" #include #include @@ -27,7 +27,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/instruction_fusion.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_pass_fix.h" @@ -85,7 +85,7 @@ TEST_F(HorizontalLoopFusionTest, BasicTest) { )") .value(); - EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_FALSE(HloDCE().Run(module.get()).value()); @@ -136,7 +136,7 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) { )") .value(); - EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); } TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) { @@ -172,7 +172,7 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) { )") .value(); - EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); } TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) { @@ -259,7 +259,7 @@ TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) { )") .value(); - EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); int input_fusion_count = 0; int loop_fusion_count = 0; @@ -308,7 +308,7 @@ TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) { fusion.AddPass(/*may_duplicate=*/true, device_info); EXPECT_TRUE(fusion.Run(module.get()).value()); - EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); VLOG(2) << "Dump after horizontal fusion:"; @@ -415,7 +415,7 @@ TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) { )") .value(); - EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_FALSE(HloDCE().Run(module.get()).value()); @@ -545,7 +545,7 @@ TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) { })") .value(); - EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_FALSE(HloDCE().Run(module.get()).value()); @@ -586,7 +586,7 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) { )") .value(); - EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); } TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) { @@ -627,7 +627,7 @@ TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) { .value(); HloPassFix iterative_h_fusion("iterative_h_fusion"); - iterative_h_fusion.AddPass(); + iterative_h_fusion.AddPass(); iterative_h_fusion.AddPass(); EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value()); @@ -699,7 +699,7 @@ TEST_F(HorizontalLoopFusionTest, TraversalOrder) { .value(); HloPassFix iterative_h_fusion("iterative_h_fusion"); - iterative_h_fusion.AddPass(); + iterative_h_fusion.AddPass(); EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value()); // Verify that the total number of fusion instructions is 2 so that we @@ -773,7 +773,7 @@ ENTRY main { )"; auto module = ParseAndReturnUnverifiedModule(hlo_text).value(); - EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); VLOG(2) << module->ToString(); @@ -843,7 +843,7 @@ TEST_F(HorizontalLoopFusionTest, DoNotMergeVariadicReductions) { })") .value(); - EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); } } // namespace diff --git a/third_party/xla/xla/service/gpu/instruction_fusion.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc similarity index 99% rename from third_party/xla/xla/service/gpu/instruction_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc index 8751d44f8972ea..5e32f2ec0c2ee1 100644 --- a/third_party/xla/xla/service/gpu/instruction_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/instruction_fusion.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" #include #include diff --git a/third_party/xla/xla/service/gpu/instruction_fusion.h b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.h similarity index 94% rename from third_party/xla/xla/service/gpu/instruction_fusion.h rename to third_party/xla/xla/service/gpu/transforms/instruction_fusion.h index 29eb0325e1a23b..d7fb7f2cb47ded 100644 --- a/third_party/xla/xla/service/gpu/instruction_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_ -#define XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_INSTRUCTION_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_INSTRUCTION_FUSION_H_ #include @@ -79,4 +79,4 @@ class GpuInstructionFusion : public InstructionFusion { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_INSTRUCTION_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/instruction_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/instruction_fusion_test.cc rename to third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc index fa96edfd364aa2..140cc6e52641ea 100644 --- a/third_party/xla/xla/service/gpu/instruction_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/instruction_fusion.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" #include #include @@ -126,12 +126,14 @@ TEST_F(InstructionFusionTest, TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) { HloComputation::Builder builder(TestName()); - HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); - HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(F32, {}), HloOpcode::kExp, const0)); - HloInstruction* transpose2 = builder.AddInstruction( - HloInstruction::CreateTranspose(ShapeUtil::MakeShape(F32, {}), exp1, {})); + Shape operand_shape = ShapeUtil::MakeShape(F32, {64, 32}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, operand_shape, "param0")); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(operand_shape, HloOpcode::kExp, param)); + HloInstruction* transpose2 = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {32, 64}), exp1, {1, 0})); auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -464,7 +466,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion) { .value(); // Multi-output fusion is disabled here and performed in the - // GpuMultiOutputFusion pass instead. + // MultiOutputFusion pass instead. ASSERT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value()); } diff --git a/third_party/xla/xla/service/gpu/gpu_layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc similarity index 99% rename from third_party/xla/xla/service/gpu/gpu_layout_assignment.cc rename to third_party/xla/xla/service/gpu/transforms/layout_assignment.cc index 008dbaeade1ab9..caa8d3c10f90e6 100644 --- a/third_party/xla/xla/service/gpu/gpu_layout_assignment.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_layout_assignment.h" +#include "xla/service/gpu/transforms/layout_assignment.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_layout_assignment.h b/third_party/xla/xla/service/gpu/transforms/layout_assignment.h similarity index 94% rename from third_party/xla/xla/service/gpu/gpu_layout_assignment.h rename to third_party/xla/xla/service/gpu/transforms/layout_assignment.h index 70741fea030efb..efa58f3f8c3c72 100644 --- a/third_party/xla/xla/service/gpu/gpu_layout_assignment.h +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ -#define XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_LAYOUT_ASSIGNMENT_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_LAYOUT_ASSIGNMENT_H_ #include #include @@ -78,4 +78,4 @@ class GpuLayoutAssignment : public LayoutAssignment { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_LAYOUT_ASSIGNMENT_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_layout_assignment_test.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/gpu_layout_assignment_test.cc rename to third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc index 81f9e00548d9da..dd1cbc65bb3fde 100644 --- a/third_party/xla/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_layout_assignment.h" +#include "xla/service/gpu/transforms/layout_assignment.h" #include #include diff --git a/third_party/xla/xla/service/gpu/move_copy_to_users.cc b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.cc similarity index 99% rename from third_party/xla/xla/service/gpu/move_copy_to_users.cc rename to third_party/xla/xla/service/gpu/transforms/move_copy_to_users.cc index acc10db6af6927..ae66093da4507d 100644 --- a/third_party/xla/xla/service/gpu/move_copy_to_users.cc +++ b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/move_copy_to_users.h" +#include "xla/service/gpu/transforms/move_copy_to_users.h" #include diff --git a/third_party/xla/xla/service/gpu/move_copy_to_users.h b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.h similarity index 87% rename from third_party/xla/xla/service/gpu/move_copy_to_users.h rename to third_party/xla/xla/service/gpu/transforms/move_copy_to_users.h index 4a7dfb43bbf6ec..698db0460602f1 100644 --- a/third_party/xla/xla/service/gpu/move_copy_to_users.h +++ b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_ -#define XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_MOVE_COPY_TO_USERS_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_MOVE_COPY_TO_USERS_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -36,4 +36,4 @@ class MoveCopyToUsers : public HloModulePass { } // end namespace xla -#endif // XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_MOVE_COPY_TO_USERS_H_ diff --git a/third_party/xla/xla/service/gpu/move_copy_to_users_test.cc b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/move_copy_to_users_test.cc rename to third_party/xla/xla/service/gpu/transforms/move_copy_to_users_test.cc index 10179c1b32cacd..85999dbf63a5b5 100644 --- a/third_party/xla/xla/service/gpu/move_copy_to_users_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/move_copy_to_users.h" +#include "xla/service/gpu/transforms/move_copy_to_users.h" #include diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc similarity index 96% rename from third_party/xla/xla/service/gpu/multi_output_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc index 6ac1217151aa65..35bfe8eb092038 100644 --- a/third_party/xla/xla/service/gpu/multi_output_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/multi_output_fusion.h" +#include "xla/service/gpu/transforms/multi_output_fusion.h" #include #include @@ -307,13 +307,13 @@ FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1, } // namespace -void GpuMultiOutputFusion::RecomputeReachability() { +void MultiOutputFusion::RecomputeReachability() { reachability_ = HloDfsReachability::Build(computation_); } -bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent, - FusionInfoCache* fusion_info_cache, - GpuHloCostAnalysis* cost_analysis) { +bool MultiOutputFusion::FuseSiblings(HloInstruction* parent, + FusionInfoCache* fusion_info_cache, + GpuHloCostAnalysis* cost_analysis) { const HloComputation* computation = parent->parent(); const HloModule* module = computation->parent(); bool dump_fusion = @@ -402,7 +402,7 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent, return changed; } -absl::StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { +absl::StatusOr MultiOutputFusion::DoMultiOutputFusion() { bool changed = false; RecomputeReachability(); GpuHloCostAnalysis cost_analysis({shape_size_function_, @@ -494,9 +494,9 @@ absl::StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { return changed; } -void GpuMultiOutputFusion::DumpFusionState(const HloInstruction& consumer, - absl::string_view label, - const HloInstruction* producer) { +void MultiOutputFusion::DumpFusionState(const HloInstruction& consumer, + absl::string_view label, + const HloInstruction* producer) { if (consumer.GetModule() ->config() .debug_options() @@ -505,7 +505,7 @@ void GpuMultiOutputFusion::DumpFusionState(const HloInstruction& consumer, } } -absl::StatusOr GpuMultiOutputFusion::Run( +absl::StatusOr MultiOutputFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion.h b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h similarity index 93% rename from third_party/xla/xla/service/gpu/multi_output_fusion.h rename to third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h index 82789d3be5791d..9ebabe6b460000 100644 --- a/third_party/xla/xla/service/gpu/multi_output_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ -#define XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_MULTI_OUTPUT_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_MULTI_OUTPUT_FUSION_H_ #include @@ -74,7 +74,7 @@ namespace gpu { // Note that sibling (1) and producer-consumer (2) multi-output fusion can be // combined. // -// The GpuMultiOutputFusion pass modifies the HLO in reverse post-order (defs +// The MultiOutputFusion pass modifies the HLO in reverse post-order (defs // before uses). First, it attempts to fuse the consumer ops of the current op, // which are siblings (1). Hereafter, it attempts to fuse the current op with // one of its consumers (2). This order avoids a phase ordering issue (described @@ -83,7 +83,7 @@ namespace gpu { // order of traversal, and hence, not get into the way of subsequent fusion // attempts. // -// The GpuMultiOutputFusion pass ensures several conditions are met for fusion. +// The MultiOutputFusion pass ensures several conditions are met for fusion. // Some of them are relevant for correctness. In particular, no cycles must be // introduced into the HLO module. Moreover, the code emitters for multi-output // fusion must support the combination of ops and their shapes. Other @@ -92,9 +92,9 @@ namespace gpu { // * Sibling fusion (1) does not fuse kInput fusions with kLoop fusions, i.e. // the fusion kinds must match. -class GpuMultiOutputFusion : public HloModulePass { +class MultiOutputFusion : public HloModulePass { public: - explicit GpuMultiOutputFusion( + explicit MultiOutputFusion( const se::DeviceDescription& device_info, HloCostAnalysis::ShapeSizeFunction shape_size_function) : device_info_(device_info), shape_size_function_(shape_size_function) {} @@ -131,4 +131,4 @@ class GpuMultiOutputFusion : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_MULTI_OUTPUT_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc similarity index 96% rename from third_party/xla/xla/service/gpu/multi_output_fusion_test.cc rename to third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc index 3cbaa26d49d723..4b6920464c8b51 100644 --- a/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/multi_output_fusion.h" +#include "xla/service/gpu/transforms/multi_output_fusion.h" #include #include @@ -48,17 +48,15 @@ class MultiOutputFusionTest : public HloTestBase { } public: - GpuMultiOutputFusion mof_{ - TestGpuDeviceInfo::RTXA6000DeviceInfo(), - ShapeSizeBytesFunction()}; + MultiOutputFusion mof_{TestGpuDeviceInfo::RTXA6000DeviceInfo(), + ShapeSizeBytesFunction()}; - void CheckGpuMultiOutputFusion(absl::string_view hlo, - std::optional expected) { + void CheckMultiOutputFusion(absl::string_view hlo, + std::optional expected) { RunAndFilecheckHloRewrite( hlo, - GpuMultiOutputFusion{ - TestGpuDeviceInfo::RTXA6000DeviceInfo(), - ShapeSizeBytesFunction()}, + MultiOutputFusion{TestGpuDeviceInfo::RTXA6000DeviceInfo(), + ShapeSizeBytesFunction()}, expected); } }; @@ -179,7 +177,7 @@ ENTRY entry { ROOT root = (f32[512]{0}, f16[512]{0}) tuple(reduce.1, fusion) })"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_computation // CHECK-NEXT: [[param_0_2_0:%[^ ]+]] = f32[128,512,28,28]{3,2,1,0} parameter(0) // CHECK-NEXT: [[c_1_1:%[^ ]+]] = f16[128,512,28,28]{3,2,1,0} convert([[param_0_2_0]]) @@ -1781,7 +1779,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32], f32[16,32]) { // CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) // CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]]) @@ -1799,28 +1797,28 @@ TEST_F(TransposeMultiOutputFusionTest, MultipleTransposes) { HloModule module fused_computation { - param_0.1 = f32[16,32]{1,0} parameter(0) - s.1 = f32[16,32]{1,0} sqrt(param_0.1) - ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0} + param_0.1 = f32[1,16,32]{2,1,0} parameter(0) + s.1 = f32[1,16,32]{2,1,0} sqrt(param_0.1) + ROOT c.1 = f32[1,32,16]{2,1,0} transpose(s.1), dimensions={0,2,1} } ENTRY main { - p = f32[16,32]{1,0} parameter(0) - fusion = f32[32,16]{1,0} fusion(p), kind=kInput, calls=fused_computation - c1 = f32[32,16]{1,0} transpose(p), dimensions={1,0} - ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(fusion, c1) + p = f32[1,16,32]{2,1,0} parameter(0) + fusion = f32[1,32,16]{2,1,0} fusion(p), kind=kInput, calls=fused_computation + c1 = f32[1,32,16]{2,1,0} transpose(p), dimensions={0,2,1} + ROOT t = (f32[1,32,16]{2,1,0}, f32[1,32,16]{2,1,0}) tuple(fusion, c1) } )"; - CheckGpuMultiOutputFusion(hlo, R"( -// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[32,16], f32[32,16]) { -// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) -// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]]) -// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]), dimensions={1,0} -// CHECK-NEXT: [[c1_1_3:%[^ ]+]] = f32[32,16]{1,0} transpose([[param_0_1_0]]), dimensions={1,0} -// CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple([[c_1_2]], [[c1_1_3]]) + CheckMultiOutputFusion(hlo, R"( +// CHECK: %fused_computation (param_0.1: f32[1,16,32]) -> (f32[1,32,16], f32[1,32,16]) { +// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[1,16,32]{2,1,0} parameter(0) +// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[1,16,32]{2,1,0} sqrt([[param_0_1_0]]) +// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[1,32,16]{2,1,0} transpose([[s_1_1]]), dimensions={0,2,1} +// CHECK-NEXT: [[c1_1_3:%[^ ]+]] = f32[1,32,16]{2,1,0} transpose([[param_0_1_0]]), dimensions={0,2,1} +// CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[1,32,16]{2,1,0}, f32[1,32,16]{2,1,0}) tuple([[c_1_2]], [[c1_1_3]]) // CHECK-NEXT: } -// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16]{1,0}, f32[32,16]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]] +// CHECK: [[fusion_0:%[^ ]+]] = (f32[1,32,16]{2,1,0}, f32[1,32,16]{2,1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]] )"); } @@ -1829,27 +1827,27 @@ TEST_F(TransposeMultiOutputFusionTest, CopyAndTranspose) { HloModule module fused_computation { - param_0.1 = f32[16,32]{1,0} parameter(0) - s.1 = f32[16,32]{1,0} sqrt(param_0.1) - ROOT c.1 = f32[16,32]{0,1} copy(s.1) + param_0.1 = f32[1,16,32]{2,1,0} parameter(0) + s.1 = f32[1,16,32]{2,1,0} sqrt(param_0.1) + ROOT c.1 = f32[1,16,32]{1,2,0} copy(s.1) } ENTRY main { - p = f32[16,32]{1,0} parameter(0) - fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation - c1 = f32[32,16]{1,0} transpose(p), dimensions={1,0} - ROOT t = (f32[16,32]{0,1}, f32[32,16]{1,0}) tuple(fusion, c1) + p = f32[1,16,32]{2,1,0} parameter(0) + fusion = f32[1,16,32]{1,2,0} fusion(p), kind=kInput, calls=fused_computation + c1 = f32[1,32,16]{2,1,0} transpose(p), dimensions={0,2,1} + ROOT t = (f32[1,16,32]{1,2,0}, f32[1,32,16]{2,1,0}) tuple(fusion, c1) } )"; - CheckGpuMultiOutputFusion(hlo, R"( - // CHECK: %fused_computation ({{[^ ]+}} f32[16,32]) -> (f32[16,32], f32[32,16]) { - // CHECK-NEXT: [[param_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) - // CHECK-NEXT: [[s_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0]]) - // CHECK-NEXT: [[copy:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1]]) - // CHECK-NEXT: [[transpose:[^ ]+]] = f32[32,16]{1,0} transpose([[param_0]]), dimensions={1,0} - // CHECK-NEXT: ROOT {{[^ ]+}} = (f32[16,32]{0,1}, f32[32,16]{1,0}) tuple([[copy]], [[transpose]]) - // CHECK: %fusion = (f32[16,32]{0,1}, f32[32,16]{1,0}) fusion(%{{.*}}), kind=kInput, calls=%fused_computation + CheckMultiOutputFusion(hlo, R"( + // CHECK: %fused_computation ({{[^ ]+}} f32[1,16,32]) -> (f32[1,16,32], f32[1,32,16]) { + // CHECK-NEXT: [[param_0:%[^ ]+]] = f32[1,16,32]{2,1,0} parameter(0) + // CHECK-NEXT: [[s_1:%[^ ]+]] = f32[1,16,32]{2,1,0} sqrt([[param_0]]) + // CHECK-NEXT: [[copy:%[^ ]+]] = f32[1,16,32]{1,2,0} copy([[s_1]]) + // CHECK-NEXT: [[transpose:[^ ]+]] = f32[1,32,16]{2,1,0} transpose([[param_0]]), dimensions={0,2,1} + // CHECK-NEXT: ROOT {{[^ ]+}} = (f32[1,16,32]{1,2,0}, f32[1,32,16]{2,1,0}) tuple([[copy]], [[transpose]]) + // CHECK: %fusion = (f32[1,16,32]{1,2,0}, f32[1,32,16]{2,1,0}) fusion(%{{.*}}), kind=kInput, calls=%fused_computation )"); } @@ -1871,7 +1869,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_computation (param_0.1: f16[16,32]) -> (f32[16,32], f16[16,32]) { // CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f16[16,32]{1,0} parameter(0) // CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} convert([[param_0_1_0]]) @@ -1908,7 +1906,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, std::nullopt); + CheckMultiOutputFusion(hlo, std::nullopt); } // Do not group incompatible transposes. @@ -1941,7 +1939,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, std::nullopt); + CheckMultiOutputFusion(hlo, std::nullopt); } // A variation of the test above, where no CSE was run. @@ -1975,7 +1973,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, std::nullopt); + CheckMultiOutputFusion(hlo, std::nullopt); } TEST_F(TransposeMultiOutputFusionTest, CopyAndInput) { @@ -1996,7 +1994,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32], f32[16,32]) { // CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) // CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]]) @@ -2013,30 +2011,30 @@ TEST_F(TransposeMultiOutputFusionTest, TransposeAndInputEpilogueFusion) { HloModule module fused_computation { - param_0.1 = f32[16,32]{1,0} parameter(0) - s.1 = f32[16,32]{1,0} sqrt(param_0.1) - t.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0} + param_0.1 = f32[1,16,32]{2,1,0} parameter(0) + s.1 = f32[1,16,32]{2,1,0} sqrt(param_0.1) + t.1 = f32[1,32,16]{2,1,0} transpose(s.1), dimensions={0,2,1} ROOT out = f32[32,16,1]{2,1,0} bitcast(t.1) } ENTRY main { - p = f32[16,32]{1,0} parameter(0) + p = f32[1,16,32]{2,1,0} parameter(0) fusion = f32[32,16,1]{2,1,0} fusion(p), kind=kInput, calls=fused_computation - c1 = exponential(p) - ROOT t = tuple(fusion, c1) + c1 = f32[1,16,32]{2,1,0} exponential(p) + ROOT t = (f32[32,16,1]{2,1,0}, f32[1,16,32]{2,1,0}) tuple(fusion, c1) } )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_computation -// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) -// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]]) -// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]) +// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[1,16,32]{2,1,0} parameter(0) +// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[1,16,32]{2,1,0} sqrt([[param_0_1_0]]) +// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[1,32,16]{2,1,0} transpose([[s_1_1]]) // CHECK-NEXT: [[out_3:%[^ ]+]] = f32[32,16,1]{2,1,0} bitcast([[c_1_2]]) -// CHECK-NEXT: [[c1_1_4:%[^ ]+]] = f32[16,32]{1,0} exponential([[param_0_1_0]]) -// CHECK-NEXT: ROOT [[tuple_5:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) tuple([[out_3]], [[c1_1_4]]) +// CHECK-NEXT: [[c1_1_4:%[^ ]+]] = f32[1,16,32]{2,1,0} exponential([[param_0_1_0]]) +// CHECK-NEXT: ROOT [[tuple_5:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[1,16,32]{2,1,0}) tuple([[out_3]], [[c1_1_4]]) // CHECK-NEXT: } -// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]] +// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[1,16,32]{2,1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]] )"); } @@ -2073,7 +2071,7 @@ ENTRY computation { )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_elementwise // CHECK-NEXT: [[p_1_0:%[^ ]+]] = f32[200]{0} parameter(0) // CHECK-NEXT: [[r_1_1:%[^ ]+]] = f32[200]{0} sqrt([[p_1_0]]) @@ -2117,7 +2115,7 @@ ENTRY computation { } )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_elementwise (p.1: f32[10,20]) -> (f32[10,20], f32[]) { // CHECK-NEXT: [[p_1_0:%[^ ]+]] = f32[10,20]{1,0} parameter(0) // CHECK-NEXT: [[r_1_1:%[^ ]+]] = f32[10,20]{1,0} sqrt([[p_1_0]]) @@ -2179,7 +2177,7 @@ ENTRY computation { } )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_computation.1 (param_0.8: f32[], param_1.10: f32[], param_2.7: f16[100,200]) -> (f16[100,200], f32[]) { // CHECK-NEXT: [[one_3_0:%[^ ]+]] = f32[] constant(1) // CHECK-NEXT: [[one_b_3_1:%[^ ]+]] = f32[100,200]{1,0} broadcast([[one_3_0]]), dimensions={} diff --git a/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker.cc b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker.cc new file mode 100644 index 00000000000000..0ddc33e3d21c46 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker.cc @@ -0,0 +1,33 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/pgle_accuracy_checker.h" + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "tsl/platform/errors.h" + +namespace xla::gpu { + +absl::StatusOr PGLEAccuracyChecker::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + TF_RETURN_IF_ERROR(pgle_estimator_.CheckAccuracy(*module)); + return false; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker.h b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker.h new file mode 100644 index 00000000000000..35bfd100f36a33 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker.h @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_PGLE_ACCURACY_CHECKER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_PGLE_ACCURACY_CHECKER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/service/profile_guided_latency_estimator.h" + +namespace xla::gpu { + +// This pass checks the accuracy of the input feedback-driven optimization (FDO) +// profile. If any non-NOP instruction from the given HloModule is not present +// in the profile this pass fails. +class PGLEAccuracyChecker : public HloModulePass { + public: + explicit PGLEAccuracyChecker(ProfileGuidedLatencyEstimator& pgle_estimator) + : pgle_estimator_(pgle_estimator) {} + absl::string_view name() const override { return "pgle-accuracy-checker"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + ProfileGuidedLatencyEstimator& pgle_estimator_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_TRANSFORMS_PGLE_ACCURACY_CHECKER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc new file mode 100644 index 00000000000000..3f2d1ab6426fd0 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc @@ -0,0 +1,159 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/pgle_accuracy_checker.h" + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/gpu_latency_hiding_scheduler.h" +#include "xla/service/latency_hiding_scheduler.h" +#include "xla/service/profile_guided_latency_estimator.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/protobuf.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +using PGLEAccuracyCheckerTest = HloTestBase; +using ::tensorflow::profiler::ProfiledInstructionsProto; +using ::tsl::protobuf::TextFormat; +using ::tsl::testing::StatusIs; + +// Constructs PGLE estimator for a given `profile`. +std::unique_ptr GetProfileGuidedLatencyEstimator( + ProfiledInstructionsProto& profile) { + auto gpu_latency_estimator = + std::make_unique(/*pointer_size=*/8); + SchedulerConfig config; + auto aggregator = std::make_unique(); + return std::make_unique( + config, std::move(gpu_latency_estimator), profile, std::move(aggregator)); +} + +TEST_F(PGLEAccuracyCheckerTest, + ReturnsOkAndNoIRChangeIfAllInstructionsAreFoundInTheProfile) { + const absl::string_view kHloString = R"( + HloModule m + + apply_op { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT apply_op = f32[] add(x, y) + } + + ENTRY ar { + p0 = f32[32] parameter(0) + p1 = f32[32,32] parameter(1) + p2 = f32[32,32] parameter(2) + p3 = f32[32] parameter(3) + + dot0 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + dot1 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + add0 = f32[32,32] add(dot0, dot1) + + ar-start = f32[32] all-reduce-start(p0), to_apply=apply_op + ar-done = f32[32] all-reduce-done(ar-start) + + ar-start1 = f32[32] all-reduce-start(p3), to_apply=apply_op + ar-done1 = f32[32] all-reduce-done(ar-start1) + + ROOT _ = (f32[32],f32[32],f32[32,32]) tuple(ar-done, ar-done1, add0) + })"; + + // Profile string, cost does not matter. + const std::string kProfileString = R"pb( + costs { name: "dot0" cost_us: 1.0 } + costs { name: "dot1" cost_us: 1.0 } + costs { name: "add0" cost_us: 1.0 } + costs { name: "ar-start" cost_us: 1.0 } + costs { name: "ar-start1" cost_us: 1.0 } + )pb"; + + ProfiledInstructionsProto profile; + ASSERT_TRUE(TextFormat::ParseFromString(kProfileString, &profile)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + *module->mutable_config().mutable_fdo_profile() = kProfileString; + + auto pgle_estimator = GetProfileGuidedLatencyEstimator(profile); + PGLEAccuracyChecker pgle_accuracy_checker(*pgle_estimator); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + pgle_accuracy_checker.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(PGLEAccuracyCheckerTest, + ReturnsInvalidArgumentIfThereAreMissingInstructionsFromTheProfile) { + const absl::string_view kHloString = R"( + HloModule m + + apply_op { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT apply_op = f32[] add(x, y) + } + + ENTRY ar { + p0 = f32[32] parameter(0) + p1 = f32[32,32] parameter(1) + p2 = f32[32,32] parameter(2) + p3 = f32[32] parameter(3) + + dot0 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + dot1 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + add0 = f32[32,32] add(dot0, dot1) + + ar-start = f32[32] all-reduce-start(p0), to_apply=apply_op + ar-done = f32[32] all-reduce-done(ar-start) + + ar-start1 = f32[32] all-reduce-start(p3), to_apply=apply_op + ar-done1 = f32[32] all-reduce-done(ar-start1) + + ROOT _ = (f32[32],f32[32],f32[32,32]) tuple(ar-done, ar-done1, add0) + })"; + + // Profile string, cost does not matter. + // We're missing `dot1` and `ar-start` from the profile. + const std::string kProfileString = R"pb( + costs { name: "dot0" cost_us: 1.0 } + costs { name: "add0" cost_us: 1.0 } + costs { name: "ar-start1" cost_us: 1.0 } + )pb"; + + ProfiledInstructionsProto profile; + ASSERT_TRUE(TextFormat::ParseFromString(kProfileString, &profile)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + *module->mutable_config().mutable_fdo_profile() = kProfileString; + + auto pgle_estimator = GetProfileGuidedLatencyEstimator(profile); + PGLEAccuracyChecker pgle_accuracy_checker(*pgle_estimator); + EXPECT_THAT(pgle_accuracy_checker.Run(module.get()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc index d0e841c4f9ebc1..493d1671e0e7fc 100644 --- a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/pipelined_p2p_rewriter.h" +#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.h similarity index 94% rename from third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.h index 88b6bb662f2ed7..d2aca8ca17064c 100644 --- a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_ -#define XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_PIPELINED_P2P_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_PIPELINED_P2P_REWRITER_H_ #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -130,4 +130,4 @@ class PipelinedP2PRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_PIPELINED_P2P_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc index c0b6fb31f1f29e..287603c6d0de93 100644 --- a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/pipelined_p2p_rewriter.h" +#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h" #include #include "absl/strings/string_view.h" diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc similarity index 81% rename from third_party/xla/xla/service/gpu/priority_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/priority_fusion.cc index 8efa94549c8cff..bae58de45849a1 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/priority_fusion.h" +#include "xla/service/gpu/transforms/priority_fusion.h" #include #include @@ -31,6 +31,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/meta/type_traits.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -46,14 +47,19 @@ limitations under the License. #include "xla/service/fusion_queue.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/fusion_process_dump.pb.h" +#include "xla/service/gpu/fusions/triton/triton_support.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_indexing_performance_model.h" #include "xla/service/gpu/model/gpu_performance_model.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/model/triton_emitter_constraints.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" @@ -64,6 +70,7 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" namespace xla { @@ -114,6 +121,19 @@ bool IsFusible(const HloInstruction& instr) { } } +// Returns a GpuBackendConfig proto for a Triton fusion with the given +// BlockLevelParameters. +GpuBackendConfig GetTritonGpuBackendConfig( + const BlockLevelParameters& block_level_parameters) { + GpuBackendConfig gpu_backend_config; + gpu_backend_config.mutable_fusion_backend_config()->set_kind( + std::string(kTritonFusionKind)); + *gpu_backend_config.mutable_fusion_backend_config() + ->mutable_block_level_fusion_config() = + block_level_parameters.ToBlockLevelFusionConfig(); + return gpu_backend_config; +} + // An implementation of FusionQueue that determines whether to fuse instructions // according to a cost model, and chooses the next fusion candidate according to // dynamically updated priorities. The elements in the queue are producer nodes @@ -121,23 +141,26 @@ bool IsFusible(const HloInstruction& instr) { // performance when fusing it to all of its fusible users. We greedily pick the // max-benefit producer to fuse, and update the estimated benefits of the fused // nodes and their operands. -class GpuPriorityFusionQueue { +class PriorityFusionQueue { using Priority = int64_t; using CanFuseCallback = std::function; public: - GpuPriorityFusionQueue( - HloComputation* computation, - const GpuHloCostAnalysis::Options& cost_analysis_options, - const se::DeviceDescription* device_info, - FusionProcessDumpProto* fusion_process_dump, - tsl::thread::ThreadPool* thread_pool, mlir::MLIRContext* mlir_context, - HloFusionAnalysisCache& fusion_analysis_cache, - bool triton_softmax_priority_fusion_enabled) + PriorityFusionQueue(HloComputation* computation, + const GpuHloCostAnalysis::Options& cost_analysis_options, + const se::DeviceDescription* device_info, + FusionProcessDumpProto* fusion_process_dump, + tsl::thread::ThreadPool* thread_pool, + mlir::MLIRContext* mlir_context, + HloFusionAnalysisCache& fusion_analysis_cache, + bool triton_softmax_priority_fusion_enabled) : computation_(computation), device_info_(device_info), cost_analysis_(cost_analysis_options, *device_info), + gpu_indexing_performance_model_(device_info, &fusion_analysis_cache, + cost_analysis_options.shape_size, + mlir_context), fusion_process_dump_(fusion_process_dump), thread_pool_(thread_pool), mlir_context_(mlir_context), @@ -155,7 +178,7 @@ class GpuPriorityFusionQueue { // Initializes the priority queue. std::vector instructions; for (auto* instruction : computation->MakeInstructionPostOrder()) { - UpdatePerformanceModelCache(instruction); + TF_CHECK_OK(UpdatePerformanceModelCache(instruction)); if (instruction->opcode() == HloOpcode::kParameter || instruction->user_count() == 0 || !instruction->IsFusible() || instruction->opcode() == HloOpcode::kTuple || @@ -247,36 +270,49 @@ class GpuPriorityFusionQueue { return !current_consumers_.empty(); } - void UpdatePerformanceModelCache(HloInstruction* producer) { - if (!IsFusible(*producer) && !IsGenericTritonFusion(*producer)) { - return; + absl::Status UpdatePerformanceModelCache(HloInstruction* producer) { + bool is_triton_fusion = IsGenericTritonFusion(*producer); + if (!IsFusible(*producer) && !is_triton_fusion) { + return absl::OkStatus(); } - auto config = GpuPerformanceModelOptions::PriorityFusion( - &fusion_analysis_cache_, &gpu_performance_model_cache_); + if (gpu_performance_model_cache_.Get(*producer)) { + return absl::OkStatus(); + } - if (!gpu_performance_model_cache_.Get(*producer)) { - auto runtime_data = GpuPerformanceModel::EstimateRunTimeForInstruction( + EstimateRunTimeData runtime_data; + if (is_triton_fusion) { + TF_ASSIGN_OR_RETURN( + runtime_data, + gpu_indexing_performance_model_.EstimateRunTimeForTriton(producer)); + } else { + auto config = GpuPerformanceModelOptions::PriorityFusion( + &fusion_analysis_cache_, &gpu_performance_model_cache_); + runtime_data = GpuPerformanceModel::EstimateRunTimeForInstruction( producer, *device_info_, &cost_analysis_, config); - gpu_performance_model_cache_.Set(*producer, runtime_data); } + + gpu_performance_model_cache_.Set(*producer, runtime_data); + + return absl::OkStatus(); } // Update priorities of all affected ops. - void UpdatePriorities() { + absl::Status UpdatePriorities() { // Revisit costs of all updated ops. It's important to update cost analysis // before recalculating priorities. for (auto instruction : to_update_priority_) { - TF_CHECK_OK(cost_analysis_.RevisitInstruction(instruction)); + TF_RETURN_IF_ERROR(cost_analysis_.RevisitInstruction(instruction)); } for (auto producer : to_update_priority_) { - UpdatePerformanceModelCache(producer); + TF_RETURN_IF_ERROR(UpdatePerformanceModelCache(producer)); } ComputeAndSetPriorities(std::vector{ to_update_priority_.begin(), to_update_priority_.end()}); to_update_priority_.clear(); + return absl::OkStatus(); } // Prepares producer and consumer instruction to be fused. Invalidates caches @@ -303,6 +339,14 @@ class GpuPriorityFusionQueue { } } + block_level_parameters_cache_.erase(instruction); + for (const HloInstruction* operand : instruction->operands()) { + auto it = block_level_parameters_cache_.find(operand); + if (it != block_level_parameters_cache_.end()) { + it->second.erase(instruction); + } + } + gpu_performance_model_cache_.Invalidate(*instruction); fusion_analysis_cache_.Invalidate(*instruction); fusion_info_cache_.Invalidate(instruction); @@ -378,6 +422,17 @@ class GpuPriorityFusionQueue { reverse_map_.erase(reverse_it); } + // Returns a map from consumer to BlockLevelParameters. This is used to + // determine if a producer-consumer fusion is a Triton fusion. + absl::flat_hash_map + GetBlockLevelParametersMap(const HloInstruction* producer) { + auto it = block_level_parameters_cache_.find(producer); + if (it == block_level_parameters_cache_.end()) { + return {}; + } + return it->second; + } + HloInstruction* current_producer() { return current_producer_; } const std::vector& current_consumers() { @@ -434,6 +489,24 @@ class GpuPriorityFusionQueue { run_times.time_fused); } + FusionDecision IsTritonSupported(const HloInstruction& instruction) { + if (instruction.opcode() != HloOpcode::kFusion) { + return IsTritonSupportedInstruction( + instruction, device_info_->gpu_compute_capability()); + } + + for (const HloInstruction* instruction : + instruction.fused_instructions_computation()->instructions()) { + if (auto codegen_decision = IsTritonSupportedInstruction( + *instruction, device_info_->gpu_compute_capability()); + !codegen_decision) { + return codegen_decision; + } + } + + return {}; + } + FusionDecision CanFuseTriton(HloInstruction* producer, HloInstruction* consumer) { if (!triton_softmax_priority_fusion_enabled_) { @@ -444,24 +517,52 @@ class GpuPriorityFusionQueue { if (!IsFusible(*consumer)) { return "the consumer is not fusible"; } + + if (auto fusion_decision = IsTritonSupported(*consumer); + !fusion_decision) { + return fusion_decision; + } } else { if (!IsFusible(*producer)) { return "the producer is not fusible"; } + + if (auto fusion_decision = IsTritonSupported(*producer); + !fusion_decision) { + return fusion_decision; + } } auto fusion = HloFusionAdaptor::ForProducerConsumer(producer, consumer); - SymbolicTileAnalysisOrError symbolic_tile_analysis_or = - SymbolicTileAnalysis::AnalyzeFusion(*fusion, mlir_context_); + absl::StatusOr tiled_run_time_data_or_error = + gpu_indexing_performance_model_.TryFindBestTilingForFusion(*fusion); + + if (!tiled_run_time_data_or_error.ok()) { + return FusionDecision{ + absl::StrCat("TiledRunTimeDataOrError return status: ", + tiled_run_time_data_or_error.status().message())}; + } if (const auto* fusion_decision = - std::get_if(&symbolic_tile_analysis_or)) { + std::get_if(&*tiled_run_time_data_or_error)) { return { absl::StrCat("Fusion can not be tiled with SymbolicTileAnalysis: ", fusion_decision->Explain())}; } + TiledRunTimeData tiled_run_time_data = + std::get(*std::move(tiled_run_time_data_or_error)); + + gpu_performance_model_cache_.Set( + *producer, *consumer, tiled_run_time_data.runtime_data.exec_time); + + { + absl::MutexLock lock(&block_level_parameters_cache_mutex_); + block_level_parameters_cache_[producer][consumer] = + tiled_run_time_data.block_level_parameters; + } + return {}; } @@ -611,9 +712,12 @@ class GpuPriorityFusionQueue { const se::DeviceDescription* device_info_; - // Reference to cost model that defines priorities in the queue. + // Cost Analysis that is used to estimate the cost of a fusion. GpuHloCostAnalysis cost_analysis_; + // Performance model that is used to estimate the run time of a fusion. + GpuPerformanceModelWithIndexingAnalysis gpu_indexing_performance_model_; + // The priority queue of producers, implemented as an ordered map, where a // key is a pair: the first element is the priority and the second element is // the unique ID of the instruction to break ties. @@ -655,6 +759,14 @@ class GpuPriorityFusionQueue { can_fuse_cache_; absl::Mutex can_fuse_cache_mutex_; + // Caches block level parameters for a (producer, consumer) pair. A cache + // entry is invalidated if producer or consumer is modified. + absl::flat_hash_map< + const HloInstruction*, + absl::flat_hash_map> + block_level_parameters_cache_; + absl::Mutex block_level_parameters_cache_mutex_; + GpuPerformanceModelCache gpu_performance_model_cache_; // Cache for `FusionFitsInBudget` to avoid recomputing expensive properties @@ -668,8 +780,7 @@ class GpuPriorityFusionQueue { } // namespace -/*static*/ bool GpuPriorityFusion::IsExpensive( - const HloInstruction& instruction) { +/*static*/ bool PriorityFusion::IsExpensive(const HloInstruction& instruction) { // Some floating-point math ops are cheap on the GPU. switch (instruction.opcode()) { case HloOpcode::kDivide: @@ -702,15 +813,15 @@ bool IsSmallConstant(const HloInstruction* instr) { ShapeUtil::ElementsIn(instr->shape()) <= 1; } -bool GpuPriorityFusion::ConsumeFuel(HloInstruction* producer, - HloInstruction* consumer) { +bool PriorityFusion::ConsumeFuel(HloInstruction* producer, + HloInstruction* consumer) { return xla::ConsumeFuel(name(), /*ran_out_of_fuel_msg=*/[&] { return absl::StrFormat("Not fusing producer %s with consumer %s", producer->name(), consumer->name()); }); }; -absl::StatusOr GpuPriorityFusion::Run( +absl::StatusOr PriorityFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool dump_enabled = @@ -756,7 +867,7 @@ absl::StatusOr GpuPriorityFusion::Run( for (auto* computation : fusible_computations) { CHECK(!computation->IsFusionComputation()); - auto fusion_queue = std::make_unique( + auto fusion_queue = std::make_unique( computation, cost_analysis_options_, &device_info_, fusion_process_dump_.get(), thread_pool_, &mlir_context_, fusion_analysis_cache_, triton_softmax_priority_fusion_enabled); @@ -764,6 +875,10 @@ absl::StatusOr GpuPriorityFusion::Run( while (fusion_queue->DequeueNextProducer()) { auto producer = fusion_queue->current_producer(); + absl::flat_hash_map + block_level_parameters_map = + fusion_queue->GetBlockLevelParametersMap(producer); + for (auto* consumer : fusion_queue->current_consumers()) { // Don't fuse into single bitcasts. We ignore them in the check // CanFuseWithAllNonBitcastUsers(), so we need to check it here. @@ -780,6 +895,12 @@ absl::StatusOr GpuPriorityFusion::Run( fusion_queue->OnFusingInstruction(fusion_instruction, producer, consumer); + auto backend_config_it = block_level_parameters_map.find(consumer); + if (backend_config_it != block_level_parameters_map.end()) { + TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config( + GetTritonGpuBackendConfig(backend_config_it->second))); + } + changed = true; } @@ -789,7 +910,7 @@ absl::StatusOr GpuPriorityFusion::Run( TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); } - fusion_queue->UpdatePriorities(); + TF_RETURN_IF_ERROR(fusion_queue->UpdatePriorities()); } // Fuse all constants. @@ -831,8 +952,8 @@ absl::StatusOr GpuPriorityFusion::Run( return changed; } -FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, - int64_t operand_index) { +FusionDecision PriorityFusion::ShouldFuse(HloInstruction* consumer, + int64_t operand_index) { // This method is called in `InstructionFusion::Run` right before fusion, but // it will always return true. Fusion decision are fully controlled by the // PriorityQueue. If the queue returns a producer that shouldn't be fused, @@ -840,7 +961,7 @@ FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, return {}; } -HloInstruction::FusionKind GpuPriorityFusion::ChooseKind( +HloInstruction::FusionKind PriorityFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { // Derive kInput/kLoop fusion kinds from fusion analysis. This shouldn't // matter but some passes downstream still query these instead of fusion @@ -862,15 +983,10 @@ HloInstruction::FusionKind GpuPriorityFusion::ChooseKind( } } -HloInstruction* GpuPriorityFusion::FuseInstruction( +HloInstruction* PriorityFusion::FuseInstruction( HloInstruction* fusion_instruction, HloInstruction* producer) { HloInstruction* result = fusion_instruction; if (producer->opcode() == HloOpcode::kFusion) { - if (IsGenericTritonFusion(*producer)) { - TF_CHECK_OK(fusion_instruction->set_backend_config( - *producer->backend_config())); - } - fusion_instruction->MergeFusionInstruction(producer); } else { result = InstructionFusion::FuseInstruction(fusion_instruction, producer); @@ -878,7 +994,7 @@ HloInstruction* GpuPriorityFusion::FuseInstruction( return result; } -std::unique_ptr GpuPriorityFusion::GetFusionQueue( +std::unique_ptr PriorityFusion::GetFusionQueue( HloComputation* computation) { return nullptr; } diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h similarity index 87% rename from third_party/xla/xla/service/gpu/priority_fusion.h rename to third_party/xla/xla/service/gpu/transforms/priority_fusion.h index 999eb78ceca245..fce2be535e23a3 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_PRIORITY_FUSION_H_ -#define XLA_SERVICE_GPU_PRIORITY_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_ #include @@ -41,12 +41,12 @@ limitations under the License. namespace xla { namespace gpu { -class GpuPriorityFusion : public InstructionFusion { +class PriorityFusion : public InstructionFusion { public: - GpuPriorityFusion(tsl::thread::ThreadPool* thread_pool, - const se::DeviceDescription& device, - GpuHloCostAnalysis::Options cost_analysis_options) - : InstructionFusion(GpuPriorityFusion::IsExpensive), + PriorityFusion(tsl::thread::ThreadPool* thread_pool, + const se::DeviceDescription& device, + GpuHloCostAnalysis::Options cost_analysis_options) + : InstructionFusion(PriorityFusion::IsExpensive), thread_pool_(thread_pool), device_info_(device), cost_analysis_options_(std::move(cost_analysis_options)), @@ -97,4 +97,4 @@ class GpuPriorityFusion : public InstructionFusion { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_PRIORITY_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc similarity index 85% rename from third_party/xla/xla/service/gpu/priority_fusion_test.cc rename to third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc index 7f56cf4e0e4691..b552fd9ade5366 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/priority_fusion.h" +#include "xla/service/gpu/transforms/priority_fusion.h" #include @@ -80,13 +80,22 @@ class PriorityFusionTest : public HloTestBase { return kinds; } - GpuPriorityFusion priority_fusion_{ + PriorityFusion priority_fusion_{ /*thread_pool=*/nullptr, TestGpuDeviceInfo::RTXA6000DeviceInfo(), GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(), /*per_second_rates=*/{}, /*count_multiple_input_accesses=*/true}}; }; +class PriorityFusionWithTritonEnabledTest : public PriorityFusionTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = PriorityFusionTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_triton_softmax_priority_fusion(true); + return debug_options; + } +}; + TEST_F(PriorityFusionTest, FuseWithSharedArgument) { auto module = ParseAndReturnVerifiedModule(R"( HloModule test_module @@ -332,9 +341,10 @@ TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) { broadcast.18310.clone.1 = f32[2048,24576]{1,0} broadcast(convert.30039.clone.1), dimensions={} multiply.12550.clone.1 = f32[2048,24576]{1,0} multiply(tanh.798.clone.1, broadcast.18310.clone.1) convert.29370.clone.1 = bf16[2048,24576]{1,0} convert(multiply.12550.clone.1) - bitcast.22330 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1) - transpose.6582 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22330), dimensions={0,3,2,1} - convert.33705 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6582) + bitcast.1 = bf16[2048,2048,12]{2,1,0} bitcast(convert.29370.clone.1) + transpose.6582 = bf16[12,2048,2048]{2,1,0} transpose(bitcast.1), dimensions={2,1,0} + bitcast = bf16[1,12,2048,2048]{3,2,1,0} bitcast(transpose.6582) + convert.33705 = f32[1,12,2048,2048]{3,2,1,0} convert(bitcast) constant_10212 = f32[] constant(-2.38197633e+38) broadcast.22828 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_10212), dimensions={} select.589 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22829, convert.33705, broadcast.22828) @@ -346,9 +356,9 @@ TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) { bitcast.11069 = pred[2048,2048]{1,0} bitcast(predarg) broadcast.22825 = pred[1,12,2048,2048]{3,2,1,0} broadcast(bitcast.11069), dimensions={2,3} - bitcast.22331 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1) - transpose.6580 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22331), dimensions={0,3,2,1} - convert.33703 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6580) + transpose.6580 = bf16[12,2048,2048]{2,1,0} transpose(bitcast.1), dimensions={2,1,0} + bitcast.2 = bf16[1,12,2048,2048]{3,2,1,0} bitcast(transpose.6580) + convert.33703 = f32[1,12,2048,2048]{3,2,1,0} convert(bitcast.2) constant_10213 = f32[] constant(-2.38197633e+38) broadcast.22824 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_10213), dimensions={} select.587 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22825, convert.33703, broadcast.22824) @@ -361,9 +371,9 @@ TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) { constant_468 = f32[] constant(-2.38197633e+38) broadcast.22833 = pred[1,12,2048,2048]{3,2,1,0} broadcast(bitcast.11069), dimensions={2,3} - bitcast.22332 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1) - transpose.6584 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22332), dimensions={0,3,2,1} - convert.33707 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6584) + transpose.6584 = bf16[12,2048,2048]{2,1,0} transpose(bitcast.1), dimensions={2,1,0} + bitcast.3 = bf16[1,12,2048,2048]{3,2,1,0} bitcast(transpose.6584) + convert.33707 = f32[1,12,2048,2048]{3,2,1,0} convert(bitcast.3) broadcast.22832 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_468), dimensions={} select.591 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22833, convert.33707, broadcast.22832) broadcast.22821 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1614), dimensions={1,2} @@ -860,11 +870,8 @@ TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) { EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); } -TEST_F(PriorityFusionTest, CanMergeTritonFusionWithBothProducerAndConsumer) { -#ifndef GOOGLE_CUDA - GTEST_SKIP() << "Triton fusion only enable for CUDA devices."; -#endif - +TEST_F(PriorityFusionWithTritonEnabledTest, + CanMergeTritonFusionWithBothProducerAndConsumer) { const std::string kHloText = R"( HloModule t add { @@ -897,13 +904,10 @@ ENTRY main { param_0 = f32[125]{0} parameter(0) param_1 = f32[125,127]{1,0} parameter(1) producer_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=producer_computation - triton_softmax = f32[125,127]{1,0} fusion(producer_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} + triton_softmax = f32[125,127]{1,0} fusion(producer_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["1","127"],"num_warps":"1"}}} ROOT consumer_fusion = f32[125,127]{1,0} fusion(param_1, triton_softmax), kind=kLoop, calls=consumer_computation })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); - auto debug_options = module->config().debug_options(); - debug_options.set_xla_gpu_enable_triton_softmax_priority_fusion(true); - module->mutable_config().set_debug_options(debug_options); EXPECT_TRUE(priority_fusion_.Run(module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); @@ -911,7 +915,140 @@ ENTRY main { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kCustom); - EXPECT_TRUE(IsGenericTritonFusion(*root)); + ASSERT_TRUE(IsGenericTritonFusion(*root)); + + EXPECT_TRUE(root->backend_config() + ->fusion_backend_config() + .has_block_level_fusion_config()); + EXPECT_EQ(root->backend_config() + ->fusion_backend_config() + .block_level_fusion_config() + .output_tile_sizes_size(), + 2); +} + +TEST_F(PriorityFusionWithTritonEnabledTest, + FuseTritonProducerWithTwoConsumers) { + const std::string kHloText = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +producer_computation { + parameter_0 = f32[125]{0} parameter(0) + ROOT broadcast = f32[125,127] broadcast(parameter_0), dimensions={0} +} + +consumer_computation.1 { + parameter_0 = f32[125,127] parameter(0) + ROOT log = f32[125,127] log(parameter_0) +} + +consumer_computation.2 { + parameter_0 = f32[125,127] parameter(0) + ROOT exp = f32[125,127] exponential(parameter_0) +} + +ENTRY main { + param_0 = f32[125]{0} parameter(0) + producer_fusion = f32[125,127] fusion(param_0), kind=kCustom, calls=producer_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["1","127"],"num_warps":"1"}}} + consumer_fusion.1 = f32[125,127] fusion(producer_fusion), kind=kLoop, calls=consumer_computation.1 + consumer_fusion.2 = f32[125,127] fusion(producer_fusion), kind=kLoop, calls=consumer_computation.2 + ROOT tuple = (f32[125,127], f32[125,127]) tuple(consumer_fusion.1, consumer_fusion.2) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + + EXPECT_TRUE(priority_fusion_.Run(module.get()).value()); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction *fusion1, *fusion2; + EXPECT_THAT(root, GmockMatch(m::Tuple(m::Fusion(&fusion1, m::Parameter()), + m::Fusion(&fusion2, m::Parameter())))); + EXPECT_TRUE(IsGenericTritonFusion(*fusion1)); + TF_ASSERT_OK_AND_ASSIGN(auto backend_config1, + fusion1->backend_config()); + EXPECT_TRUE( + backend_config1.fusion_backend_config().has_block_level_fusion_config()); + EXPECT_EQ(backend_config1.fusion_backend_config() + .block_level_fusion_config() + .output_tile_sizes_size(), + 2); + + EXPECT_TRUE(IsGenericTritonFusion(*fusion2)); + TF_ASSERT_OK_AND_ASSIGN(auto backend_config2, + fusion2->backend_config()); + EXPECT_TRUE( + backend_config2.fusion_backend_config().has_block_level_fusion_config()); + EXPECT_EQ(backend_config2.fusion_backend_config() + .block_level_fusion_config() + .output_tile_sizes_size(), + 2); +} + +TEST_F(PriorityFusionWithTritonEnabledTest, + TritonProducerNotSupported_DoNotFuse) { + const std::string kHloText = R"( +HloModule t + +producer_computation { + parameter_0 = c64[] parameter(0) + broadcast = c64[125,127] broadcast(parameter_0), dimensions={} + ROOT real = f32[125,127] real(broadcast) +} + +triton_computation { + parameter_0 = f32[125,127] parameter(0) + parameter_1 = f32[125,127] parameter(1) + ROOT add = f32[125,127] add(parameter_0, parameter_1) +} + +ENTRY main { + param_0 = c64[] parameter(0) + param_1 = f32[125,127] parameter(1) + producer_fusion = f32[125,127] fusion(param_0), kind=kLoop, calls=producer_computation + ROOT triton_fusion = f32[125,127] fusion(producer_fusion, param_1), kind=kCustom, calls=triton_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["1","127"],"num_warps":"1"}}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + + // Triton does not support c64, so producer_fusion and triton_fusion and will + // not be fused. + EXPECT_FALSE(priority_fusion_.Run(module.get()).value()); +} + +TEST_F(PriorityFusionWithTritonEnabledTest, + TritonConsumerNotSupported_DoNotFuse) { + const std::string kHloText = R"( +HloModule t + +triton_computation { + parameter_0 = f32[] parameter(0) + ROOT boardcast = f32[125,127] broadcast(parameter_0), dimensions={} +} + +consumer_computation { + parameter_0 = c64[] parameter(0) + parameter_1 = f32[125,127] parameter(1) + broadcast = c64[125,127] broadcast(parameter_0), dimensions={} + real = f32[125,127] real(broadcast) + ROOT add = f32[125,127] add(real, parameter_1) +} + +ENTRY main { + param_0 = f32[] parameter(1) + param_1 = c64[] parameter(0) + triton_fusion = f32[125,127] fusion(param_0), kind=kCustom, calls=triton_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["1","127"],"num_warps":"1"}}} + ROOT consumer_fusion = f32[125,127] fusion(param_1, triton_fusion), kind=kLoop, calls=consumer_computation +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + + // Triton does not support c64, so triton_fusion and consumer_fusion will not + // be fused. + EXPECT_FALSE(priority_fusion_.Run(module.get()).value()); } TEST_F(PriorityFusionTest, DoNotFuseInsideReducer) { diff --git a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.cc b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.cc rename to third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc index 7f1f800d2bc3de..d33c849168151e 100644 --- a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_reduce_scatter_creator.h" +#include "xla/service/gpu/transforms/reduce_scatter_creator.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.h b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.h similarity index 87% rename from third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.h rename to third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.h index fcecb460747cc7..4e74394052595a 100644 --- a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.h +++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_ -#define XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCE_SCATTER_CREATOR_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_REDUCE_SCATTER_CREATOR_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -40,4 +40,4 @@ class ReduceScatterCreator : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCE_SCATTER_CREATOR_H_ diff --git a/third_party/xla/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc rename to third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator_test.cc index b1d2734d9b0e49..39a2c72a10a213 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_reduce_scatter_creator.h" +#include "xla/service/gpu/transforms/reduce_scatter_creator.h" #include #include diff --git a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.cc b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc similarity index 98% rename from third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc index ac5419cf28d872..8c2929c0787f54 100644 --- a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_degenerate_dim_remover.h" +#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h" #include #include diff --git a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.h b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.h similarity index 88% rename from third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.h rename to third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.h index 03d6819081d5da..1630aecff00e76 100644 --- a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.h +++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_ -#define XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DEGENERATE_DIM_REMOVER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DEGENERATE_DIM_REMOVER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -53,4 +53,4 @@ class ReductionDegenerateDimRemover : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DEGENERATE_DIM_REMOVER_H_ diff --git a/third_party/xla/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover_test.cc index bb6eb634db78a5..7a9b7fa3fdbe0e 100644 --- a/third_party/xla/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_degenerate_dim_remover.h" +#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h" #include diff --git a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.cc b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc similarity index 98% rename from third_party/xla/xla/service/gpu/reduction_dimension_grouper.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc index 8ab4fcf648a255..ca4fba4fac9403 100644 --- a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_dimension_grouper.h" +#include "xla/service/gpu/transforms/reduction_dimension_grouper.h" #include #include diff --git a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.h b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.h similarity index 80% rename from third_party/xla/xla/service/gpu/reduction_dimension_grouper.h rename to third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.h index 8ee4efd0cfd261..d179dcd6c78415 100644 --- a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.h +++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_ -#define XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DIMENSION_GROUPER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DIMENSION_GROUPER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -32,12 +32,12 @@ namespace gpu { // // For example, // -// f[] out = reduce(f[10,20,30] input, dimensions={0,1,2}) +// out = f32[] reduce(f32[10,20,30] input, dimensions={0,1,2}) // // becomes: // -// f[600] tmp = f[600] bitcast(f[10,20,30] input) -// f[] out = reduce(f[600] tmp, dimensions={0}) +// tmp = f32[6000] bitcast(f32[10,20,30] input) +// out = f32[] reduce(f32[6000] tmp, dimensions={0}) // class ReductionDimensionGrouper : public HloModulePass { public: @@ -53,4 +53,4 @@ class ReductionDimensionGrouper : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DIMENSION_GROUPER_H_ diff --git a/third_party/xla/xla/service/gpu/tests/reduction_dimension_grouper_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/tests/reduction_dimension_grouper_test.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper_test.cc index fa149a13b940c0..afbbbec01d3c27 100644 --- a/third_party/xla/xla/service/gpu/tests/reduction_dimension_grouper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_dimension_grouper.h" +#include "xla/service/gpu/transforms/reduction_dimension_grouper.h" #include diff --git a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.cc b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc similarity index 99% rename from third_party/xla/xla/service/gpu/reduction_layout_normalizer.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc index a91fdf7e387b7a..fd45f8b34ec55b 100644 --- a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_layout_normalizer.h" +#include "xla/service/gpu/transforms/reduction_layout_normalizer.h" #include #include diff --git a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.h b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.h similarity index 89% rename from third_party/xla/xla/service/gpu/reduction_layout_normalizer.h rename to third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.h index 7d2d207773e057..f6e2d7c200dd67 100644 --- a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.h +++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_ -#define XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_LAYOUT_NORMALIZER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_LAYOUT_NORMALIZER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -51,4 +51,4 @@ class ReductionLayoutNormalizer : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_LAYOUT_NORMALIZER_H_ diff --git a/third_party/xla/xla/service/gpu/tests/reduction_layout_normalizer_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/tests/reduction_layout_normalizer_test.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer_test.cc index 817d9c73c95b16..46f5e9320eadfc 100644 --- a/third_party/xla/xla/service/gpu/tests/reduction_layout_normalizer_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_layout_normalizer.h" +#include "xla/service/gpu/transforms/reduction_layout_normalizer.h" #include diff --git a/third_party/xla/xla/service/gpu/reduction_splitter.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc similarity index 98% rename from third_party/xla/xla/service/gpu/reduction_splitter.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc index cd37319a47de30..dce9288888a8a5 100644 --- a/third_party/xla/xla/service/gpu/reduction_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_splitter.h" +#include "xla/service/gpu/transforms/reduction_splitter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/reduction_splitter.h b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h similarity index 92% rename from third_party/xla/xla/service/gpu/reduction_splitter.h rename to third_party/xla/xla/service/gpu/transforms/reduction_splitter.h index 7e7652500e6d3a..87520d3d7063b1 100644 --- a/third_party/xla/xla/service/gpu/reduction_splitter.h +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_ -#define XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_SPLITTER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_SPLITTER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -56,4 +56,4 @@ class ReductionSplitter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_SPLITTER_H_ diff --git a/third_party/xla/xla/service/gpu/reduction_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/reduction_splitter_test.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc index 13a5210fee2ee6..4b9f6fb130ed0f 100644 --- a/third_party/xla/xla/service/gpu/reduction_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_splitter.h" +#include "xla/service/gpu/transforms/reduction_splitter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/rename_fusions.cc b/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc similarity index 98% rename from third_party/xla/xla/service/gpu/rename_fusions.cc rename to third_party/xla/xla/service/gpu/transforms/rename_fusions.cc index a2a3048a05655e..9ab62f68664ebd 100644 --- a/third_party/xla/xla/service/gpu/rename_fusions.cc +++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/rename_fusions.h" +#include "xla/service/gpu/transforms/rename_fusions.h" #include #include diff --git a/third_party/xla/xla/service/gpu/rename_fusions.h b/third_party/xla/xla/service/gpu/transforms/rename_fusions.h similarity index 90% rename from third_party/xla/xla/service/gpu/rename_fusions.h rename to third_party/xla/xla/service/gpu/transforms/rename_fusions.h index c3065a4dbd1df5..5abcd6169cc9d1 100644 --- a/third_party/xla/xla/service/gpu/rename_fusions.h +++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RENAME_FUSIONS_H_ -#define XLA_SERVICE_GPU_RENAME_FUSIONS_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_RENAME_FUSIONS_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_RENAME_FUSIONS_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -44,4 +44,4 @@ class RenameFusions : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RENAME_FUSIONS_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_RENAME_FUSIONS_H_ diff --git a/third_party/xla/xla/service/gpu/rename_fusions_test.cc b/third_party/xla/xla/service/gpu/transforms/rename_fusions_test.cc similarity index 97% rename from third_party/xla/xla/service/gpu/rename_fusions_test.cc rename to third_party/xla/xla/service/gpu/transforms/rename_fusions_test.cc index 60c97cf2ff9438..47470859f84d2e 100644 --- a/third_party/xla/xla/service/gpu/rename_fusions_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/rename_fusions.h" +#include "xla/service/gpu/transforms/rename_fusions.h" #include diff --git a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.cc b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.cc similarity index 96% rename from third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.cc rename to third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.cc index 771e8cbed8a9a0..3841f4a1551f77 100644 --- a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.cc +++ b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_sanitize_constant_names.h" +#include "xla/service/gpu/transforms/sanitize_constant_names.h" #include @@ -29,7 +29,7 @@ namespace xla { namespace gpu { -absl::StatusOr GpuSanitizeConstantNames::Run( +absl::StatusOr SanitizeConstantNames::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.h b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.h similarity index 84% rename from third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.h rename to third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.h index 08701a4fe3432d..f743137f764ffb 100644 --- a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.h +++ b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ -#define XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SANITIZE_CONSTANT_NAMES_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SANITIZE_CONSTANT_NAMES_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -28,7 +28,7 @@ namespace gpu { // Sanitizes HLO instruction names for the GPU backend. Currently, it only // replaces . and - in the HLO constant instruction names with _ to please the // LLVM PTX backend. -class GpuSanitizeConstantNames : public HloModulePass { +class SanitizeConstantNames : public HloModulePass { public: absl::string_view name() const override { return "sanitize-constant-names"; } @@ -41,4 +41,4 @@ class GpuSanitizeConstantNames : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SANITIZE_CONSTANT_NAMES_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names_test.cc b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names_test.cc similarity index 91% rename from third_party/xla/xla/service/gpu/gpu_sanitize_constant_names_test.cc rename to third_party/xla/xla/service/gpu/transforms/sanitize_constant_names_test.cc index 17f45dc100f684..8e9779003af6f5 100644 --- a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_sanitize_constant_names.h" +#include "xla/service/gpu/transforms/sanitize_constant_names.h" #include #include @@ -44,7 +44,7 @@ TEST_F(SanitizeConstantNamesTest, InstructionNameWithHyphenSanitized) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value()); + EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value()); HloInstruction *root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->name(), "equal_to"); } @@ -59,7 +59,7 @@ TEST_F(SanitizeConstantNamesTest, InstructionNameWithDotSanitized) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value()); + EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value()); HloInstruction *root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->name(), "equal_to"); } @@ -74,7 +74,7 @@ TEST_F(SanitizeConstantNamesTest, NewInstructionNameRegisteredWithModule) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value()); + EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value()); HloInstruction *root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->name(), "equal_to"); @@ -99,7 +99,7 @@ TEST_F(SanitizeConstantNamesTest, BufferSanitizedNameCollisionResolved) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value()); + EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value()); EXPECT_THAT(FindInstruction(module.get(), "equal_to_1"), GmockMatch(m::Constant())); EXPECT_THAT(FindInstruction(module.get(), "equal_to_2"), diff --git a/third_party/xla/xla/service/gpu/gpu_scatter_expander.cc b/third_party/xla/xla/service/gpu/transforms/scatter_expander.cc similarity index 95% rename from third_party/xla/xla/service/gpu/gpu_scatter_expander.cc rename to third_party/xla/xla/service/gpu/transforms/scatter_expander.cc index b03b340cb8bbd9..26eb2107087a0d 100644 --- a/third_party/xla/xla/service/gpu/gpu_scatter_expander.cc +++ b/third_party/xla/xla/service/gpu/transforms/scatter_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_scatter_expander.h" +#include "xla/service/gpu/transforms/scatter_expander.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" diff --git a/third_party/xla/xla/service/gpu/gpu_scatter_expander.h b/third_party/xla/xla/service/gpu/transforms/scatter_expander.h similarity index 83% rename from third_party/xla/xla/service/gpu/gpu_scatter_expander.h rename to third_party/xla/xla/service/gpu/transforms/scatter_expander.h index 100350cb67ac01..f86b93235b2b5b 100644 --- a/third_party/xla/xla/service/gpu/gpu_scatter_expander.h +++ b/third_party/xla/xla/service/gpu/transforms/scatter_expander.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_ -#define XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCATTER_EXPANDER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SCATTER_EXPANDER_H_ #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -26,7 +26,7 @@ namespace xla { class GpuScatterExpander : public ScatterExpander { public: // Although we pass kEliminateAllScatters, we override this behavior in - // InstruuctionMatchesPattern and select only some scatters to expand. + // InstructionMatchesPattern and select only some scatters to expand. GpuScatterExpander() : ScatterExpander(kEliminateAllScatters) {} absl::string_view name() const override { return "gpu_scatter_expander"; } @@ -37,4 +37,4 @@ class GpuScatterExpander : public ScatterExpander { } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SCATTER_EXPANDER_H_ diff --git a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc similarity index 99% rename from third_party/xla/xla/service/gpu/scatter_slice_simplifier.cc rename to third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc index 9672bf259a328c..d9c1debacc5e27 100644 --- a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/scatter_slice_simplifier.h" +#include "xla/service/gpu/transforms/scatter_slice_simplifier.h" #include #include diff --git a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.h b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.h similarity index 92% rename from third_party/xla/xla/service/gpu/scatter_slice_simplifier.h rename to third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.h index 349837747466b6..96f39b5fbed1a6 100644 --- a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.h +++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_ -#define XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCATTER_SLICE_SIMPLIFIER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SCATTER_SLICE_SIMPLIFIER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -55,4 +55,4 @@ class ScatterSliceSimplifier : public HloModulePass { } // namespace xla -#endif // XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SCATTER_SLICE_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/gpu/scatter_slice_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/scatter_slice_simplifier_test.cc rename to third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc index 281a4f0576e0c7..8f1c93c1ec31d0 100644 --- a/third_party/xla/xla/service/gpu/scatter_slice_simplifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/scatter_slice_simplifier.h" +#include "xla/service/gpu/transforms/scatter_slice_simplifier.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc rename to third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc index a0af798118669d..9929b355345b4d 100644 --- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc +++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_schedule_postprocessing.h" +#include "xla/service/gpu/transforms/schedule_postprocessing.h" #include @@ -132,7 +132,7 @@ absl::StatusOr ProcessComputation( } // anonymous namespace -absl::StatusOr GpuSchedulePostprocessing::Run( +absl::StatusOr SchedulePostprocessing::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (!module->has_schedule()) return false; diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h similarity index 83% rename from third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h rename to third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h index d8eda81f257803..899098dfcce68f 100644 --- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h +++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_ -#define XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -34,11 +34,9 @@ namespace gpu { // attribute value untouch for the operations with is_sync=true and for P2P // operations, assumming the runtime won't use those values. // -class GpuSchedulePostprocessing : public HloModulePass { +class SchedulePostprocessing : public HloModulePass { public: - absl::string_view name() const override { - return "gpu-schedule-postprocessing"; - } + absl::string_view name() const override { return "schedule-postprocessing"; } using HloPassInterface::Run; absl::StatusOr Run( @@ -49,4 +47,4 @@ class GpuSchedulePostprocessing : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc similarity index 91% rename from third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc rename to third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc index 9d4956bdd5b4db..0c9c6e675e1fa7 100644 --- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_schedule_postprocessing.h" +#include "xla/service/gpu/transforms/schedule_postprocessing.h" #include @@ -32,9 +32,9 @@ namespace xla { namespace gpu { namespace { -using GpuSchedulePostprocessingTest = HloTestBase; +using SchedulePostprocessingTest = HloTestBase; -TEST_F(GpuSchedulePostprocessingTest, SynchronousOpsNotChanged) { +TEST_F(SchedulePostprocessingTest, SynchronousOpsNotChanged) { constexpr absl::string_view kHloString = R"( HloModule module, is_scheduled=true @@ -47,12 +47,12 @@ TEST_F(GpuSchedulePostprocessingTest, SynchronousOpsNotChanged) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule((kHloString))); - GpuSchedulePostprocessing pass; + SchedulePostprocessing pass; TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); EXPECT_FALSE(changed); } -TEST_F(GpuSchedulePostprocessingTest, P2POpsNotChanged) { +TEST_F(SchedulePostprocessingTest, P2POpsNotChanged) { constexpr absl::string_view kHloString = R"( HloModule module, is_scheduled=true @@ -71,12 +71,12 @@ TEST_F(GpuSchedulePostprocessingTest, P2POpsNotChanged) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule((kHloString))); - GpuSchedulePostprocessing pass; + SchedulePostprocessing pass; TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); EXPECT_FALSE(changed); } -TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsChanged) { +TEST_F(SchedulePostprocessingTest, AsynchronousOpsChanged) { constexpr absl::string_view kHloString = R"( HloModule module, is_scheduled=true @@ -89,7 +89,7 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsChanged) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule((kHloString))); - GpuSchedulePostprocessing pass; + SchedulePostprocessing pass; TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); EXPECT_TRUE(changed); @@ -101,7 +101,7 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsChanged) { EXPECT_TRUE(collective_backend_config.no_parallel_custom_call()); } -TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { +TEST_F(SchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { constexpr absl::string_view kHloString = R"( HloModule module, is_scheduled=true @@ -115,7 +115,7 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule((kHloString))); - GpuSchedulePostprocessing pass; + SchedulePostprocessing pass; TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); EXPECT_FALSE(changed); @@ -127,7 +127,7 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { EXPECT_FALSE(collective_backend_config.no_parallel_custom_call()); } -TEST_F(GpuSchedulePostprocessingTest, +TEST_F(SchedulePostprocessingTest, AsynchronousOpsWithParallelNestedCustomcall) { constexpr absl::string_view kHloString = R"( HloModule module, is_scheduled=true @@ -146,7 +146,7 @@ TEST_F(GpuSchedulePostprocessingTest, )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule((kHloString))); - GpuSchedulePostprocessing pass; + SchedulePostprocessing pass; TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); EXPECT_FALSE(changed); diff --git a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.cc b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc similarity index 85% rename from third_party/xla/xla/service/gpu/scheduling_instruction_annotator.cc rename to third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc index f9b4ae37cb0249..d7962130a2eeb8 100644 --- a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.cc +++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/scheduling_instruction_annotator.h" +#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" #include @@ -23,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "tsl/platform/statusor.h" namespace xla::gpu { @@ -37,6 +38,12 @@ absl::StatusOr AnnotateSchedulingInstructionNames( if (!inst->metadata().scheduling_name().empty()) { continue; } + // We skip constants as we might have to sanitize them in order to satisfy + // LLVM backend. I.e. we allow `GpuSanitizeConstantNames` pass to run post + // scheduling. + if (inst->opcode() == HloOpcode::kConstant) { + continue; + } inst->set_metadata_scheduling_name(inst->name()); changed = true; } diff --git a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.h b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.h similarity index 87% rename from third_party/xla/xla/service/gpu/scheduling_instruction_annotator.h rename to third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.h index 3f9b769d3b85f0..03c21bbf09b784 100644 --- a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.h +++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ -#define XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -41,4 +41,4 @@ class SchedulingInstructionAnnotator : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ diff --git a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc similarity index 73% rename from third_party/xla/xla/service/gpu/scheduling_instruction_annotator_test.cc rename to third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc index 146607f790da52..abe8d50a63c09b 100644 --- a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/scheduling_instruction_annotator.h" +#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" #include @@ -72,6 +72,40 @@ TEST_F(SchedulingInstructionAnnotatorTest, EXPECT_TRUE(filecheck_matches); } +TEST_F(SchedulingInstructionAnnotatorTest, SkipsAnnotatingConstants) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + p0 = f32[1] parameter(0) + c1 = f32[1] constant(42) + ROOT add0 = f32[1] add(p0, c1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + SchedulingInstructionAnnotator pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + + ASSERT_TRUE(changed); + constexpr absl::string_view kExpected = R"( +// CHECK: %[[P0:.+]] = {{.*}} parameter(0) +// CHECK-SAME: scheduling_name="[[P0]]" +// CHECK-NEXT: %[[C1:.+]] = f32[1] +// CHECK-NOT: scheduling_name +// CHECK-SAME: constant({42}) +// CHECK: %[[ADD0:.+]] = {{.*}} add(%[[P0]], %[[C1]]) +// CHECK-SAME: scheduling_name="[[ADD0]]" + )"; + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matches, + RunFileCheck( + module->ToString(HloPrintOptions().set_print_operand_shape(false)), + kExpected)); + EXPECT_TRUE(filecheck_matches); +} + TEST_F(SchedulingInstructionAnnotatorTest, DoesNotAnnotateAllInstructionsWithTheirRespectiveNames) { constexpr absl::string_view kHloString = R"( diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc similarity index 99% rename from third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc rename to third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc index c6bd79636b924c..fe43b285c834dd 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/softmax_rewriter_triton.h" +#include "xla/service/gpu/transforms/softmax_rewriter_triton.h" #include #include @@ -47,6 +47,7 @@ limitations under the License. #include "xla/service/gpu/model/gpu_indexing_performance_model.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/model/triton_emitter_constraints.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -457,7 +458,8 @@ absl::StatusOr CanSymbolicTileAnalysisTileDiamondChain( mlir::MLIRContext context; SymbolicTileAnalysisOrError symbolic_tile_analysis_or_error = SymbolicTileAnalysis::AnalyzeComputation( - *softmax_fusion->called_computation(), &context); + *softmax_fusion->called_computation(), &context, + TritonEmitterConstraints::GetBuilder()); bool can_tile = std::holds_alternative( symbolic_tile_analysis_or_error); diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h similarity index 94% rename from third_party/xla/xla/service/gpu/softmax_rewriter_triton.h rename to third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h index 9da8cc54daf400..36f780f43cd1e8 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_ -#define XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SOFTMAX_REWRITER_TRITON_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SOFTMAX_REWRITER_TRITON_H_ #include #include @@ -98,4 +98,4 @@ class SoftmaxRewriterTriton : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SOFTMAX_REWRITER_TRITON_H_ diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc rename to third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc index 8488031e19afdc..1b3139c9d40132 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/softmax_rewriter_triton.h" +#include "xla/service/gpu/transforms/softmax_rewriter_triton.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc similarity index 96% rename from third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc index 217387c2548f60..b299db8d19316a 100644 --- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_sort_rewriter.h" +#include "xla/service/gpu/transforms/sort_rewriter.h" #include #include @@ -203,8 +203,7 @@ bool IsCubCompatibleSort(HloSortInstruction* sort_op) { VLOG(2) << "Sort dimension should be the minor one"; return false; } - if (Product(operand_shape.dimensions()) < - GpuSortRewriter::SortSizeThreshold()) { + if (Product(operand_shape.dimensions()) < SortRewriter::SortSizeThreshold()) { VLOG(2) << "Tensor shape size is too small to see an improvement"; return false; } @@ -239,7 +238,7 @@ HloInstruction* UnpackResultPair(HloSortInstruction* sort_op, } // namespace // Rewrites a single sort instruction with a custom call. -absl::StatusOr GpuSortRewriter::RunOnInstruction( +absl::StatusOr SortRewriter::RunOnInstruction( HloSortInstruction* sort_op) { // Get the sort tensor index and direction. SortComputationAnalysis sort_config = AnalyzeSortOp(*sort_op).value(); @@ -307,7 +306,7 @@ absl::StatusOr GpuSortRewriter::RunOnInstruction( } // Rewrites the sorts in the given computation into calls to CUB. -absl::StatusOr GpuSortRewriter::RunOnComputation( +absl::StatusOr SortRewriter::RunOnComputation( HloComputation* computation) { std::vector sort_ops; for (auto* inst : computation->instructions()) { @@ -325,17 +324,17 @@ absl::StatusOr GpuSortRewriter::RunOnComputation( } // Replace compatible sort operations with custom calls. -absl::StatusOr GpuSortRewriter::Run( +absl::StatusOr SortRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - XLA_VLOG_LINES(2, "GpuSortRewriter::Run(), before:\n" + module->ToString()); + XLA_VLOG_LINES(2, "SortRewriter::Run(), before:\n" + module->ToString()); bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); changed |= result; } - XLA_VLOG_LINES(2, "GpuSortRewriter::Run(), after:\n" + module->ToString()); + XLA_VLOG_LINES(2, "SortRewriter::Run(), after:\n" + module->ToString()); return changed; } diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h similarity index 88% rename from third_party/xla/xla/service/gpu/gpu_sort_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/sort_rewriter.h index 51dba3c95d9efa..406df7a0472a27 100644 --- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_ -#define XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SORT_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SORT_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -31,9 +31,9 @@ namespace gpu { // Only a subset of shapes is supported - either a single tensor with a simple // compare function or a pair of tensors where keys are unsigned integers. -class GpuSortRewriter : public HloModulePass { +class SortRewriter : public HloModulePass { public: - absl::string_view name() const override { return "gpu-sort-rewriter"; } + absl::string_view name() const override { return "sort-rewriter"; } // CUB radix sort is slower than XLA sort on small shapes, so do not rewrite // tensors with sizes below this limit. @@ -60,4 +60,4 @@ class GpuSortRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SORT_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_stub.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_stub.cc similarity index 85% rename from third_party/xla/xla/service/gpu/gpu_sort_rewriter_stub.cc rename to third_party/xla/xla/service/gpu/transforms/sort_rewriter_stub.cc index abacbc1111bfdb..e9bf60cdb4c9b7 100644 --- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_stub.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_stub.cc @@ -13,30 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_sort_rewriter.h" - #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/transforms/sort_rewriter.h" #include "tsl/platform/statusor.h" namespace xla { namespace gpu { -absl::StatusOr GpuSortRewriter::RunOnInstruction( +absl::StatusOr SortRewriter::RunOnInstruction( HloSortInstruction* sort_op) { return false; } -absl::StatusOr GpuSortRewriter::RunOnComputation( +absl::StatusOr SortRewriter::RunOnComputation( HloComputation* computation) { return false; } -absl::StatusOr GpuSortRewriter::Run( +absl::StatusOr SortRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return false; diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc similarity index 91% rename from third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc index 69cdb92e39ed77..853de5b50ba6c6 100644 --- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_sort_rewriter.h" +#include "xla/service/gpu/transforms/sort_rewriter.h" #include @@ -35,18 +35,18 @@ namespace { namespace m = ::xla::match; -class GpuSortRewriterTest : public HloTestBase { +class SortRewriterTest : public HloTestBase { public: void SetUp() override { HloTestBase::SetUp(); - GpuSortRewriter::SetSortSizeThresholdForTestingOnly(1000); + SortRewriter::SetSortSizeThresholdForTestingOnly(1000); } bool RunModuleAndPass(HloModule* module) { auto cloned = module->Clone(); - bool changed = GpuSortRewriter().Run(module).value(); + bool changed = SortRewriter().Run(module).value(); if (changed) { - // Here we run an end to end test to make sure that GpuSortRewriter does + // Here we run an end to end test to make sure that SortRewriter does // not introduce an incorrect rewrite. To do this, we need to clone the // original module because the interpreter cannot process the already // optimized module. @@ -62,7 +62,7 @@ class GpuSortRewriterTest : public HloTestBase { }; // Basic sort: ascending. -TEST_F(GpuSortRewriterTest, SortKeysLessThan) { +TEST_F(SortRewriterTest, SortKeysLessThan) { constexpr char kHlo[] = R"( HloModule TestModule @@ -88,7 +88,7 @@ ENTRY %main { } // Basic sort: descending. -TEST_F(GpuSortRewriterTest, SortKeysGreaterThan) { +TEST_F(SortRewriterTest, SortKeysGreaterThan) { constexpr char kHlo[] = R"( HloModule TestModule @@ -114,7 +114,7 @@ ENTRY %main { } // Comparer swaps the parameter order -> direction is reversed. -TEST_F(GpuSortRewriterTest, SortKeysGreaterThanSwapped) { +TEST_F(SortRewriterTest, SortKeysGreaterThanSwapped) { constexpr char kHlo[] = R"( HloModule TestModule @@ -140,7 +140,7 @@ ENTRY %main { } // Sort a pair of tensors, keys go first. -TEST_F(GpuSortRewriterTest, SortPairs) { +TEST_F(SortRewriterTest, SortPairs) { constexpr char kHlo[] = R"( HloModule TestModule @@ -167,7 +167,7 @@ ENTRY %main { } // Sort a pair of tensors, keys go last. -TEST_F(GpuSortRewriterTest, SortPairsSwapped) { +TEST_F(SortRewriterTest, SortPairsSwapped) { constexpr char kHlo[] = R"( HloModule TestModule @@ -194,7 +194,7 @@ ENTRY %main { } // CUB sort doesn't support more than two tensors. -TEST_F(GpuSortRewriterTest, NoRewriteManyTensors) { +TEST_F(SortRewriterTest, NoRewriteManyTensors) { constexpr char kHlo[] = R"( HloModule TestModule @@ -221,7 +221,7 @@ ENTRY %main { } // Only 1D shapes are supported. -TEST_F(GpuSortRewriterTest, NoRewriteNonMinorSortDimension) { +TEST_F(SortRewriterTest, NoRewriteNonMinorSortDimension) { constexpr char kHlo[] = R"( HloModule TestModule @@ -241,7 +241,7 @@ ENTRY %main { } // Kernels are compiled for a subset of types. -TEST_F(GpuSortRewriterTest, NoRewriteUnsupportedType) { +TEST_F(SortRewriterTest, NoRewriteUnsupportedType) { constexpr char kHlo[] = R"( HloModule TestModule @@ -261,7 +261,7 @@ ENTRY %main { } // Comparer must be a simple function. -TEST_F(GpuSortRewriterTest, NoRewriteComplexComparer) { +TEST_F(SortRewriterTest, NoRewriteComplexComparer) { constexpr char kHlo[] = R"( HloModule TestModule @@ -282,7 +282,7 @@ ENTRY %main { } // Comparer must use adjacent input values. -TEST_F(GpuSortRewriterTest, NoRewriteMixedKeysValues) { +TEST_F(SortRewriterTest, NoRewriteMixedKeysValues) { constexpr char kHlo[] = R"( HloModule TestModule @@ -306,7 +306,7 @@ ENTRY %main { } // Small shapes do not see improvement from CUB sort. -TEST_F(GpuSortRewriterTest, NoRewriteSmallSize) { +TEST_F(SortRewriterTest, NoRewriteSmallSize) { constexpr char kHlo[] = R"( HloModule TestModule @@ -326,7 +326,7 @@ ENTRY %main { } // Basic sort: with batch dimension. -TEST_F(GpuSortRewriterTest, SortWithBatchDim) { +TEST_F(SortRewriterTest, SortWithBatchDim) { constexpr char kHlo[] = R"( HloModule TestModule @@ -352,7 +352,7 @@ ENTRY %main { } // Basic sort: with multiple batch dimensions. -TEST_F(GpuSortRewriterTest, SortWithMultipleBatchDims) { +TEST_F(SortRewriterTest, SortWithMultipleBatchDims) { constexpr char kHlo[] = R"( HloModule TestModule @@ -379,7 +379,7 @@ ENTRY %main { // Sort a pair of tensors (values, indices generated by iota) with a complex // compare. -TEST_F(GpuSortRewriterTest, SortPairsIotaComparerSimple) { +TEST_F(SortRewriterTest, SortPairsIotaComparerSimple) { constexpr char kHlo[] = R"( HloModule TestModule @@ -412,7 +412,7 @@ ENTRY %main { // Sort a pair of tensors (values, indices generated by iota) with a complex // compare computation that matches the output of the StableSortExpander pass. -TEST_F(GpuSortRewriterTest, SortPairsIotaComparerLikeStableSortExpander) { +TEST_F(SortRewriterTest, SortPairsIotaComparerLikeStableSortExpander) { constexpr char kHlo[] = R"( HloModule TestModule @@ -444,8 +444,8 @@ ENTRY %main { m::GetTupleElement(m::CustomCall(), 1)))); } -TEST_F(GpuSortRewriterTest, SortSizeThresholdIsSet) { - EXPECT_EQ(GpuSortRewriter::SortSizeThreshold(), 1000); +TEST_F(SortRewriterTest, SortSizeThresholdIsSet) { + EXPECT_EQ(SortRewriter::SortSizeThreshold(), 1000); } } // namespace diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc similarity index 99% rename from third_party/xla/xla/service/gpu/stream_attribute_annotator.cc rename to third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc index 35c338039f5abb..68805b1ddc3c0c 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/stream_attribute_annotator.h" +#include "xla/service/gpu/transforms/stream_attribute_annotator.h" #include #include diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator.h b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h similarity index 91% rename from third_party/xla/xla/service/gpu/stream_attribute_annotator.h rename to third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h index 8a0284adee390e..81816f88dabba2 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_annotator.h +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_ -#define XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ANNOTATOR_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ANNOTATOR_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -57,4 +57,4 @@ class StreamAttributeAnnotator : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ANNOTATOR_H_ diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc rename to third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc index e12d985a3fb2a3..c7d2ca59cff0e9 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/stream_attribute_annotator.h" +#include "xla/service/gpu/transforms/stream_attribute_annotator.h" #include #include diff --git a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc similarity index 97% rename from third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.cc rename to third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc index 822c6473dba483..be0eb6fc7ac5e0 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/stream_attribute_async_wrapper.h" +#include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" diff --git a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.h b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.h similarity index 88% rename from third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.h rename to third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.h index 95fe7bba66508e..157b57913b6b71 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.h +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ -#define XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -46,4 +46,4 @@ class StreamAttributeAsyncWrapper : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ diff --git a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper_test.cc similarity index 97% rename from third_party/xla/xla/service/gpu/stream_attribute_async_wrapper_test.cc rename to third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper_test.cc index 8b3dcb23eac7bc..32ed4c50c57ca1 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/stream_attribute_async_wrapper.h" +#include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h" #include diff --git a/third_party/xla/xla/service/gpu/topk_specializer.cc b/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc similarity index 98% rename from third_party/xla/xla/service/gpu/topk_specializer.cc rename to third_party/xla/xla/service/gpu/transforms/topk_specializer.cc index bd01a076cc1711..1cc6206ee8908a 100644 --- a/third_party/xla/xla/service/gpu/topk_specializer.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/topk_specializer.h" +#include "xla/service/gpu/transforms/topk_specializer.h" #include diff --git a/third_party/xla/xla/service/gpu/topk_specializer.h b/third_party/xla/xla/service/gpu/transforms/topk_specializer.h similarity index 88% rename from third_party/xla/xla/service/gpu/topk_specializer.h rename to third_party/xla/xla/service/gpu/transforms/topk_specializer.h index 5b57f57b77bba7..e3ec5658f497cd 100644 --- a/third_party/xla/xla/service/gpu/topk_specializer.h +++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_ -#define XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPECIALIZER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPECIALIZER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -38,4 +38,4 @@ class TopkSpecializer : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPECIALIZER_H_ diff --git a/third_party/xla/xla/service/gpu/topk_test.cc b/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/topk_test.cc rename to third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc index 43e25b8543cc61..96d7e49bade1c4 100644 --- a/third_party/xla/xla/service/gpu/topk_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/service/gpu/transforms/topk_specializer.h" + #include #include @@ -33,7 +35,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/topk_specializer.h" #include "xla/service/hlo_pass_interface.h" #include "xla/service/platform_util.h" #include "xla/service/topk_rewriter.h" diff --git a/third_party/xla/xla/service/gpu/topk_splitter.cc b/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/topk_splitter.cc rename to third_party/xla/xla/service/gpu/transforms/topk_splitter.cc index d20116dd22dd7c..41ba13500c4182 100644 --- a/third_party/xla/xla/service/gpu/topk_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/topk_splitter.h" +#include "xla/service/gpu/transforms/topk_splitter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/topk_splitter.h b/third_party/xla/xla/service/gpu/transforms/topk_splitter.h similarity index 91% rename from third_party/xla/xla/service/gpu/topk_splitter.h rename to third_party/xla/xla/service/gpu/transforms/topk_splitter.h index 8fee2dc4975dbd..c6fe4290d7e225 100644 --- a/third_party/xla/xla/service/gpu/topk_splitter.h +++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TOPK_SPLITTER_H_ -#define XLA_SERVICE_GPU_TOPK_SPLITTER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPLITTER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPLITTER_H_ #include @@ -49,4 +49,4 @@ class TopKSplitter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_TOPK_SPLITTER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPLITTER_H_ diff --git a/third_party/xla/xla/service/gpu/topk_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/topk_splitter_test.cc rename to third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc index 834185f990956c..8236c26d4056ae 100644 --- a/third_party/xla/xla/service/gpu/topk_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/topk_splitter.h" +#include "xla/service/gpu/transforms/topk_splitter.h" #include diff --git a/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc new file mode 100644 index 00000000000000..d81d3be88b4273 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc @@ -0,0 +1,76 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/transpose_dimension_grouper.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/permutation_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +class TransposeDimensionGroupVisitor : public DfsHloRewriteVisitor { + public: + absl::Status HandleTranspose(HloInstruction *transpose) override { + VLOG(4) << "Input: " << transpose->ToString(); + absl::InlinedVector permutation; + auto normalized_dims = ShapeUtil::GetNormalizedLogicalTransposeShape( + transpose->shape(), transpose->dimensions(), permutation); + if (!normalized_dims.has_value() || + normalized_dims == transpose->shape().dimensions()) { + return absl::OkStatus(); + } + auto normalized_operand_dims = + ComposePermutations(*normalized_dims, InversePermutation(permutation)); + Shape grouped_operand_shape = ShapeUtil::MakeShapeWithDescendingLayout( + transpose->shape().element_type(), normalized_operand_dims); + auto new_operand = transpose->AddInstruction(HloInstruction::CreateBitcast( + grouped_operand_shape, transpose->mutable_operand(0))); + Shape grouped_shape = ShapeUtil::MakeShapeWithDescendingLayout( + transpose->shape().element_type(), *normalized_dims); + auto new_transpose = + transpose->AddInstruction(HloInstruction::CreateTranspose( + grouped_shape, new_operand, permutation)); + VLOG(5) << "Generated new transpose: " << new_transpose->ToString(); + return ReplaceWithNewInstruction( + transpose, + HloInstruction::CreateBitcast(transpose->shape(), new_transpose)); + } +}; + +absl::StatusOr TransposeDimensionGrouper::Run( + HloModule *module, + const absl::flat_hash_set &execution_threads) { + TF_ASSIGN_OR_RETURN( + bool changed, + TransposeDimensionGroupVisitor().RunOnModule(module, execution_threads)); + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.h b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.h new file mode 100644 index 00000000000000..c07ada3c39d7a7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.h @@ -0,0 +1,57 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRANSPOSE_DIMENSION_GROUPER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_TRANSPOSE_DIMENSION_GROUPER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Groups dimensions that are adjacent (logically and physically) in the +// transpose operand and the transpose output. +// +// Precondition: LayoutNormalization has been run (physical proximity and +// logical proximity become the same). +// +// For example, +// +// out = f32[30,10,20] transpose(f32[10,20,30] input, dimensions={2,0,1}) +// +// becomes: +// +// tmp = f32[200,30] bitcast(f32[10,20,30] input) +// transpose = f32[30,200] transpose(f32[200,30] tmp, dimensions={1,0}) +// out = f32[30,0,20] bitcast(f32[30,200] transpose) +// +class TransposeDimensionGrouper : public HloModulePass { + public: + absl::string_view name() const override { + return "transpose-dimension-grouper"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRANSFORMS_TRANSPOSE_DIMENSION_GROUPER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper_test.cc b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper_test.cc new file mode 100644 index 00000000000000..bbcf3dbe68dcf8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper_test.cc @@ -0,0 +1,77 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/transpose_dimension_grouper.h" + +#include + +#include "absl/strings/string_view.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla { + +namespace { + +class TransposeDimensionGrouperTest : public HloTestBase { + public: + void CheckDimensionGrouper(absl::string_view hlo, + std::optional expected) { + RunAndFilecheckHloRewrite(hlo, gpu::TransposeDimensionGrouper{}, expected); + } +}; + +TEST_F(TransposeDimensionGrouperTest, TransposeWithGrouping) { + const char* hlo = R"( +HloModule TransposeWithGrouping + +ENTRY main { + input = f32[100,1,10,32,2]{4,3,2,1,0} parameter(0) + ROOT out = f32[10,1,32,100,2]{4,3,2,1,0} transpose(input), dimensions={2,1,3,0,4} +} +)"; + + CheckDimensionGrouper(hlo, + R"( +// CHECK: [[input_0:%[^ ]+]] = f32[100,1,10,32,2]{4,3,2,1,0} parameter(0) +// CHECK: [[bitcast_1:%[^ ]+]] = f32[100,320,2]{2,1,0} bitcast([[input_0]]) +// CHECK: [[transpose:%[^ ]+]] = f32[320,100,2]{2,1,0} transpose([[bitcast_1]]), dimensions={1,0,2} +// CHECK: ROOT {{.*}} = f32[10,1,32,100,2]{4,3,2,1,0} bitcast([[transpose]]) + )"); +} + +// TODO(b/328656780): Do not normalize to 3D once the emitter supports any +// number of dimensions. +TEST_F(TransposeDimensionGrouperTest, Normalize2DTo3D) { + const char* hlo = R"( +HloModule TransposeWithGrouping + +ENTRY main { + input = f32[50,20,30]{2,1,0} parameter(0) + ROOT out = f32[20,30,50]{2,1,0} transpose(input), dimensions={1,2,0} +} +)"; + + CheckDimensionGrouper(hlo, + R"( +// CHECK: [[input_0:%[^ ]+]] = f32[50,20,30]{2,1,0} parameter(0) +// CHECK: [[bitcast_1:%[^ ]+]] = f32[1,50,600]{2,1,0} bitcast([[input_0]]) +// CHECK: [[transpose:%[^ ]+]] = f32[1,600,50]{2,1,0} transpose([[bitcast_1]]), dimensions={0,2,1} +// CHECK: ROOT {{.*}} = f32[20,30,50]{2,1,0} bitcast([[transpose]]) + )"); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc index b54d006947f9c1..fb023fc8cc693f 100644 --- a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/tree_reduction_rewriter.h" +#include "xla/service/gpu/transforms/tree_reduction_rewriter.h" #include #include @@ -374,7 +374,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { se::GpuComputeCapability gpu_version_; }; -absl::StatusOr GpuTreeReductionRewriter::Run( +absl::StatusOr TreeReductionRewriter::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { VLOG(5) << "Rewriter input: " << module->ToString(); diff --git a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.h b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h similarity index 86% rename from third_party/xla/xla/service/gpu/tree_reduction_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h index 5f6edf8ac33e4e..7f57d211a8acbd 100644 --- a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h @@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_ -#define XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_ - +#ifndef XLA_SERVICE_GPU_TRANSFORMS_TREE_REDUCTION_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_TREE_REDUCTION_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -74,15 +73,13 @@ namespace gpu { // f32[A, Q, C] inner_reduce = reduce(reshaped, dimensions={2}) // f32[A, C] outer_reduce = reduce(inner_reduce, dimensions={1}) // -class GpuTreeReductionRewriter : public HloModulePass { +class TreeReductionRewriter : public HloModulePass { public: - explicit GpuTreeReductionRewriter(se::GpuComputeCapability gpu_version) + explicit TreeReductionRewriter(se::GpuComputeCapability gpu_version) : gpu_version_(gpu_version) {} - ~GpuTreeReductionRewriter() override = default; - absl::string_view name() const override { - return "gpu-tree-reduction-rewriter"; - } + ~TreeReductionRewriter() override = default; + absl::string_view name() const override { return "tree-reduction-rewriter"; } using HloPassInterface::Run; absl::StatusOr Run( @@ -96,4 +93,4 @@ class GpuTreeReductionRewriter : public HloModulePass { } // end namespace gpu } // end namespace xla -#endif // XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_TREE_REDUCTION_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc index ef6e18966e47c0..bea969efccac4a 100644 --- a/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/tree_reduction_rewriter.h" +#include "xla/service/gpu/transforms/tree_reduction_rewriter.h" #include @@ -33,11 +33,11 @@ class TreeReductionRewriterTest : public HloTestBase { RunAndFilecheckHloRewrite( hlo, #if TENSORFLOW_USE_ROCM - gpu::GpuTreeReductionRewriter{se::RocmComputeCapability { + gpu::TreeReductionRewriter{se::RocmComputeCapability { "908" }}, #else - gpu::GpuTreeReductionRewriter{se::CudaComputeCapability{8, 1}}, + gpu::TreeReductionRewriter{se::CudaComputeCapability{8, 1}}, #endif expected); } diff --git a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc similarity index 97% rename from third_party/xla/xla/service/gpu/triangular_solve_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc index 2dcd36569b7073..e81bdae50a25bf 100644 --- a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/triangular_solve_rewriter.h" +#include "xla/service/gpu/transforms/triangular_solve_rewriter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.h b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.h similarity index 91% rename from third_party/xla/xla/service/gpu/triangular_solve_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.h index 6d4b1c14188a08..c52e0ffb545a3e 100644 --- a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_ -#define XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRIANGULAR_SOLVE_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_TRIANGULAR_SOLVE_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -57,4 +57,4 @@ class TriangularSolveRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_TRIANGULAR_SOLVE_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc similarity index 99% rename from third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc rename to third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc index fc426e4876905d..10ae640f3659b1 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/triton_fusion_numerics_verifier.h" +#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h" #include #include diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.h b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h similarity index 92% rename from third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.h rename to third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h index 18f9527817e74d..e3dc6ebe5dd9f7 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.h +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TRITON_FUSION_NUMERICS_VERIFIER_H_ -#define XLA_SERVICE_GPU_TRITON_FUSION_NUMERICS_VERIFIER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_ #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" @@ -71,4 +71,4 @@ absl::Status ForAllTritonFusions( } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_TRITON_FUSION_NUMERICS_VERIFIER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_ diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier_test.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier_test.cc rename to third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc index eab1e553f2efef..0382577d0d0fb9 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/triton_fusion_numerics_verifier.h" +#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h" #include #include diff --git a/third_party/xla/xla/service/gpu/variadic_op_splitter.cc b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.cc similarity index 98% rename from third_party/xla/xla/service/gpu/variadic_op_splitter.cc rename to third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.cc index f1371575b7d625..0712040a7d1029 100644 --- a/third_party/xla/xla/service/gpu/variadic_op_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/variadic_op_splitter.h" +#include "xla/service/gpu/transforms/variadic_op_splitter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/variadic_op_splitter.h b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.h similarity index 88% rename from third_party/xla/xla/service/gpu/variadic_op_splitter.h rename to third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.h index 4449ce2a0bdcda..304afa1d80a605 100644 --- a/third_party/xla/xla/service/gpu/variadic_op_splitter.h +++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_ -#define XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_VARIADIC_OP_SPLITTER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_VARIADIC_OP_SPLITTER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -40,4 +40,4 @@ class VariadicOpSplitter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_VARIADIC_OP_SPLITTER_H_ diff --git a/third_party/xla/xla/service/gpu/variadic_op_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/variadic_op_splitter_test.cc rename to third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc index 6d7b72eebe0ba3..1d726136a3a8ee 100644 --- a/third_party/xla/xla/service/gpu/variadic_op_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/variadic_op_splitter.h" +#include "xla/service/gpu/transforms/variadic_op_splitter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc similarity index 81% rename from third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc rename to third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc index 8f5e26124f24a4..04d5905c652467 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_windowed_einsum_handler.h" +#include "xla/service/gpu/transforms/windowed_einsum_handler.h" #include +#include #include #include "absl/container/flat_hash_set.h" @@ -27,6 +28,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/hlo_creation_utils.h" @@ -48,15 +50,15 @@ namespace m = match; // and type conversions of FP8 operands into the bodies of their while loops, // i.e. rewrites // -// inputs --> dequant --> while loop {dynamic-slice/collective-permute/dot} +// inputs --> dequant --> while loop {collective-permute/dot/etc} // // into // -// inputs --> while loop {dequant --> dynamic-slice/collective-permute/dot}. -absl::Status ShiftDequantizationF8(const HloComputation* comp, - const std::array& gte) { - HloInstruction* while_instr = comp->WhileCallInstruction(); - if (!while_instr) { +// inputs --> while loop {dequant --> collective-permute/dot/etc}. +absl::Status ShiftDequantizationF8(HloComputation* while_body) { + HloInstruction* while_instr = while_body->WhileCallInstruction(); + // The input of the while loop will be modified and must have no other users. + if (!while_instr || while_instr->operand(0)->user_count() != 1) { return absl::OkStatus(); } @@ -105,39 +107,42 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, return absl::OkStatus(); } - // Identify the dot and collective-permute or dynamic-slice instructions in - // the all-gather or reduce-scatter patterns in while's body. - HloComputation* while_body = while_instr->while_body(); + // Identify the dot, get-tuple-element and collective-permute or dynamic-slice + // instructions in the all-gather or reduce-scatter patterns in while's body. HloComputation* while_condition = while_instr->while_condition(); HloInstruction* while_root = while_body->root_instruction(); - std::array dots, dyn_slices{nullptr, nullptr}, + std::array dots, gtes, dyn_slices{nullptr, nullptr}, coll_perms{nullptr, nullptr}; - if (Match( - while_root, - m::Tuple(m::CollectivePermute( - &coll_perms[1], m::CollectivePermute( - &coll_perms[0], m::Op().Is(gte[0]))), - m::Op().Is(gte[1]), - m::DynamicUpdateSlice( - m::DynamicUpdateSlice().WithOperand( - 1, m::Dot(&dots[0], m::Op().Is(gte[0]), - m::Op().Is(gte[1]))), - m::Dot(&dots[1], m::Op(), m::Op().Is(gte[1])), m::Op(), - m::Op(), m::Op()), - m::Op(), m::Op()))) { + if (Match(while_root, + m::Tuple(m::CollectivePermute( + &coll_perms[1], + m::CollectivePermute( + &coll_perms[0], + m::GetTupleElement(>es[0], m::Parameter(), 0))), + m::GetTupleElement(>es[1], m::Parameter(), 1), + m::DynamicUpdateSlice( + m::DynamicUpdateSlice().WithOperand( + 1, m::Dot(&dots[0], m::Op(), m::Op())), + m::Dot(&dots[1], m::Op(), m::Op()), m::Op(), m::Op(), + m::Op()), + m::Op(), m::Op())) && + dots[0]->operand(0) == gtes[0] && dots[0]->operand(1) == gtes[1] && + dots[1]->operand(1) == gtes[1]) { VLOG(5) << "Identified all-gather windowed einsum pattern."; } else if (Match( while_root, - m::Tuple(m::Op().Is(gte[0]), m::Op().Is(gte[1]), + m::Tuple(m::GetTupleElement(>es[0], m::Parameter(), 0), + m::GetTupleElement(>es[1], m::Parameter(), 1), m::AddAnyOrder( m::Dot(&dots[0], m::DynamicSlice(&dyn_slices[0]), - m::Op().Is(gte[1])), + m::Op()), m::Op()), m::CollectivePermute(m::AddAnyOrder( m::Dot(&dots[1], m::DynamicSlice(&dyn_slices[1]), - m::Op().Is(gte[1])), + m::Op()), m::Op())), - m::Op()))) { + m::Op())) && + dots[0]->operand(1) == gtes[1] && dots[1]->operand(1) == gtes[1]) { VLOG(5) << "Identified reduce-scatter windowed einsum pattern."; } else { VLOG(5) << "Unable to identify valid windowed einsum pattern."; @@ -165,14 +170,14 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, } // In the while body, replace the existing get-tuple-element instructions - // retrieving BF16/FP16/FP32 dot operands with dequantized get-tuple-element + // retrieving BF16/FP16/FP32 dot operands with get-tuple-element // instructions retrieving FP8 dot operands from the input tuple. HloInstruction* body_param = while_body->parameter_instruction(0); for (int k = 0; k < 2; ++k) { TF_ASSIGN_OR_RETURN(HloInstruction * operand_f8, MakeGetTupleElementHlo(body_param, k)); - if (while_root->operand(k) == gte[k]) { + if (while_root->operand(k) == gtes[k]) { TF_RETURN_IF_ERROR( while_root->ReplaceOperandWithDifferentShape(k, operand_f8)); ShapeUtil::UpdateTupleShape(operand_f8->shape(), k, @@ -191,7 +196,7 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, // Dequantize the operands of the dots and dynamic-slices. HloInstruction* operand_f32 = - MakeConvertToHlo(operand_f8, gte[k]->shape().element_type()); + MakeConvertToHlo(operand_f8, gtes[k]->shape().element_type()); HloInstruction* broadcast_scale = MakeBroadcastHlo(operand_scale, {}, operand_f32->shape()); TF_ASSIGN_OR_RETURN( @@ -203,10 +208,10 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, // operands. The order of dequantization and dynamic-slices will be // exchanged in gemm_rewriter.cc. for (int l = 0; l < 2; ++l) { - if (dots[l]->operand(k) == gte[k]) { + if (dots[l]->operand(k) == gtes[k]) { TF_RETURN_IF_ERROR(dots[l]->ReplaceOperandWith(k, operand_scaled)); } - if (dyn_slices[l] && dyn_slices[l]->operand(0) == gte[k]) { + if (dyn_slices[l] && dyn_slices[l]->operand(0) == gtes[k]) { TF_RETURN_IF_ERROR( dyn_slices[l]->ReplaceOperandWith(0, operand_scaled)); } @@ -216,7 +221,7 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, // dots[1], which prevents it from being exchanged with dequantization in // gemm_rewriter.cc. Instead, directly insert the dequantization before // dots[1] here. - if (coll_perms[0] && coll_perms[0]->operand(0) == gte[k]) { + if (coll_perms[0] && coll_perms[0]->operand(0) == gtes[k]) { std::array coll_perms_f8{nullptr, nullptr}; // Change the type of both collective-permutes to FP8. coll_perms_f8[0] = @@ -228,7 +233,7 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, // Insert the dequantization between coll_perms[0] and dots[1]. HloInstruction* coll_perm0_f32 = - MakeConvertToHlo(coll_perms_f8[0], gte[k]->shape().element_type()); + MakeConvertToHlo(coll_perms_f8[0], gtes[k]->shape().element_type()); TF_ASSIGN_OR_RETURN(HloInstruction * x_scaled, MakeBinaryHlo(binaries[k]->opcode(), coll_perm0_f32, broadcast_scale)); @@ -243,17 +248,19 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, } // Update the shape of the while call in the parent computation. + HloInstruction* new_while_instr = while_instr->AddInstruction( + while_instr->CloneWithNewShape(while_root->shape())); TF_RETURN_IF_ERROR( - while_instr->ReplaceAllUsesWithDifferentShape(while_instr->AddInstruction( - while_instr->CloneWithNewShape(while_root->shape())))); + while_instr->ReplaceAllUsesWithDifferentShape(new_while_instr)); + while_instr->while_body()->SetWhileCallInstruction(new_while_instr); TF_RETURN_IF_ERROR(while_instr->parent()->RemoveInstruction(while_instr)); if (coll_perms[0]) { TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[1])); TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[0])); } - TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gte[0])); - TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gte[1])); + TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gtes[0])); + TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gtes[1])); VLOG(5) << "FP8 dequantization moved into while loop."; return absl::OkStatus(); @@ -302,22 +309,11 @@ absl::StatusOr HandleRsWindowedEinsumLoop(HloComputation* comp, return changed; } for (auto inst : comp->MakeInstructionPostOrder()) { - HloInstruction* matched_dot; - std::array gte; // The dot we'd like to parallelize is consuming the second loop input // as RHS. - if (Match(inst, - m::Dot(&matched_dot, - m::DynamicSlice().WithOperand( - 0, m::GetTupleElement(>e[0], m::Parameter(), 0)), - m::GetTupleElement(>e[1], m::Parameter(), 1)))) { - // If present, move the dequantization of FP8 operands of the dot into the - // while loop to allow gemm_rewriter.cc to rewrite into an FP8 Custom - // Call. - TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp, gte)); - + if (Match(inst, m::Dot())) { // Dispatch the dot to additional compute stream. - TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id)); + TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(inst, stream_id)); ++stream_id; changed = true; } @@ -332,6 +328,10 @@ absl::StatusOr HandleRsWindowedEinsumLoop(HloComputation* comp, changed = true; } } + // If present, move the dequantization of FP8 operands of the dot into the + // while loop to allow e.g. gemm_rewriter.cc to fuse the dequantization and + // dot into an FP8 GEMM. + TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp)); return changed; } @@ -345,23 +345,15 @@ absl::StatusOr HandleAgWindowedEinsumLoop(HloComputation* comp, return changed; } for (auto inst : comp->MakeInstructionPostOrder()) { - HloInstruction* matched_dot; - std::array gte; // The dot we'd like to parallelize is consuming the second loop input // as RHS and first loop input as LHS. - if (Match(inst, m::Dot(&matched_dot, - m::GetTupleElement(>e[0], m::Parameter(), 0), - m::GetTupleElement(>e[1], m::Parameter(), 1)))) { - // If present, move the dequantization of FP8 operands of the dot into the - // while loop to allow gemm_rewriter.cc to rewrite into an FP8 Custom - // Call. - TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp, gte)); - + if (Match(inst, m::Dot(m::GetTupleElement(m::Parameter(), 0), + m::GetTupleElement(m::Parameter(), 1)))) { // Dispatch the dot to additional compute stream. - TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id)); + TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(inst, stream_id)); ++stream_id; TF_RETURN_IF_ERROR( - SetForceDelayForInstruction(matched_dot, /*force_delay=*/true)); + SetForceDelayForInstruction(inst, /*force_delay=*/true)); changed = true; } @@ -375,6 +367,11 @@ absl::StatusOr HandleAgWindowedEinsumLoop(HloComputation* comp, changed = true; } } + // If present, move the dequantization of FP8 operands of the dot into the + // while loop to allow e.g. gemm_rewriter.cc to fuse the dequantization and + // dot into an FP8 GEMM. + TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp)); + return changed; } @@ -382,12 +379,11 @@ static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) { const HloInstruction* loop_tuple = while_loop->operand(0); const Shape& tuple_shape = loop_tuple->shape(); CHECK(tuple_shape.IsTuple()); - return tuple_shape.tuple_shapes_size(); + return tuple_shape.tuple_shapes_size() - 1; } absl::Status ProcessWindowedEinsumLoopForActivationCaching( - GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop, - HloInstruction* ag_with_shared_operand) { + WindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop) { HloInstruction* loop = ag_loop.loop; // Transform the while body to cache the allgathered result in the // output buffer to be consumed by the dot @@ -406,41 +402,10 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( // The full buffer that we will use to cache the accumulated activation // is the last operand in the output tuple. int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop); - std::vector new_input_shapes(input_shape.tuple_shapes().begin(), - input_shape.tuple_shapes().end()); - new_input_shapes.push_back(ag_with_shared_operand->shape()); - // Update body input shape - Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes); - *input_tuple->mutable_shape() = new_input_shape; HloInstruction* full_buffer_output_gte = while_body->AddInstruction(HloInstruction::CreateGetTupleElement( - ag_with_shared_operand->shape(), input_tuple, - full_cache_buffer_index)); - - // Update condition input shape - HloComputation* cond_comp = loop->while_condition(); - HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0); - *cond_input_tuple->mutable_shape() = new_input_shape; - - // Update input to the while instruction in parent computation - HloInstruction* original_while_input = loop->mutable_operand(0); - HloComputation* parent_comp = loop->parent(); - std::vector new_operands( - original_while_input->operands().begin(), - original_while_input->operands().end()); - new_operands.push_back( - parent_comp->AddInstruction(HloInstruction::CreateBroadcast( - ag_with_shared_operand->shape(), - parent_comp->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(new_input_shapes[0].element_type()))), - {}))); - HloInstruction* new_while_input = - parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands)); - TF_RETURN_IF_ERROR( - loop->ReplaceOperandWithDifferentShape(0, new_while_input)); - TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape( - original_while_input, new_while_input)); - *loop->mutable_shape() = new_input_shape; + ShapeUtil::GetTupleElementShape(input_shape, full_cache_buffer_index), + input_tuple, full_cache_buffer_index)); HloInstruction* new_full_buffer_output = nullptr; // Find the DUS in the loop body and re-use the slice indices @@ -550,6 +515,7 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( HloInstruction::CreateTuple(original_operands)); TF_RETURN_IF_ERROR( while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple)); + return absl::OkStatus(); } @@ -579,8 +545,7 @@ struct MatchedGemmA2aResult { class WindowedEinsumVisitor : public DfsHloRewriteVisitor { public: explicit WindowedEinsumVisitor( - std::vector& - all_ag_loops) + std::vector& all_ag_loops) : all_ag_loops_(all_ag_loops) {} absl::StatusOr MatchA2aGemmWithIntermediateReshapes( HloInstruction* dot, HloInstruction** lhs, HloInstruction** rhs) { @@ -673,65 +638,145 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { absl::Status HandleDot(HloInstruction* dot) override { CHECK_EQ(dot->opcode(), HloOpcode::kDot); HloComputation* comp = dot->parent(); - // Rewrites a allgather-dot pattern that shares the same operand - // with a windowed einsum loop to consume the output of the loop - // and remove the all-gather. - // Now that we have processed all loops, we can check if there are any - // allgather-dot pattern that we can optimize. We'd want to transform: + // Rewrites an allgather-dot pattern that shares the same operand with a + // windowed einsum loop to consume the output of the loop and remove the + // all-gather. Now that we have processed all loops, we can check if there + // are any allgather-dot pattern that we can optimize. We'd want to + // transform: // input // / | - // / | - // AG windowed loop - // / - // / - // dot + // dequantize | + // (optional) | + // / | + // AG windowed loop + // / + // / + // dot // to: - // input + // input // | // | - // windowed loop + // windowed loop // | + // dequantize + // (FP8) // | // dot // The windowed einsum loop will also be rewritten to output the full input // to be consumed by the dot. This is advantageous since the chained dot can // fully utilize all the resources on the GPU while comm is hidden by the - // first collective matmul loop. - for (GpuWindowedEinsumHandler::WindowedEinsumAgLoops ag_loop : + // first collective matmul loop. When the data type is FP8, input is + // dequantized, i.e. type converted and scaled, ahead of the all-gather. The + // dequantization is moved in WindowedEinsumVisitor between the windowed + // loop and the dot. + for (WindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop : all_ag_loops_) { + HloComputation* comp = dot->parent(); HloInstruction* loop = ag_loop.loop; - HloInstruction* ag_operand = nullptr; - - if (Match(dot, m::Dot(m::AllGather(&ag_operand), m::Op())) || - Match(dot, m::Dot(m::Op(), m::AllGather(&ag_operand)))) { - HloInstruction* windowed_lhs = - loop->mutable_operand(0)->mutable_operand(0); - HloInstruction* ag_with_shared_operand = nullptr; - if (ag_operand && ag_operand->mutable_operand(0) == windowed_lhs) { - ag_with_shared_operand = ag_operand; + + HloInstruction* windowed_lhs = + loop->mutable_operand(0)->mutable_operand(0); + + // In the FP8 case, the all-gather operates on the dequantized + // windowed_lhs. The dequantization is shifted to the output of the while + // loop below. + HloInstruction *all_gather, *binary, *scale = nullptr; + auto all_gather_optionally_dequantized = m::AnyOf( + m::AllGather(&all_gather, + m::Divide(&binary, m::Convert(m::Op().Is(windowed_lhs)), + m::Broadcast(m::Op(&scale)))), + m::AllGather( + &all_gather, + m::MultiplyAnyOrder(&binary, m::Convert(m::Op().Is(windowed_lhs)), + m::Broadcast(m::Op(&scale)))), + m::AllGather(&all_gather, m::Op().Is(windowed_lhs))); + + if (!Match(dot, m::Dot(all_gather_optionally_dequantized, m::Op())) && + !Match(dot, m::Dot(m::Op(), all_gather_optionally_dequantized))) { + continue; + } + + if (scale) { + // When the loop contains an FP8 GEMM, a scalar scaling factor must be + // captured. + if (!ShapeUtil::IsScalar(scale->shape())) { + continue; } - if (!ag_with_shared_operand) { + // The element type of windowed_lhs must be a supported FP8 type. + if (windowed_lhs->shape().element_type() != F8E4M3FN && + windowed_lhs->shape().element_type() != F8E5M2) { continue; } + // The scaling multiplication or division must be in BF16, FP16 or FP32. + if (binary->shape().element_type() != BF16 && + binary->shape().element_type() != F16 && + binary->shape().element_type() != F32) { + continue; + } + } + + if (!ag_loop.consumed) { + // Add a broadcasted zero of the same type as windowed_lhs. This caches + // the accumulated activation inside the loop. + Literal zero_literal = + LiteralUtil::Zero(windowed_lhs->shape().element_type()); + HloInstruction* zero = comp->AddInstruction( + HloInstruction::CreateConstant(std::move(zero_literal))); + Shape zero_bcast_shape = ShapeUtil::ChangeElementType( + all_gather->shape(), windowed_lhs->shape().element_type()); + HloInstruction* zero_bcast = + MakeBroadcastHlo(zero, {}, zero_bcast_shape); + loop->mutable_operand(0)->AppendOperand(zero_bcast); + ShapeUtil::AppendShapeToTuple( + zero_bcast->shape(), loop->mutable_operand(0)->mutable_shape()); + + // Update the parameter tuples of while's body and condition + // computations. + for (HloComputation* while_comp : + {loop->while_body(), loop->while_condition()}) { + while_comp->ReplaceParameter( + 0, HloInstruction::CreateParameter( + 0, loop->mutable_operand(0)->shape(), + while_comp->parameter_instruction(0)->name())); + } + + // Update the shape of the while loop in the parent computation. + *loop->mutable_shape() = loop->operand(0)->shape(); + VLOG(5) << "Found all-gather that shares the same operand with a " "windowed einsum loop : " << loop->ToString(); - if (!ag_loop.consumed) { - TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching( - ag_loop, ag_with_shared_operand)); - ag_loop.consumed = true; - } - int64_t cache_output_index = dot->operand_index(ag_with_shared_operand); - HloComputation* comp = dot->parent(); - HloInstruction* new_gte = - comp->AddInstruction(HloInstruction::CreateGetTupleElement( - loop, GetAgActivationCacheIndex(loop) - 1)); TF_RETURN_IF_ERROR( - dot->ReplaceOperandWith(cache_output_index, new_gte)); - TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand)); + ProcessWindowedEinsumLoopForActivationCaching(ag_loop)); + ag_loop.consumed = true; + } + + int64_t cache_output_index = dot->operand_index(all_gather); + HloInstruction* new_gte = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + loop, GetAgActivationCacheIndex(loop))); + + HloInstruction* new_gte_scaled; + + if (scale) { + // In the FP8 case, insert the dequantization of windowed_lhs between + // the while loop and the dot. + HloInstruction* new_convert = + MakeConvertToHlo(new_gte, binary->shape().element_type()); + HloInstruction* bcast_scale = + MakeBroadcastHlo(scale, {}, new_convert->shape()); + TF_ASSIGN_OR_RETURN( + new_gte_scaled, + MakeBinaryHlo(binary->opcode(), new_convert, bcast_scale)); + } + + TF_RETURN_IF_ERROR(dot->ReplaceOperandWith( + cache_output_index, scale ? new_gte_scaled : new_gte)); + if (all_gather->user_count() == 0) { + TF_RETURN_IF_ERROR(comp->RemoveInstruction(all_gather)); } } // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms @@ -1106,16 +1151,16 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { } private: - std::vector& all_ag_loops_; + std::vector& all_ag_loops_; }; } // namespace -absl::StatusOr GpuWindowedEinsumHandler::Run( +absl::StatusOr WindowedEinsumHandler::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( - 5, "GpuWindowedEinsumHandler::Run(), before:\n" + module->ToString()); + 5, "WindowedEinsumHandler::Run(), before:\n" + module->ToString()); bool changed = false; int64_t stream_id = hlo_query::NextChannelId(*module); @@ -1128,13 +1173,12 @@ absl::StatusOr GpuWindowedEinsumHandler::Run( changed = comp_result; } else if (comp->name().find(kWindowedEinsumAgLoopName) == 0) { VLOG(5) << "Processing computation: " << comp->name(); - TF_ASSIGN_OR_RETURN(bool comp_result, - HandleAgWindowedEinsumLoop(comp, stream_id)); + TF_ASSIGN_OR_RETURN(changed, HandleAgWindowedEinsumLoop(comp, stream_id)); all_ag_loops_.push_back( WindowedEinsumAgLoops(comp->WhileCallInstruction())); - changed = comp_result; } } + for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { WindowedEinsumVisitor visitor(all_ag_loops_); @@ -1142,8 +1186,8 @@ absl::StatusOr GpuWindowedEinsumHandler::Run( changed |= visitor.changed(); } - XLA_VLOG_LINES( - 5, "GpuWindowedEinsumHandler::Run(), after:\n" + module->ToString()); + XLA_VLOG_LINES(5, + "WindowedEinsumHandler::Run(), after:\n" + module->ToString()); return changed; } diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.h similarity index 86% rename from third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h rename to third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.h index b511920f7f24b0..bcc7680e1b7ef5 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_ -#define XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_WINDOWED_EINSUM_HANDLER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_WINDOWED_EINSUM_HANDLER_H_ #include @@ -35,11 +35,9 @@ namespace xla::gpu { // optimize it on GPU by annotating independent gemms with // stream ids in the backend config. By running them in different // streams, we can practically achieve overlap between gemms too. -class GpuWindowedEinsumHandler : public HloModulePass { +class WindowedEinsumHandler : public HloModulePass { public: - absl::string_view name() const override { - return "gpu-windowed-einsum-handler"; - } + absl::string_view name() const override { return "windowed-einsum-handler"; } struct WindowedEinsumAgLoops { explicit WindowedEinsumAgLoops(HloInstruction* loop) : loop(loop) {} @@ -63,4 +61,4 @@ class GpuWindowedEinsumHandler : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_WINDOWED_EINSUM_HANDLER_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc similarity index 85% rename from third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc rename to third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc index 6f23319980e90c..151b5b41b5b866 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_windowed_einsum_handler.h" +#include "xla/service/gpu/transforms/windowed_einsum_handler.h" #include #include @@ -34,7 +34,7 @@ namespace { namespace m = ::xla::match; -using GpuWindowedEinsumHanlderTest = HloTestBase; +using WindowedEinsumHandlerTest = HloTestBase; HloInstruction* FindInstructionByName(HloComputation* comp, std::string name) { for (auto inst : comp->instructions()) { @@ -45,7 +45,7 @@ HloInstruction* FindInstructionByName(HloComputation* comp, std::string name) { return nullptr; } -TEST_F(GpuWindowedEinsumHanlderTest, AgLoopsHaveStreamIds) { +TEST_F(WindowedEinsumHandlerTest, AgLoopsHaveStreamIds) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,512,24576]{2,1,0}, bf16[24576,24576]{1,0})->bf16[2048,24576]{1,0}}, num_partitions=4 @@ -102,7 +102,7 @@ ENTRY test_main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); @@ -121,7 +121,7 @@ ENTRY test_main { cp1->backend_config()->force_earliest_schedule()); } -TEST_F(GpuWindowedEinsumHanlderTest, RsLoopsHaveStreamIds) { +TEST_F(WindowedEinsumHandlerTest, RsLoopsHaveStreamIds) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[2048,24576]{1,0})->bf16[512,24576]{1,0}}, num_partitions=4 @@ -180,7 +180,7 @@ ENTRY main.9_spmd { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); @@ -198,7 +198,7 @@ ENTRY main.9_spmd { cp1->backend_config()->force_earliest_schedule()); } -TEST_F(GpuWindowedEinsumHanlderTest, AgLoopsMultipleConsumersAreChained) { +TEST_F(WindowedEinsumHandlerTest, AgLoopsMultipleConsumersAreChained) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[24576,24576]{1,0})->bf16[2,2048,24576]{2,1,0}}, num_partitions=4 @@ -259,7 +259,7 @@ ENTRY main.12_spmd { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); @@ -286,7 +286,7 @@ ENTRY main.12_spmd { m::Op(), m::Op(), m::Op(), m::Op()), m::Op(), m::Op(), m::Op(), m::Op())))); } -TEST_F(GpuWindowedEinsumHanlderTest, A2aGemmHaveStreamIds) { +TEST_F(WindowedEinsumHandlerTest, A2aGemmHaveStreamIds) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,4,2048,8192]{3,2,1,0})->bf16[1,4,2048,32768]{3,2,1,0}}, num_partitions=8 @@ -350,7 +350,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, @@ -358,7 +358,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2, EXPECT_TRUE(filecheck_matched); } -TEST_F(GpuWindowedEinsumHanlderTest, GemmA2aHaveStreamIds) { +TEST_F(WindowedEinsumHandlerTest, GemmA2aHaveStreamIds) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,4,2048,32768]{3,2,1,0})->bf16[1,4,2048,8192]{3,2,1,0}}, num_partitions=4 @@ -422,7 +422,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, @@ -430,7 +430,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1, EXPECT_TRUE(filecheck_matched); } -TEST_F(GpuWindowedEinsumHanlderTest, A2aTransposeLoopsHaveStreamIds) { +TEST_F(WindowedEinsumHandlerTest, A2aTransposeLoopsHaveStreamIds) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0})->bf16[1,4,2048,32768]{3,2,1,0}}, num_partitions=4 @@ -504,7 +504,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); @@ -513,7 +513,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2, EXPECT_TRUE(filecheck_matched); } -TEST_F(GpuWindowedEinsumHanlderTest, GemmA2aTransposeLoopsHaveStreamIds) { +TEST_F(WindowedEinsumHandlerTest, GemmA2aTransposeLoopsHaveStreamIds) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,4,2048,32768]{3,2,1,0}, bf16[1,32768,8192]{2,1,0})->bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0}}, num_partitions=4 @@ -588,7 +588,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,204 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); @@ -597,7 +597,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,204 EXPECT_TRUE(filecheck_matched); } -TEST_F(GpuWindowedEinsumHanlderTest, AllGatherF8) { +TEST_F(WindowedEinsumHandlerTest, AllGatherF8) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4 @@ -660,7 +660,7 @@ ENTRY test_main { } )"; - RunAndFilecheckHloRewrite(kHloString, GpuWindowedEinsumHandler(), + RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(), R"( ; CHECK-LABEL: windowed_dot_general_body_ag ; CHECK-NEXT: [[P0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0) @@ -716,7 +716,7 @@ ENTRY test_main { )"); } -TEST_F(GpuWindowedEinsumHanlderTest, ReduceScatterF8) { +TEST_F(WindowedEinsumHandlerTest, ReduceScatterF8) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f8e4m3fn[2,2048,24576]{2,1,0}, f32[], f32[])->f32[2,512,24576]{2,1,0}}, num_partitions=4 @@ -780,7 +780,7 @@ ENTRY main.9_spmd { } )"; - RunAndFilecheckHloRewrite(kHloString, GpuWindowedEinsumHandler(), + RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(), R"( ; CHECK-LABEL: windowed_dot_general_body_rs ; CHECK-NEXT: [[P0:%[^ ]+]] = (f8e4m3fn[2,2048,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0) @@ -814,13 +814,13 @@ ENTRY main.9_spmd { ; CHECK-DAG: lhs_contracting_dims={2}, ; CHECK-DAG: rhs_contracting_dims={0}, ; CHECK-DAG: backend_config={ -; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID:[0-9]+]]", +; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID0:[1-9][0-9]*]]", ; CHECK-DAG: "wait_on_operation_queues":[], ; CHECK-DAG: "force_earliest_schedule":false} ; CHECK-NEXT: [[ADD3:%[^ ]+]] = f32[2,512,24576]{2,1,0} add([[CP0]], [[DOT0]]), ; CHECK-DAG: backend_config={" ; CHECK-DAG: operation_queue_id":"0", -; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID]]"], +; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID0]]"], ; CHECK-DAG: "force_earliest_schedule":false} ; CHECK-NEXT: [[GTE6:[^ ]+]] = f32[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=3 ; CHECK-NEXT: [[ADD4:%[^ ]+]] = u32[] add([[GTE4]], [[PID]]) @@ -830,14 +830,137 @@ ENTRY main.9_spmd { ; CHECK-NEXT: [[DSLICE3:%[^ ]+]] = f32[2,512,24576]{2,1,0} dynamic-slice([[MUL0]], [[C0]], [[RESHAPE1]], [[C0]]), dynamic_slice_sizes={2,512,24576} ; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[DSLICE3]], [[MUL1]]), ; CHECK-DAG: lhs_contracting_dims={2}, -; CHECK-DAG: rhs_contracting_dims={0} +; CHECK-DAG: rhs_contracting_dims={0}, +; CHECK-DAG: backend_config={ +; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID1:[1-9][0-9]*]]", +; CHECK-DAG: "wait_on_operation_queues":[], +; CHECK-DAG: "force_earliest_schedule":false} ; CHECK-NEXT: [[ADD5:%[^ ]+]] = f32[2,512,24576]{2,1,0} add([[GTE6]], [[DOT1]]) +; CHECK-DAG: backend_config={" +; CHECK-DAG: operation_queue_id":"0", +; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID1]]"], +; CHECK-DAG: "force_earliest_schedule":false} ; CHECK-NEXT: [[CP1:[^ ]+]] = f32[2,512,24576]{2,1,0} collective-permute([[ADD5]]), channel_id=10 ; CHECK-NEXT: ROOT [[OUT:[^ ]+]] = (f8e4m3fn[2,2048,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[GTE0]], [[GTE1]], [[ADD3]], [[CP1]], [[ADD0]], /*index=5*/[[GTE3]], [[GTE5]]) )"); } -TEST_F(GpuWindowedEinsumHanlderTest, +TEST_F(WindowedEinsumHandlerTest, AllGatherMultipleConsumersF8) { + constexpr absl::string_view kHloString = R"( +HloModule all_gather_multiple_consumers_f8, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f8e4m3fn[24576,24576]{1,0}, f8e4m3fn[24576,24576]{1,0}, f32[], f32[], f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4 +windowed_dot_general_body_ag { + input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0) + lhs = f32[2,512,24576]{2,1,0} get-tuple-element(input), index=0 + permuted_lhs0 = f32[2,512,24576]{2,1,0} collective-permute(lhs), channel_id=2, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} + permuted_lhs1 = f32[2,512,24576]{2,1,0} collective-permute(permuted_lhs0), channel_id=3, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} + rhs = f32[24576,24576]{1,0} get-tuple-element(input), index=1 + partial_dot_output = f32[2,2048,24576]{2,1,0} get-tuple-element(input), index=2 + dot0 = f32[2,512,24576]{2,1,0} dot(lhs, rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} + c0 = s32[] constant(0) + dot_update_slice_offsets = s32[4]{0} constant({0, 512, 1024, 1536}) + loop_counter = u32[] get-tuple-element(input), index=4 + partition_id = u32[] partition-id() + loop_counter_plus_partition_id = u32[] add(loop_counter, partition_id) + c4 = u32[] constant(4) + dot_update_slice_offsets_index0 = u32[] remainder(loop_counter_plus_partition_id, c4) + dot_update_slice_offset0 = s32[1]{0} dynamic-slice(dot_update_slice_offsets, dot_update_slice_offsets_index0), dynamic_slice_sizes={1} + dot_update_slice_offset_scalar0 = s32[] reshape(dot_update_slice_offset0) + updated_dot_output0 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(partial_dot_output, dot0, c0, dot_update_slice_offset_scalar0, c0) + dot1 = f32[2,512,24576]{2,1,0} dot(permuted_lhs0, rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} + c1 = u32[] constant(1) + loop_counter_plus_one = u32[] add(loop_counter, c1) + loop_counter_plus_partition_id_plus_one = u32[] add(loop_counter_plus_one, partition_id) + dot_update_slice_offsets_index1 = u32[] remainder(loop_counter_plus_partition_id_plus_one, c4) + dot_update_slice_offset1 = s32[1]{0} dynamic-slice(dot_update_slice_offsets, dot_update_slice_offsets_index1), dynamic_slice_sizes={1} + dot_update_slice_offset1_scalar = s32[] reshape(dot_update_slice_offset1) + updated_dot_output1 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(updated_dot_output0, dot1, c0, dot_update_slice_offset1_scalar, c0) + pass_through = f32[2,2048,24576]{2,1,0} get-tuple-element(input), index=3 + next_loop_counter = u32[] add(loop_counter_plus_one, c1) + ROOT tuple = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(permuted_lhs1, rhs, updated_dot_output1, pass_through, next_loop_counter) +} // windowed_dot_general_body_ag + +windowed_dot_general_cond_ag { + input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0) + loop_counter = u32[] get-tuple-element(input), index=4 + loop_limit = u32[] constant(4) + ROOT compare = pred[] compare(loop_counter, loop_limit), direction=LT +} + +ENTRY main { + lhs = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + rhs0 = f8e4m3fn[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + c0_f32 = f32[] constant(0) + c0_f32_bcast = f32[2,2048,24576]{2,1,0} broadcast(c0_f32), dimensions={} + c0_u32 = u32[] constant(0) + // Dequantization of LHS and RHS: + scale_lhs = f32[] parameter(4) + scale_lhs_bcast = f32[2,512,24576]{2,1,0} broadcast(scale_lhs), dimensions={} + lhs_f32 = f32[2,512,24576]{2,1,0} convert(lhs) + lhs_scaled = f32[2,512,24576]{2,1,0} multiply(lhs_f32, scale_lhs_bcast) + scale_rhs0 = f32[] parameter(5) + scale_rhs0_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs0), dimensions={} + rhs0_f32 = f32[24576,24576]{1,0} convert(rhs0) + rhs0_scaled = f32[24576,24576]{1,0} multiply(rhs0_f32, scale_rhs0_bcast) + // While loop of all-gather windowed einsum: + while_input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs0_scaled, c0_f32_bcast, c0_f32_bcast, c0_u32) + while = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) while(while_input), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag + // Additional all-gather FP8 dot operating on a dequantized RHS and the LHS also consumed by the windowed einsum. + all-gather1 = f32[2,2048,24576]{2,1,0} all-gather(lhs_scaled), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true + rhs1 = f8e4m3fn[24576,24576]{1,0} parameter(2), sharding={devices=[1,4]<=[4]} + scale_rhs1 = f32[] parameter(6) + scale_rhs1_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs1), dimensions={} + rhs1_f32 = f32[24576,24576]{1,0} convert(rhs1) + rhs1_scaled = f32[24576,24576]{1,0} multiply(rhs1_f32, scale_rhs1_bcast) + dot1 = f32[2,2048,24576]{2,1,0} dot(all-gather1, rhs1_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + // Another all-gather FP8 dot operating on a dequantized RHS and the LHS also consumed by the windowed einsum. + all-gather2 = f32[2,2048,24576]{2,1,0} all-gather(lhs_scaled), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true + rhs2 = f8e4m3fn[24576,24576]{1,0} parameter(3), sharding={devices=[1,4]<=[4]} + scale_rhs2 = f32[] parameter(7) + scale_rhs2_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs2), dimensions={} + rhs2_f32 = f32[24576,24576]{1,0} convert(rhs2) + rhs2_scaled = f32[24576,24576]{1,0} multiply(rhs2_f32, scale_rhs2_bcast) + dot2 = f32[2,2048,24576]{2,1,0} dot(all-gather2, rhs2_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT product = f32[2,2048,24576]{2,1,0} multiply(dot1, dot2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(), + R"( +; CHECK-LABEL: %main +; CHECK: [[WHILE0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[], f8e4m3fn[2,2048,24576]{2,1,0}) while([[TUPLE0:%[^ ]+]]), +; CHECK-DAG: condition=%windowed_dot_general_cond_ag, +; CHECK-DAG: body=%windowed_dot_general_body_ag +; CHECK: [[LHS1:%[^ ]+]] = f8e4m3fn[2,2048,24576]{2,1,0} get-tuple-element([[WHILE0]]), index=7 +; CHECK-NEXT: [[LHS1_F32:%[^ ]+]] = f32[2,2048,24576]{2,1,0} convert([[LHS1]]) +; CHECK-NEXT: [[SCALE_LHS1_BCAST:%[^ ]+]] = f32[2,2048,24576]{2,1,0} broadcast([[SCALE_LHS1:%[^ ]+]]), dimensions={} +; CHECK-NEXT: [[LHS1_SCALED:%[^ ]+]] = f32[2,2048,24576]{2,1,0} multiply([[LHS1_F32]], [[SCALE_LHS1_BCAST]]) +; CHECK-NEXT: [[RHS1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} parameter(2), sharding={devices=[1,4]<=[4]} +; CHECK-NEXT: [[RHS1_F32:%[^ ]+]] = f32[24576,24576]{1,0} convert([[RHS1]]) +; CHECK: [[SCALE_RHS1_BCAST:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[SCALE_RHS1:%[^ ]+]]), dimensions={} +; CHECK-NEXT: [[RHS1_SCALED:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[RHS1_F32]], [[SCALE_RHS1_BCAST]]) +; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dot([[LHS1_SCALED]], [[RHS1_SCALED]]), +; CHECK-DAG: lhs_contracting_dims={2}, +; CHECK-DAG: rhs_contracting_dims={0} +; CHECK: [[LHS2:%[^ ]+]] = f8e4m3fn[2,2048,24576]{2,1,0} get-tuple-element([[WHILE0]]), index=7 +; CHECK-NEXT: [[LHS2_F32:%[^ ]+]] = f32[2,2048,24576]{2,1,0} convert([[LHS2]]) +; CHECK-NEXT: [[SCALE_LHS2_BCAST:%[^ ]+]] = f32[2,2048,24576]{2,1,0} broadcast([[SCALE_LHS2:%[^ ]+]]), dimensions={} +; CHECK-NEXT: [[LHS2_SCALED:%[^ ]+]] = f32[2,2048,24576]{2,1,0} multiply([[LHS2_F32]], [[SCALE_LHS2_BCAST]]) +; CHECK-NEXT: [[RHS2:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} parameter(3), sharding={devices=[1,4]<=[4]} +; CHECK-NEXT: [[RHS2_F32:%[^ ]+]] = f32[24576,24576]{1,0} convert([[RHS2]]) +; CHECK-NEXT: [[SCALE_RHS2:%[^ ]+]] = f32[] parameter(7) +; CHECK-NEXT: [[SCALE_RHS2_BCAST:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[SCALE_RHS2]]), dimensions={} +; CHECK-NEXT: [[RHS2_SCALED:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[RHS2_F32]], [[SCALE_RHS2_BCAST]]) +; CHECK-NEXT: [[DOT2:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dot([[LHS2_SCALED]], [[RHS2_SCALED]]), +; CHECK-DAG: lhs_contracting_dims={2}, +; CHECK-DAG: rhs_contracting_dims={0} +; CHECK-NEXT: ROOT [[OUT:[^ ]+]] = f32[2,2048,24576]{2,1,0} multiply([[DOT1]], [[DOT2]]) +)"); +} + +TEST_F(WindowedEinsumHandlerTest, AgLoopsMultipleConsumersAreChainedWithShardedContratingDim) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8 @@ -900,7 +1023,7 @@ ENTRY main.12_spmd { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); @@ -914,5 +1037,6 @@ ENTRY main.12_spmd { EXPECT_EQ(inst->operand(0)->tuple_index(), 5); EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); } + } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index 6a9a66539854dc..42ab73a15a0412 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/layout.h" #include "xla/permutation_util.h" #include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -1011,15 +1012,10 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( if (hlo.opcode() == HloOpcode::kPad) { return "Pads are not fused yet."; } - for (const HloInstruction* operand : hlo.operands()) { - if (!legacy_triton::IsTritonSupportedDataType( - operand->shape().element_type(), gpu_version)) { - return "Unsupported input data type."; - } - } - if (!legacy_triton::IsTritonSupportedDataType(hlo.shape().element_type(), - gpu_version)) { - return "Unsupported output data type."; + if (auto decision = + legacy_triton::IsTritonSupportedInstruction(hlo, gpu_version); + !decision.CanFuse()) { + return decision; } DimOrdersAndReqsOrError result_or_error = GetPropagatedDimOrdersAndRequirements(hlo, src_dim_order, diff --git a/third_party/xla/xla/service/heap_simulator/BUILD b/third_party/xla/xla/service/heap_simulator/BUILD index c0b486ffb0a8c1..bada0fdf1d597b 100644 --- a/third_party/xla/xla/service/heap_simulator/BUILD +++ b/third_party/xla/xla/service/heap_simulator/BUILD @@ -43,6 +43,8 @@ cc_library( "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_buffer", + "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", "//xla/service:logical_buffer", @@ -59,6 +61,8 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -68,16 +72,25 @@ xla_cc_test( deps = [ ":allocation_block", ":heap_simulator", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:buffer_value", + "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_dataflow_analysis", + "//xla/service:hlo_module_config", "//xla/service:hlo_parser", "//xla/service:hlo_value", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc index fc319e681f769a..a499fcf119c424 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -43,13 +44,22 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/map_util.h" +#include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_value.h" +#include "xla/service/logical_buffer.h" #include "xla/service/time_utils.h" #include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -219,14 +229,12 @@ absl::StatusOr HeapSimulator::MinimumMemoryForModule( absl::StatusOr HeapSimulator::MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, const HloAliasAnalysis& alias_analysis, - const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map* - memory_by_computation) { + const LogicalBuffer::SizeFunction& size_function) { TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, HeapSimulator::Run(std::make_unique>(), computation, sequence, alias_analysis, size_function, - HeapSimulator::Options(), memory_by_computation)); + HeapSimulator::Options())); return result.heap_size; } @@ -267,11 +275,9 @@ absl::StatusOr> HeapSimulator::Run( const HloComputation& computation, const HloInstructionSequence& instruction_sequence, const HloAliasAnalysis& alias_analysis, - const BufferValue::SizeFunction& size_fn, const Options& options, - const absl::flat_hash_map* - memory_by_computation) { + const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*schedule=*/nullptr, memory_by_computation); + /*schedule=*/nullptr); HloSchedule schedule(computation.parent()); schedule.set_sequence(&computation, instruction_sequence); TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, @@ -291,7 +297,7 @@ absl::StatusOr> HeapSimulator::Run( const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*schedule=*/schedule, nullptr); + /*schedule=*/schedule); TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_live_range, HloLiveRange::Run(*schedule, alias_analysis, &computation)); @@ -492,19 +498,16 @@ absl::Status HeapSimulator::RunComputation( return absl::OkStatus(); } -HeapSimulator::HeapSimulator( - std::unique_ptr> algorithm, - const BufferValue::SizeFunction& size_fn, const Options& options, - const HloSchedule* schedule, - const absl::flat_hash_map* - memory_by_computation) +HeapSimulator::HeapSimulator(std::unique_ptr> algorithm, + const BufferValue::SizeFunction& size_fn, + const Options& options, + const HloSchedule* schedule) : no_fragmentation_stats_( std::make_unique>()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - schedule_(schedule), - memory_by_computation_(memory_by_computation) { + schedule_(schedule) { debug_trace_.set_whole_module_simulation(schedule_ != nullptr); } @@ -629,21 +632,10 @@ void NoFragmentationStatsHeap::Alloc(const BufferType* buffer, template void NoFragmentationStatsHeap::AccountForSubcomputationMemory( - const HloInstruction* instruction, int64_t alloc_size_by_instruction, - const absl::flat_hash_map& - memory_by_computation) { + const HloInstruction* instruction, int64_t alloc_size_by_instruction) { // We only count the memory usage of the largest subcomputation, instead of // adding them all, because subcomputations won't execute in parallel. int64_t max_subcomputation_bytes = 0; - for (const auto* c : instruction->called_computations()) { - auto it = memory_by_computation.find(c); - if (it != memory_by_computation.end()) { - int64_t subcomputation_bytes = it->second; - if (subcomputation_bytes > max_subcomputation_bytes) { - max_subcomputation_bytes = subcomputation_bytes; - } - } - } if (max_subcomputation_bytes > 0 && (instruction->opcode() == HloOpcode::kWhile || instruction->opcode() == HloOpcode::kCall || @@ -1019,6 +1011,23 @@ std::string BufferIntervalTree::NodesOverlappingInTimeToAsciiArt( memory_map); } +std::vector BufferIntervalTree::MemoryUsedInInterval( + int64_t start, int64_t end) const { + int64_t total_time = end - start + 1; + CHECK_GE(total_time, 0); + std::vector nodes = + NodesOverlappingInTime(start, end); + std::vector memory_used_in_interval(total_time, 0); + for (const BufferIntervalTreeNode* node : nodes) { + int64_t node_start = std::max(node->start, start); + int64_t node_end = std::min(node->end, end); + for (int64_t time = node_start; time <= node_end; ++time) { + memory_used_in_interval[time - start] += node->chunk.size; + } + } + return memory_used_in_interval; +} + template std::string GlobalDecreasingSizeBestFitHeap::BufferInterval::ToString() const { diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.h b/third_party/xla/xla/service/heap_simulator/heap_simulator.h index 09e12d2aca7042..6d5f4558b6e6b4 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator.h +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.h @@ -34,7 +34,9 @@ limitations under the License. #endif #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" @@ -148,9 +150,7 @@ class HeapSimulator { static absl::StatusOr MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, const HloAliasAnalysis& alias_analysis, - const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map* - memory_by_computation = nullptr); + const LogicalBuffer::SizeFunction& size_function); static absl::StatusOr MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, @@ -184,9 +184,7 @@ class HeapSimulator { const HloInstructionSequence& instruction_sequence, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_fn, - const Options& options = Options(), - const absl::flat_hash_map* - memory_by_computation = nullptr); + const Options& options = Options()); // Same as above, but runs on with a schedule that covers all nested // computations. @@ -204,9 +202,7 @@ class HeapSimulator { // be run recursively. I.e. the simulation is run over the whole module. HeapSimulator(std::unique_ptr> algorithm, const BufferValue::SizeFunction& size_fn, - const Options& options, const HloSchedule* schedule = nullptr, - const absl::flat_hash_map* - memory_by_computation = nullptr); + const Options& options, const HloSchedule* schedule = nullptr); ~HeapSimulator(); absl::Status RunComputation( @@ -244,13 +240,10 @@ class HeapSimulator { const std::unique_ptr> algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; - // schedule_ is set by buffer assignment, and memory_by_computation_ is - // set by hlo scheduling. Then, in RunComputation, we check both in order to - // handle subcomputations. It would be good to unify the handling of - // subcomputations, but it's not clear how. + // schedule_ is set by buffer assignment. Then, in RunComputation, we check + // both in order to handle subcomputations. It would be good to unify the + // handling of subcomputations, but it's not clear how. const HloSchedule* schedule_; - const absl::flat_hash_map* - memory_by_computation_; // Hold some sets for error-checking the sequence of Alloc and Free calls. absl::flat_hash_set allocated_buffers_; @@ -290,9 +283,7 @@ class HeapAlgorithm { virtual void AccountForSubcomputationMemory( const HloInstruction* instruction, // The total number of bytes allocated by instruction. - int64_t alloc_size_by_instruction, - const absl::flat_hash_map& - memory_by_computation) {} + int64_t alloc_size_by_instruction) {} // Free de-allocates a previously allocated buffer. virtual void Free(const BufferType* buffer, int64_t size) = 0; @@ -328,9 +319,8 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { void Alloc(const BufferType* buffer, int64_t size) override; void AccountForSubcomputationMemory( - const HloInstruction* instruction, int64_t alloc_size_by_instruction, - const absl::flat_hash_map& - memory_by_computation) override; + const HloInstruction* instruction, + int64_t alloc_size_by_instruction) override; void Free(const BufferType* buffer, int64_t size) override; @@ -417,6 +407,11 @@ class BufferIntervalTree { std::string NodesOverlappingInTimeToAsciiArt(int64_t start, int64_t end, int64_t group_size = 0) const; + // Returns a vector of size `end - start + 1` where the element at index i is + // the memory used at the time instant `start + i`. Both `start` and `end` are + // inclusive. + std::vector MemoryUsedInInterval(int64_t start, int64_t end) const; + private: std::vector NodesOverlappingInTime( int64_t start, int64_t end) const; diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc index cff0e2f3a72547..878030b01e99b3 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc @@ -25,22 +25,35 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/literal_util.h" #include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_dataflow_analysis.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_value.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { namespace { +using ::testing::ContainerEq; using ::testing::HasSubstr; using ::testing::StrEq; @@ -210,9 +223,6 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; - absl::flat_hash_map memory_by_computation; - memory_by_computation[cond_computation] = 5; - memory_by_computation[body_computation] = 16; std::unique_ptr alias_analysis = HloAliasAnalysis::Run(module.get()).value(); @@ -221,7 +231,7 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { // so we don't double count. EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( *entry_computation, schedule.sequence(entry_computation), - *alias_analysis, size_fn, &memory_by_computation) + *alias_analysis, size_fn) .value()); } @@ -2057,6 +2067,20 @@ TEST_F(IntervalTreeTest, BufferIntervalTreeToAsciiArtFreeMemory) { EXPECT_THAT(output, StrEq("No nodes overlapping in time. Memory is free!")); } +TEST_F(IntervalTreeTest, BufferIntervalTreeMemoryUsedInInterval) { + // Buffer 1: memory block [0, 16), time interval [15, 25] + // Buffer 2: memory block [16, 48), time interval [15, 19] + // Buffer 3: memory block [32, 64), time interval [20, 22] + BufferIntervalTree tree; + tree.Add(15, 25, HeapSimulator::Chunk::FromOffsetEnd(0, 16)); + tree.Add(15, 19, HeapSimulator::Chunk::FromOffsetEnd(16, 48)); + tree.Add(20, 22, HeapSimulator::Chunk::FromOffsetEnd(32, 64)); + std::vector memory_used_by_time = tree.MemoryUsedInInterval( + /*start=*/18, /*end=*/23); + std::vector expected_memory_used_by_time = {48, 48, 48, 48, 48, 16}; + EXPECT_THAT(memory_used_by_time, ContainerEq(expected_memory_used_by_time)); +} + class SlicedBufferIntervalTest : public ::testing::Test { public: using HeapTy = GlobalDecreasingSizeBestFitHeap; diff --git a/third_party/xla/xla/service/hlo_computation_test.cc b/third_party/xla/xla/service/hlo_computation_test.cc index ece63e311a55dd..a7190b33f2088d 100644 --- a/third_party/xla/xla/service/hlo_computation_test.cc +++ b/third_party/xla/xla/service/hlo_computation_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" +#include #include -#include #include #include #include @@ -24,19 +24,24 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" +#include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_cse.cc b/third_party/xla/xla/service/hlo_cse.cc index 2594fa392a5c1c..3162204e68c6dd 100644 --- a/third_party/xla/xla/service/hlo_cse.cc +++ b/third_party/xla/xla/service/hlo_cse.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -316,6 +317,32 @@ absl::StatusOr HloCSE::Run( } } } + if (auto fusion = computation->FusionInstruction()) { + if (fusion->IsMultiOutputFusion()) { + // Attach users to the representative instruction, thus making the + // duplicate fusion roots unused. HloDCE can then cleanup the unused + // fusion roots. + absl::flat_hash_map + root_to_unique_index; + int64_t root_index = 0; + HloInstruction* root = computation->root_instruction(); + for (const HloInstruction* hlo : root->operands()) { + if (root_to_unique_index.find(hlo) == root_to_unique_index.end()) { + root_to_unique_index[hlo] = root_to_unique_index[hlo] = root_index; + } + ++root_index; + } + if (root_to_unique_index.size() < root->operand_count()) { + for (HloInstruction* user : fusion->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + const HloInstruction* fusion_root = + root->operand(user->tuple_index()); + user->set_tuple_index(root_to_unique_index[fusion_root]); + } + } + } + } + } } return changed; } diff --git a/third_party/xla/xla/service/hlo_cse_test.cc b/third_party/xla/xla/service/hlo_cse_test.cc index 106eea0923b0be..f6378353b8d507 100644 --- a/third_party/xla/xla/service/hlo_cse_test.cc +++ b/third_party/xla/xla/service/hlo_cse_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include "absl/algorithm/container.h" #include "absl/strings/substitute.h" #include "xla/hlo/ir/hlo_computation.h" @@ -918,7 +919,10 @@ TEST_F(HloCseTest, MultiOutputFusion) { ENTRY entry { p0 = f32[] parameter(0) p1 = f32[] parameter(1) - ROOT root = (f32[], f32[]) fusion(p0, p1), kind=kLoop, calls=f + fusion = (f32[], f32[]) fusion(p0, p1), kind=kLoop, calls=f + gte0 = f32[] get-tuple-element(fusion), index=0 + gte1 = f32[] get-tuple-element(fusion), index=1 + ROOT res = (f32[], f32[]) tuple(gte0, gte1) } )"; @@ -928,10 +932,18 @@ TEST_F(HloCseTest, MultiOutputFusion) { SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString())); EXPECT_EQ(changed, true); + HloInstruction* root = m->entry_computation()->root_instruction(); HloInstruction* add0; HloInstruction* add1; + HloInstruction* gte0; + HloInstruction* gte1; + ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(>e0), + m::GetTupleElement(>e1)))); + EXPECT_EQ(gte0, gte1); + EXPECT_EQ(gte0->tuple_index(), 0); + const HloInstruction* fusion = gte0->operand(0); ASSERT_THAT( - m->entry_computation()->root_instruction()->fused_expression_root(), + fusion->fused_expression_root(), GmockMatch(m::Tuple(m::Add(&add0, m::Parameter(0), m::Parameter(1)), m::Add(&add1, m::Parameter(0), m::Parameter(1))))); EXPECT_EQ(add0, add1); diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc index 981b967d997690..7709bda6032e7f 100644 --- a/third_party/xla/xla/service/hlo_instruction_test.cc +++ b/third_party/xla/xla/service/hlo_instruction_test.cc @@ -15,6 +15,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" +#include +#include +#include +#include +#include #include #include #include @@ -22,16 +27,23 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/literal.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout_util.h" +#include "xla/literal_util.h" #include "xla/protobuf_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" @@ -40,6 +52,7 @@ limitations under the License. #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -2706,6 +2719,16 @@ TEST_F(HloInstructionTest, VerifyBodyComputationPointsToWhile) { } } EXPECT_EQ(num_while_body_comp, 1); + + for (HloInstruction* instruction : + module->entry_computation()->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile) { + HloComputation* while_body = instruction->while_body(); + EXPECT_TRUE(while_body->IsWhileBodyComputation()); + HloInstruction* while_back_ref = while_body->WhileCallInstruction(); + EXPECT_EQ(while_back_ref->while_body(), while_body); + } + } } TEST_F(HloInstructionTest, @@ -2752,7 +2775,7 @@ TEST_F(HloInstructionTest, module->AddEntryComputation(main_builder.Build()); // Should find conditional branch computations in the graph and it should - // point to the conditonal instruction. + // point to the conditional instruction. int num_conditional_branch_comp = 0; for (HloComputation* comp : module->MakeComputationPostOrder()) { if (comp->IsConditionalBranchComputation()) { @@ -2827,7 +2850,7 @@ TEST_F(HloInstructionTest, module->AddEntryComputation(main_builder.Build()); // Should find conditional branch computations in the graph and it should - // point to the conditonal instruction. + // point to the conditional instruction. int num_conditional_branch_comp = 0; for (HloComputation* comp : module->MakeComputationPostOrder()) { if (comp->IsConditionalBranchComputation()) { diff --git a/third_party/xla/xla/service/hlo_memory_scheduler.cc b/third_party/xla/xla/service/hlo_memory_scheduler.cc index 283b82e23ec738..83e40723895289 100644 --- a/third_party/xla/xla/service/hlo_memory_scheduler.cc +++ b/third_party/xla/xla/service/hlo_memory_scheduler.cc @@ -90,11 +90,8 @@ class ListScheduler { static absl::StatusOr Run( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation) { - ListScheduler scheduler(computation, points_to_analysis, size_function, - memory_by_computation); + const BufferValue::SizeFunction& size_function) { + ListScheduler scheduler(computation, points_to_analysis, size_function); return scheduler.CreateSchedule(); } @@ -115,13 +112,10 @@ class ListScheduler { ListScheduler(HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation) + const BufferValue::SizeFunction& size_function) : computation_(computation), points_to_analysis_(points_to_analysis), - size_function_(size_function), - memory_by_computation_(memory_by_computation) { + size_function_(size_function) { // Create a map containing the LogicalBuffer uses for each HLO // instruction. An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by @@ -242,29 +236,7 @@ class ListScheduler { freed_bytes += size_function_(*buffer); } } - // We only count the memory usage of the largest subcomputation, instead of - // adding them all, because subcomputations won't execute in parallel. - int64_t max_subcomputation_bytes = 0; - for (const auto* c : instruction->called_computations()) { - auto it = memory_by_computation_.find(c); - if (it != memory_by_computation_.end()) { - int64_t subcomputation_bytes = it->second; - if (subcomputation_bytes > max_subcomputation_bytes) { - max_subcomputation_bytes = subcomputation_bytes; - } - } - } - int64_t bytes_defined; - if (max_subcomputation_bytes > 0 && - (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || - opcode == HloOpcode::kConditional)) { - // The output buffer of while/call/conditional is always aliased with the - // output buffer of the root instruction in the body. Don't double count. - bytes_defined = max_subcomputation_bytes; - } else { - bytes_defined = entry.bytes_defined + max_subcomputation_bytes; - } - return freed_bytes - bytes_defined; + return freed_bytes - entry.bytes_defined; } // Constructs the scheduling priority of the given instruction. @@ -392,11 +364,6 @@ class ListScheduler { HloComputation* computation_; const TuplePointsToAnalysis& points_to_analysis_; const BufferValue::SizeFunction& size_function_; - // Computations are analyzed in post-order. When scheduling an instruction - // that includes subcomputations, such as a while loop, we use this map to - // look up the memory needed by subcomputations. - const absl::flat_hash_map& - memory_by_computation_; // A map containing the LogicalBuffers that each instruction uses. absl::flat_hash_map> @@ -426,19 +393,15 @@ absl::StatusOr ScheduleComputationHelper( const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) { VLOG(2) << "Computation: " << computation->name(); if (algorithm) { return algorithm(computation, points_to_analysis, alias_analysis, - size_function, memory_by_computation, postprocessor, - peak_memory); + size_function, postprocessor, peak_memory); } return DefaultMemoryScheduler(computation, points_to_analysis, alias_analysis, - size_function, memory_by_computation, - postprocessor, peak_memory); + size_function, postprocessor, peak_memory); } } // namespace @@ -448,8 +411,6 @@ absl::StatusOr DFSMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) { // These variables are a hack to prevent overflows. int64_t cumulative_total_size = 0; @@ -526,9 +487,9 @@ absl::StatusOr DFSMemoryScheduler( CHECK_EQ(sequence.size(), computation->instruction_count()); if (peak_memory) { TF_ASSIGN_OR_RETURN( - *peak_memory, HeapSimulator::MinimumMemoryForComputation( - *computation, sequence, alias_analysis, size_function, - &memory_by_computation)); + *peak_memory, + HeapSimulator::MinimumMemoryForComputation( + *computation, sequence, alias_analysis, size_function)); } return sequence; } @@ -538,8 +499,6 @@ absl::StatusOr BFSMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) { // Index of HloInstruction in the `computation`. absl::flat_hash_map inst_index; @@ -586,9 +545,9 @@ absl::StatusOr BFSMemoryScheduler( CHECK_EQ(sequence.size(), computation->instruction_count()); if (peak_memory) { TF_ASSIGN_OR_RETURN( - *peak_memory, HeapSimulator::MinimumMemoryForComputation( - *computation, sequence, alias_analysis, size_function, - &memory_by_computation)); + *peak_memory, + HeapSimulator::MinimumMemoryForComputation( + *computation, sequence, alias_analysis, size_function)); } return sequence; @@ -605,16 +564,14 @@ ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler( const absl::flat_hash_set& execution_threads, int64_t* peak_memory) -> absl::StatusOr { HloSchedule schedule(module); - absl::flat_hash_map memory_by_computation; for (auto* computation : module->MakeComputationPostOrder(execution_threads)) { if (!computation->IsFusionComputation()) { - TF_ASSIGN_OR_RETURN( - HloInstructionSequence computation_sequence, - ScheduleComputationHelper( - computation, points_to_analysis, alias_analysis, size_func, - computation_scheduler, memory_by_computation, postprocessor, - /*peak_memory=*/nullptr)); + TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, + ScheduleComputationHelper( + computation, points_to_analysis, alias_analysis, + size_func, computation_scheduler, postprocessor, + /*peak_memory=*/nullptr)); schedule.set_sequence(computation, std::move(computation_sequence)); } } @@ -631,20 +588,18 @@ absl::StatusOr ListMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) { - TF_ASSIGN_OR_RETURN(HloInstructionSequence sequence, - ListScheduler::Run(computation, points_to_analysis, - size_function, memory_by_computation)); + TF_ASSIGN_OR_RETURN( + HloInstructionSequence sequence, + ListScheduler::Run(computation, points_to_analysis, size_function)); if (postprocessor) { sequence = postprocessor(sequence); } if (peak_memory) { TF_ASSIGN_OR_RETURN( - *peak_memory, HeapSimulator::MinimumMemoryForComputation( - *computation, sequence, alias_analysis, size_function, - &memory_by_computation)); + *peak_memory, + HeapSimulator::MinimumMemoryForComputation( + *computation, sequence, alias_analysis, size_function)); } return sequence; } @@ -654,8 +609,6 @@ absl::StatusOr PostOrderMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) { HloInstructionSequence sequence(computation->MakeInstructionPostOrder()); if (postprocessor) { @@ -663,9 +616,9 @@ absl::StatusOr PostOrderMemoryScheduler( } if (peak_memory) { TF_ASSIGN_OR_RETURN( - *peak_memory, HeapSimulator::MinimumMemoryForComputation( - *computation, sequence, alias_analysis, size_function, - &memory_by_computation)); + *peak_memory, + HeapSimulator::MinimumMemoryForComputation( + *computation, sequence, alias_analysis, size_function)); } return sequence; } @@ -675,8 +628,6 @@ absl::StatusOr DefaultMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) { // We try a few schedulers and choose whichever returns a lower min-memory, // not accounting for fragmentation. @@ -690,24 +641,21 @@ absl::StatusOr DefaultMemoryScheduler( TF_ASSIGN_OR_RETURN( HloInstructionSequence list_sequence, ListMemoryScheduler(computation, points_to_analysis, alias_analysis, - size_function, memory_by_computation, postprocessor, - &list_memory)); + size_function, postprocessor, &list_memory)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); int64_t dfs_memory; TF_ASSIGN_OR_RETURN( HloInstructionSequence dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, alias_analysis, - size_function, memory_by_computation, postprocessor, - &dfs_memory)); + size_function, postprocessor, &dfs_memory)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); int64_t post_order_memory; - TF_ASSIGN_OR_RETURN( - HloInstructionSequence post_order_sequence, - PostOrderMemoryScheduler(computation, points_to_analysis, alias_analysis, - size_function, memory_by_computation, - postprocessor, &post_order_memory)); + TF_ASSIGN_OR_RETURN(HloInstructionSequence post_order_sequence, + PostOrderMemoryScheduler( + computation, points_to_analysis, alias_analysis, + size_function, postprocessor, &post_order_memory)); VLOG(2) << "Min-memory post order sequence: " << HumanReadableNumBytes(post_order_memory); @@ -815,21 +763,6 @@ absl::StatusOr ScheduleModule( return std::move(schedule); } -absl::StatusOr ScheduleComputation( - HloComputation* computation, const BufferValue::SizeFunction& size_function, - const MemorySchedulerPostprocessor& postprocessor) { - CHECK(!computation->IsFusionComputation()); - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(computation->parent())); - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(computation->parent())); - absl::flat_hash_map empty_map; - return ScheduleComputationHelper( - computation, *points_to_analysis, *alias_analysis, size_function, - /*algorithm=*/nullptr, empty_map, postprocessor, - /*peak_memory=*/nullptr); -} - HloMemoryScheduler::HloMemoryScheduler( const BufferValue::SizeFunction& size_function, const ModuleSchedulerAlgorithm& algorithm) diff --git a/third_party/xla/xla/service/hlo_memory_scheduler.h b/third_party/xla/xla/service/hlo_memory_scheduler.h index 112ced3ee95112..2fb211ac6531a2 100644 --- a/third_party/xla/xla/service/hlo_memory_scheduler.h +++ b/third_party/xla/xla/service/hlo_memory_scheduler.h @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -51,7 +50,6 @@ using MemorySchedulerAlgorithm = std::function( HloComputation*, const TuplePointsToAnalysis&, const HloAliasAnalysis&, const LogicalBuffer::SizeFunction&, - const absl::flat_hash_map&, const MemorySchedulerPostprocessor&, /*peak_memory*/ int64_t*)>; @@ -73,8 +71,6 @@ absl::StatusOr ListMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); // DFS-order scheduler @@ -83,8 +79,6 @@ absl::StatusOr DFSMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); // BFS-order scheduler @@ -102,8 +96,6 @@ absl::StatusOr BFSMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); // Naive Post Order scheduler @@ -112,8 +104,6 @@ absl::StatusOr PostOrderMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); // The default scheduling algorithm. Runs the list scheduler, the DFS scheduler, @@ -125,8 +115,6 @@ absl::StatusOr DefaultMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); absl::StatusOr DefaultModuleScheduler( @@ -146,13 +134,6 @@ absl::StatusOr ScheduleModule( const absl::flat_hash_set& execution_threads = {}, int64_t* peak_memory = nullptr); -// Computes the schedule for a single computation. -// Currently only used by the GPU backend. -absl::StatusOr ScheduleComputation( - HloComputation* computation, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerPostprocessor& postprocessor); - // A pass which schedules the HLO instructions in a module. The HloModule's // schedule field is set to the resulting HloSchedule using // HloModule::set_schedule. diff --git a/third_party/xla/xla/service/hlo_module_test.cc b/third_party/xla/xla/service/hlo_module_test.cc index 291093213e6719..f2375751a90f55 100644 --- a/third_party/xla/xla/service/hlo_module_test.cc +++ b/third_party/xla/xla/service/hlo_module_test.cc @@ -37,9 +37,9 @@ limitations under the License. #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/hlo_ordering.cc b/third_party/xla/xla/service/hlo_ordering.cc index 466f64cee1d49f..388de97291fab1 100644 --- a/third_party/xla/xla/service/hlo_ordering.cc +++ b/third_party/xla/xla/service/hlo_ordering.cc @@ -363,14 +363,13 @@ bool HloOrdering::UsesBeforeValueDefinition( return true; } } - // The use at an async call occurs before values that are defined in the - // called computation of the async wrapped instruction. - if (use.instruction->IsAsynchronous() && - use.instruction->async_wrapped_opcode() == HloOpcode::kCall) { + // The use at an async op occurs before values that are defined in the async + // wrapped computation or any of its nested computations. + if (use.instruction->IsAsynchronous()) { const HloInstruction* async = use.instruction; if (call_graph_->InstructionIsNestedIn( value.defining_instruction(), - async->async_wrapped_instruction()->to_apply())) { + async->async_wrapped_computation())) { VLOG(4) << " use is async " << use.instruction->name() << " and def is in called computation"; return true; diff --git a/third_party/xla/xla/service/hlo_ordering_test.cc b/third_party/xla/xla/service/hlo_ordering_test.cc index 743f9f24f20c5e..c0b1dc9c0c6bb7 100644 --- a/third_party/xla/xla/service/hlo_ordering_test.cc +++ b/third_party/xla/xla/service/hlo_ordering_test.cc @@ -675,6 +675,7 @@ ENTRY %main { HloInstruction* async_wrapped_call = FindInstruction(module.get(), "async_wrapped_call"); HloInstruction* p0 = FindInstruction(module.get(), "p0"); + HloInstruction* broadcast1 = FindInstruction(module.get(), "broadcast.1"); ASSERT_NE(async_start, nullptr); ASSERT_NE(async_done, nullptr); @@ -685,13 +686,16 @@ ENTRY %main { HloUse async_done_use = HloUse{async_done, 0, {0, 0}}; HloUse call_use = HloUse{async_wrapped_call, 0}; const HloValue& value = dataflow->GetUniqueValueAt(async_wrapped_call, {}); + const HloValue& broadcast_value = dataflow->GetUniqueValueAt(broadcast1, {}); DependencyHloOrdering ordering(module.get()); - EXPECT_FALSE( + EXPECT_TRUE( ordering.UsesBeforeValueDefinition({&async_start_use}, value, *dataflow)); + EXPECT_TRUE(ordering.UsesBeforeValueDefinition({&async_start_use}, + broadcast_value, *dataflow)); EXPECT_FALSE( ordering.UsesBeforeValueDefinition({&call_use}, value, *dataflow)); - EXPECT_FALSE( + EXPECT_TRUE( ordering.UsesBeforeValueDefinition({&async_done_use}, value, *dataflow)); } @@ -795,11 +799,11 @@ ENTRY %main { const HloValue& value = dataflow->GetUniqueValueAt(async_wrapped_call, {}); DependencyHloOrdering ordering(module.get()); - EXPECT_FALSE( + EXPECT_TRUE( ordering.UsesBeforeValueDefinition({&async_start_use}, value, *dataflow)); EXPECT_FALSE( ordering.UsesBeforeValueDefinition({&call_use}, value, *dataflow)); - EXPECT_FALSE( + EXPECT_TRUE( ordering.UsesBeforeValueDefinition({&async_done_use}, value, *dataflow)); } diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc index d3b3038a3703f4..2ff069772c771e 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/service/hlo_parser.cc @@ -80,6 +80,7 @@ limitations under the License. #include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" namespace xla { @@ -1915,6 +1916,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT std::vector async_wrapped_operands; std::vector async_wrapped_operand_shapes; Shape async_wrapped_root_shape; + async_wrapped_operand_shapes.reserve(operands.size()); for (const HloInstruction* operand : operands) { async_wrapped_operand_shapes.push_back(operand->shape()); } diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/service/hlo_parser_test.cc index 0fb60432c4a0a4..6378f08744e76f 100644 --- a/third_party/xla/xla/service/hlo_parser_test.cc +++ b/third_party/xla/xla/service/hlo_parser_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/hlo_parser.h" +#include #include #include #include @@ -22,15 +23,25 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/array.h" #include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/service/hlo_lexer.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" @@ -39,6 +50,7 @@ limitations under the License. #include "xla/tsl/lib/core/status_test_util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/hlo_unstacker.cc b/third_party/xla/xla/service/hlo_unstacker.cc index c6b0971f4f3312..21d0eb9d42a27f 100644 --- a/third_party/xla/xla/service/hlo_unstacker.cc +++ b/third_party/xla/xla/service/hlo_unstacker.cc @@ -790,14 +790,8 @@ absl::Status UnstackDSFusionPattern( HloInstruction* bitcast = mutable_dynamic_slicing_fusion->AddInstruction( HloInstruction::CreateBitcast(mutable_dynamic_slicing_fusion->shape(), new_operand)); - HloInstruction* bitcast_fusion = - mutable_dynamic_slicing_fusion->AddInstruction( - HloInstruction::CreateFusion(mutable_dynamic_slicing_fusion->shape(), - HloInstruction::FusionKind::kLoop, - bitcast)); - return mutable_dynamic_slicing_fusion->ReplaceAllUsesWithDifferentShape( - bitcast_fusion); + bitcast); } // This function recognizes fusions with the following pattern: @@ -1430,6 +1424,7 @@ absl::StatusOr HloUnstacker::Run( /*force_unroll=*/true, /*prepare=*/false)); CHECK(unrolled); } + VLOG(3) << "after unstacking \n" << module->ToString(); return true; } diff --git a/third_party/xla/xla/service/hlo_unstacker_test.cc b/third_party/xla/xla/service/hlo_unstacker_test.cc index 4c1d4f8dffa328..3b00f9236a1ae7 100644 --- a/third_party/xla/xla/service/hlo_unstacker_test.cc +++ b/third_party/xla/xla/service/hlo_unstacker_test.cc @@ -34,15 +34,15 @@ namespace { using UnstackerTest = HloTestBase; -int64_t GetSliceCountInEntry(HloModule* module) { - int64_t slice_instrs_count = 0; +int64_t GetInstrCountWithOpcodeInEntry(HloModule* module, HloOpcode opcode) { + int64_t instr_with_opcode_count = 0; for (HloInstruction* instr : module->entry_computation()->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kSlice) { - slice_instrs_count++; + if (instr->opcode() == opcode) { + instr_with_opcode_count++; } } - return slice_instrs_count; + return instr_with_opcode_count; } TEST_F(UnstackerTest, UnstackDSFusionPattern) { @@ -63,7 +63,8 @@ TEST_F(UnstackerTest, UnstackDSFusionPattern) { p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 one = s32[] constant(1) inc = s32[] add(i, one) - %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) } @@ -80,7 +81,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPattern) { init = s32[] constant(0) while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body - while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } )"; @@ -90,9 +91,12 @@ TEST_F(UnstackerTest, UnstackDSFusionPattern) { TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3); + // Check that the bitcast is unfused and there are not fusions. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion), + 0); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), - std::nullopt)); + std::nullopt, false)); } TEST_F(UnstackerTest, UnstackReduceFusionPattern) { @@ -148,7 +152,7 @@ TEST_F(UnstackerTest, UnstackReduceFusionPattern) { TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), - std::nullopt)); + std::nullopt, false)); } TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcast) { @@ -195,10 +199,12 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcast) { ParseAndReturnVerifiedModule(hlo_string)); auto original = module->Clone(); TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); - std::cout << module->ToString() << std::endl; EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3); + // Check that all the fusions are removed. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion), + 0); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } @@ -249,10 +255,12 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcastKeepFused) { auto unfuse = [](HloInstruction* instruction) { return false; }; TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker(unfuse).Run(module.get())); - std::cout << module->ToString() << std::endl; EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 0); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 0); + // Check that dynamic-slices are still fused. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion), + 3); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } @@ -261,21 +269,21 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternWithDifferentLayout) { std::string hlo_string = R"( HloModule SimpleLoop %fused_computation.30.clone (param_0.153: bf16[32,4,64,64,3], param_1.123: s32[]) -> bf16[64,4,64,3] { - %param_0.153 = bf16[32,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} parameter(0) + %param_0.153 = bf16[32,4,64,64,3]{2,1,4,3,0} parameter(0) %param_1.123 = s32[]{:T(128)} parameter(1) %constant.227 = s32[]{:T(128)} constant(0) - %dynamic-slice.5 = bf16[1,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} dynamic-slice(bf16[32,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} %param_0.153, s32[]{:T(128)} %param_1.123, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, /*index=5*/s32[]{:T(128)} %constant.227), dynamic_slice_sizes={1,4,64,64,3}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[]} - ROOT %bitcast.102 = bf16[64,4,64,3]{0,1,3,2:T(4,128)(2,1)} bitcast(bf16[1,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} %dynamic-slice.5) + %dynamic-slice.5 = bf16[1,4,64,64,3]{2,1,4,3,0} dynamic-slice(bf16[32,4,64,64,3]{2,1,4,3,0} %param_0.153, s32[]{:T(128)} %param_1.123, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, /*index=5*/s32[]{:T(128)} %constant.227), dynamic_slice_sizes={1,4,64,64,3} + ROOT %bitcast.102 = bf16[64,4,64,3]{0,1,3,2} bitcast(bf16[1,4,64,64,3]{2,1,4,3,0} %dynamic-slice.5) } %while.body (wide_param: (s32[], bf16[8,128], bf16[32,4,64,64,3])) -> (s32[], bf16[8,128], bf16[32,4,64,64,3]) { wide_p = (s32[], bf16[8,128], bf16[32,4,64,64,3]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 p0 = bf16[8,128] get-tuple-element(wide_p), index=1 - p1 = bf16[32,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} get-tuple-element(wide_p), index=2 + p1 = bf16[32,4,64,64,3]{2,1,4,3,0} get-tuple-element(wide_p), index=2 one = s32[] constant(1) inc = s32[] add(i, one) - %fusion.67830 = bf16[64,4,64,3]{0,1,3,2:T(4,128)(2,1)} fusion(p1, i), kind=kLoop, calls=%fused_computation.30.clone + %fusion.67830 = bf16[64,4,64,3]{0,1,3,2} fusion(p1, i), kind=kLoop, calls=%fused_computation.30.clone ROOT out = (s32[], bf16[8,128], bf16[32,4,64,64,3]) tuple(inc, p0, p1) } @@ -291,7 +299,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternWithDifferentLayout) { p1 = bf16[8,128] parameter(1) init = s32[] constant(0) while.input = (s32[], bf16[8,128], bf16[32,4,64,64,3]) tuple(init, p1, p0) - while.out = (s32[], bf16[8,128], bf16[32,4,64,64,3]) while(while.input), condition=%while.cond , body=%while.body + while.out = (s32[], bf16[8,128], bf16[32,4,64,64,3]) while(while.input), condition=%while.cond , body=%while.body while_use = bf16[32,4,64,64,3] get-tuple-element(while.out), index=2 ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } @@ -301,6 +309,12 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternWithDifferentLayout) { auto original = module->Clone(); TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), + 32); + // Check that dynamic-slices are still fused. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion), + 0); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt)); } @@ -358,7 +372,7 @@ TEST_F(UnstackerTest, UnstackNestedDSFusionPattern) { TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } @@ -497,7 +511,7 @@ TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithMultipleIndex) { EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. For each unstacked input, we // create 4 slices, 8 in total. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 8); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 8); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } @@ -555,7 +569,7 @@ TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithDiffereOperandsOrder) { TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } @@ -631,7 +645,7 @@ TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithSameUnstackingComps) { TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } @@ -765,7 +779,7 @@ TEST_F(UnstackerTest, UnstackNestedDSFusionPatternSingleNestedLoop) { TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 4); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 4); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } @@ -906,7 +920,7 @@ TEST_F(UnstackerTest, UnstackNestedDSFusionPatternTwoNestedLoops) { EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. For each loop there is one // unstacked input that creates 4 slices, in total 8 slices for two loops. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 8); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 8); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc index 67259f886f3383..06ae99051ddcce 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc @@ -1757,6 +1757,7 @@ absl::Status HloValueSemanticsPropagation::HandleConditional( [&](const ShapeIndex& index, const HloValueSemantics* semantics) -> absl::Status { std::vector semantics_vector; + semantics_vector.reserve(semantics_tree_vec.size()); for (size_t i = 0; i < semantics_tree_vec.size(); ++i) { semantics_vector.push_back( *(semantics_tree_vec[i].find(index)->second)); diff --git a/third_party/xla/xla/service/host_offload_utils.cc b/third_party/xla/xla/service/host_offload_utils.cc new file mode 100644 index 00000000000000..203c08e9d0c39a --- /dev/null +++ b/third_party/xla/xla/service/host_offload_utils.cc @@ -0,0 +1,243 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/host_offload_utils.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/call_graph.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/shape_util.h" +#include "xla/util.h" + +namespace xla { +namespace host_offload_utils { + +namespace { + +using ::xla::host_memory_offload_annotations::kMoveToDeviceCustomCallTarget; +using ::xla::host_memory_offload_annotations::kMoveToHostCustomCallTarget; + +bool CustomCallReusesBuffer(const HloInstruction* custom_call, + int64_t operand_index) { + if (custom_call->custom_call_target() == kMoveToDeviceCustomCallTarget || + custom_call->custom_call_target() == kMoveToHostCustomCallTarget) { + // Does not define a new buffer. + return true; + } + // Check the custom call's output_to_operand_aliasing. + const std::vector>>& + aliases = custom_call->output_operand_aliasing(); + for (const std::pair>& alias : + aliases) { + int64_t alias_operand_index = alias.second.first; + if (alias_operand_index == operand_index) { + // This operand aliases with the output. + return true; + } + } + // By default, assume custom calls define new buffers. + return false; +} + +} // namespace + +absl::StatusOr> GetSuccessors( + const InstructionAndShapeIndex& instruction_and_shape_index) { + std::vector result; + HloInstruction* instruction = instruction_and_shape_index.instruction; + if (instruction->IsRoot()) { + // Successor of the root is the call instruction(s). + std::unique_ptr call_graph = + CallGraph::Build(instruction->GetModule()); + auto callers = call_graph->GetComputationCallers(instruction->parent()); + for (HloInstruction* caller : callers) { + result.push_back({caller, instruction_and_shape_index.shape_index}); + } + } + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kTuple) { + auto operand_indices = user->OperandIndices(instruction); + for (const auto i : operand_indices) { + auto tmp_shape_index = instruction_and_shape_index.shape_index; + tmp_shape_index.push_back(i); + result.push_back({user, std::move(tmp_shape_index)}); + } + } else if (user->opcode() == HloOpcode::kGetTupleElement) { + ShapeIndex tmp_shape_index = instruction_and_shape_index.shape_index; + const auto index = tmp_shape_index.front(); + if (index == user->tuple_index()) { + // This GTE is for the buffer we're tracking. + tmp_shape_index.pop_front(); + result.push_back({user, std::move(tmp_shape_index)}); + } + } else if (user->opcode() == HloOpcode::kCall) { + auto operand_indices = user->OperandIndices(instruction); + CHECK(user->called_computations().size() == 1) + << "Expect call to only have one called computation."; + for (const auto i : operand_indices) { + HloComputation* called_computation = + user->called_computations().front(); + HloInstruction* parameter_instruction = + called_computation->parameter_instruction(i); + result.push_back( + {parameter_instruction, instruction_and_shape_index.shape_index}); + } + } else if (user->opcode() == HloOpcode::kWhile) { + auto operand_indices = user->OperandIndices(instruction); + HloComputation* while_body_computation = user->while_body(); + HloComputation* while_condition_computation = user->while_condition(); + for (const auto i : operand_indices) { + HloInstruction* parameter_instruction = + while_body_computation->parameter_instruction(i); + result.push_back( + {parameter_instruction, instruction_and_shape_index.shape_index}); + + HloInstruction* condition_instruction = + while_condition_computation->parameter_instruction(i); + result.push_back( + {condition_instruction, instruction_and_shape_index.shape_index}); + } + } else if (user->opcode() == HloOpcode::kAsyncStart) { + auto operand_indices = user->OperandIndices(instruction); + CHECK(user->called_computations().size() == 1) + << "Expect async-start to only have one called computation."; + for (const auto i : operand_indices) { + HloComputation* called_computation = + user->called_computations().front(); + HloInstruction* parameter_instruction = + called_computation->parameter_instruction(i); + result.push_back( + {parameter_instruction, instruction_and_shape_index.shape_index}); + } + } else if (user->opcode() == HloOpcode::kCustomCall) { + const auto operand_indices = user->OperandIndices(instruction); + // TODO(b/342650757): Rather than a boolean indicating whether the + // instruction reuses the buffer, return the shape index of the output + // that the operand aliases with. + bool found_one = false; + for (const auto i : operand_indices) { + if (CustomCallReusesBuffer(user, i)) { + if (found_one) { + return absl::InternalError( + "Found multiple operands of a custom call that reuse the same " + "output buffer."); + } + result.push_back({user, instruction_and_shape_index.shape_index}); + found_one = true; + } + } + } else { + result.push_back({user, instruction_and_shape_index.shape_index}); + } + } + return result; +} + +std::vector GetPredecessors( + const InstructionAndShapeIndex& instruction_and_shape_index) { + std::vector result; + HloInstruction* instruction = instruction_and_shape_index.instruction; + if (instruction->opcode() == HloOpcode::kGetTupleElement) { + const int64_t index = instruction->tuple_index(); + auto tmp_shape_index = instruction_and_shape_index.shape_index; + tmp_shape_index.push_front(index); + result.push_back({instruction->mutable_operand(0), tmp_shape_index}); + } else if (instruction->opcode() == HloOpcode::kTuple) { + CHECK(!instruction_and_shape_index.shape_index.empty()) + << "Did not store an index before encountering a tuple."; + auto tmp_shape_index = instruction_and_shape_index.shape_index; + const int64_t index = tmp_shape_index.front(); + tmp_shape_index.pop_front(); + result.push_back({instruction->mutable_operand(index), tmp_shape_index}); + } else if (instruction->opcode() == HloOpcode::kCall) { + // Predecessor of a call is its computation's root instruction. + CHECK(instruction->called_computations().size() == 1) + << "Expect call to only have one called computation."; + HloComputation* called_computation = + instruction->called_computations().front(); + result.push_back({called_computation->root_instruction(), + instruction_and_shape_index.shape_index}); + } else if (instruction->opcode() == HloOpcode::kParameter) { + std::unique_ptr call_graph = + CallGraph::Build(instruction->GetModule()); + auto callers = call_graph->GetComputationCallers(instruction->parent()); + for (HloInstruction* caller : callers) { + result.push_back( + {caller->mutable_operand(instruction->parameter_number()), + instruction_and_shape_index.shape_index}); + } + } else if (instruction->opcode() == HloOpcode::kDynamicSlice) { + result.push_back({instruction->mutable_operand(0), + instruction_and_shape_index.shape_index}); + } else if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) { + result.push_back({instruction->mutable_operand(0), + instruction_and_shape_index.shape_index}); + } else if (instruction->opcode() == HloOpcode::kWhile) { + HloComputation* while_body_computation = instruction->while_body(); + result.push_back({while_body_computation->root_instruction(), + instruction_and_shape_index.shape_index}); + } else { + CHECK(instruction->operand_count() == 1) << absl::StreamFormat( + "Expecting instruction %s to have 1 operand, but it has %d.", + instruction->name(), instruction->operand_count()); + result.push_back({instruction->mutable_operand(0), + instruction_and_shape_index.shape_index}); + } + return result; +} + +bool IsValidDuringPureMemoryOffload(const HloInstruction* instruction) { + static constexpr std::array allowed_opcodes = { + HloOpcode::kGetTupleElement, + HloOpcode::kBitcast, + HloOpcode::kTuple, + HloOpcode::kCall, + HloOpcode::kWhile, + HloOpcode::kParameter, + HloOpcode::kOptimizationBarrier, + HloOpcode::kAsyncStart, + HloOpcode::kAsyncDone, + HloOpcode::kCustomCall}; + return absl::c_linear_search(allowed_opcodes, instruction->opcode()); +} + +bool operator==(const InstructionAndShapeIndex& lhs, + const InstructionAndShapeIndex& rhs) { + return lhs.instruction == rhs.instruction && + lhs.shape_index == rhs.shape_index; +} + +std::string InstructionAndShapeIndex::ToString() const { + return absl::StrFormat("{Instr: %s, ShapeIndex: %s}", instruction->name(), + shape_index.ToString()); +} + +} // namespace host_offload_utils +} // namespace xla diff --git a/third_party/xla/xla/service/host_offload_utils.h b/third_party/xla/xla/service/host_offload_utils.h new file mode 100644 index 00000000000000..22e1c359dca83e --- /dev/null +++ b/third_party/xla/xla/service/host_offload_utils.h @@ -0,0 +1,101 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_HOST_OFFLOAD_UTILS_H_ +#define XLA_SERVICE_HOST_OFFLOAD_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/service/pattern_matcher.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace host_offload_utils { + +struct InstructionAndShapeIndex { + explicit InstructionAndShapeIndex(HloInstruction* instruction) + : instruction(instruction) {} + InstructionAndShapeIndex(HloInstruction* instruction, ShapeIndex shape_index) + : instruction(instruction), shape_index(shape_index) {} + HloInstruction* instruction; + ShapeIndex shape_index; + std::string ToString() const; + + template + static H Hash(H h, const InstructionAndShapeIndex& i) { + h = H::combine(std::move(h), i.instruction); + h = H::combine(std::move(h), i.shape_index); + return std::move(h); + } + + template + friend H AbslHashValue(H h, const InstructionAndShapeIndex& i) { + return InstructionAndShapeIndex::Hash(std::move(h), i); + } +}; + +bool operator==(const InstructionAndShapeIndex& lhs, + const InstructionAndShapeIndex& rhs); + +// If an instruction's user is a call, we descend into the call first. +// Eventually, a later invocation of this function while walking the graph will +// return the call itself as a successor of the ROOT instruction of the +// computation. +absl::StatusOr> GetSuccessors( + const InstructionAndShapeIndex& instruction_and_shape_index); + +// If an instruction's operand is a call, return the call now. A follow up call +// of this function on that call returns the ROOT. Eventually, once the given +// instruction is a parameter, the returned predecessor will be the appropriate +// operand of the call (not the call itself, since we already returned it). +std::vector GetPredecessors( + const InstructionAndShapeIndex& instruction_and_shape_index); + +// Returns true if the instruction is allowed to be in the +// middle of a pure memory offload path. +bool IsValidDuringPureMemoryOffload(const HloInstruction* instruction); + +} // namespace host_offload_utils +} // namespace xla + +#endif // XLA_SERVICE_HOST_OFFLOAD_UTILS_H_ diff --git a/third_party/xla/xla/service/host_offload_utils_test.cc b/third_party/xla/xla/service/host_offload_utils_test.cc new file mode 100644 index 00000000000000..6f38b45ab09544 --- /dev/null +++ b/third_party/xla/xla/service/host_offload_utils_test.cc @@ -0,0 +1,114 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/host_offload_utils.h" + +#include +#include + +#include +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace host_offload_utils { +namespace { + +class HostOffloadUtilsTest : public HloTestBase {}; + +TEST_F(HostOffloadUtilsTest, SimpleGetSuccessorsGetPredecessorsTest) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[1,2048,2048] parameter(0) + index_param = s32[] parameter(1) + constant_f32_0 = f32[] constant(0) + constant_s32_0 = s32[] constant(0) + broadcast = f32[2,2048,2048] broadcast(constant_f32_0), dimensions={} + offload_custom_call = f32[1,2048,2048] custom-call(data_param), custom_call_target="MoveToHost" + dynamic_update_slice = f32[2,2048,2048] dynamic-update-slice(broadcast, offload_custom_call, index_param, constant_s32_0, constant_s32_0) + dynamic_slice = f32[1,2048,2048] dynamic-slice(dynamic_update_slice, index_param, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,2048,2048} + ROOT load_custom_call = f32[1,2048,2048] custom-call(dynamic_slice), custom_call_target="MoveToDevice" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + HloInstruction* data_param = FindInstruction(module.get(), "data_param"); + ASSERT_NE(data_param, nullptr); + HloInstruction* offload_custom_call = + FindInstruction(module.get(), "offload_custom_call"); + ASSERT_NE(offload_custom_call, nullptr); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector succ, + GetSuccessors(InstructionAndShapeIndex(data_param, {}))); + std::vector expected_succ = { + InstructionAndShapeIndex(offload_custom_call, {})}; + EXPECT_EQ(succ, expected_succ); + + std::vector pred = + GetPredecessors(InstructionAndShapeIndex(offload_custom_call, {})); + std::vector expected_pred = { + InstructionAndShapeIndex(data_param, {})}; + EXPECT_EQ(pred, expected_pred); +} + +TEST_F(HostOffloadUtilsTest, ComputationGetSuccessorsGetPredecessorsTest) { + const std::string& hlo_string = R"( +HloModule my_module +other_computation { + param_0 = f32[2048] parameter(0) + param_1 = f32[2048] parameter(1) + ROOT tuple = (f32[2048], f32[2048]) tuple(param_0, param_1) +} +ENTRY main { + data_param = f32[2048] parameter(0) + other_param = f32[2048] parameter(1) + offload_custom_call = f32[2048] custom-call(data_param), custom_call_target="MoveToHost" + call = (f32[2048], f32[2048]) call(offload_custom_call, other_param), to_apply=other_computation + gte_0 = f32[2048] get-tuple-element(call), index=0 + gte_1 = f32[2048] get-tuple-element(call), index=1 + ROOT load_custom_call = f32[2048] custom-call(gte_0), custom_call_target="MoveToDevice" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + HloInstruction* call = FindInstruction(module.get(), "call"); + ASSERT_NE(call, nullptr); + HloInstruction* gte_0 = FindInstruction(module.get(), "gte_0"); + ASSERT_NE(gte_0, nullptr); + HloInstruction* tuple = FindInstruction(module.get(), "tuple"); + ASSERT_NE(tuple, nullptr); + + TF_ASSERT_OK_AND_ASSIGN(std::vector succ, + GetSuccessors(InstructionAndShapeIndex(call, {0}))); + std::vector expected_succ = { + InstructionAndShapeIndex(gte_0, {})}; + EXPECT_EQ(succ, expected_succ); + + std::vector pred = + GetPredecessors(InstructionAndShapeIndex(call, {0})); + std::vector expected_pred = { + InstructionAndShapeIndex(tuple, {0})}; + EXPECT_EQ(pred, expected_pred); +} + +} // namespace +} // namespace host_offload_utils +} // namespace xla diff --git a/third_party/xla/xla/service/host_offloader.cc b/third_party/xla/xla/service/host_offloader.cc index 95c97e94c704da..7e1971302c981b 100644 --- a/third_party/xla/xla/service/host_offloader.cc +++ b/third_party/xla/xla/service/host_offloader.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/service/hlo_cse.h" #include "xla/service/hlo_value.h" #include "xla/service/host_memory_offload_annotations.h" +#include "xla/service/host_offload_utils.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -58,6 +59,7 @@ namespace { using ::xla::host_memory_offload_annotations::kMoveToDeviceCustomCallTarget; using ::xla::host_memory_offload_annotations::kMoveToHostCustomCallTarget; +using ::xla::host_offload_utils::InstructionAndShapeIndex; void SetMemorySpace(Shape* shape, int64_t memory_space_color) { CHECK(shape->has_layout()); @@ -85,210 +87,8 @@ bool SetBuffersToMemorySpaceColor( return changed; } -bool CustomCallReusesBuffer(const HloInstruction* custom_call, - int64_t operand_index) { - if (custom_call->custom_call_target() == kMoveToDeviceCustomCallTarget || - custom_call->custom_call_target() == kMoveToHostCustomCallTarget) { - // Does not define a new buffer. - return true; - } - // Check the custom call's output_to_operand_aliasing. - const std::vector>>& - aliases = custom_call->output_operand_aliasing(); - for (const std::pair>& alias : - aliases) { - int64_t alias_operand_index = alias.second.first; - if (alias_operand_index == operand_index) { - // This operand aliases with the output. - return true; - } - } - // By default, assume custom calls define new buffers. - return false; -} - -// If an instruction's user is a call, we descend into the call first. -// Eventually, a later invocation of this function while walking the graph will -// return the call itself as a successor of the ROOT instruction of the -// computation. -absl::StatusOr> GetSuccessors( - const InstructionAndShapeIndex& instruction_and_shape_index) { - std::vector result; - HloInstruction* instruction = instruction_and_shape_index.instruction; - if (instruction->IsRoot()) { - // Successor of the root is the call instruction(s). - std::unique_ptr call_graph = - CallGraph::Build(instruction->GetModule()); - auto callers = call_graph->GetComputationCallers(instruction->parent()); - for (HloInstruction* caller : callers) { - result.push_back({caller, instruction_and_shape_index.shape_index}); - } - } - for (HloInstruction* user : instruction->users()) { - if (user->opcode() == HloOpcode::kTuple) { - auto operand_indices = user->OperandIndices(instruction); - for (const auto i : operand_indices) { - auto tmp_shape_index = instruction_and_shape_index.shape_index; - tmp_shape_index.push_back(i); - result.push_back({user, std::move(tmp_shape_index)}); - } - } else if (user->opcode() == HloOpcode::kGetTupleElement) { - ShapeIndex tmp_shape_index = instruction_and_shape_index.shape_index; - const auto index = tmp_shape_index.front(); - if (index == user->tuple_index()) { - // This GTE is for the buffer we're tracking. - tmp_shape_index.pop_front(); - result.push_back({user, std::move(tmp_shape_index)}); - } - } else if (user->opcode() == HloOpcode::kCall) { - auto operand_indices = user->OperandIndices(instruction); - CHECK(user->called_computations().size() == 1) - << "Expect call to only have one called computation."; - for (const auto i : operand_indices) { - HloComputation* called_computation = - user->called_computations().front(); - HloInstruction* parameter_instruction = - called_computation->parameter_instruction(i); - result.push_back( - {parameter_instruction, instruction_and_shape_index.shape_index}); - } - } else if (user->opcode() == HloOpcode::kWhile) { - auto operand_indices = user->OperandIndices(instruction); - HloComputation* while_body_computation = user->while_body(); - HloComputation* while_condition_computation = user->while_condition(); - for (const auto i : operand_indices) { - HloInstruction* parameter_instruction = - while_body_computation->parameter_instruction(i); - result.push_back( - {parameter_instruction, instruction_and_shape_index.shape_index}); - - HloInstruction* condition_instruction = - while_condition_computation->parameter_instruction(i); - result.push_back( - {condition_instruction, instruction_and_shape_index.shape_index}); - } - } else if (user->opcode() == HloOpcode::kAsyncStart) { - auto operand_indices = user->OperandIndices(instruction); - CHECK(user->called_computations().size() == 1) - << "Expect async-start to only have one called computation."; - for (const auto i : operand_indices) { - HloComputation* called_computation = - user->called_computations().front(); - HloInstruction* parameter_instruction = - called_computation->parameter_instruction(i); - result.push_back( - {parameter_instruction, instruction_and_shape_index.shape_index}); - } - } else if (user->opcode() == HloOpcode::kCustomCall) { - const auto operand_indices = user->OperandIndices(instruction); - // TODO(b/342650757): Rather than a boolean indicating whether the - // instruction reuses the buffer, return the shape index of the output - // that the operand aliases with. - bool found_one = false; - for (const auto i : operand_indices) { - if (CustomCallReusesBuffer(user, i)) { - if (found_one) { - return absl::InternalError( - "Found multiple operands of a custom call that reuse the same " - "output buffer."); - } - result.push_back({user, instruction_and_shape_index.shape_index}); - found_one = true; - } - } - } else { - result.push_back({user, instruction_and_shape_index.shape_index}); - } - } - return result; -} - -// If an instruction's operand is a call, return the call now. A follow up call -// of this function on that call returns the ROOT. Eventually, once the given -// instruction is a parameter, the returned predecessor will be the appropriate -// operand of the call (not the call itself, since we already returned it). -std::vector GetPredecessors( - const InstructionAndShapeIndex& instruction_and_shape_index) { - std::vector result; - HloInstruction* instruction = instruction_and_shape_index.instruction; - if (instruction->opcode() == HloOpcode::kGetTupleElement) { - const int64_t index = instruction->tuple_index(); - auto tmp_shape_index = instruction_and_shape_index.shape_index; - tmp_shape_index.push_front(index); - result.push_back({instruction->mutable_operand(0), tmp_shape_index}); - } else if (instruction->opcode() == HloOpcode::kTuple) { - CHECK(!instruction_and_shape_index.shape_index.empty()) - << "Did not store an index before encountering a tuple."; - auto tmp_shape_index = instruction_and_shape_index.shape_index; - const int64_t index = tmp_shape_index.front(); - tmp_shape_index.pop_front(); - result.push_back({instruction->mutable_operand(index), tmp_shape_index}); - } else if (instruction->opcode() == HloOpcode::kCall) { - // Predecessor of a call is its computation's root instruction. - CHECK(instruction->called_computations().size() == 1) - << "Expect call to only have one called computation."; - HloComputation* called_computation = - instruction->called_computations().front(); - result.push_back({called_computation->root_instruction(), - instruction_and_shape_index.shape_index}); - } else if (instruction->opcode() == HloOpcode::kParameter) { - std::unique_ptr call_graph = - CallGraph::Build(instruction->GetModule()); - auto callers = call_graph->GetComputationCallers(instruction->parent()); - for (HloInstruction* caller : callers) { - result.push_back( - {caller->mutable_operand(instruction->parameter_number()), - instruction_and_shape_index.shape_index}); - } - } else if (instruction->opcode() == HloOpcode::kDynamicSlice) { - result.push_back({instruction->mutable_operand(0), - instruction_and_shape_index.shape_index}); - } else if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) { - result.push_back({instruction->mutable_operand(0), - instruction_and_shape_index.shape_index}); - } else if (instruction->opcode() == HloOpcode::kWhile) { - HloComputation* while_body_computation = instruction->while_body(); - result.push_back({while_body_computation->root_instruction(), - instruction_and_shape_index.shape_index}); - } else { - CHECK(instruction->operand_count() == 1) << absl::StreamFormat( - "Expecting instruction %s to have 1 operand, but it has %d.", - instruction->name(), instruction->operand_count()); - result.push_back({instruction->mutable_operand(0), - instruction_and_shape_index.shape_index}); - } - return result; -} - } // namespace -bool operator==(const InstructionAndShapeIndex& lhs, - const InstructionAndShapeIndex& rhs) { - return lhs.instruction == rhs.instruction && - lhs.shape_index == rhs.shape_index; -} - -std::string InstructionAndShapeIndex::ToString() const { - return absl::StrFormat("{Instr: %s, ShapeIndex: %s}", instruction->name(), - shape_index.ToString()); -} - -bool HostOffloader::IsValidDuringPureMemoryOffload( - const HloInstruction* instruction) const { - static constexpr std::array allowed_opcodes = { - HloOpcode::kGetTupleElement, - HloOpcode::kBitcast, - HloOpcode::kTuple, - HloOpcode::kCall, - HloOpcode::kWhile, - HloOpcode::kParameter, - HloOpcode::kOptimizationBarrier, - HloOpcode::kAsyncStart, - HloOpcode::kAsyncDone, - HloOpcode::kCustomCall}; - return absl::c_linear_search(allowed_opcodes, instruction->opcode()); -} - bool HostOffloader::InstructionIsAllowedBetweenMoveToHostAndDus( const HloInstruction* instruction) const { if (instruction->opcode() == HloOpcode::kReshape) { @@ -355,7 +155,8 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( // this so that we don't try to create an AllocateBuffer later. dynamic_update_slices_already_allocated_.insert(instruction); } - } else if (IsValidDuringPureMemoryOffload(instruction)) { + } else if (host_offload_utils::IsValidDuringPureMemoryOffload( + instruction)) { if (instruction->opcode() == HloOpcode::kAsyncStart) { // When visiting the parameter, we already set the memory space of the // input of the async-start; do not set it now. @@ -433,8 +234,9 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( } } // Push successors onto the queue to be visited. - TF_ASSIGN_OR_RETURN(const std::vector successors, - GetSuccessors(instruction_and_shape_index)); + TF_ASSIGN_OR_RETURN( + const std::vector successors, + host_offload_utils::GetSuccessors(instruction_and_shape_index)); for (const InstructionAndShapeIndex& successor : successors) { queue.push(successor); } @@ -454,7 +256,8 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( } if (insert_copy_before) { - const auto predecessors = GetPredecessors(starting_instruction_and_index); + const auto predecessors = + host_offload_utils::GetPredecessors(starting_instruction_and_index); CHECK_EQ(predecessors.size(), 1); TF_ASSIGN_OR_RETURN(bool inserted_copy, InsertCopyBetween(predecessors.front(), @@ -687,7 +490,8 @@ HostOffloader::GetStartingInstructions( std::queue queue; TF_ASSIGN_OR_RETURN( const std::vector successors_of_custom_call, - GetSuccessors(InstructionAndShapeIndex(custom_call_instruction))); + host_offload_utils::GetSuccessors( + InstructionAndShapeIndex(custom_call_instruction))); for (const InstructionAndShapeIndex& successor : successors_of_custom_call) { queue.push(successor); } @@ -707,8 +511,9 @@ HostOffloader::GetStartingInstructions( } else { // Is a logical bitcast/reshape, we won't offload this yet. } - TF_ASSIGN_OR_RETURN(const std::vector successors, - GetSuccessors(instruction_and_shape)); + TF_ASSIGN_OR_RETURN( + const std::vector successors, + host_offload_utils::GetSuccessors(instruction_and_shape)); for (const InstructionAndShapeIndex& successor : successors) { queue.push(successor); } @@ -730,7 +535,7 @@ absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall( std::queue queue; TF_ASSIGN_OR_RETURN( const std::vector successors_of_slice, - GetSuccessors(InstructionAndShapeIndex(slice))); + host_offload_utils::GetSuccessors(InstructionAndShapeIndex(slice))); for (const InstructionAndShapeIndex& successor : successors_of_slice) { queue.push(successor); } @@ -751,8 +556,9 @@ absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall( "the MoveToDevice custom call.", slice->name(), current_instruction->name())); } - TF_ASSIGN_OR_RETURN(const std::vector successors, - GetSuccessors(instruction_and_shape)); + TF_ASSIGN_OR_RETURN( + const std::vector successors, + host_offload_utils::GetSuccessors(instruction_and_shape)); for (const InstructionAndShapeIndex& successor : successors) { queue.push(successor); } @@ -824,7 +630,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( InstructionAndShapeIndex nested_instruction_and_shape = nested_queue.front(); nested_queue.pop(); - if (!IsValidDuringPureMemoryOffload( + if (!host_offload_utils::IsValidDuringPureMemoryOffload( nested_instruction_and_shape.instruction)) { return absl::InvalidArgumentError(absl::StrFormat( "Tensor which is moved to host is used by an invalid " @@ -838,7 +644,8 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( kHostMemorySpaceColor); TF_ASSIGN_OR_RETURN( const std::vector successors, - GetSuccessors(nested_instruction_and_shape)); + host_offload_utils::GetSuccessors( + nested_instruction_and_shape)); for (const InstructionAndShapeIndex& successor : successors) { nested_queue.push(successor); } @@ -851,7 +658,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( dynamic_update_slices_already_allocated_.insert(instruction); } const std::vector predecessors = - GetPredecessors(instruction_and_shape); + host_offload_utils::GetPredecessors(instruction_and_shape); for (const InstructionAndShapeIndex& predecessor : predecessors) { HloInstruction* predecessor_instruction = predecessor.instruction; if (predecessor_instruction->opcode() == HloOpcode::kBroadcast) { diff --git a/third_party/xla/xla/service/host_offloader.h b/third_party/xla/xla/service/host_offloader.h index 880cda3d77b621..8dfee6d455eb6b 100644 --- a/third_party/xla/xla/service/host_offloader.h +++ b/third_party/xla/xla/service/host_offloader.h @@ -26,36 +26,12 @@ #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/service/host_offload_utils.h" namespace xla { class HloCostAnalysis; -struct InstructionAndShapeIndex { - explicit InstructionAndShapeIndex(HloInstruction* instruction) - : instruction(instruction) {} - InstructionAndShapeIndex(HloInstruction* instruction, ShapeIndex shape_index) - : instruction(instruction), shape_index(shape_index) {} - HloInstruction* instruction; - ShapeIndex shape_index; - std::string ToString() const; - - template - static H Hash(H h, const InstructionAndShapeIndex& i) { - h = H::combine(std::move(h), i.instruction); - h = H::combine(std::move(h), i.shape_index); - return std::move(h); - } - - template - friend H AbslHashValue(H h, const InstructionAndShapeIndex& i) { - return InstructionAndShapeIndex::Hash(std::move(h), i); - } -}; - -bool operator==(const InstructionAndShapeIndex& lhs, - const InstructionAndShapeIndex& rhs); - // This pass does "host memory offloading". If a tensor is annotated to be moved // to or from the host, this pass will remove the annotations and update each // tensor's layout with host memory spaces and insert copies if necessary. This @@ -90,17 +66,14 @@ class HostOffloader : public HloModulePass { absl::flat_hash_set validated_slices_; absl::flat_hash_map copies_created_after_; absl::flat_hash_set move_to_device_custom_calls_to_remove_; - absl::flat_hash_set already_inserted_copy_before_; + absl::flat_hash_set + already_inserted_copy_before_; // Sometimes previous transformations turn a DynamicSlice into a Slice. Since // we're doing a DMA between the host and device, we need to turn the Slice // back into a DynamicSlice. absl::StatusOr DynamifySlice(HloInstruction* slice); - // Returns true if the instruction is allowed to be in the - // middle of a pure memory offload path. - bool IsValidDuringPureMemoryOffload(const HloInstruction* instruction) const; - // Returns true if the instruction is allowed to be in the // middle of a path between a MoveToHost custom-call annotation and a // DynamicUpdateSlice. Ideally the custom-call should be immediately followed @@ -146,19 +119,22 @@ class HostOffloader : public HloModulePass { // Common function for doing the actual walking of the graph. Host memory // spaces are set and copies are inserted in here. absl::StatusOr WalkDownHostMemoryOffloadPaths( - const InstructionAndShapeIndex& starting_instruction_and_index, + const host_offload_utils::InstructionAndShapeIndex& + starting_instruction_and_index, bool insert_copy_before); // Given a custom call, this returns the first instruction and shape index to // start the host memory offload path from for each use of the custom call. - absl::StatusOr> GetStartingInstructions( - HloInstruction* custom_call_instruction); + absl::StatusOr> + GetStartingInstructions(HloInstruction* custom_call_instruction); // When a MoveToHost custom call is not paired with a DynamicUpdateSlice, a // copy from device to host must be inserted. absl::StatusOr InsertCopyBetween( - const InstructionAndShapeIndex& before_instruction_and_index, - const InstructionAndShapeIndex& after_instruction_and_index); + const host_offload_utils::InstructionAndShapeIndex& + before_instruction_and_index, + const host_offload_utils::InstructionAndShapeIndex& + after_instruction_and_index); // This is a fix for scheduling. Add copies to inputs of dynamic-update-slice // if the inserted value is directly a parameter of a computation. This is to diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc index 5f7757bcd2056a..dc59e5cca70151 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc @@ -486,6 +486,15 @@ bool AsyncTracker::ReleasesSelectiveResource(const HloGraphNode* node) const { }); } +bool AsyncTracker::OccupiesSelectiveResource(const HloGraphNode* node) const { + return absl::c_any_of( + node->GetResources(), [&](const ResourcePair& resource) { + return resource.second == ResourceUsageType::kResourceOccupy && + GetResourceHazardType(resource.first) == + ResourceHazardType::kSelective; + }); +} + BufferInfoTracker::BufferInfoTracker( const HloModule* module, const HloAliasAnalysis* alias_analysis, const HloCostAnalysis::ShapeSizeFunction& shape_size_bytes) { @@ -731,6 +740,25 @@ DefaultSchedulerCore::ScheduleCandidate InitializeCandidate( namespace { +// Find the num hops to the closest selective resource overlap in ready set that +// provided node can be scheduled in between. +int64_t GetNumHopsToClosestSelectiveOverlap( + const DefaultSchedulerCore::ReadyQueueSet& ready_set, + const HloGraphNode* node) { + int64_t num_hops_to_closest_selective_resource_occupier = + std::numeric_limits::max(); + for (const HloGraphNode* n : ready_set) { + // Skip the node itself. + if (n == node) { + continue; + } + num_hops_to_closest_selective_resource_occupier = + std::min(num_hops_to_closest_selective_resource_occupier, + n->GetNumHopsToClosestSelectiveResourceOccupier()); + } + return num_hops_to_closest_selective_resource_occupier; +} + // Comparator for the ready set. This class represents the priority policies // for the nodes in the ready set. The policy can be whatever is appropriate to // reduce the execution time of the graph or achieve interesting properties @@ -1002,6 +1030,31 @@ class ReadySetLt { return *value; } } + // If there are no selective overlaps open currently and there will be + // overlaps opened in the near future, hold off scheduling instructions + // that are valuable for selective overlaps. + if (sched_state_.config.enable_selective_resources && + sched_state_.selective_resource_releasers.empty()) { + int64_t distance_to_selective_overlap_for_a = + GetNumHopsToClosestSelectiveOverlap(sched_state_.ready_set, a.node); + int64_t distance_to_selective_overlap_for_b = + GetNumHopsToClosestSelectiveOverlap(sched_state_.ready_set, b.node); + // If a is valuable for selective overlap and there is a selective + // overlap in the near future a can be scheduled inside, hold off + // scheduling a and schedule b instead. Same logic applies in reverse. + int64_t max_distance = + sched_state_.config.max_hops_to_closest_selective_overlap; + if (auto value = DefaultSchedulerCore::ChooseBestCandidate( + (a.node->GetValuableForSelectiveOverlap() && + distance_to_selective_overlap_for_a <= max_distance), + b, + (b.node->GetValuableForSelectiveOverlap() && + distance_to_selective_overlap_for_b <= max_distance), + a, "kNotValuableForSelectiveOverlap")) { + return *value; + } + } + if (sched_state_.config.aggressive_scheduling_policies) { // Favor nodes that unlock other nodes to be scheduled if possible. // This makes us more flexible in what we can use in scheduling. @@ -1693,6 +1746,8 @@ HloScheduleGraph::HloScheduleGraph( new_node_it->second->GetResources()); new_node_it->second->releases_selective_resource_ = async_tracker->ReleasesSelectiveResource(new_node_it->second.get()); + new_node_it->second->occupies_selective_resource_ = + async_tracker->OccupiesSelectiveResource(new_node_it->second.get()); // Gather while instructions for subsequent send-done dependency checks. if (instr->opcode() == HloOpcode::kWhile) { while_instrs.push_back(instr); @@ -1900,6 +1955,25 @@ void HloScheduleGraph::InitializeGraphAnalysis( while (!stack.empty()) { auto* node = stack.back(); stack.pop_back(); + // If a node occupies a selective resource, it is the closest selective + // resource occupier to itself and is 0 hops away. Otherwise, the num hops + // to closest selective resource occupier is the minimum of that of all + // predecessors plus 1. + if (async_tracker->OccupiesSelectiveResource(node)) { + node->num_hops_to_closest_selective_resource_occupier_ = 0; + } else { + int64_t closest_predecessor_distance = + std::numeric_limits::max(); + for (auto& pred : node->GetPredecessors()) { + closest_predecessor_distance = std::min( + closest_predecessor_distance, + pred.Target().num_hops_to_closest_selective_resource_occupier_); + } + if (closest_predecessor_distance != std::numeric_limits::max()) { + node->num_hops_to_closest_selective_resource_occupier_ = + closest_predecessor_distance + 1; + } + } if (async_tracker->IsSupportedAsyncDone(node->GetInstr())) { for (auto& pred : node->GetPredecessors()) { node->SetAsyncDepth( diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.h b/third_party/xla/xla/service/latency_hiding_scheduler.h index ebe1cf0c6bcc8c..b0d8a8d08e9886 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/latency_hiding_scheduler.h @@ -137,6 +137,7 @@ struct SchedulerConfig { bool resource_serializing = false; bool depth_based_memory_pressure_reduction = false; bool enable_selective_resources = false; + int64_t max_hops_to_closest_selective_overlap = 0; int64_t rerun = 0; }; @@ -284,6 +285,9 @@ class AsyncTracker { // Returns whether the provided node releases a selective resource. bool ReleasesSelectiveResource(const HloGraphNode* node) const; + // Returns whether the provided node occupies a selective resource. + bool OccupiesSelectiveResource(const HloGraphNode* node) const; + inline CanonicalAsyncOp GetCanonicalAsyncOp(const HloInstruction& hlo) const { return get_canonical_async_op_(hlo); } @@ -386,6 +390,17 @@ class HloGraphNode { bool ReleasesSelectiveResource() const { return releases_selective_resource_; } + bool OccupiesSelectiveResource() const { + return occupies_selective_resource_; + } + int64_t GetNumHopsToClosestSelectiveResourceOccupier() const { + return num_hops_to_closest_selective_resource_occupier_; + } + void SetNumHopsToClosestSelectiveResourceOccupier( + int64_t num_hops_to_closest_selective_resource_occupier) { + num_hops_to_closest_selective_resource_occupier_ = + num_hops_to_closest_selective_resource_occupier; + } ResourcesVector GetResources() const { return resources_; } bool DoesOccupyAnyResource() const { @@ -525,6 +540,11 @@ class HloGraphNode { bool valuable_for_selective_overlap_ = true; // Whether this node releases a selective resource. bool releases_selective_resource_ = false; + // Whether this node occupies a selective resource. + bool occupies_selective_resource_ = false; + // Nums hops to closest selective resource occupier. + int64_t num_hops_to_closest_selective_resource_occupier_ = + std::numeric_limits::max(); }; // Schedule graph that can be used to drive scheduling @@ -920,7 +940,6 @@ class DefaultSchedulerCore : public SchedulerCore { virtual absl::StatusOr FindAndExtractBestNodeAvailable( SchedulingState& sched_state, DefaultSchedulerCore::ShouldSkipNodeFunction should_skip_node); - bool DoesNodeReleaseSelectiveResource(const HloGraphNode* node) const; void DumpLatencyHidingSchedule( const HloComputation* computation, const HloScheduleGraph& schedule_graph, const std::vector& instructions, diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index 688af31615c710..f749cdba55d57a 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -482,7 +482,8 @@ absl::Status LayoutAssignment::SetInstructionLayout( absl::Status LayoutAssignment::SetInstructionLayout( const Shape& shape_with_layout, const HloInstruction* instruction, - bool mandatory, bool dfs, bool allow_alias, int64_t priority) { + bool mandatory, bool dfs, bool allow_alias, int64_t priority, + ShapeIndexView subshape_index) { VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", " << ShapeUtil::HumanStringWithLayout(shape_with_layout) << ": priority = " << priority << " : mandatory = " << mandatory @@ -499,8 +500,12 @@ absl::Status LayoutAssignment::SetInstructionLayout( // instruction. TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( shape_with_layout, - [this, dfs, instruction, mandatory, allow_alias, priority]( - const Shape& subshape, const ShapeIndex& index) -> absl::Status { + [this, dfs, instruction, mandatory, allow_alias, priority, + subshape_index](const Shape& subshape, + const ShapeIndex& index) -> absl::Status { + if (!subshape_index.empty() && index != subshape_index) { + return absl::OkStatus(); + } auto buffers = points_to_analysis_->GetPointsToSet(instruction).element(index); CHECK_EQ(1, buffers.size()); diff --git a/third_party/xla/xla/service/layout_assignment.h b/third_party/xla/xla/service/layout_assignment.h index ba12a2a325bc99..ba59743018c386 100644 --- a/third_party/xla/xla/service/layout_assignment.h +++ b/third_party/xla/xla/service/layout_assignment.h @@ -378,14 +378,16 @@ class LayoutAssignment : public HloModulePass { absl::Status SetInstructionLayout(const Shape& shape_with_layout, const HloInstruction* instruction, bool mandatory = true, bool dfs = true, - bool allow_alias = false) { + bool allow_alias = false, + ShapeIndexView subshape_index = {}) { return SetInstructionLayout(shape_with_layout, instruction, mandatory, dfs, - allow_alias, current_priority_); + allow_alias, current_priority_, subshape_index); } absl::Status SetInstructionLayout(const Shape& shape_with_layout, const HloInstruction* instruction, bool mandatory, bool dfs, bool allow_alias, - int64_t priority); + int64_t priority, + ShapeIndexView subshape_index = {}); // Set the same given layout across all components of the instruction output. // It works the same as the API above if the output is a single array. absl::Status SetInstructionLayout(const Layout& layout, diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index bf375e84182ad8..4271cc897f41d7 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -4069,10 +4069,8 @@ static absl::Status ValidateGatherDimensionNumbers( absl::c_binary_search(gather_dim_numbers.offset_dims(), i); if (is_window_index) { while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(), - offset_dims_seen)) { - offset_dims_seen++; - } - while (absl::c_binary_search(gather_dim_numbers.operand_batching_dims(), + offset_dims_seen) || + absl::c_binary_search(gather_dim_numbers.operand_batching_dims(), offset_dims_seen)) { offset_dims_seen++; } @@ -4308,7 +4306,7 @@ absl::Status ValidateScatterDimensionNumbers( TF_RETURN_IF_ERROR(ExpectArray( updates_shape, absl::StrCat("updates ", operand_i, " of scatter op"))); - int64_t inserted_dims_seen = 0; + int64_t inserted_dims_seen = 0, input_batching_dims_seen = 0; std::vector max_update_slice_sizes; const auto dimensions_size = operand_shape.dimensions_size(); max_update_slice_sizes.reserve(dimensions_size); @@ -4317,6 +4315,11 @@ absl::Status ValidateScatterDimensionNumbers( scatter_dim_numbers.inserted_window_dims_size() && scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) { ++inserted_dims_seen; + } else if (input_batching_dims_seen < + scatter_dim_numbers.input_batching_dims_size() && + scatter_dim_numbers.input_batching_dims( + input_batching_dims_seen) == i) { + ++input_batching_dims_seen; } else { max_update_slice_sizes.push_back(operand_shape.dimensions(i)); } diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index 14c3e804563815..29ae32add358e3 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -2870,6 +2870,24 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { << ShapeUtil::HumanString(gather_shape); } +TEST_F(GatherShapeInferenceTest, TensorFlowGatherBatchingDims) { + TF_ASSERT_OK_AND_ASSIGN(const Shape gather_shape, + ShapeInference::InferGatherShape( + ShapeUtil::MakeShape(F32, {100, 64, 5, 48}), + ShapeUtil::MakeShape(S64, {5, 100, 32}), + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{3}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, + /*index_vector_dim=*/3, + /*operand_batching_dims=*/{0, 2}, + /*start_indices_batching_dims=*/{1, 0}), + /*slice_sizes=*/{1, 1, 1, 8})); + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {5, 100, 32, 8}))) + << ShapeUtil::HumanString(gather_shape); +} + TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { TF_ASSERT_OK_AND_ASSIGN(const Shape gather_shape, ShapeInference::InferGatherShape( @@ -3481,6 +3499,27 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) { << statusor.status(); } +TEST_P(ScatterShapeInferenceTest, + TfScatterBatchingDimsWithUpdatesBiggerThanInput) { + const auto shapes = CreateShapes({100, 64, 48}, s64_tensor({100, 32}), + {100, 65, 32}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( + shapes.ptrs, to_apply(types()), + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{2}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/2, + /*input_batching_dims=*/{0}, + /*scatter_indices_batching_dims=*/{0})); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndices) { const auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 31}, types()); const absl::StatusOr statusor = ShapeInference::InferScatterShape( diff --git a/third_party/xla/xla/service/spmd/shardy/BUILD b/third_party/xla/xla/service/spmd/shardy/BUILD index ecd0481a22f3e1..bd15f2048ec50d 100644 --- a/third_party/xla/xla/service/spmd/shardy/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/BUILD @@ -143,6 +143,7 @@ xla_cc_binary( "//xla/service/spmd/shardy/mhlo_round_trip:shard_map_export", "//xla/service/spmd/shardy/round_trip_common:convert_sharding_custom_calls", "//xla/service/spmd/shardy/round_trip_common:import_constants", + "//xla/service/spmd/shardy/round_trip_common:open_while_free_vars_sharding", "//xla/service/spmd/shardy/round_trip_common:shard_map_import", "//xla/service/spmd/shardy/sdy_round_trip:export_ops", "//xla/service/spmd/shardy/sdy_round_trip:export_shardings", diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD b/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD index c3335b52779680..e929f614006e81 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD @@ -47,6 +47,22 @@ cc_library( ], ) +cc_library( + name = "open_while_free_vars_sharding", + srcs = ["open_while_free_vars_sharding.cc"], + hdrs = ["open_while_free_vars_sharding.h"], + deps = [ + "//xla/mlir_hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@shardy//shardy/dialect/sdy/ir:dialect", + ], +) + cc_library( name = "shard_map_import", srcs = ["shard_map_import.cc"], @@ -79,6 +95,7 @@ cc_library( deps = [ ":convert_sharding_custom_calls", ":import_constants", + ":open_while_free_vars_sharding", ":shard_map_import", "//xla/mlir_hlo:mhlo_passes", "@llvm-project//mlir:FuncDialect", diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc new file mode 100644 index 00000000000000..603b270eefa46f --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc @@ -0,0 +1,95 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h" + +#include + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/TypeID.h" +#include "mlir/Transforms/RegionUtils.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace xla { +namespace sdy { + +namespace { + +using ::mlir::StringRef; +using ::mlir::func::FuncOp; +using ::mlir::sdy::TensorShardingAttr; + +class OpenWhileFreeVarsShardingPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenWhileFreeVarsShardingPass) + + void runOnOperation() final { + FuncOp funcOp = getOperation(); + mlir::IRRewriter rewriter(funcOp); + + funcOp.walk([&](mlir::mhlo::WhileOp op) { + llvm::SetVector freeVars; + mlir::getUsedValuesDefinedAbove(op->getRegions(), freeVars); + rewriter.setInsertionPoint(op); + for (mlir::Value freeVar : freeVars) { + TensorShardingAttr sharding = mlir::sdy::getSharding(freeVar); + if (!sharding || sharding.getRank() == 0) { + continue; + } + auto shardingConstraint = + rewriter.create( + freeVar.getLoc(), freeVar, + TensorShardingAttr::getFullyOpenLike(sharding)); + // Only replace uses in the regions of the while op. + rewriter.replaceUsesWithIf( + freeVar, shardingConstraint, [op](mlir::OpOperand& use) { + return op->isProperAncestor(use.getOwner()); + }); + } + }); + } + + StringRef getArgument() const override { + return "xla-sdy-open-while-free-vars-sharding"; + } + + StringRef getDescription() const override { + return "Adds a fully open sharding constraint to free variables of while " + "op that already have a sharding."; + } +}; + +} // namespace + +std::unique_ptr createOpenWhileFreeVarsShardingPass() { + return std::make_unique(); +} + +void registerOpenWhileFreeVarsShardingPass() { + mlir::registerPass(createOpenWhileFreeVarsShardingPass); +} + +} // namespace sdy +} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h new file mode 100644 index 00000000000000..c06776f3c368fc --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h @@ -0,0 +1,40 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_OPEN_WHILE_FREE_VARS_SHARDING_H_ +#define XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_OPEN_WHILE_FREE_VARS_SHARDING_H_ + +#include + +#include "mlir/Pass/Pass.h" + +namespace xla { +namespace sdy { + +// Creates a pass that adds a fully open sharding constraint to free variables +// of while op that already have a user-defined sharding. +// +// This allows for their uses in the while op to be further sharded, which is +// important when converting to HLO as they will be lifted as passthrough while +// operands/results. +std::unique_ptr createOpenWhileFreeVarsShardingPass(); + +// Registers the xla-sdy-open-while-free-vars-sharding pass. +void registerOpenWhileFreeVarsShardingPass(); + +} // namespace sdy +} // namespace xla + +#endif // XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_OPEN_WHILE_FREE_VARS_SHARDING_H_ diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc index 14e2795b8eb880..23960ab48aadca 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc @@ -21,6 +21,7 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h" #include "xla/service/spmd/shardy/round_trip_common/import_constants.h" +#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h" #include "xla/service/spmd/shardy/round_trip_common/shard_map_import.h" namespace xla { @@ -45,13 +46,15 @@ void addCommonPreImportPasses(mlir::OpPassManager& pm) { pm.addNestedPass(mlir::mhlo::createFlattenTuplePass()); // We need to canonicalize redundant mhlo::GetTupleElementOp and - // mhlo::GetTupleOp. + // mhlo::GetTupleOp. We also need to canonicalize mhlo::WhileOp before + // `createOpenWhileFreeVarsShardingPass`. pm.addPass(mlir::createCanonicalizerPass()); } void addCommonPostImportPasses(mlir::OpPassManager& pm) { pm.addPass(createShardMapImportPass()); pm.addPass(createConvertShardingCustomCallsPass()); + pm.addNestedPass(createOpenWhileFreeVarsShardingPass()); } } // namespace sdy diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc index 52c5c23aba614e..b5670e78ace9b3 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc @@ -30,6 +30,7 @@ limitations under the License. #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h" #include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h" #include "xla/service/spmd/shardy/round_trip_common/import_constants.h" +#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h" #include "xla/service/spmd/shardy/round_trip_common/shard_map_import.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h" @@ -54,6 +55,7 @@ int main(int argc, char** argv) { xla::sdy::registerMhloImportShardingsPass(); xla::sdy::registerShardMapImportPass(); xla::sdy::registerConvertShardingCustomCallsPass(); + xla::sdy::registerOpenWhileFreeVarsShardingPass(); xla::sdy::registerImportConstantsPass(); xla::sdy::registerMhloExportPipeline(); diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc index 59197cc4a38c1f..40463d0dc74fce 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc @@ -535,36 +535,42 @@ TEST_F(ShardyXLATest, RngBitGenerator) { TEST_F(ShardyXLATest, WhileWithFreeVariables) { const char* const hloString = R"( - HloModule main - - %region_0.6 (arg_tuple.7: (f32[32,96], s32[], s32[], s32[])) -> (f32[32,96], s32[], s32[], s32[]) { - %arg_tuple.7 = (f32[32,96]{1,0}, s32[], s32[], s32[]) parameter(0) - %get-tuple-element.8 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=0 - %add.13 = f32[32,96]{1,0} add(f32[32,96]{1,0} %get-tuple-element.8, f32[32,96]{1,0} %get-tuple-element.8) - %get-tuple-element.9 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=1 - %get-tuple-element.11 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=3 - %add.12 = s32[] add(s32[] %get-tuple-element.9, s32[] %get-tuple-element.11) - %get-tuple-element.10 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=2 - ROOT %tuple.14 = (f32[32,96]{1,0}, s32[], s32[], s32[]) tuple(f32[32,96]{1,0} %add.13, s32[] %add.12, s32[] %get-tuple-element.10, s32[] %get-tuple-element.11) + HloModule main, entry_computation_layout={(f32[32,96]{1,0}, f32[32,96]{1,0})->f32[32,96]{1,0}} + + %region_0.7 (arg_tuple.8: (f32[32,96], s32[], s32[], s32[], f32[32,96])) -> (f32[32,96], s32[], s32[], s32[], f32[32,96]) { + %arg_tuple.8 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) parameter(0) + %get-tuple-element.9 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=0 + %get-tuple-element.13 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=4 + %add.15 = f32[32,96]{1,0} add(f32[32,96]{1,0} %get-tuple-element.9, f32[32,96]{1,0} %get-tuple-element.13), metadata={source_file="-" source_line=25} + %get-tuple-element.10 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=1 + %get-tuple-element.12 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=3 + %add.14 = s32[] add(s32[] %get-tuple-element.10, s32[] %get-tuple-element.12), metadata={source_file="-" source_line=24} + %get-tuple-element.11 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=2 + ROOT %tuple.16 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %add.15, s32[] %add.14, s32[] %get-tuple-element.11, s32[] %get-tuple-element.12, f32[32,96]{1,0} %get-tuple-element.13) } - %region_1.15 (arg_tuple.16: (f32[32,96], s32[], s32[], s32[])) -> pred[] { - %arg_tuple.16 = (f32[32,96]{1,0}, s32[], s32[], s32[]) parameter(0) - %get-tuple-element.17 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=0 - %get-tuple-element.20 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=3 - %get-tuple-element.18 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=1 - %get-tuple-element.19 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=2 - ROOT %compare.21 = pred[] compare(s32[] %get-tuple-element.18, s32[] %get-tuple-element.19), direction=LT + %region_1.17 (arg_tuple.18: (f32[32,96], s32[], s32[], s32[], f32[32,96])) -> pred[] { + %arg_tuple.18 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) parameter(0) + %get-tuple-element.19 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=0 + %get-tuple-element.22 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=3 + %get-tuple-element.23 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=4 + %get-tuple-element.20 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=1 + %get-tuple-element.21 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=2 + ROOT %compare.24 = pred[] compare(s32[] %get-tuple-element.20, s32[] %get-tuple-element.21), direction=LT, metadata={source_file="-" source_line=21} } - ENTRY %main.27 (Arg_0.1: f32[32,96]) -> f32[32,96] { + ENTRY %main.30 (Arg_0.1: f32[32,96], Arg_1.2: f32[32,96]) -> f32[32,96] { %Arg_0.1 = f32[32,96]{1,0} parameter(0), sharding={devices=[2,2]<=[4]} - %constant.2 = s32[] constant(0) - %constant.4 = s32[] constant(32) - %constant.3 = s32[] constant(1) - %tuple.5 = (f32[32,96]{1,0}, s32[], s32[], s32[]) tuple(f32[32,96]{1,0} %Arg_0.1, s32[] %constant.2, s32[] %constant.4, s32[] %constant.3) - %while.22 = (f32[32,96]{1,0}, s32[], s32[], s32[]) while((f32[32,96]{1,0}, s32[], s32[], s32[]) %tuple.5), condition=%region_1.15, body=%region_0.6 - ROOT %get-tuple-element.23 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %while.22), index=0 + %constant.3 = s32[] constant(0) + %constant.5 = s32[] constant(32) + %constant.4 = s32[] constant(1) + %Arg_1.2 = f32[32,96]{1,0} parameter(1), sharding={devices=[2,1,2]<=[4] last_tile_dim_replicate} + %tuple.6 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %Arg_0.1, s32[] %constant.3, s32[] %constant.5, s32[] %constant.4, f32[32,96]{1,0} %Arg_1.2), metadata={source_file="-" source_line=19} + %while.25 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) while((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %tuple.6), condition=%region_1.17, body=%region_0.7, metadata={source_file="-" source_line=19} + %get-tuple-element.27 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %while.25), index=1, metadata={source_file="-" source_line=19} + %get-tuple-element.26 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %while.25), index=0, metadata={source_file="-" source_line=19} + %tuple.28 = (f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %get-tuple-element.26) + ROOT %get-tuple-element.29 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}) %tuple.28), index=0 })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); @@ -575,10 +581,14 @@ TEST_F(ShardyXLATest, WhileWithFreeVariables) { HloInstruction* whileInst = FindInstruction(module.get(), xla::HloOpcode::kWhile); EXPECT_NE(whileInst, nullptr); - EXPECT_THAT( - whileInst, - op::Sharding( - "{{devices=[2,2]<=[4]}, {replicated}, {replicated}, {replicated}}")); + // Verify that the sharding of parameter(1) hasn't changed. + EXPECT_THAT(module->entry_computation()->parameter_instruction(1), + op::Sharding("{devices=[2,1,2]<=[4] last_tile_dim_replicate}")); + // Verify the sharding of the while, and specifically that the sharding of the + // result that corresponds to parameter(1) is further sharded. + EXPECT_THAT(whileInst, + op::Sharding("{{devices=[2,2]<=[4]}, {replicated}, {replicated}, " + "{devices=[2,2]<=[4]}, {replicated}}")); } TEST_F(ShardyXLATest, ShardMap) { diff --git a/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir b/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir index a04734d8a1f667..f191acf1aaf687 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir @@ -126,7 +126,7 @@ func.func @unknown_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4 // ----- // CHECK-LABEL: sdy.mesh @mesh = <> -// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = +// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = // CHECK-LABEL: func @one_maximal_mesh( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>} @@ -138,8 +138,8 @@ func.func @one_maximal_mesh(%arg0: tensor<8x8xf32> {mhlo.sharding = "{maximal de // ----- -// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = -// CHECK-LABEL: sdy.mesh @maximal_mesh_4 = +// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = +// CHECK-LABEL: sdy.mesh @maximal_mesh_4 = // CHECK-LABEL: func @two_maximal_shardings_should_be_sorted( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_4, [{}, {}]>}, @@ -151,7 +151,7 @@ func.func @two_maximal_shardings_should_be_sorted(%arg0: tensor<8x8xf32> {mhlo.s } // ----- -// CHECK-COUNT-1: sdy.mesh @maximal_mesh_0 = +// CHECK-COUNT-1: sdy.mesh @maximal_mesh_0 = // CHECK-LABEL: func @duplicate_maximal_sharding_should_be_deduped( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>}, @@ -165,7 +165,7 @@ func.func @duplicate_maximal_sharding_should_be_deduped(%arg0: tensor<8x8xf32> { // ----- // CHECK-LABEL: sdy.mesh @mesh = <"axis_0"=8, "axis_1"=4> -// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = +// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = // CHECK-LABEL: func @two_meshes( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_1"}, {}]>}, @@ -180,7 +180,7 @@ func.func @two_meshes(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,1,8]< // ----- // CHECK-LABEL: sdy.mesh @mesh = <"axis_0"=8, "axis_1"=4> -// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = +// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = // CHECK-LABEL: func @maximal_sharding_on_op( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_1"}, {}]>}, diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir index c2707bae962cd3..f3e17fd2defac1 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir @@ -4,8 +4,8 @@ sdy.mesh @mesh_0 = <"axis_0"=2, "axis_1"=4, "axis_2"=4> sdy.mesh @mesh_1 = <"axis_0"=16> sdy.mesh @mesh_2 = <"x"=8, "y"=4> sdy.mesh @mesh_3 = <"a"=2, "b"=2, "c"=2, "d"=2> -sdy.mesh @maximal_mesh_0 = -sdy.mesh @maximal_mesh_1 = +sdy.mesh @maximal_mesh_0 = +sdy.mesh @maximal_mesh_1 = // CHECK-NOT: sdy.mesh diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir index 182504a423bf93..b022afcb921d43 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir @@ -52,36 +52,44 @@ func.func @shmap_body(%arg0: tensor<1x8xf32>, %arg1: tensor<1x8xf32>) -> (tensor // ----- +// CHECK-LABEL: sdy.mesh @mesh = <"axis_0"=2, "axis_1"=2> + // CHECK-LABEL: func @while_with_free_variables -func.func @while_with_free_variables(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { +func.func @while_with_free_variables( + %arg0: tensor<32x96xf32>, + %arg1: tensor<32x96xf32> {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}"}) + -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1> - // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> + // CHECK-NEXT: %[[C32:.*]] = sdy.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, []>]>} dense<32> + // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]> // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] // CHECK-NEXT: mhlo.return %[[COND]] // CHECK-NEXT: } do { // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg + // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]] // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<1> : tensor - %2 = mhlo.constant dense<32> : tensor + %2 = mhlo.constant {mhlo.sharding = "{replicated}"} dense<32> : tensor %3:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor cond { %4 = mhlo.compare LT, %iterArg_0, %2 : (tensor, tensor) -> tensor mhlo.return %4 : tensor } do { %4 = mhlo.add %iterArg_0, %1 : tensor - %5 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32> + %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> mhlo.return %5, %4 : tensor<32x96xf32>, tensor } return %3#0 : tensor<32x96xf32> } +// ----- + // CHECK-LABEL: func @while_with_sinked_constants func.func @while_with_sinked_constants(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> diff --git a/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir b/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir new file mode 100644 index 00000000000000..b87048e4979a62 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir @@ -0,0 +1,93 @@ +// RUN: sdy_opt %s -xla-sdy-open-while-free-vars-sharding 2>&1 | FileCheck %s + +sdy.mesh @mesh1 = <"a"=2> +sdy.mesh @mesh2 = <"b"=2> + +// CHECK-LABEL: func @while_with_free_variables +func.func @while_with_free_variables( + %arg0: tensor<32x96xf32>, + %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>}, + %arg2: tensor<32x96xf32>) + -> (tensor<32x96xf32>, tensor<32x96xf32>) { + // CHECK-NEXT: %[[C0:.*]] = mhlo.constant dense<0> + // CHECK-NEXT: %[[C1:.*]] = mhlo.constant dense<1> + // CHECK-NEXT: %[[C32:.*]] = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> + // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} + // CHECK-NEXT: %[[SC_0:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]> + // CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %[[ADD_0]] <@mesh2, [{?}, {?}]> + // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: cond { + // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: } do { + // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg_0, %[[C1]] + // CHECK-NEXT: %[[ADD_2:.*]] = mhlo.add %iterArg, %[[SC_0]] + // CHECK-NEXT: %[[ADD_3:.*]] = mhlo.add %[[ADD_2]], %arg2 + // CHECK-NEXT: %[[ADD_4:.*]] = mhlo.add %[[ADD_3]], %[[SC_1]] + // CHECK-NEXT: mhlo.return %[[ADD_4]], %[[ADD_1]] + // CHECK-NEXT: } + // CHECK-NEXT: return %[[ADD_0]], %[[WHILE]]#0 + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + %2 = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor + %3 = mhlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} : tensor<32x96xf32> + %4:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor + cond { + %5 = mhlo.compare LT, %iterArg_0, %2 : (tensor, tensor) -> tensor + mhlo.return %5 : tensor + } do { + %5 = mhlo.add %iterArg_0, %1 : tensor + %6 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> + %7 = mhlo.add %6, %arg2 : tensor<32x96xf32> + %8 = mhlo.add %7, %3 : tensor<32x96xf32> + mhlo.return %8, %5 : tensor<32x96xf32>, tensor + } + return %3, %4#0 : tensor<32x96xf32>, tensor<32x96xf32> +} + +// CHECK-LABEL: func @free_var_used_in_multiple_while_ops +func.func @free_var_used_in_multiple_while_ops( + %arg0: tensor<32x96xf32>, + %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>}) + -> tensor<32x96xf32> { + // CHECK-NEXT: %[[C0:.*]] = mhlo.constant dense<0> + // CHECK-NEXT: %[[C32:.*]] = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> + // CHECK-NEXT: %[[SC_0:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]> + // CHECK-NEXT: %[[WHILE_0:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: cond { + // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: } do { + // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg, %[[SC_0]] + // CHECK-NEXT: mhlo.return %[[ADD_0]], %iterArg_0 + // CHECK-NEXT: } + // CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]> + // CHECK-NEXT: %[[WHILE_1:.*]]:2 = mhlo.while(%iterArg = %[[WHILE_0]]#0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: cond { + // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: } do { + // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC_1]] + // CHECK-NEXT: mhlo.return %[[ADD_1]], %iterArg_0 + // CHECK-NEXT: } + // CHECK-NEXT: return %[[WHILE_1]]#0 + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor + %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor + cond { + %4 = mhlo.compare LT, %iterArg_0, %1 : (tensor, tensor) -> tensor + mhlo.return %4 : tensor + } do { + %4 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> + mhlo.return %4, %iterArg_0 : tensor<32x96xf32>, tensor + } + %3:2 = mhlo.while(%iterArg = %2#0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor + cond { + %4 = mhlo.compare LT, %iterArg_0, %1 : (tensor, tensor) -> tensor + mhlo.return %4 : tensor + } do { + %4 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> + mhlo.return %4, %iterArg_0 : tensor<32x96xf32>, tensor + } + return %3#0 : tensor<32x96xf32> +} diff --git a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir index 292b3544c05bc3..66b227dffd0f93 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir @@ -163,12 +163,19 @@ func.func @main( // ----- +// CHECK: sdy.mesh @mesh = <"x"=2> +sdy.mesh @mesh = <"x"=2> + // Test WhileOp with lifted free variables and sinked constants. // CHECK-LABEL: func @main -func.func @main(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { +func.func @main( + %arg0: tensor<32x96xf32>, + %arg1: tensor<32x96xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}, {}]>"}}) + -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> + // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]> // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] @@ -176,7 +183,7 @@ func.func @main(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { // CHECK-NEXT: } do { // CHECK-DAG: %[[C1:.*]] = sdy.constant dense<1> // CHECK-DAG: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-DAG: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg + // CHECK-DAG: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]] // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 @@ -189,7 +196,7 @@ func.func @main(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { } do { %3 = sdy.constant dense<1> : tensor %4 = mhlo.add %iterArg_0, %3 : tensor - %5 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32> + %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> mhlo.return %5, %4 : tensor<32x96xf32>, tensor } return %2#0 : tensor<32x96xf32> diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir index 857d4c5d790f54..e782de02815699 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir @@ -30,17 +30,21 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x } // CHECK-LABEL: func @while_with_free_variables - func.func @while_with_free_variables(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { + func.func @while_with_free_variables( + %arg0: tensor<32x96xf32>, + %arg1: tensor<32x96xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}, {}]>"}}) + -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1> // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> + // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]> // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] // CHECK-NEXT: mhlo.return %[[COND]] // CHECK-NEXT: } do { // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg + // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]] // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 @@ -53,7 +57,7 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x mhlo.return %4 : tensor } do { %4 = mhlo.add %iterArg_0, %1 : tensor - %5 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32> + %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> mhlo.return %5, %4 : tensor<32x96xf32>, tensor } return %3#0 : tensor<32x96xf32> diff --git a/third_party/xla/xla/service/stable_sort_expander.cc b/third_party/xla/xla/service/stable_sort_expander.cc index 910ab5da82a01e..ca87dce4df65a7 100644 --- a/third_party/xla/xla/service/stable_sort_expander.cc +++ b/third_party/xla/xla/service/stable_sort_expander.cc @@ -55,7 +55,6 @@ absl::StatusOr StableSortExpander::ExpandInstruction( HloComputation* computation = sort->parent(); HloInstruction* expanded_sort = nullptr; - absl::flat_hash_set used_indices; int64_t iota_index = IotaOperandIndexForStableSort(*sort); // If there is currently no iota operand which we could use for making the diff --git a/third_party/xla/xla/service/stream_pool_test.cc b/third_party/xla/xla/service/stream_pool_test.cc index 6eaeb912e4e2c0..2bea4119a4d9b7 100644 --- a/third_party/xla/xla/service/stream_pool_test.cc +++ b/third_party/xla/xla/service/stream_pool_test.cc @@ -29,8 +29,7 @@ class StreamPoolTest : public ::testing::Test { se::StreamExecutor* NewStreamExecutor() { se::Platform* platform = se::PlatformManager::PlatformWithName("Host").value(); - se::StreamExecutorConfig config(/*ordinal=*/0); - return platform->GetExecutor(config).value(); + return platform->ExecutorForDevice(/*ordinal=*/0).value(); } }; diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.cc b/third_party/xla/xla/service/while_loop_fusible_sinking.cc index d1fd7acd8ca110..07b49dbafe45d1 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking.cc +++ b/third_party/xla/xla/service/while_loop_fusible_sinking.cc @@ -136,10 +136,6 @@ absl::StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( } bool changed = false; - - absl::flat_hash_map> - conditional_gte_index_to_insts = - WhileUtil::GetGTEsMapForWhileConditional(*while_cond); std::vector invariant_body_gtes = WhileUtil::GetInvariantGTEsForWhileBody(*while_body); std::vector tuple_indices; diff --git a/third_party/xla/xla/service/while_loop_unroller.cc b/third_party/xla/xla/service/while_loop_unroller.cc index 0e1c3288c468df..7b22244fa9023f 100644 --- a/third_party/xla/xla/service/while_loop_unroller.cc +++ b/third_party/xla/xla/service/while_loop_unroller.cc @@ -98,8 +98,11 @@ std::unique_ptr MakeTrivialLoopCondition( absl::Status HandleDynamicGteOrTuple(HloInstruction* instr) { if (instr->IsCustomCall("DynamicGte")) { HloEvaluator evaluator(/*max_loop_iterations=*/0); - TF_ASSIGN_OR_RETURN(Literal index_lit, - evaluator.Evaluate(instr->mutable_operand(1), true)); + TF_ASSIGN_OR_RETURN( + Literal index_lit, + evaluator.Evaluate(instr->mutable_operand(1), + /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true)); auto index = LiteralUtil::LiteralAsScalarInt64(std::move(index_lit)); // The index must have a compile-time integer value at this point. TF_RET_CHECK(index.has_value()); @@ -109,8 +112,11 @@ absl::Status HandleDynamicGteOrTuple(HloInstruction* instr) { } else if (instr->IsCustomCall("DynamicTuple")) { HloEvaluator evaluator(/*max_loop_iterations=*/0); std::vector tuple_operands; - TF_ASSIGN_OR_RETURN(Literal index_lit, - evaluator.Evaluate(instr->mutable_operand(2), true)); + TF_ASSIGN_OR_RETURN( + Literal index_lit, + evaluator.Evaluate(instr->mutable_operand(2), + /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true)); auto index = LiteralUtil::LiteralAsScalarInt64(std::move(index_lit)); // The index must have a compile-time integer value at this point. TF_RET_CHECK(index.has_value()); diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 0deca085f6c94c..9581d6d673f6f0 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/shape_util.h" #include -#include #include #include #include @@ -1982,8 +1981,9 @@ struct ParallelState { // Returns the indices of the first elements of all consecutive subarrays of the // given array. For example: // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} -static std::vector ConsecutiveSegments(absl::Span xs) { - std::vector is = {0}; +static absl::InlinedVector ConsecutiveSegments( + absl::Span xs) { + absl::InlinedVector is = {0}; for (size_t i = 1; i < xs.size(); ++i) { if (1 != xs[i] - xs[i - 1]) { is.push_back(i); @@ -2010,83 +2010,74 @@ static Shape MergeDimensions(absl::Span segs, dimensions); } -static std::vector MajorToMinorLayout(const Shape& s) { +static absl::InlinedVector MajorToMinorLayout(const Shape& s) { absl::Span minor_to_major = LayoutUtil::MinorToMajor(s); - return std::vector{minor_to_major.rbegin(), minor_to_major.rend()}; -} - -static std::optional GetNormalizedTransposeShapeHelper( - const Shape& input_shape, absl::Span output_to_input, - const Vector3& permutation) { - // 'permutation' should not be the identity permutation. - if (permutation[0] == 0 && permutation[1] == 1 && permutation[2] == 2) { - return std::nullopt; - } - std::vector segments = ConsecutiveSegments(output_to_input); - if (segments.size() > 3) { + return absl::InlinedVector{minor_to_major.rbegin(), + minor_to_major.rend()}; +} + +static std::optional> +GetNormalizedTransposeShapeHelper( + const Shape& output_shape, absl::Span output_to_input, + absl::InlinedVector& permutation) { + absl::InlinedVector segments = + ConsecutiveSegments(output_to_input); + // This means that after normalization there is actually no transpose. + if (segments.size() == 1) { return std::nullopt; } - - Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - Shape normalized_shape = MergeDimensions(segments, normalized_input_shape); - std::vector normalized_dims{normalized_shape.dimensions().begin(), - normalized_shape.dimensions().end()}; + Shape normalized_shape = MergeDimensions(segments, output_shape); if (segments.size() == 2) { - // If we have two segments, we know that at least one transpose is - // happening, otherwise we would have only 1 segment. - int64_t untransposed = 0; - while (untransposed < permutation.size() && - permutation[untransposed] != untransposed) { - ++untransposed; - } - // The desired permutation may not contain any untransposed dimension. With - // just 2 segments, we cannot uniquely match that. - if (untransposed == permutation.size()) { - return std::nullopt; - } - // Insert a 1-dimension at the position of the untransposed dimension. - normalized_dims.insert(normalized_dims.begin() + untransposed, 1); - } else if (segments.size() == 3) { - // Derive the order from the segments. - Vector3 segment_order{output_to_input[segments[0]], - output_to_input[segments[1]], - output_to_input[segments[2]]}; - // We expect the same relative order. - for (int64_t i = 1; i < 3; ++i) { - if ((segment_order[i] > segment_order[i - 1]) != - (permutation[i] > permutation[i - 1])) { - return std::nullopt; - } + // If we have two segments, we know that exactly two dimensions are swapped. + // Insert a 1-dimension at the front and detect a 021 transpose. + // TODO(b/328656780): Don't insert the extra 1-dimension once the emitter + // supports any number of dimensions >= 2. + permutation = {0, 2, 1}; + return absl::InlinedVector{1, normalized_shape.dimensions(0), + normalized_shape.dimensions(1)}; + } + // We have at least 3 segments. Derive the permutation from the segments. + std::vector segment_to_normalized_dim(output_shape.rank(), -1); + for (size_t segment : segments) { + segment_to_normalized_dim[output_to_input[segment]] = 0; + } + int64_t normalized_dim = 0; + for (int64_t i = 0; i < segment_to_normalized_dim.size(); ++i) { + if (segment_to_normalized_dim[i] >= 0) { + segment_to_normalized_dim[i] = normalized_dim++; } } - if (normalized_dims.size() == 3) { - return Vector3{normalized_dims[permutation[0]], - normalized_dims[permutation[1]], - normalized_dims[permutation[2]]}; + permutation.reserve(segments.size()); + for (int64_t i = 0; i < segments.size(); ++i) { + permutation.push_back( + segment_to_normalized_dim[output_to_input[segments[i]]]); } - return std::nullopt; + absl::InlinedVector normalized_dims( + normalized_shape.dimensions().begin(), + normalized_shape.dimensions().end()); + return normalized_dims; } -/* static */ std::optional +/* static */ std::optional> ShapeUtil::GetNormalizedLogicalTransposeShape( - const Shape& input_shape, const Shape& output_shape, - absl::Span dimensions, const Vector3& permutation) { - if (!LayoutUtil::IsMonotonicWithDim0Major(input_shape.layout()) || - !LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())) { + const Shape& output_shape, absl::Span dimensions, + absl::InlinedVector& permutation) { + permutation.clear(); + if (!LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())) { // Only works on default layouts. return std::nullopt; } // Drop degenerate dimensions. - std::vector delta(input_shape.rank() + 1, 0); - for (int i = 0; i < input_shape.rank(); ++i) { + absl::InlinedVector delta(output_shape.rank() + 1, 0); + auto input_dimensions = ComposePermutations(output_shape.dimensions(), + InversePermutation(dimensions)); + for (int i = 0; i < output_shape.rank(); ++i) { delta[i + 1] = delta[i]; - if (input_shape.dimensions(i) == static_cast(1)) { + if (input_dimensions[i] == static_cast(1)) { ++delta[i + 1]; } } - std::vector new_dimensions; + absl::InlinedVector new_dimensions; for (int i = 0; i < dimensions.size(); i++) { if (output_shape.dimensions(i) != 1) { new_dimensions.push_back(dimensions[i] - delta[dimensions[i]]); @@ -2094,24 +2085,29 @@ ShapeUtil::GetNormalizedLogicalTransposeShape( } return GetNormalizedTransposeShapeHelper( - DropDegenerateDimensions(input_shape), InversePermutation(new_dimensions), - permutation); + DropDegenerateDimensions(output_shape), new_dimensions, permutation); } -/* static */ std::optional ShapeUtil::GetNormalizedTransposeShape( +/* static */ std::optional> +ShapeUtil::GetNormalizedTransposeShape( const Shape& input_shape, const Shape& output_shape, - const Vector3& permutation) { + absl::InlinedVector& permutation) { + permutation.clear(); if (!ShapeUtil::CompatibleIgnoringElementType(input_shape, output_shape)) { return std::nullopt; } - std::vector major_to_minor_input = MajorToMinorLayout(input_shape); - std::vector major_to_minor_output = MajorToMinorLayout(output_shape); + absl::InlinedVector major_to_minor_input = + MajorToMinorLayout(input_shape); + absl::InlinedVector major_to_minor_output = + MajorToMinorLayout(output_shape); std::vector output_to_input = ComposePermutations( - InversePermutation(major_to_minor_output), major_to_minor_input); + InversePermutation(major_to_minor_input), major_to_minor_output); - return GetNormalizedTransposeShapeHelper(input_shape, output_to_input, - permutation); + return GetNormalizedTransposeShapeHelper( + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + output_shape), + output_to_input, permutation); } Shape ShapeUtil::DeviceShapeToHostShape(Shape s) { diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index a773adeaf08d6b..aeb043ebeb4d13 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -44,7 +44,6 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/printer.h" #include "xla/shape.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep @@ -1012,10 +1011,9 @@ class ShapeUtil { const Shape& shape, const ForEachParallelVisitorFunction& visitor_function); - // In this case, we care about transposes that swap two dimensions of a - // a shape that can be viewed as three logical components 0-1-2 in the order - // of major to minor. - // As an example, let's consider a 0-2-1 transpose: + // In this case, we care about transposes that permute dimensions of a shape + // that can be viewed as several logical components in the order of major to + // minor. As an example, let's consider a 0-2-1 transpose: // // If a shape can be viewed as three logical components 0-1-2 in the order of // major to minor, a 0-2-1-transpose changes the order of such logical @@ -1029,15 +1027,18 @@ class ShapeUtil { // should be set to {0, 2, 1}. // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the // normalized shape of `b` or the 0-2-1 shape. In general, the - // permutation[0]-permutation[1]-permutation[2] shape is returned. - static std::optional GetNormalizedTransposeShape( - const Shape& input_shape, const Shape& output_shape, - const Vector3& permutation); + // permutation[0]-permutation[1]-...-permutation[permutation.size()-1] shape + // is returned. + static std::optional> + GetNormalizedTransposeShape(const Shape& input_shape, + const Shape& output_shape, + absl::InlinedVector& permutation); // Entry point for physical + logical transposition. - static std::optional GetNormalizedLogicalTransposeShape( - const Shape& input_shape, const Shape& output_shape, - absl::Span dimensions, const Vector3& permutation); + static std::optional> + GetNormalizedLogicalTransposeShape( + const Shape& output_shape, absl::Span dimensions, + absl::InlinedVector& permutation); // Strips device-specific information, namely tiling and memory-space // information, from a shape. diff --git a/third_party/xla/xla/shape_util_test.cc b/third_party/xla/xla/shape_util_test.cc index e7c1beb972958d..c35464af6d55c5 100644 --- a/third_party/xla/xla/shape_util_test.cc +++ b/third_party/xla/xla/shape_util_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -1391,167 +1392,173 @@ TEST(ShapeUtilTest, DecomposeBitcastToTrt) { EXPECT_FALSE(decomposition_trt.IsTranspose2Identity()); } -TEST(Transpose021Test, NoTranspose) { +TEST(NormalizedTransposeShapeTest, NoTranspose) { Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64}, {1, 0}); Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {64, 128}, {0, 1}); + absl::InlinedVector permutation; EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape( - shape, transposed, Vector3{0, 2, 1})); + shape, transposed, permutation)); } -TEST(Transpose021Test, NoTranspose2) { +TEST(NormalizedTransposeShapeTest, NoTranspose2) { Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64, 32}, {2, 1, 0}); Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 64, 128}, {0, 1, 2}); + absl::InlinedVector permutation; EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape( - shape, transposed, Vector3{0, 1, 2})); + shape, transposed, permutation)); } -TEST(Transpose021Test, WrongTranspose) { - Shape input_shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {2, 1, 0}); - Shape output_shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {0, 1, 2}); - EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape( - input_shape, output_shape, Vector3{0, 2, 1})); -} - -TEST(Transpose021Test, WrongTranspose2) { - Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {1, 0}); - Shape output_shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {0, 1}); - EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape( - input_shape, output_shape, Vector3{0, 1, 2})); -} - -TEST(Transpose021Test, WrongTranspose3) { - Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {1, 0}); - Shape output_shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {0, 1}); - EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape( - input_shape, output_shape, Vector3{1, 2, 0})); -} - -TEST(Transpose021Test, Simple) { +TEST(NormalizedTransposeShapeTest, Simple) { Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64}, {1, 0}); Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64}, {0, 1}); - EXPECT_EQ(std::make_optional(Vector3{1, 64, 128}), - ShapeUtil::GetNormalizedTransposeShape(shape, transposed, - Vector3{0, 2, 1})); + absl::InlinedVector permutation; + EXPECT_EQ( + std::make_optional(absl::InlinedVector{1, 64, 128}), + ShapeUtil::GetNormalizedTransposeShape(shape, transposed, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, Simple2) { +TEST(NormalizedTransposeShapeTest, Simple2) { Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {2, 1, 0}); Shape output_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {1, 2, 0}); - EXPECT_EQ(std::make_optional(Vector3{8, 16, 32768}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{8, 16, 32768}), ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape, - Vector3{0, 2, 1})); + permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, Simple3) { +TEST(NormalizedTransposeShapeTest, Simple3) { Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {2, 1, 0}); Shape output_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {0, 1, 2}); - EXPECT_EQ(std::make_optional(Vector3{16, 32768, 8}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{16, 32768, 8}), ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape, - Vector3{2, 1, 0})); + permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{2, 1, 0})); } -TEST(Transpose021Test, Simple4) { - Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {1, 0}); +TEST(NormalizedTransposeShapeTest, NormalizedShapeRank4) { + Shape input_shape = + ShapeUtil::MakeShapeWithDenseLayout(F32, {16, 4, 8, 32768}, {2, 1, 3, 0}); Shape output_shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {0, 1}); - EXPECT_EQ(std::make_optional(Vector3{16, 1, 8}), - ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape, - Vector3{2, 1, 0})); + ShapeUtil::MakeShapeWithDenseLayout(F32, {16, 4, 8, 32768}, {1, 0, 2, 3}); + absl::InlinedVector permutation; + EXPECT_EQ( + std::make_optional(absl::InlinedVector{32768, 8, 16, 4}), + ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape, + permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{1, 3, 0, 2})); } -TEST(Transpose021Test, LargeView) { +TEST(NormalizedTransposeShapeTest, LargeView) { Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout( F32, {8, 32, 32, 32, 16}, {4, 3, 2, 1, 0}); Shape output_shape = ShapeUtil::MakeShapeWithDenseLayout( F32, {8, 32, 32, 32, 16}, {3, 2, 1, 4, 0}); - EXPECT_EQ(std::make_optional(Vector3{8, 16, 32768}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{8, 16, 32768}), ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape, - Vector3{0, 2, 1})); + permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, LargeSizeOverflowTest) { +TEST(NormalizedTransposeShapeTest, LargeSizeOverflowTest) { Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(BF16, {4096, 4096, 128}, {2, 1, 0}); Shape output_shape = ShapeUtil::MakeShapeWithDenseLayout(BF16, {4096, 4096, 128}, {2, 1, 0}); + absl::InlinedVector permutation; EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape( - input_shape, output_shape, Vector3{0, 2, 1})); + input_shape, output_shape, permutation)); } -TEST(Transpose021Test, Batched) { +TEST(NormalizedTransposeShapeTest, Batched) { Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 3, 64}, {2, 1, 0}); Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 3, 64}, {1, 0, 2}); - EXPECT_EQ(std::make_optional(Vector3{1, 64, 96}), - ShapeUtil::GetNormalizedTransposeShape(shape, transposed, - Vector3{0, 2, 1})); + absl::InlinedVector permutation; + EXPECT_EQ( + std::make_optional(absl::InlinedVector{1, 64, 96}), + ShapeUtil::GetNormalizedTransposeShape(shape, transposed, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, BatchedLogical) { - Shape shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 3, 64}, {2, 1, 0}); +TEST(NormalizedTransposeShapeTest, BatchedLogical) { Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {64, 32, 3}, {2, 1, 0}); std::vector dimensions = {2, 0, 1}; - EXPECT_EQ(std::make_optional(Vector3{1, 64, 96}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{1, 64, 96}), ShapeUtil::GetNormalizedLogicalTransposeShape( - shape, transposed, dimensions, Vector3{0, 2, 1})); + transposed, dimensions, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, LogicalWithDegenerateDims) { - Shape shape = ShapeUtil::MakeShapeWithDenseLayout( - F32, {1, 32, 1, 3, 1, 64, 1}, {6, 5, 4, 3, 2, 1, 0}); +TEST(NormalizedTransposeShapeTest, LogicalWithDegenerateDims) { Shape transposed = ShapeUtil::MakeShapeWithDenseLayout( F32, {1, 32, 1, 64, 1, 3, 1}, {6, 5, 4, 3, 2, 1, 0}); std::vector dimensions = {6, 1, 4, 5, 2, 3, 0}; - EXPECT_EQ(std::make_optional(Vector3{32, 64, 3}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{32, 64, 3}), ShapeUtil::GetNormalizedLogicalTransposeShape( - shape, transposed, dimensions, Vector3{0, 2, 1})); + transposed, dimensions, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, LogicalWithDegenerateLastDim) { - Shape shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {1, 64, 32}, {2, 1, 0}); +TEST(NormalizedTransposeShapeTest, LogicalWithDegenerateLastDim) { Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 64, 1}, {2, 1, 0}); std::vector dimensions = {2, 1, 0}; - EXPECT_EQ(std::make_optional(Vector3{1, 32, 64}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{1, 32, 64}), ShapeUtil::GetNormalizedLogicalTransposeShape( - shape, transposed, dimensions, Vector3{0, 2, 1})); + transposed, dimensions, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, Large) { +TEST(NormalizedTransposeShapeTest, Large) { Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 31, 31, 65}, {3, 2, 1, 0}); Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 31, 31, 65}, {2, 1, 3, 0}); - EXPECT_EQ(std::make_optional(Vector3{8, 65, 961}), - ShapeUtil::GetNormalizedTransposeShape(shape, transposed, - Vector3{0, 2, 1})); + absl::InlinedVector permutation; + EXPECT_EQ( + std::make_optional(absl::InlinedVector{8, 65, 961}), + ShapeUtil::GetNormalizedTransposeShape(shape, transposed, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose210Test, LogicalTranspose) { - Shape shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {10, 11, 12, 13}, {3, 2, 1, 0}); +TEST(NormalizedLogicialTransposeShapeTest, LogicalTranspose) { Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {13, 12, 10, 11}, {3, 2, 1, 0}); std::vector dimensions = {3, 2, 0, 1}; - EXPECT_EQ(std::make_optional(Vector3{13, 12, 110}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{13, 12, 110}), ShapeUtil::GetNormalizedLogicalTransposeShape( - shape, transposed, dimensions, Vector3{2, 1, 0})); + transposed, dimensions, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{2, 1, 0})); +} + +TEST(NormalizedLogicalTransposeShapeTest, NormalizedShapeRank4) { + Shape transposed = + ShapeUtil::MakeShapeWithDenseLayout(F32, {16, 32768, 8, 4}, {3, 2, 1, 0}); + std::vector dimensions = {2, 0, 3, 1}; + absl::InlinedVector permutation; + EXPECT_EQ( + std::make_optional(absl::InlinedVector{16, 32768, 8, 4}), + ShapeUtil::GetNormalizedLogicalTransposeShape(transposed, dimensions, + permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{2, 0, 3, 1})); } TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 2cfd29bb5f49ec..a3c96f03aa88a7 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -201,16 +201,67 @@ cc_library( ], ) +cc_library( + name = "stream_finder", + srcs = ["stream_finder.cc"], + hdrs = ["stream_finder.h"], + deps = [ + ":platform", + ":stream", + ":stream_executor_h", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "mock_platform", + testonly = True, + hdrs = ["mock_platform.h"], + deps = [ + ":device_description", + ":platform", + ":stream_executor_h", + "//xla:test", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "mock_stream", + testonly = True, + hdrs = ["mock_stream.h"], + deps = [ + ":device_description", + ":device_memory", + ":event", + ":event_based_timer", + ":kernel", + ":launch_dim", + ":platform", + ":stream", + "//xla:test", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + cc_library( name = "mock_stream_executor", testonly = True, hdrs = ["mock_stream_executor.h"], deps = [ ":allocator_stats", + ":blas", ":command_buffer", ":device_description", ":device_memory", + ":dnn", ":event", + ":fft", ":kernel", ":kernel_spec", ":launch_dim", @@ -220,7 +271,6 @@ cc_library( ":stream", ":stream_executor_h", "//xla:test", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -380,6 +430,7 @@ cc_library( ":device_memory", ":numeric_options", "//xla/stream_executor/platform", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -391,7 +442,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", # buildcleaner: keep - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:status", @@ -449,12 +499,10 @@ cc_library( ":fft", ":kernel", ":kernel_spec", - ":launch_dim", ":memory_allocation", ":module_spec", ":platform", ":stream", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -573,7 +621,6 @@ cc_library( srcs = ["executor_cache.cc"], hdrs = ["executor_cache.h"], deps = [ - ":platform", ":stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -670,13 +717,10 @@ cc_library( ":blas", ":device_description", ":fft", - ":kernel", - ":launch_dim", ":platform", ":stream", ":stream_executor_h", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -818,10 +862,27 @@ xla_cc_test( deps = [ ":executor_cache", ":mock_stream_executor", - ":platform", ":stream", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/log", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +xla_cc_test( + name = "stream_finder_test", + srcs = ["stream_finder_test.cc"], + deps = [ + ":mock_platform", + ":mock_stream", + ":mock_stream_executor", + ":stream_finder", + "//xla:test", + "@com_google_absl//absl/status", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 6b9fda16902a19..b0b0bf5e608d24 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -29,7 +29,14 @@ load( "tf_additional_gpu_compilation_copts", ) load("//xla/tests:build_defs.bzl", "xla_test") -load("//xla/tsl:tsl.bzl", "if_google", "if_nccl", "internal_visibility", "tsl_copts") +load( + "//xla/tsl:tsl.bzl", + "if_google", + "if_hermetic_cuda_tools", + "if_nccl", + "internal_visibility", + "tsl_copts", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -134,9 +141,21 @@ cuda_only_cc_library( # Buildozer can not remove dependencies inside select guards, so we have to use # an intermediate target. -cc_library(name = "ptxas_wrapper") +cc_library( + name = "ptxas_wrapper", + data = if_hermetic_cuda_tools( + ["@cuda_nvcc//:ptxas"], + [], + ), +) -cc_library(name = "nvlink_wrapper") +cc_library( + name = "nvlink_wrapper", + data = if_hermetic_cuda_tools( + ["@cuda_nvcc//:nvlink"], + [], + ), +) cuda_only_cc_library( name = "cuda_driver", @@ -226,12 +245,9 @@ xla_test( name = "cuda_driver_test", srcs = ["cuda_driver_test.cc"], backends = ["gpu"], - tags = [ - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - "gpu", - "no_rocm", - ], + tags = ["no_rocm"], deps = [ + ":cuda_diagnostics", ":cuda_driver", ":cuda_status", "//xla/stream_executor/gpu:gpu_driver_header", @@ -429,7 +445,6 @@ gpu_kernel_library( "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_semaphore", - "//xla/stream_executor/gpu:gpu_stream", "@com_google_absl//absl/status:statusor", ], ) @@ -601,6 +616,23 @@ cc_library( ], ) +xla_test( + name = "cuda_platform_test", + srcs = ["cuda_platform_test.cc"], + backends = ["gpu"], + tags = ["no_rocm"], + deps = [ + ":cuda_platform", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + xla_cc_test( name = "ptx_compiler_test", srcs = ["ptx_compiler_test.cc"], @@ -712,7 +744,7 @@ xla_cc_test( name = "nvjitlink_test", srcs = ["nvjitlink_test.cc"], args = if_google([ - # nvjitlink allocates memory and only keeps a pointer past the usual offest of 1024 bytes; + # nvjitlink allocates memory and only keeps a pointer past the usual offset of 1024 bytes; # so we need to increase the max pointer offset. -1 means no limit. # This is only relevant for Google's HeapLeakChecker. The newer Leak sanitizer doesn't # have this issue. @@ -750,6 +782,13 @@ cuda_only_cc_library( # "@local_config_cuda//cuda:runtime_ptxas", # ], # copybara:uncomment_end + # copybara:comment_begin + data = if_hermetic_cuda_tools([ + "@cuda_nvcc//:fatbinary", + "@cuda_nvcc//:nvlink", + "@cuda_nvcc//:ptxas", + ]), + # copybara:comment_end visibility = internal_visibility([ "//third_party/py/jax:__subpackages__", "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc index 01aa15313c2cd0..7f2183f85a0a95 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc index a4337dfe60e497..9d200e74dcada8 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/cuda/cuda_blas.h" #include "xla/stream_executor/cuda/cuda_blas_utils.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" @@ -448,12 +449,15 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, b_scale.opaque())); } - if (c_scale != nullptr) { + auto isF8Input = [](const auto& desc) { + return desc.type() == CUDA_R_8F_E4M3 || desc.type() == CUDA_R_8F_E5M2; + }; + if (c_scale != nullptr && isF8Input(c_desc_)) { TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, c_scale.opaque())); } - if (d_scale != nullptr) { + if (d_scale != nullptr && isF8Input(d_desc_)) { TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, d_scale.opaque())); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h index 2fae670f87edca..3d61c816024af9 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.cc b/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.cc index 2060fb2e296ead..155c3383e7d843 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.cc @@ -225,7 +225,7 @@ absl::StatusOr Diagnostician::FindDsoVersion() { absl::StatusOr Diagnostician::FindKernelModuleVersion( const std::string &driver_version_file_contents) { - static const char *kDriverFilePrelude = "Kernel Module "; + static const char *kDriverFilePrelude = "Kernel Module"; size_t offset = driver_version_file_contents.find(kDriverFilePrelude); if (offset == std::string::npos) { return absl::NotFoundError( @@ -233,9 +233,17 @@ absl::StatusOr Diagnostician::FindKernelModuleVersion( "driver version file contents: \"", driver_version_file_contents, "\"")); } + static const char *kDriverVersionPrelude = " "; + offset = driver_version_file_contents.find(kDriverVersionPrelude, offset); + if (offset == std::string::npos) { + return absl::NotFoundError( + absl::StrCat("driver version not preceded by two spaces in " + "driver version file contents: \"", + driver_version_file_contents, "\"")); + } std::string version_and_rest = driver_version_file_contents.substr( - offset + strlen(kDriverFilePrelude), std::string::npos); + offset + strlen(kDriverVersionPrelude), std::string::npos); size_t space_index = version_and_rest.find(' '); auto kernel_version = version_and_rest.substr(0, space_index); // TODO(b/22689637): Eliminate the explicit namespace if possible. diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 90640a839dcf23..440f647b84f1ce 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -40,7 +40,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -3762,32 +3761,6 @@ absl::StatusOr CreateCudnnTensor( } #if CUDNN_VERSION >= 8800 -enum CudnnfMHAUid { - Q_ID = 400, - K_ID, - V_ID, - P_ID, - O_ID, - dQ_ID, - dK_ID, - dV_ID, - dP_ID, - dO_ID, - dS_ID, - dBIAS_ID, - BIAS_ID, - MASK_ID, - ZERO_VAL_ID, - ONE_VAL_ID, - NEG_INFINITY_ID, - ALPHA_SCALE_ID, - DROPOUT_SCALE_ID, - Q_SEQLEN_ID, - K_SEQLEN_ID, - D_OFFSET_ID, - D_SEED_ID, - VIRTUAL_ID = 34857 -}; absl::StatusOr CreatePwDesc( dnn::DataType dtype, cudnnPointwiseMode_t mode) { @@ -5032,12 +5005,14 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_io_data_type(ioDataType) .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + std::shared_ptr q_tensor = graph.tensor(Tensor_attributes() .set_name("Q") .set_dim(q_descriptor.GetCudnnCompatibleDimensions(true)) .set_stride(q_descriptor.GetCudnnCompatibleStrides(true)) - .set_uid(CudnnfMHAUid::Q_ID)); + .set_uid(next_uid())); auto dim = k_descriptor.GetCudnnCompatibleDimensions(true); std::shared_ptr k_tensor = @@ -5045,13 +5020,13 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_name("K") .set_dim(k_descriptor.GetCudnnCompatibleDimensions(true)) .set_stride(k_descriptor.GetCudnnCompatibleStrides(true)) - .set_uid(CudnnfMHAUid::K_ID)); + .set_uid(next_uid())); std::shared_ptr v_tensor = graph.tensor( Tensor_attributes() .set_name("V") .set_dim(v_descriptor.GetCudnnCompatibleDimensions(false)) .set_stride(v_descriptor.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::V_ID)); + .set_uid(next_uid())); // Setting sdpa, and is_inference bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || @@ -5069,7 +5044,7 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_name("bias") .set_dim(bias_descriptor->dimensions()) .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(CudnnfMHAUid::BIAS_ID)); + .set_uid(next_uid())); sdpa_options.set_bias(bias_tensor); } // Setting actual seqlen @@ -5083,37 +5058,38 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_name("seq_q") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::Q_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); auto seq_kv_tensor = graph.tensor(Tensor_attributes() .set_name("seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::K_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); sdpa_options.set_padding_mask(true); sdpa_options.set_seq_len_q(seq_q_tensor); sdpa_options.set_seq_len_kv(seq_kv_tensor); } // Setting seed and offset + std::shared_ptr seed_tensor; + std::shared_ptr offset_tensor; if (use_dropout) { - auto seed_tensor = + // Skip setting UIDs: pass by value tensors go at the end. + seed_tensor = graph.tensor(Tensor_attributes() .set_name("seed") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_SEED_ID)); - auto offset_tensor = + .set_is_pass_by_value(true)); + offset_tensor = graph.tensor(Tensor_attributes() .set_name("offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_OFFSET_ID)); + .set_is_pass_by_value(true)); sdpa_options.set_dropout((float)dropout_rate.value(), seed_tensor, offset_tensor); } @@ -5127,7 +5103,7 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_output(true) .set_dim(o_descriptor.dimensions()) .set_stride(o_descriptor.GetLogicalStrides()) - .set_uid(CudnnfMHAUid::O_ID); + .set_uid(next_uid()); if (stats_descriptor.has_value()) { cudnn_frontend::DataType_t statsType = ToCudnnFrontendDataType(stats_descriptor->type()); @@ -5140,7 +5116,13 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_data_type(statsType) .set_dim(stat_dims) .set_stride(stat_strides) - .set_uid(CudnnfMHAUid::P_ID); + .set_uid(next_uid()); + } + if (seed_tensor != nullptr) { + seed_tensor->set_uid(next_uid()); + } + if (offset_tensor != nullptr) { + offset_tensor->set_uid(next_uid()); } CudnnGraph cudnnGraph(std::move(graph)); TF_RETURN_IF_ERROR(cudnnGraph.Prepare( @@ -5195,71 +5177,66 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) .set_io_data_type(ioDataType); + auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); + auto p_strides = p_desc.GetCudnnCompatibleStrides(false); + std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); + p_reduction_dims.push_back(1); + + // Divide every stride by the last dim value. + std::vector p_reduction_strides; + p_reduction_strides.reserve(p_strides.size()); + int64_t p_reduced_dim_len = p_dims.back(); + for (auto stride : p_strides) { + p_reduction_strides.push_back(stride / p_reduced_dim_len); + } + p_reduction_strides[3] = 1; + bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || + mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; + auto sdpa_backward_options = + cudnn_frontend::graph::SDPA_backward_attributes() + .set_name("flash_attention_backward") + .set_causal_mask(is_causal) + .set_attn_scale(scale) + .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + std::shared_ptr q = graph.tensor(Tensor_attributes() .set_name("Q") .set_dim(q_desc.GetCudnnCompatibleDimensions(false)) .set_stride(q_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::Q_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr k = graph.tensor(Tensor_attributes() .set_name("K") .set_dim(k_desc.GetCudnnCompatibleDimensions(false)) .set_stride(k_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::K_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr v = graph.tensor(Tensor_attributes() .set_name("V") .set_dim(v_desc.GetCudnnCompatibleDimensions(true)) .set_stride(v_desc.GetCudnnCompatibleStrides(true)) - .set_uid(CudnnfMHAUid::V_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); - std::shared_ptr o = + std::shared_ptr stats = graph.tensor(Tensor_attributes() - .set_name("O") - .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) - .set_stride(do_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::O_ID) - .set_data_type(ioDataType)); + .set_name("stats") + .set_dim(p_reduction_dims) + .set_stride(p_reduction_strides) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::FLOAT)); std::shared_ptr dO = graph.tensor(Tensor_attributes() .set_name("dO") .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) .set_stride(do_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::dO_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); - auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); - auto p_strides = p_desc.GetCudnnCompatibleStrides(false); - std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); - p_reduction_dims.push_back(1); - - // Divide every stride by the last dim value. - std::vector p_reduction_strides; - p_reduction_strides.reserve(p_strides.size()); - int64_t p_reduced_dim_len = p_dims.back(); - for (auto stride : p_strides) { - p_reduction_strides.push_back(stride / p_reduced_dim_len); - } - p_reduction_strides[3] = 1; - std::shared_ptr stats = - graph.tensor(Tensor_attributes() - .set_name("stats") - .set_dim(p_reduction_dims) - .set_stride(p_reduction_strides) - .set_uid(CudnnfMHAUid::P_ID) - .set_data_type(cudnn_frontend::DataType_t::FLOAT)); - bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || - mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - auto sdpa_backward_options = - cudnn_frontend::graph::SDPA_backward_attributes() - .set_name("flash_attention_backward") - .set_causal_mask(is_causal) - .set_attn_scale(scale) - .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); - - // Setting bias + std::shared_ptr d_bias_tensor; if (use_bias) { DCHECK(bias_descriptor != std::nullopt); auto bias_dim = bias_descriptor->dimensions(); @@ -5272,21 +5249,29 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_name("bias") .set_dim(bias_descriptor->dimensions()) .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(CudnnfMHAUid::BIAS_ID)); + .set_uid(next_uid())); sdpa_backward_options.set_bias(bias_tensor); // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] are not supported for // dbias calculation but they are supported for forward bias calculation + // Set UID later: this is the last output tuple element. if (b == 1 && n == q_n) { - auto d_bias_tensor = + d_bias_tensor = graph.tensor(Tensor_attributes() .set_name("dBias") .set_dim(bias_descriptor->dimensions()) - .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(CudnnfMHAUid::dBIAS_ID)); + .set_stride(bias_descriptor->GetLogicalStrides())); sdpa_backward_options.set_dbias(d_bias_tensor); } } + std::shared_ptr o = + graph.tensor(Tensor_attributes() + .set_name("O") + .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_stride(do_desc.GetCudnnCompatibleStrides(false)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + // Setting actual seqlen bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; @@ -5298,38 +5283,39 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_name("seq_q") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::Q_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); auto seq_kv_tensor = graph.tensor(Tensor_attributes() .set_name("seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::K_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); sdpa_backward_options.set_padding_mask(true); sdpa_backward_options.set_seq_len_q(seq_q_tensor); sdpa_backward_options.set_seq_len_kv(seq_kv_tensor); } // Setting seed and offset + std::shared_ptr seed_tensor; + std::shared_ptr offset_tensor; if (use_dropout) { DCHECK(dropout_rate != std::nullopt); - auto seed_tensor = + // Skip setting UIDs: pass by value tensors go at the end. + seed_tensor = graph.tensor(Tensor_attributes() .set_name("seed") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_SEED_ID)); - auto offset_tensor = + .set_is_pass_by_value(true)); + offset_tensor = graph.tensor(Tensor_attributes() .set_name("offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_OFFSET_ID)); + .set_is_pass_by_value(true)); sdpa_backward_options.set_dropout((float)dropout_rate.value(), seed_tensor, offset_tensor); } @@ -5344,21 +5330,30 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( dQ->set_output(true) .set_dim(dq_desc.dimensions()) .set_stride(dq_desc.GetLogicalStrides()) + .set_uid(next_uid()) .set_name("dQ") - .set_uid(CudnnfMHAUid::dQ_ID) .set_data_type(ioDataType); dK->set_output(true) .set_dim(dk_desc.dimensions()) .set_stride(dk_desc.GetLogicalStrides()) + .set_uid(next_uid()) .set_name("dK") - .set_uid(CudnnfMHAUid::dK_ID) .set_data_type(ioDataType); dV->set_output(true) .set_dim(dv_desc.dimensions()) .set_stride(dv_desc.GetLogicalStrides()) + .set_uid(next_uid()) .set_name("dV") - .set_uid(CudnnfMHAUid::dV_ID) .set_data_type(ioDataType); + if (d_bias_tensor != nullptr) { + d_bias_tensor->set_uid(next_uid()); + } + if (seed_tensor != nullptr) { + seed_tensor->set_uid(next_uid()); + } + if (offset_tensor != nullptr) { + offset_tensor->set_uid(next_uid()); + } CudnnGraph cudnnGraph(std::move(graph)); TF_RETURN_IF_ERROR( @@ -5696,8 +5691,8 @@ absl::Status CudnnSupport::DoConvolve( } // Utility for dealing with CUDA's type-erased scaling parameters, where some -// sets of parameters expect a void* pointing at a float while others expect it -// to point at a double. +// sets of parameters expect a void* pointing at a float while others expect +// it to point at a double. // // This is rather ugly, but its purpose is to quarantine the corresponding // ugliness that already exists in the CUDA API. @@ -5721,9 +5716,9 @@ class ScalingParam { // // See // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters - // for more info; the behavior for int8 result tensors is not described there, - // but is maintained from the existing behavior (namely, using a float scaling - // parameter). + // for more info; the behavior for int8 result tensors is not described + // there, but is maintained from the existing behavior (namely, using a + // float scaling parameter). void* ToVoidPointer(dnn::DataType element_type) { if (element_type == dnn::DataType::kDouble) { return &as_double_; @@ -5795,10 +5790,11 @@ absl::StatusOr> GetDescriptorAttribute( absl::c_transform(result, std::back_inserter(raw_ptrs), [](const BackendDescriptor& ptr) { return ptr.get(); }); - // This API evidently does a deep copy of the descriptors into the pointers in - // the output array, rather than writing pointers to the descriptors into the - // output array. So, this writes the memory behind each BackendDescriptor in - // result, rather than writing the contents of raw_ptrs. + // This API evidently does a deep copy of the descriptors into the pointers + // in the output array, rather than writing pointers to the descriptors into + // the output array. So, this writes the memory behind each + // BackendDescriptor in result, rather than writing the contents of + // raw_ptrs. RETURN_IF_CUDNN_ERROR(cudnnBackendGetAttribute( desc, name, CUDNN_TYPE_BACKEND_DESCRIPTOR, n, &n, raw_ptrs.data())); @@ -5834,9 +5830,9 @@ absl::StatusOr ExecutionPlanToAlgorithmDesc( cudnnBackendGetAttribute(engines[0].get(), CUDNN_ATTR_ENGINE_GLOBAL_INDEX, CUDNN_TYPE_INT64, 1, &n, &engine_id)); - // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query the - // number of elements in the attribute by using an output limit value of 0 - // just returns 0; the only way to find out how many there are is to + // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query + // the number of elements in the attribute by using an output limit value of + // 0 just returns 0; the only way to find out how many there are is to // pre-allocate space for every existing knob type (as an upper bound on the // number of knob choices a config can have), and then look back at how many // were filled. @@ -6047,103 +6043,7 @@ class CudnnExecutionPlanRunner std::vector scalar_input_uids_; std::vector scalar_input_values_; }; -#endif // CUDNN_VERSION >= 8100 - -template -class CudnnGraphRunner; -// An OpRunner implemented by a cuDNN frontend graph. -// -// This is the class holding the implementation of ToString, GetWorkspaceSize, -// and operator() for use by the cudnn frontend op runners. -template -class CudnnGraphRunner : public dnn::OpRunner { - private: - using Graph = cudnn_frontend::graph::Graph; - using Tensor_attributes = cudnn_frontend::graph::Tensor_attributes; - - public: - std::string ToString() const override { return graph_.Graph().print(); } - - size_t GetWorkspaceSize() const override { - return graph_.Graph().get_workspace_size(); - } - - absl::StatusOr ToAlgorithmDesc() const override { - return absl::InternalError( - "Unexpected call to CudnnGraphRunner::ToAlgorithmDesc"); - } - - absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, - DeviceMemoryBase scratch_memory, - Args... inputs) const override { - if (parent_ != stream->parent()) { - return tsl::errors::Internal( - "CudnnExecutionPlanRunner cached across multiple StreamExecutors."); - } - CudnnHandle handle = cudnn_->GetHandle(parent_, stream); - std::unordered_map variant_pack; - std::vector vec = {inputs.opaque()...}; - - // add device buffers to the variant pack - for (int i = 0; i < uids_.size(); ++i) { - if (uids_[i].has_value()) { - variant_pack[*uids_[i]] = vec[i]; - } - } - if (dropout_rng_offset_increment_ > 0) { -#if CUDNN_VERSION >= 8800 - variant_pack[CudnnfMHAUid::D_SEED_ID] = (void*)&dropout_rng_seed_; - current_dropout_rng_offset_ += dropout_rng_offset_increment_; - variant_pack[CudnnfMHAUid::D_OFFSET_ID] = - (void*)¤t_dropout_rng_offset_; -#else - return absl::UnimplementedError( - "Cudnn dropout offset and seed are only supported with Cudnn >= " - "8.8.0"); -#endif // CUDNN_VERSION >= 8800 - } - int workspace = graph_.Graph().get_workspace_size(); - if (workspace > scratch_memory.size()) { - return tsl::errors::Internal( - absl::StrFormat("CuDNN FMHA requires %d workspace, got %d workspace.", - workspace, scratch_memory.size())); - } - RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.Graph().execute( - handle.handle(), variant_pack, scratch_memory.opaque())); - - return absl::OkStatus(); - } - - static absl::StatusOr Create( - GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph, - int64_t dropout_rng_seed, int64_t dropout_rng_offset, - std::vector> uids) { - return CudnnGraphRunner(parent, cudnn, std::move(graph), dropout_rng_seed, - dropout_rng_offset, uids); - } - - private: - CudnnGraphRunner(GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph, - int64_t dropout_rng_seed, int64_t dropout_rng_offset, - std::vector> uids) - : parent_(parent), - cudnn_(cudnn), - graph_(std::move(graph)), - dropout_rng_seed_(dropout_rng_seed), - current_dropout_rng_offset_(0), - dropout_rng_offset_increment_(dropout_rng_offset), - uids_(uids) {} - GpuExecutor* parent_; - CudnnAccess* cudnn_; - Stream* stream_; - CudnnGraph graph_; - int64_t dropout_rng_seed_; - mutable int64_t current_dropout_rng_offset_; - int64_t dropout_rng_offset_increment_; - std::vector> uids_; -}; -#if CUDNN_VERSION >= 8100 namespace { template @@ -6929,7 +6829,8 @@ absl::Status CudnnSupport::GetFusedMatmulRunners( use_fallback, out_exec_plans, /*need_side_input=*/true, numeric_options); #else return tsl::errors::Unimplemented( - "Cudnn execution plans for matmul are only supported with Cudnn >= 8.4."); + "Cudnn execution plans for matmul are only supported with Cudnn >= " + "8.4."); #endif // CUDNN_VERSION >= 8400 } @@ -7131,139 +7032,6 @@ int64_t GetDropoutRngOffset(std::vector& intermediate_shape) { return max_seq_len * max_seq_len / cudnn_mha_num_threads; } -absl::StatusOr> -CudnnSupport::FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type) { -#if CUDNN_VERSION >= 90000 - auto cudnn = cudnn_->GetHandle(parent_, stream); - bool use_dropout = dropout_rate && *dropout_rate > 0.0; - std::vector intermediate_shape; - - TF_ASSIGN_OR_RETURN(auto graph, - GetCudnnFlashAttentionOperationGraph( - *this, /*q_descriptor=*/bmm1_lhs_descriptor, - /*k_descriptor=*/bmm1_rhs_descriptor, - /*v_descriptor=*/bmm2_rhs_descriptor, - /*o_descriptor=*/output_descriptor, bias_descriptor, - /*stats_descriptor=*/activation_descriptor, - /*scale=*/static_cast(scale), use_dropout, - dropout_rate, mask_type)); - - std::vector intermediate_bmm2_lhs_dims = - intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleDimensions(true); - intermediate_shape = intermediate_bmm2_lhs_dims; - int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); - int64_t dropout_rng_seed = seed.has_value() ? *seed : 0; - std::vector> uids = { - CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::V_ID, - CudnnfMHAUid::O_ID}; - uids.emplace_back(bias_descriptor.has_value() - ? std::optional(CudnnfMHAUid::BIAS_ID) - : std::nullopt); - uids.emplace_back(activation_descriptor.has_value() - ? std::optional(CudnnfMHAUid::P_ID) - : std::nullopt); - bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || - mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::Q_SEQLEN_ID) - : std::nullopt); - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::K_SEQLEN_ID) - : std::nullopt); - TF_ASSIGN_OR_RETURN(auto runner, - CudnnGraphRunner::Create( - parent_, cudnn_.get(), std::move(graph), - dropout_rng_seed, dropout_rng_offset, uids)); - - return {std::make_unique>( - std::move(runner))}; -#else - return absl::UnimplementedError( - "Cudnn flash attention are only supported with Cudnn >= 9.0.0"); -#endif // CUDNN_VERSION >= 90000 -} - -absl::StatusOr> -CudnnSupport::FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& d_output_descriptor, - const dnn::TensorDescriptor& d_bmm1_lhs_descriptor, - const dnn::TensorDescriptor& d_bmm1_rhs_descriptor, - const dnn::TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic) { -#if CUDNN_VERSION >= 90000 - auto cudnn = cudnn_->GetHandle(parent_, stream); - - bool use_dropout = dropout_rate && *dropout_rate > 0.0; - std::vector intermediate_shape; - - TF_ASSIGN_OR_RETURN( - auto graph, - GetCudnnFlashAttentionBackwardOperationGraph( - *this, bmm1_grad_gemm1_rhs_descriptor, bmm1_grad_gemm2_rhs_descriptor, - bmm2_grad_gemm1_lhs_descriptor, bmm2_grad_gemm2_rhs_descriptor, - d_output_descriptor, d_bmm1_lhs_descriptor, d_bmm1_rhs_descriptor, - d_bmm2_rhs_descriptor, bias_descriptor, dropout_rate, seed, scale, - use_dropout, bias_descriptor != std::nullopt, mask_type, - force_deterministic)); - - std::vector p_dims = - bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleDimensions(false); - intermediate_shape = p_dims; - int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); - int64_t dropout_rng_seed = seed.has_value() ? *seed : 0; - - std::vector> uids; - uids = {CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::P_ID, - CudnnfMHAUid::V_ID, CudnnfMHAUid::dO_ID, CudnnfMHAUid::dQ_ID, - CudnnfMHAUid::dK_ID, CudnnfMHAUid::dV_ID, std::nullopt}; - uids.emplace_back(d_bias_descriptor.has_value() - ? std::optional(CudnnfMHAUid::dBIAS_ID) - : std::nullopt); - uids.push_back(CudnnfMHAUid::O_ID); - uids.emplace_back(bias_descriptor.has_value() - ? std::optional(CudnnfMHAUid::BIAS_ID) - : std::nullopt); - bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || - mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::Q_SEQLEN_ID) - : std::nullopt); - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::K_SEQLEN_ID) - : std::nullopt); - TF_ASSIGN_OR_RETURN(auto runner, - CudnnGraphRunner::Create( - parent_, cudnn_.get(), graph, dropout_rng_seed, - dropout_rng_offset, uids)); - return {std::make_unique>( - std::move(runner))}; -#else - return absl::UnimplementedError( - "Cudnn flash attention bwd are only " - "supported with Cudnn >= 9.0.0"); -#endif // CUDNN_VERSION >= 90000 -} - bool CudnnSupport::GetRnnAlgorithms( std::vector* out_algorithms) { PreloadCudnnSubLibs(PreloadCudnnType::Rnn); @@ -8348,15 +8116,30 @@ absl::Status CudnnGraph::Execute(Stream& stream, std::unordered_map tensor_to_ptr_map; absl::Span operands_without_workspace = operands; DeviceMemoryBase workspace; - if (graph_.get_workspace_size() != 0) { + if (graph_.get_workspace_size() > 0) { workspace = operands.back(); CHECK_EQ(graph_.get_workspace_size(), workspace.size()); + } + if (graph_.get_workspace_size() > 0 || operands.back().size() == 0) { operands_without_workspace = operands.first(operands.size() - 1); } - int operand_number = 0; + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; for (DeviceMemoryBase operand : operands_without_workspace) { - tensor_to_ptr_map[CuDnnTensorUID(operand_number++)] = operand.opaque(); + tensor_to_ptr_map[next_uid()] = operand.opaque(); } + + if (dropout_rng_offset_increment_ > 0) { +#if CUDNN_VERSION >= 8800 + tensor_to_ptr_map[next_uid()] = (void*)&dropout_rng_seed_; + current_dropout_rng_offset_ += dropout_rng_offset_increment_; + tensor_to_ptr_map[next_uid()] = (void*)¤t_dropout_rng_offset_; +#else + return absl::UnimplementedError( + "Cudnn dropout offset and seed are only supported with Cudnn >= " + "8.8.0"); +#endif // CUDNN_VERSION >= 8800 + } + const CudnnSupport& dnn_support = static_cast(*stream.parent()->AsDnn()); RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.execute( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h index 083a2b431e62c1..24d84e369cb138 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h @@ -70,6 +70,9 @@ class CudnnGraph : public dnn::DnnGraph { private: cudnn_frontend::graph::Graph graph_; + int64_t dropout_rng_seed_; + mutable int64_t current_dropout_rng_offset_; + int64_t dropout_rng_offset_increment_ = 0; }; #endif // CUDNN_VERSION >= 8100 @@ -335,37 +338,6 @@ class CudnnSupport : public dnn::DnnSupport { std::optional dscale_descriptor, std::optional dbias_descriptor) override; - absl::StatusOr> - FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type) override; - - absl::StatusOr> - FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& d_output_descriptor, - const dnn::TensorDescriptor& d_bmm1_lhs_descriptor, - const dnn::TensorDescriptor& d_bmm1_rhs_descriptor, - const dnn::TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic); - bool GetRnnAlgorithms( std::vector* out_algorithms) override; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 866c1ff7131462..e8e26e2c9de5ee 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -27,7 +28,6 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/base/const_init.h" -#include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" #include "absl/debugging/leak_check.h" #include "absl/log/check.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.h b/third_party/xla/xla/stream_executor/cuda/cuda_driver.h index 5c04ab6ccbee02..aefd89650fda0f 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.h @@ -19,16 +19,13 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_ #include -#include #include -#include #include #include #include "absl/container/node_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" -#include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/cuda/cuda_status.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc index 7cb402a91ca43a..ba855635f3ecdb 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" +#include "xla/stream_executor/cuda/cuda_diagnostics.h" #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "tsl/platform/status.h" @@ -49,7 +50,7 @@ TEST(CudaDriverTest, ScopedActivateContextTest) { CUcontext context0, context1; CHECK_CUDA(cuCtxCreate(&context0, 0, device)); CHECK_CUDA(cuCtxCreate(&context1, 0, device)); - GpuContext se_context1(context1, /*id=*/101); + GpuContext se_context1(context1, /*device_ordinal=*/101); { ScopedActivateContext scope(&se_context1); CUcontext c; @@ -68,4 +69,25 @@ TEST(CudaDriverTest, ScopedActivateContextTest) { } } // namespace gpu + +namespace cuda { + +TEST(CudaDriverTest, DriverVersionParsingTest) { + // Tests that the driver version can be right after 'Kernel Module', + // or later as well. + auto driver_version = Diagnostician::FindKernelModuleVersion( + "... NVIDIA UNIX Open Kernel Module for x86_64 570.00 Release Build " + "... Mon Aug 12 04:17:20 UTC 2024"); + TF_CHECK_OK(driver_version.status()); + EXPECT_EQ("570.0.0", cuda::DriverVersionToString(driver_version.value())); + + driver_version = Diagnostician::FindKernelModuleVersion( + "... NVIDIA UNIX Open Kernel Module 571.00 Release Build " + "... Mon Aug 12 04:17:20 UTC 2024"); + TF_CHECK_OK(driver_version.status()); + EXPECT_EQ("571.0.0", cuda::DriverVersionToString(driver_version.value())); +} + +} // namespace cuda + } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 0b90f27b8811d9..8ae24775c7558f 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -17,20 +17,17 @@ limitations under the License. #include #include #include -#include #include #include #include #include #include #include -#include -#include "absl/base/casts.h" #include "absl/numeric/int128.h" -#include "absl/strings/str_join.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" @@ -46,7 +43,6 @@ limitations under the License. #include #endif -#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -54,7 +50,6 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -75,14 +70,12 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/integrations/device_mem_allocator.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/logging.h" @@ -477,83 +470,6 @@ absl::Status GpuExecutor::GetKernelMetadata(GpuKernel* cuda_kernel, return absl::OkStatus(); } -absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const Kernel& kernel, const KernelArgs& args) { - return Launch(stream, thread_dims, block_dims, std::nullopt, kernel, args); -} - -absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const ClusterDim& cluster_dims, - const Kernel& kernel, const KernelArgs& args) { - return Launch(stream, thread_dims, block_dims, - std::make_optional(cluster_dims), kernel, args); -} - -absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args) { - CUstream custream = AsGpuStreamValue(stream); - const GpuKernel* cuda_kernel = AsGpuKernel(&kernel); - CUfunction cufunc = cuda_kernel->gpu_function(); - - if (cuda_kernel->cache_config() != KernelCacheConfig::kNoPreference) { - TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( - cufunc, cuda_kernel->GetGpuCacheConfig())); - } - - // Launch CUDA kernels with packed arguments. - auto launch = [&](const KernelArgsPackedArrayBase& packed) { - int32_t expected_number_of_arguments = - kernel.Arity() + (packed.number_of_shared_bytes() > 0); - - CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) - << "Kernel " << kernel.name() << " has " << packed.number_of_arguments() - << " arguments, but expected " << expected_number_of_arguments - << "; arity=" << kernel.Arity() - << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); - - void** params = const_cast(packed.argument_addresses().data()); - - if (cluster_dims.has_value()) { - return GpuDriver::LaunchKernel( - context_, kernel.name(), cufunc, cluster_dims->x, cluster_dims->y, - cluster_dims->z, block_dims.x, block_dims.y, block_dims.z, - thread_dims.x, thread_dims.y, thread_dims.z, - packed.number_of_shared_bytes(), custream, params, - /*extra=*/nullptr); - } else { - return GpuDriver::LaunchKernel( - context_, kernel.name(), cufunc, block_dims.x, block_dims.y, - block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, - packed.number_of_shared_bytes(), custream, params, - /*extra=*/nullptr); - } - }; - - // If arguments are already packed we can just launch the kernel. - if (auto* packed = DynCast(&args)) { - return launch(*packed); - } - - // For device memory array we rely on a custom kernel arguments packing. - if (auto* device_mem = DynCast(&args)) { - auto& pack = kernel.args_packing(); - if (!pack) { - return absl::InternalError( - "Kernel is missing a custom arguments packing function for device " - "memory arguments array"); - } - - TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); - return launch(*packed); - } - - return absl::InternalError("Unsupported kernel arguments type"); -} - DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) { if (memory_space == 1) { auto result = GpuCollectives::CollectiveMemoryAllocate(context_, size); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc index ea86363ce27e9f..83dab87a0c6c85 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc @@ -59,26 +59,17 @@ CudaPlatform::DescriptionForDevice(int ordinal) const { } absl::StatusOr CudaPlatform::ExecutorForDevice(int ordinal) { - StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); + return executor_cache_.GetOrCreate( + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } -absl::StatusOr CudaPlatform::GetExecutor( - const StreamExecutorConfig& config) { - if (config.gpu_stream) { - // If the GPU stream was provided, it's not possible to get-or-create a - // stream with a required pointer: so we are looking for previously - // allocated streams. - return executor_cache_.Get(config); - } - return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); +absl::StatusOr CudaPlatform::FindExisting(int ordinal) { + return executor_cache_.Get(ordinal); } absl::StatusOr> -CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { - auto executor = std::make_unique(this, config.ordinal); +CudaPlatform::GetUncachedExecutor(int ordinal) { + auto executor = std::make_unique(this, ordinal); TF_RETURN_IF_ERROR(executor->Init()); return std::move(executor); } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h index 25fec73b90f372..e4ba806343f091 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h @@ -53,15 +53,13 @@ class CudaPlatform : public Platform { int ordinal) const override; absl::StatusOr ExecutorForDevice(int ordinal) override; + absl::StatusOr FindExisting(int ordinal) override; - absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) override; - - // Returns a device constructed with the options specified in "config" without + // Returns a device constructed with the ordinal without // looking in or storing to the Platform's executor cache. // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config); + int ordinal); private: // This platform's name. diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc new file mode 100644 index 00000000000000..b9621f76aee349 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_platform.h" + +#include +#include "absl/container/flat_hash_map.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { + +TEST(CudaPlatformTest, FindExistingWorks) { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + PlatformManager::PlatformWithName("CUDA")); + CHECK_GT(platform->VisibleDeviceCount(), 0); + for (int i = 0; i < platform->VisibleDeviceCount(); ++i) { + EXPECT_FALSE(platform->FindExisting(i).ok()); + } + absl::flat_hash_map executors; + for (int i = 0; i < platform->VisibleDeviceCount(); ++i) { + TF_ASSERT_OK_AND_ASSIGN(auto executor, platform->ExecutorForDevice(i)); + executors[i] = executor; + } + EXPECT_EQ(executors.size(), platform->VisibleDeviceCount()); + for (int i = 0; i < platform->VisibleDeviceCount(); ++i) { + TF_ASSERT_OK_AND_ASSIGN(auto executor, platform->FindExisting(i)); + EXPECT_EQ(executor, executors[i]); + } +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h b/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h index aa59af500ba7a3..0a30c1af59c0c4 100644 --- a/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h +++ b/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h @@ -29,6 +29,11 @@ namespace gpu { } \ } while (false) +// UIDs for cuDNN are unique identifiers of tensors within a graph. They are +// assigned during graph construction; then graph execution takes a {uid: +// buffer pointer} map defining the correspondance of buffers to tensors. +// UID assignment scheme can be arbitrary; at the moment for simplicity XLA uses +// a scheme UID = (HLO operand number + 1). int CuDnnTensorUID(int offset); } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/cuda/delay_kernel.h b/third_party/xla/xla/stream_executor/cuda/delay_kernel.h index 09aad2f6e85a67..016639d0ba2136 100644 --- a/third_party/xla/xla/stream_executor/cuda/delay_kernel.h +++ b/third_party/xla/xla/stream_executor/cuda/delay_kernel.h @@ -18,7 +18,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" -#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/stream.h" namespace stream_executor::gpu { diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc index c2958332c154c3..aae94067af0ceb 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include #include #include diff --git a/third_party/xla/xla/stream_executor/dnn.cc b/third_party/xla/xla/stream_executor/dnn.cc index 5a674a05e175c2..951b2f6e147cd8 100644 --- a/third_party/xla/xla/stream_executor/dnn.cc +++ b/third_party/xla/xla/stream_executor/dnn.cc @@ -41,7 +41,7 @@ limitations under the License. #include "xla/stream_executor/data_type.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" -#include "tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/protobuf/dnn.pb.h" @@ -249,42 +249,6 @@ DnnSupport::NormRunnerFromDesc( return absl::UnimplementedError("NormRunnerFromDesc not implemented."); } -absl::StatusOr> -DnnSupport::FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type) { - return absl::UnimplementedError("FusedMHARunnerFromDesc not implemented."); -} - -absl::StatusOr> -DnnSupport::FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& d_output_descriptor, - const TensorDescriptor& d_bmm1_lhs_descriptor, - const TensorDescriptor& d_bmm1_rhs_descriptor, - const TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic) { - return absl::UnimplementedError( - "FusedMHABackwardRunnerFromDesc not implemented."); -} - bool DnnSupport::GetMIOpenConvolveAlgorithms( dnn::ConvolutionKind /*kind*/, dnn::DataType /*element_type*/, dnn::DataType /*output_type*/, Stream* /*stream*/, diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h index af709946eeb241..a2e1cd629dc2b4 100644 --- a/third_party/xla/xla/stream_executor/dnn.h +++ b/third_party/xla/xla/stream_executor/dnn.h @@ -993,30 +993,6 @@ using FusedMatmulRunner = OpRunner; using NormSignature = void(std::vector); using NormRunner = OpRunner; -using FusedMHASignature = void(DeviceMemoryBase /*BMM1_inputA_data*/, - DeviceMemoryBase /* BMM1_inputB_data */, - DeviceMemoryBase /* BMM2_inputA_data */, - DeviceMemoryBase /* output_data */, - DeviceMemoryBase /* bias_data */, - DeviceMemoryBase /* activation_data */, - DeviceMemoryBase /* seqlen_q_data */, - DeviceMemoryBase /* seqlen_k_data */); -using FusedMHARunner = OpRunner; - -using FusedMHABackwardSignature = void( - DeviceMemoryBase /* BMM1_GRAD_GEMM1_inputA_data */, - DeviceMemoryBase /* BMM1_GRAD_GEMM2_inputB_data */, - DeviceMemoryBase /* BMM2_GRAD_GEMM1_inputA_data */, - DeviceMemoryBase /* BMM2_GRAD_GEMM2_inputB_data */, - DeviceMemoryBase /* d_output_data */, - DeviceMemoryBase /* d_BMM1_inputA_data */, - DeviceMemoryBase /* d_BMM1_inputB_data */, - DeviceMemoryBase /* d_BMM2_inputB_data */, DeviceMemoryBase /* d_S_data */, - DeviceMemoryBase /* d_bias_data */, DeviceMemoryBase /* fwd_output_data */, - DeviceMemoryBase /* bias_data */, DeviceMemoryBase /* seqlen_q_data */, - DeviceMemoryBase /* seqlen_k_data */); -using FusedMHABackwardRunner = OpRunner; - // Describes the configuration for the algorithms that will used. // // Arguments: @@ -1731,37 +1707,6 @@ class DnnSupport { return absl::UnimplementedError("Graph support requires cuDNN >= 8.1."); }; - virtual absl::StatusOr> - FusedMHARunnerFromDesc( - Stream* stream, const AlgorithmDesc& algorithm_desc, - const MatmulTensorDescriptor& bmm1_lhs_descriptor, - const MatmulTensorDescriptor& bmm1_rhs_descriptor, - const MatmulTensorDescriptor& bmm2_rhs_descriptor, - const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type); - - virtual absl::StatusOr> - FusedMHABackwardRunnerFromDesc( - Stream* stream, const AlgorithmDesc& algorithm_desc, - const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& d_output_descriptor, - const TensorDescriptor& d_bmm1_lhs_descriptor, - const TensorDescriptor& d_bmm1_rhs_descriptor, - const TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic); - virtual bool GetMIOpenConvolveAlgorithms( ConvolutionKind kind, DataType element_type, DataType output_type, Stream* stream, const BatchDescriptor& input_descriptor, diff --git a/third_party/xla/xla/stream_executor/executor_cache.cc b/third_party/xla/xla/stream_executor/executor_cache.cc index 341af6f2d4b5da..1fcfd6b847f907 100644 --- a/third_party/xla/xla/stream_executor/executor_cache.cc +++ b/third_party/xla/xla/stream_executor/executor_cache.cc @@ -22,7 +22,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" @@ -33,11 +32,11 @@ ExecutorCache::ExecutorCache() = default; ExecutorCache::~ExecutorCache() = default; absl::StatusOr ExecutorCache::GetOrCreate( - const StreamExecutorConfig& config, const ExecutorFactory& factory) { + int ordinal, const ExecutorFactory& factory) { // In the fast path case, the cache already has an entry and we can just // return after Get() which only takes a shared lock and not a unique lock. // If we need to create, we take a unique lock on cache_. - if (auto fast_result = Get(config); fast_result.ok()) { + if (auto fast_result = Get(ordinal); fast_result.ok()) { return fast_result; } @@ -45,32 +44,19 @@ absl::StatusOr ExecutorCache::GetOrCreate( TF_ASSIGN_OR_RETURN(std::unique_ptr result, factory()); auto returned_executor = result.get(); absl::MutexLock lock(&mutex_); - cache_.emplace(config.ordinal, std::move(result)); + cache_.emplace(ordinal, std::move(result)); return returned_executor; } -absl::StatusOr ExecutorCache::Get( - const StreamExecutorConfig& config) { +absl::StatusOr ExecutorCache::Get(int ordinal) { absl::ReaderMutexLock lock{&mutex_}; - // If gpu stream is not nullptr we have to find StreamExecutor that owns it, - // and return NOT_FOUND error if we can't find it. - if (config.gpu_stream) { - for (auto& [ordinal, executor] : cache_) { - if (executor->FindAllocatedStream(config.gpu_stream)) { - return executor.get(); - } - } - return absl::NotFoundError( - absl::StrFormat("No executors own stream %p", config.gpu_stream)); - } - - if (auto it = cache_.find(config.ordinal); it != cache_.end()) { + if (auto it = cache_.find(ordinal); it != cache_.end()) { return it->second.get(); } - return absl::NotFoundError(absl::StrFormat( - "No executors registered for ordinal %d", config.ordinal)); + return absl::NotFoundError( + absl::StrFormat("No executors registered for ordinal %d", ordinal)); } } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/executor_cache.h b/third_party/xla/xla/stream_executor/executor_cache.h index ae62c6d49224d3..d4cf4b5e31441d 100644 --- a/third_party/xla/xla/stream_executor/executor_cache.h +++ b/third_party/xla/xla/stream_executor/executor_cache.h @@ -23,13 +23,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream_executor.h" namespace stream_executor { -// Forward declare. -class StreamExecutor; - // Utility class to allow Platform objects to manage cached StreamExecutors. // Thread-safe. class ExecutorCache { @@ -40,15 +37,15 @@ class ExecutorCache { ExecutorCache(); ~ExecutorCache(); - // Looks up 'config' in the cache. Returns a pointer to the existing executor, - // if already present, or creates it using 'factory', if it does not. - // Factories may be executed concurrently for different device ordinals. - absl::StatusOr GetOrCreate( - const StreamExecutorConfig& config, const ExecutorFactory& factory); + // Looks up 'ordinal' in the cache. Returns a pointer to the existing + // executor, if already present, or creates it using 'factory', if it does + // not. Factories may be executed concurrently for different device ordinals. + absl::StatusOr GetOrCreate(int ordinal, + const ExecutorFactory& factory); - // Returns a pointer to the described executor (if one with a matching config + // Returns a pointer to the described executor (if one with a matching ordinal // has been created), or a NOT_FOUND status. - absl::StatusOr Get(const StreamExecutorConfig& config); + absl::StatusOr Get(int ordinal); private: // Protects cache_. diff --git a/third_party/xla/xla/stream_executor/executor_cache_test.cc b/third_party/xla/xla/stream_executor/executor_cache_test.cc index 71e9f72a64cc3d..84bed1ecaf576b 100644 --- a/third_party/xla/xla/stream_executor/executor_cache_test.cc +++ b/third_party/xla/xla/stream_executor/executor_cache_test.cc @@ -16,112 +16,48 @@ limitations under the License. #include "xla/stream_executor/executor_cache.h" #include -#include -#include -#include +#include "absl/log/log.h" #include "xla/stream_executor/mock_stream_executor.h" -#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" namespace stream_executor { namespace { TEST(ExecutorCacheTest, GetOnEmptyCacheFails) { ExecutorCache cache; - StreamExecutorConfig config; - config.ordinal = 0; - EXPECT_FALSE(cache.Get(config).ok()); + EXPECT_FALSE(cache.Get(0).ok()); } -TEST(ExecutorCacheTest, GetViaStreamOnEmptyCacheFails) { +TEST(ExecutorCacheTest, GetReturnsExpectedExecutor) { ExecutorCache cache; - StreamExecutorConfig config; - config.ordinal = 0; - config.gpu_stream = reinterpret_cast(0x1234); - EXPECT_FALSE(cache.Get(config).ok()); -} - -TEST(ExecutorCacheTest, GetOrCreateConstructsAndRepeatedlyReturns) { - ExecutorCache cache; - StreamExecutorConfig config; - config.ordinal = 0; - StreamExecutor *created = nullptr; - auto factory = [&created]() { - auto executor = std::make_unique(); - created = executor.get(); - return executor; - }; - TF_ASSERT_OK_AND_ASSIGN(auto executor, cache.GetOrCreate(config, factory)); - EXPECT_EQ(executor, created); - TF_ASSERT_OK_AND_ASSIGN(auto found, cache.GetOrCreate(config, factory)); - EXPECT_EQ(found, created); - TF_ASSERT_OK_AND_ASSIGN(found, cache.Get(config)); - EXPECT_EQ(found, created); -} - -TEST(ExecutorCacheTest, GetViaStreamFailsIfNotFound) { - ExecutorCache cache; - StreamExecutorConfig config; - config.ordinal = 0; - StreamExecutor *created = nullptr; - void *expected_stream = reinterpret_cast(0x1234); - auto factory = [&created, &expected_stream]() { + StreamExecutor *executor0 = nullptr; + StreamExecutor *executor1 = nullptr; + auto factory = [&executor0, &executor1]() { auto executor = std::make_unique(); - EXPECT_CALL(*executor, FindAllocatedStream(expected_stream)) - .WillRepeatedly(testing::Return(nullptr)); - created = executor.get(); - return executor; - }; - - // Create the executor. - TF_ASSERT_OK_AND_ASSIGN(auto executor, cache.GetOrCreate(config, factory)); - EXPECT_EQ(executor, created); - // Now look for the expected stream, and don't expected to find it. - config.gpu_stream = expected_stream; - EXPECT_FALSE(cache.Get(config).ok()); -} - -TEST(ExecutorCacheTest, GetViaStreamWorksOnSecondStream) { - ExecutorCache cache; - StreamExecutorConfig config; - config.ordinal = 0; - StreamExecutor *created = nullptr; - Stream *expected_stream = reinterpret_cast(0x1234); - - // Create a factory that will make the second StreamExecutor find the - // expected_stream. - auto factory = [&created, &expected_stream]() { - static int count = 0; - auto executor = std::make_unique(); - if (count != 1) { - EXPECT_CALL(*executor, FindAllocatedStream(expected_stream)) - .WillRepeatedly(testing::Return(nullptr)); + if (executor0 == nullptr) { + executor0 = executor.get(); + } else if (executor1 == nullptr) { + executor1 = executor.get(); } else { - created = executor.get(); - EXPECT_CALL(*executor, FindAllocatedStream(expected_stream)) - .WillRepeatedly(testing::Invoke( - [expected_stream](void *stream) { return expected_stream; })); + LOG(FATAL) << "Bad call to factory."; } - ++count; return executor; }; - - // Create four executors. - std::vector created_executors; - for (int i = 0; i < 4; ++i) { - config.ordinal = i; - TF_ASSERT_OK_AND_ASSIGN(auto executor, cache.GetOrCreate(config, factory)); - EXPECT_NE(executor, nullptr); - created_executors.push_back(executor); - } - EXPECT_EQ(created_executors.size(), 4); - // Now look for the expected stream, and expect to find it on the second - // stream. - config.gpu_stream = expected_stream; - TF_ASSERT_OK_AND_ASSIGN(auto found, cache.Get(config)); - EXPECT_EQ(found, created); + TF_ASSERT_OK_AND_ASSIGN(auto found, cache.GetOrCreate(0, factory)); + EXPECT_EQ(found, executor0); + TF_ASSERT_OK_AND_ASSIGN(found, cache.GetOrCreate(1, factory)); + EXPECT_EQ(found, executor1); + TF_ASSERT_OK_AND_ASSIGN(found, cache.GetOrCreate(0, factory)); + EXPECT_EQ(found, executor0); + TF_ASSERT_OK_AND_ASSIGN(found, cache.GetOrCreate(1, factory)); + EXPECT_EQ(found, executor1); + TF_ASSERT_OK_AND_ASSIGN(found, cache.Get(0)); + EXPECT_EQ(found, executor0); + TF_ASSERT_OK_AND_ASSIGN(found, cache.Get(1)); + EXPECT_EQ(found, executor1); } } // namespace diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 8e3c1e234d95e3..8b4fe379ddc162 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -220,7 +220,6 @@ gpu_only_cc_library( "//xla/stream_executor:host_memory_allocation", "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", - "//xla/stream_executor:launch_dim", "//xla/stream_executor:memory_allocation", "//xla/stream_executor:module_spec", "//xla/stream_executor:platform", @@ -228,7 +227,6 @@ gpu_only_cc_library( "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -243,8 +241,7 @@ gpu_only_cc_library( name = "gpu_helpers_header", hdrs = ["gpu_helpers.h"], deps = [ - ":gpu_types_header", - "@local_tsl//tsl/platform:logging", + "//xla/stream_executor:device_memory", ], ) @@ -314,6 +311,8 @@ gpu_only_cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:event_based_timer", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", @@ -331,10 +330,13 @@ gpu_only_cc_library( ":gpu_driver_header", ":gpu_event_header", ":gpu_executor_header", + ":gpu_kernel_header", ":gpu_types_header", "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:event_based_timer", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", @@ -345,6 +347,7 @@ gpu_only_cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:nvtx_utils", ], ) @@ -377,7 +380,6 @@ gpu_only_cc_library( ":gpu_stream", ":gpu_types_header", "//xla/stream_executor", - "//xla/stream_executor:event", "//xla/stream_executor:event_based_timer", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -386,7 +388,6 @@ gpu_only_cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@com_google_absl//absl/utility", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -556,8 +557,6 @@ xla_test( name = "redzone_allocator_test", srcs = ["redzone_allocator_test.cc"], backends = ["gpu"], - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - tags = ["gpu"], deps = [ ":gpu_asm_opts", ":gpu_init", @@ -603,11 +602,7 @@ xla_test( name = "gpu_cudamallocasync_allocator_test", srcs = ["gpu_cudamallocasync_allocator_test.cc"], backends = ["gpu_any"], - tags = [ - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - "gpu", - "no_rocm", - ], + tags = ["no_rocm"], deps = [ ":gpu_cudamallocasync_allocator", ":gpu_stream", @@ -660,40 +655,89 @@ cc_library( gpu_kernel_library( name = "gpu_test_kernels", testonly = 1, - srcs = if_gpu_is_configured(["gpu_test_kernels.cu.cc"]), - hdrs = if_gpu_is_configured(["gpu_test_kernels.h"]), + srcs = ["gpu_test_kernels.cu.cc"], + hdrs = ["gpu_test_kernels.h"], + tags = ["gpu"], deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", - "//xla/stream_executor/rocm:add_i32_kernel", ]), ) +genrule( + name = "gpu_test_kernels_fatbin_extractor", + testonly = True, + srcs = [":gpu_test_kernels"], + outs = ["gpu_test_kernels.fatbin"], + cmd = """ + STATIC_LIBRARY="" + for src in $(SRCS); do + if [[ $$src == *.a ]]; then + STATIC_LIBRARY=$$src + break + fi + done + + if [[ -z $$STATIC_LIBRARY ]]; then + echo "No static library found in $(SRCS)" >&2 + exit 1 + fi + + $(OBJCOPY) "--dump-section=.nv_fatbin=$@" "$$STATIC_LIBRARY" || true + + if [ ! -f "$@" ]; then + # binutils' objcopy doesn't return a non-zero exit code if the + # section was not found, so we need to check for the file's existence instead. + $(OBJCOPY) "--dump-section=.hip_fatbin=$@" "$$STATIC_LIBRARY" + fi + """, + tags = ["gpu"], + toolchains = ["@bazel_tools//tools/cpp:current_cc_toolchain"], +) + +cc_library( + name = "gpu_test_kernels_fatbin", + testonly = True, + srcs = ["gpu_test_kernels_fatbin.cc"], + hdrs = ["gpu_test_kernels_fatbin.h"], + data = [":gpu_test_kernels_fatbin_extractor"], + local_defines = [ + "FATBIN_SRC=\\\"$(rootpath :gpu_test_kernels_fatbin_extractor)\\\"", + ], + tags = ["gpu"], + deps = [ + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + ], +) + xla_test( name = "gpu_kernel_test", - srcs = if_gpu_is_configured(["gpu_kernel_test.cc"]), + srcs = ["gpu_kernel_test.cc"], backends = ["gpu"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ ":gpu_test_kernels", + ":gpu_test_kernels_fatbin", "//xla/service:platform_util", "//xla/stream_executor", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/rocm:rocm_platform_id", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - ] + if_cuda([ - "//xla/stream_executor/cuda:cuda_platform", - ]) + if_rocm([ - "//xla/stream_executor/rocm:rocm_platform", - ]), + ], ) xla_test( @@ -739,8 +783,6 @@ xla_test( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - tags = ["gpu"], deps = [ "//xla/stream_executor", "//xla/stream_executor:device_memory", @@ -764,10 +806,9 @@ xla_test( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - tags = ["gpu"], deps = [ "//xla/stream_executor", + "//xla/stream_executor:stream_finder", "//xla/stream_executor/host:host_platform", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", @@ -788,8 +829,6 @@ xla_test( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - tags = ["gpu"], deps = [ "//xla/service:platform_util", "//xla/stream_executor:platform", @@ -818,10 +857,12 @@ xla_test( "//xla/tools/hlo_opt:gpu_specs/a6000.txtpb", "//xla/tools/hlo_opt:gpu_specs/h100_pcie.txtpb", "//xla/tools/hlo_opt:gpu_specs/h100_sxm.txtpb", + "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb", "//xla/tools/hlo_opt:gpu_specs/p100.txtpb", "//xla/tools/hlo_opt:gpu_specs/v100.txtpb", ], deps = [ + "//xla/service:platform_util", "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:platform", diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h index b4bf7ebd46d8dc..20caccbe18e62e 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/host_or_device_scalar.h" diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc index 0376a5a97b7796..6852cebf9a1014 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" #include "xla/stream_executor/gpu/gpu_types.h" // IWYU pragma: keep @@ -60,11 +61,7 @@ static Platform* GpuPlatform() { static MultiKernelLoaderSpec GetAddI32KernelSpec() { MultiKernelLoaderSpec spec(/*arity=*/3); -#if defined(GOOGLE_CUDA) - spec.AddCudaPtxInMemory(internal::kAddI32Kernel, "add"); -#elif defined(TENSORFLOW_USE_ROCM) - spec.AddCudaCubinInMemory(internal::kAddI32KernelModule, "add"); -#endif + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); return spec; } @@ -113,7 +110,7 @@ TEST(GpuCommandBufferTest, LaunchSingleKernel) { TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); int64_t length = 4; @@ -183,7 +180,7 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { cast(bufs[2]), }); }); - spec.AddInProcessSymbol(internal::GetAddI32Ptrs3Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Ptrs3Kernel(), "AddI32Ptrs3"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Ptrs3::Create(executor, spec)); @@ -701,7 +698,7 @@ TEST(GpuCommandBufferTest, ConditionalIf) { TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); int64_t length = 4; @@ -851,12 +848,12 @@ TEST(GpuCommandBufferTest, ConditionalIfElse) { // Load addition kernel. MultiKernelLoaderSpec add_spec(/*arity=*/3); - add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); // Load multiplication kernel. MultiKernelLoaderSpec mul_spec(/*arity=*/3); - mul_spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "mul"); + mul_spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "MulI32"); TF_ASSERT_OK_AND_ASSIGN(auto mul, MulI32Kernel::Create(executor, mul_spec)); int64_t length = 4; @@ -947,12 +944,12 @@ TEST(GpuCommandBufferTest, ConditionalCase) { // Load addition kernel. MultiKernelLoaderSpec add_spec(/*arity=*/3); - add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); // Load multiplication kernel. MultiKernelLoaderSpec mul_spec(/*arity=*/3); - mul_spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "mul"); + mul_spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "MulI32"); TF_ASSERT_OK_AND_ASSIGN(auto mul, MulI32Kernel::Create(executor, mul_spec)); int64_t length = 4; @@ -1035,7 +1032,7 @@ TEST(GpuCommandBufferTest, ConditionalFor) { TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); int64_t length = 4; @@ -1085,12 +1082,12 @@ TEST(GpuCommandBufferTest, ConditionalWhile) { // Load addition kernel. MultiKernelLoaderSpec add_spec(/*arity=*/3); - add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); // Load inc_and_cmp kernel. MultiKernelLoaderSpec icmp_spec(/*arity=*/3); - icmp_spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "inc_and_cmp"); + icmp_spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "IncAndCmp"); TF_ASSERT_OK_AND_ASSIGN(auto inc_and_cmp, IncAndCmpKernel::Create(executor, icmp_spec)); @@ -1250,12 +1247,12 @@ TEST(GpuCommandBufferTest, ConditionalWhileInExecutionScope) { // Load addition kernel. MultiKernelLoaderSpec add_spec(/*arity=*/3); - add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); // Load inc_and_cmp kernel. MultiKernelLoaderSpec icmp_spec(/*arity=*/3); - icmp_spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "inc_and_cmp"); + icmp_spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "IncAndCmp"); TF_ASSERT_OK_AND_ASSIGN(auto inc_and_cmp, IncAndCmpKernel::Create(executor, icmp_spec)); @@ -1352,7 +1349,7 @@ static void BM_CreateCommandBuffer(benchmark::State& state) { StreamExecutor* executor = platform->ExecutorForDevice(0).value(); MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); DeviceMemory b = executor->AllocateArray(1, 0); @@ -1375,7 +1372,7 @@ static void BM_TraceCommandBuffer(benchmark::State& state) { TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); DeviceMemory b = executor->AllocateArray(1, 0); @@ -1400,7 +1397,7 @@ static void BM_UpdateCommandBuffer(benchmark::State& state) { StreamExecutor* executor = platform->ExecutorForDevice(0).value(); MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); DeviceMemory b = executor->AllocateArray(1, 0); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_device_info_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_device_info_test.cc index 9ecfe692fae457..b5ec38ff58ca5d 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_device_info_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_device_info_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "absl/container/flat_hash_map.h" +#include "xla/service/platform_util.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/platform.h" @@ -33,7 +34,7 @@ TEST(DeviceInfoTest, DeviceInfoMatches) { absl::flat_hash_map gpu_specs; for (const std::string file_name : {"a100_pcie_80", "a100_sxm_40", "a100_sxm_80", "a6000", "h100_pcie", - "h100_sxm", "p100", "v100"}) { + "h100_sxm", "p100", "v100", "mi200"}) { GpuTargetConfigProto proto; std::string spec_string; TF_ASSERT_OK(tsl::ReadFileToString( @@ -45,9 +46,10 @@ TEST(DeviceInfoTest, DeviceInfoMatches) { tsl::protobuf::TextFormat::ParseFromString(spec_string, &proto)); gpu_specs[proto.device_description_str()] = proto.gpu_device_info(); } - + auto name = absl::AsciiStrToUpper( + xla::PlatformUtil::CanonicalPlatformName("gpu").value()); TF_ASSERT_OK_AND_ASSIGN(Platform * platform, - PlatformManager::PlatformWithName("CUDA")); + PlatformManager::PlatformWithName(name)); bool all_skipped = false; for (int i = 0; i < platform->VisibleDeviceCount(); ++i) { TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index 599480c13e92da..94cff4632638e1 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index f7dd572e918ccd..f7eab3beb9f626 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -1,4 +1,3 @@ -#include "xla/stream_executor/event_based_timer.h" /* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -31,10 +30,10 @@ limitations under the License. #include #include #include +#include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "absl/functional/any_invocable.h" #include "absl/numeric/int128.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -46,6 +45,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/gpu/gpu_collectives.h" #include "xla/stream_executor/gpu/gpu_driver.h" @@ -53,7 +53,6 @@ limitations under the License. #include "xla/stream_executor/host_memory_allocation.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" @@ -138,15 +137,6 @@ class GpuExecutor : public StreamExecutorCommon { absl::StatusOr> CreateOrShareConstant( Stream* stream, absl::Span content) override; - absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& kernel, - const KernelArgs& args) override; - - absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const ClusterDim& cluster_dims, const Kernel& kernel, - const KernelArgs& args) override; - DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; void Deallocate(DeviceMemoryBase* mem) override; @@ -316,11 +306,6 @@ class GpuExecutor : public StreamExecutorCommon { absl::Status LoadModuleFromHsaco(const char* hsaco, GpuModuleHandle* module) TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); - absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args); - bool UnloadGpuBinary(const void* gpu_binary) TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h b/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h index 62db12705491bc..187d882c78369c 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h @@ -23,17 +23,10 @@ limitations under the License. #include -#include -#include - -#include "xla/stream_executor/gpu/gpu_types.h" -#include "tsl/platform/logging.h" +#include "xla/stream_executor/device_memory.h" namespace stream_executor { -template -class DeviceMemory; - namespace gpu { // Converts a const DeviceMemory reference to its underlying typed pointer in diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc index 9d93c1264d9128..fcb97ca7e790c7 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc @@ -14,16 +14,25 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/ascii.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/service/platform_util.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/gpu/gpu_test_kernels_fatbin.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/typed_kernel_factory.h" @@ -32,48 +41,75 @@ limitations under the License. #include "tsl/platform/test.h" namespace stream_executor::gpu { - -TEST(GpuKernelTest, Add) { - using AddI32Kernel = - TypedKernelFactory, DeviceMemory, - DeviceMemory>; - auto name = absl::AsciiStrToUpper( - xla::PlatformUtil::CanonicalPlatformName("gpu").value()); - Platform* platform = PlatformManager::PlatformWithName(name).value(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); +namespace { + +using AddI32Kernel = + TypedKernelFactory, DeviceMemory, + DeviceMemory>; + +class GpuKernelTest : public ::testing::Test { + public: + void SetUp() override { + auto name = absl::AsciiStrToUpper( + xla::PlatformUtil::CanonicalPlatformName("gpu").value()); + Platform* platform = PlatformManager::PlatformWithName(name).value(); + executor_ = platform->ExecutorForDevice(0).value(); + } + + void RunAddI32Kernel(const MultiKernelLoaderSpec& spec) { + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor_->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor_, spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=2, c=0 + DeviceMemory a = executor_->AllocateArray(length, 0); + DeviceMemory b = executor_->AllocateArray(length, 0); + DeviceMemory c = executor_->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream->Memset32(&b, 2, byte_length)); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Launch kernel. + ASSERT_TRUE( + stream->ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c).ok()); + + // Copy data back to host. + std::vector dst(4, 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + std::vector expected = {3, 3, 3, 3}; + ASSERT_EQ(dst, expected); + } + + StreamExecutor* executor_; +}; + +TEST_F(GpuKernelTest, LoadAndRunKernelFromPtx) { + if (executor_->GetPlatform()->id() == + stream_executor::rocm::kROCmPlatformId) { + GTEST_SKIP() << "There is no PTX or any equivalent abstraction for ROCm."; + } MultiKernelLoaderSpec spec(/*arity=*/3); -#if defined(GOOGLE_CUDA) - spec.AddCudaPtxInMemory(internal::kAddI32Kernel, "add"); -#elif defined(TENSORFLOW_USE_ROCM) - spec.AddCudaCubinInMemory(internal::kAddI32KernelModule, "add"); -#endif - - TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=1, b=2, c=0 - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - DeviceMemory c = executor->AllocateArray(length, 0); - - TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); - TF_ASSERT_OK(stream->Memset32(&b, 2, byte_length)); - TF_ASSERT_OK(stream->MemZero(&c, byte_length)); - - // Launch kernel. - ASSERT_TRUE(stream->ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c).ok()); + spec.AddCudaPtxInMemory(internal::kAddI32KernelPtx, "AddI32"); + RunAddI32Kernel(spec); +} - // Copy data back to host. - std::vector dst(4, 42); - TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); +TEST_F(GpuKernelTest, LoadAndRunKernelFromCubin) { + MultiKernelLoaderSpec spec(/*arity=*/3); + TF_ASSERT_OK_AND_ASSIGN(auto binary, GetGpuTestKernelsFatbin()); + spec.AddCudaCubinInMemory(binary, "AddI32"); + RunAddI32Kernel(spec); +} - std::vector expected = {3, 3, 3, 3}; - ASSERT_EQ(dst, expected); +TEST_F(GpuKernelTest, LoadAndRunKernelFromSymbol) { + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); + RunAddI32Kernel(spec); } +} // namespace } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc index 706826553e4363..b257ffa0b675ec 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -31,10 +32,14 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_kernel.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/nvtx_utils.h" namespace stream_executor { @@ -195,6 +200,83 @@ GpuStream::CreateEventBasedTimer(bool use_delay_kernel) { return parent_->CreateEventBasedTimer(this, use_delay_kernel); } +absl::Status GpuStream::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, const Kernel& kernel, + const KernelArgs& args) { + return Launch(thread_dims, block_dims, std::nullopt, kernel, args); +} + +absl::Status GpuStream::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const ClusterDim& cluster_dims, + const Kernel& kernel, const KernelArgs& args) { + return Launch(thread_dims, block_dims, std::make_optional(cluster_dims), + kernel, args); +} + +absl::Status GpuStream::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const std::optional& cluster_dims, + const Kernel& kernel, const KernelArgs& args) { + const GpuKernel* gpu_kernel = AsGpuKernel(&kernel); + GpuFunctionHandle function = gpu_kernel->gpu_function(); + + if (gpu_kernel->cache_config() != KernelCacheConfig::kNoPreference) { + TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( + function, gpu_kernel->GetGpuCacheConfig())); + } + + // Launch kernels with packed arguments. + auto launch = [this, &kernel, &cluster_dims, &thread_dims, &block_dims, + &function](const KernelArgsPackedArrayBase& packed) { + int32_t expected_number_of_arguments = + kernel.Arity() + (packed.number_of_shared_bytes() > 0); + + CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) + << "Kernel " << kernel.name() << " has " << packed.number_of_arguments() + << " arguments, but expected " << expected_number_of_arguments + << "; arity=" << kernel.Arity() + << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); + + void** params = const_cast(packed.argument_addresses().data()); + + if (cluster_dims.has_value()) { + return GpuDriver::LaunchKernel( + parent_->gpu_context(), kernel.name(), function, cluster_dims->x, + cluster_dims->y, cluster_dims->z, block_dims.x, block_dims.y, + block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, + packed.number_of_shared_bytes(), gpu_stream(), params, + /*extra=*/nullptr); + } else { + return GpuDriver::LaunchKernel( + parent_->gpu_context(), kernel.name(), function, block_dims.x, + block_dims.y, block_dims.z, thread_dims.x, thread_dims.y, + thread_dims.z, packed.number_of_shared_bytes(), gpu_stream(), params, + /*extra=*/nullptr); + } + }; + + // If arguments are already packed we can just launch the kernel. + if (auto* packed = DynCast(&args)) { + return launch(*packed); + } + + // For device memory array we rely on a custom kernel arguments packing. + if (auto* device_mem = DynCast(&args)) { + auto& pack = kernel.args_packing(); + if (!pack) { + return absl::InternalError( + "Kernel is missing a custom arguments packing function for device " + "memory arguments array"); + } + + TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); + return launch(*packed); + } + + return absl::InternalError("Unsupported kernel arguments type"); +} + GpuStream* AsGpuStream(Stream* stream) { DCHECK(stream != nullptr); return static_cast(stream); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h index 4cf21ca82207ed..249fbf78877a4e 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h @@ -34,6 +34,8 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_common.h" @@ -103,8 +105,18 @@ class GpuStream : public StreamCommon { void set_name(absl::string_view name) override; absl::StatusOr> CreateEventBasedTimer( bool use_delay_kernel) override; + absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const Kernel& k, const KernelArgs& args) override; + absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const ClusterDim& cluster_dims, const Kernel& k, + const KernelArgs& args) override; private: + // Helper method to launch a kernel with optional cluster dimensions. + absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const std::optional& cluster_dims, + const Kernel& kernel, const KernelArgs& args); + GpuExecutor* parent_; // Executor that spawned this stream. GpuStreamHandle gpu_stream_; // Wrapped CUDA stream handle. std::variant stream_priority_; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.cu.cc b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.cu.cc index cab05701159ad9..b97771724d0ad6 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.cu.cc @@ -17,8 +17,16 @@ limitations under the License. #include +#ifdef TENSORFLOW_USE_ROCM +#include "rocm/include/hip/hip_runtime.h" +#endif + namespace stream_executor::gpu::internal { +// We want to be able to load those kernels by symbol name, so let's make them +// C functions. +extern "C" { + __global__ void AddI32(int32_t* a, int32_t* b, int32_t* c) { int index = threadIdx.x + blockIdx.x * blockDim.x; c[index] = a[index] + b[index]; @@ -39,6 +47,7 @@ __global__ void AddI32Ptrs3(Ptrs3 ptrs) { int index = threadIdx.x + blockIdx.x * blockDim.x; ptrs.c[index] = ptrs.a[index] + ptrs.b[index]; } +} void* GetAddI32Kernel() { return reinterpret_cast(&AddI32); } diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h index 74931452bb6624..dc143779389f56 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h @@ -23,11 +23,10 @@ namespace stream_executor::gpu::internal { // This is a collection of gpu kernels for writing simple StreamExecutor tests. // // Some of the kernels available as pre-compiled PTX blobs (can be loaded with -// CUDA driver API) / HSACO modules (can be loaded with ROCM driver api), and +// CUDA driver API), and // some of the kernels are written directly in CUDA C++ and can be loaded from a // symbol pointer (to test StreamExecutor CUDA runtime integration). -#if !defined(TENSORFLOW_USE_ROCM) // PTX kernel compiled from: // // __global__ void add(int* a, int* b, int* c) { @@ -36,24 +35,24 @@ namespace stream_executor::gpu::internal { // } // // Easiest way to get PTX from C++ is to use https://godbolt.org. -inline constexpr std::string_view kAddI32Kernel = R"( +inline constexpr std::string_view kAddI32KernelPtx = R"( .version 4.0 .target sm_50 .address_size 64 -.visible .entry add( - .param .u64 add_param_0, - .param .u64 add_param_1, - .param .u64 add_param_2 +.visible .entry AddI32( + .param .u64 AddI32_param_0, + .param .u64 AddI32_param_1, + .param .u64 AddI32_param_2 ) { .reg .b32 %r<8>; .reg .b64 %rd<11>; .loc 1 1 0 - ld.param.u64 %rd1, [add_param_0]; - ld.param.u64 %rd2, [add_param_1]; - ld.param.u64 %rd3, [add_param_2]; + ld.param.u64 %rd1, [AddI32_param_0]; + ld.param.u64 %rd2, [AddI32_param_1]; + ld.param.u64 %rd3, [AddI32_param_2]; .loc 1 3 3 cvta.to.global.u64 %rd4, %rd3; cvta.to.global.u64 %rd5, %rd2; @@ -75,9 +74,6 @@ inline constexpr std::string_view kAddI32Kernel = R"( ret; })"; -#else -#include "xla/stream_executor/rocm/add_i32_kernel.h" -#endif // !defined(TENSORFLOW_USE_ROCM) template struct Ptrs3 { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.cc b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.cc new file mode 100644 index 00000000000000..da638565540cb2 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.cc @@ -0,0 +1,34 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/gpu/gpu_test_kernels_fatbin.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" + +namespace stream_executor::gpu { + +absl::StatusOr> GetGpuTestKernelsFatbin() { + tsl::Env* env = tsl::Env::Default(); + std::string file_contents; + TF_RETURN_IF_ERROR(tsl::ReadFileToString(env, FATBIN_SRC, &file_contents)); + return std::vector(file_contents.begin(), file_contents.end()); +} +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.h b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.h new file mode 100644 index 00000000000000..803b8b3cab4b4f --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.h @@ -0,0 +1,35 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_TEST_KERNELS_FATBIN_H_ +#define XLA_STREAM_EXECUTOR_GPU_GPU_TEST_KERNELS_FATBIN_H_ + +#include +#include + +#include "absl/status/statusor.h" + +namespace stream_executor::gpu { + +// Returns the NVIDIA or HIP fatbin for the :gpu_test_kernels target. +// The fatbin is being extracted at compile time from the compilation artifact. +// Note that this function will read the extracted fatbin from the file system +// at runtime and will only be able to succeed when the test is being invoked by +// `bazel test`. +absl::StatusOr> GetGpuTestKernelsFatbin(); + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_GPU_GPU_TEST_KERNELS_FATBIN_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h index be0f9a54a2af98..656dd1e9809490 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h @@ -21,12 +21,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" -#include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/stream.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/stream_executor/gpu/stream_search_test.cc b/third_party/xla/xla/stream_executor/gpu/stream_search_test.cc index c0f66159400039..c1e053e7914942 100644 --- a/third_party/xla/xla/stream_executor/gpu/stream_search_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/stream_search_test.cc @@ -16,7 +16,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_finder.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -37,19 +39,15 @@ class StreamSearchTest : public ::testing::Test { TEST_F(StreamSearchTest, NoMatchBadPtr) { void* bad_ptr = reinterpret_cast(0xdeadbeef); - StreamExecutorConfig config; - config.gpu_stream = bad_ptr; - - absl::StatusOr found_executor = - GetPlatform()->GetExecutor(config); - - // No executor found. - EXPECT_FALSE(found_executor.ok()); + EXPECT_FALSE(FindStream(GetPlatform(), bad_ptr).ok()); } TEST_F(StreamSearchTest, FoundPrevExecutor) { - TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, - GetPlatform()->ExecutorForDevice(0)); + int number_devices = GetPlatform()->VisibleDeviceCount(); + EXPECT_GT(number_devices, 0); + TF_ASSERT_OK_AND_ASSIGN( + StreamExecutor * executor, + GetPlatform()->ExecutorForDevice(number_devices > 1 ? 1 : 0)); TF_ASSERT_OK_AND_ASSIGN(auto s, executor->CreateStream()); TF_ASSERT_OK_AND_ASSIGN(auto s2, executor->CreateStream()); @@ -57,17 +55,10 @@ TEST_F(StreamSearchTest, FoundPrevExecutor) { void* gpu_ptr = s->platform_specific_handle().stream; void* gpu_ptr_2 = s2->platform_specific_handle().stream; - StreamExecutorConfig c; - c.gpu_stream = gpu_ptr; - - TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * found_executor, - GetPlatform()->GetExecutor(c)); - EXPECT_EQ(found_executor, executor); - - Stream* found1 = found_executor->FindAllocatedStream(gpu_ptr); + TF_ASSERT_OK_AND_ASSIGN(Stream * found1, FindStream(GetPlatform(), gpu_ptr)); EXPECT_EQ(found1, s.get()); - - Stream* found2 = found_executor->FindAllocatedStream(gpu_ptr_2); + TF_ASSERT_OK_AND_ASSIGN(Stream * found2, + FindStream(GetPlatform(), gpu_ptr_2)); EXPECT_EQ(found2, s2.get()); } diff --git a/third_party/xla/xla/stream_executor/host/BUILD b/third_party/xla/xla/stream_executor/host/BUILD index 326b3d60fe12e3..a03a21ceb5592b 100644 --- a/third_party/xla/xla/stream_executor/host/BUILD +++ b/third_party/xla/xla/stream_executor/host/BUILD @@ -81,10 +81,14 @@ cc_library( ], deps = [ ":host_event", + ":host_kernel", "//xla/stream_executor:device_memory", "//xla/stream_executor:event", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", @@ -141,7 +145,6 @@ xla_cc_test( "//xla/stream_executor:kernel_spec", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -166,20 +169,21 @@ cc_library( ":host_event", ":host_kernel", ":host_stream", - "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:host_memory_allocation", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_common", - "//xla/stream_executor:stream_executor_h", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/stream_executor/host/host_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc index 38715ce56ed3db..7e2ac758903eed 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.cc +++ b/third_party/xla/xla/stream_executor/host/host_executor.cc @@ -22,27 +22,27 @@ limitations under the License. #include #include +#include #include #include +#include #include -#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/notification.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/host/host_event.h" #include "xla/stream_executor/host/host_kernel.h" #include "xla/stream_executor/host/host_stream.h" +#include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" #include "tsl/platform/mem.h" @@ -91,26 +91,6 @@ absl::StatusOr> HostExecutor::LoadKernel( return absl::InternalError("No method of loading host kernel provided"); } -absl::Status HostExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const Kernel& kernel, - const KernelArgs& args) { - const HostKernel* host_kernel = AsHostKernel(&kernel); - - const KernelArgsDeviceMemoryArray* device_mem = - DynCast(&args); - - absl::Status result; - if (device_mem != nullptr) { - result = host_kernel->Launch(thread_dims, device_mem->device_memory_args()); - } else { - result = absl::UnimplementedError( - "Host kernel implements Launch method only for DeviceMemoryArray " - "arguments."); - } - return result; -} - bool HostExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const { tsl::port::MemoryInfo mem_info = tsl::port::GetMemoryInfo(); *free = (mem_info.free != INT64_MAX) ? mem_info.free : -1; diff --git a/third_party/xla/xla/stream_executor/host/host_executor.h b/third_party/xla/xla/stream_executor/host/host_executor.h index 7ab168d29f9d6d..55eacc5fff4851 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.h +++ b/third_party/xla/xla/stream_executor/host/host_executor.h @@ -13,20 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Declares the HostExecutor class, which is a CPU-only implementation of -// the StreamExecutor interface. For now, this is used for testing and to -// examine the performance of host-based StreamExecutor code. #ifndef XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_ #define XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_ -#include #include #include #include #include #include -#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/stream_executor/device_description.h" @@ -36,24 +31,21 @@ limitations under the License. #include "xla/stream_executor/host_memory_allocation.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_common.h" #include "tsl/platform/threadpool.h" namespace stream_executor { namespace host { -// An implementation of StreamExecutor that does no communication or interaction -// with a device, but DOES perform memory operations backed by the host. -// Kernel invocations will fail, but host callbacks may be enqueued on this -// executor and its associated stream, and should follow standard ordering -// semantics. +// Declares the HostExecutor class, which is a CPU-only implementation of +// the StreamExecutor interface. For now, this is used for testing and to +// examine the performance of host-based StreamExecutor code. // // This is useful for evaluating the performance of host-based or fallback // routines executed under the context of a GPU executor. -// See stream_executor.h for description of the below operations. class HostExecutor : public StreamExecutorCommon { public: // A function that loads a kernel function from a given spec. If spec is not @@ -73,10 +65,6 @@ class HostExecutor : public StreamExecutorCommon { absl::StatusOr> LoadKernel( const MultiKernelLoaderSpec& spec) override; - absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& kernel, - const KernelArgs& args) override; - DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; void Deallocate(DeviceMemoryBase* mem) override; @@ -88,7 +76,6 @@ class HostExecutor : public StreamExecutorCommon { delete[] static_cast(mem); } - // No "synchronize all activity" implemented for this platform at the moment. bool SynchronizeAllActivity() override { return true; } absl::Status SynchronousMemZero(DeviceMemoryBase* location, uint64_t size) override; diff --git a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc index d92e25ec9a65a9..4e766fc92158d5 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" @@ -90,9 +91,9 @@ define ptr @LlvmAddI32(ptr noundef %0) { )"; static absl::StatusOr NewStreamExecutor() { - StreamExecutorConfig config(/*ordinal=*/0); TF_ASSIGN_OR_RETURN(auto platform, PlatformManager::PlatformWithName("Host")); - TF_ASSIGN_OR_RETURN(auto stream_exec, platform->GetExecutor(config)); + TF_ASSIGN_OR_RETURN(auto stream_exec, + platform->ExecutorForDevice(/*ordinal=*/0)); return stream_exec; } diff --git a/third_party/xla/xla/stream_executor/host/host_platform.cc b/third_party/xla/xla/stream_executor/host/host_platform.cc index c9a12709d70f22..b70ea46fa25825 100644 --- a/third_party/xla/xla/stream_executor/host/host_platform.cc +++ b/third_party/xla/xla/stream_executor/host/host_platform.cc @@ -52,25 +52,18 @@ HostPlatform::DescriptionForDevice(int ordinal) const { } absl::StatusOr HostPlatform::ExecutorForDevice(int ordinal) { - StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); -} - -absl::StatusOr HostPlatform::GetExecutor( - const StreamExecutorConfig& config) { return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } absl::StatusOr> -HostPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { - auto executor = std::make_unique(this, config.ordinal); +HostPlatform::GetUncachedExecutor(int ordinal) { + auto executor = std::make_unique(this, ordinal); auto init_status = executor->Init(); if (!init_status.ok()) { return absl::InternalError(absl::StrFormat( - "failed initializing StreamExecutor for device ordinal %d: %s", - config.ordinal, init_status.ToString().c_str())); + "failed initializing StreamExecutor for device ordinal %d: %s", ordinal, + init_status.ToString().c_str())); } return std::move(executor); diff --git a/third_party/xla/xla/stream_executor/host/host_platform.h b/third_party/xla/xla/stream_executor/host/host_platform.h index 3dd90a6878bb9b..b8ce8f4340d6c4 100644 --- a/third_party/xla/xla/stream_executor/host/host_platform.h +++ b/third_party/xla/xla/stream_executor/host/host_platform.h @@ -51,15 +51,12 @@ class HostPlatform : public Platform { absl::StatusOr ExecutorForDevice(int ordinal) override; - absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) override; - private: - // Returns a device constructed with the options specified in "config" without + // Returns a device constructed with ordinal without // looking in or storing to the Platform's executor cache. // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config); + int ordinal); // This platform's name. std::string name_; diff --git a/third_party/xla/xla/stream_executor/host/host_stream.cc b/third_party/xla/xla/stream_executor/host/host_stream.cc index ed6e040431e478..76b66711e03d62 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.cc +++ b/third_party/xla/xla/stream_executor/host/host_stream.cc @@ -33,6 +33,9 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/host/host_event.h" +#include "xla/stream_executor/host/host_kernel.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_common.h" #include "tsl/platform/denormal.h" @@ -192,6 +195,21 @@ absl::Status HostStream::BlockUntilDone() { return status; } -} // namespace host +absl::Status HostStream::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const Kernel& kernel, const KernelArgs& args) { + const HostKernel* host_kernel = AsHostKernel(&kernel); + + const KernelArgsDeviceMemoryArray* device_mem = + DynCast(&args); + + if (device_mem != nullptr) { + return host_kernel->Launch(thread_dims, device_mem->device_memory_args()); + } + return absl::UnimplementedError( + "Host kernel implements Launch method only for DeviceMemoryArray " + "arguments."); +} +} // namespace host } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/host/host_stream.h b/third_party/xla/xla/stream_executor/host/host_stream.h index ed1bbc2011f48f..a43ba610e25417 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.h +++ b/third_party/xla/xla/stream_executor/host/host_stream.h @@ -13,12 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Class declaration for Stream type that enqueues tasks onto a host/CPU-based -// execution context (as opposed to a GPU device), HostExecutor. #ifndef XLA_STREAM_EXECUTOR_HOST_HOST_STREAM_H_ #define XLA_STREAM_EXECUTOR_HOST_HOST_STREAM_H_ -#include +#include #include #include @@ -27,13 +25,20 @@ limitations under the License. #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_common.h" +#include "xla/stream_executor/stream_executor.h" #include "tsl/platform/env.h" #include "tsl/platform/thread_annotations.h" namespace stream_executor { namespace host { +// Class declaration for Stream type that enqueues tasks onto a host/CPU-based +// execution context (as opposed to a GPU device), HostExecutor. class HostStream : public StreamCommon { public: explicit HostStream(StreamExecutor* executor); @@ -65,6 +70,8 @@ class HostStream : public StreamCommon { uint64_t size) override; absl::Status DoHostCallbackWithStatus( absl::AnyInvocable callback) override; + absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const Kernel& kernel, const KernelArgs& args) override; private: bool WorkAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); diff --git a/third_party/xla/xla/stream_executor/kernel_test.cc b/third_party/xla/xla/stream_executor/kernel_test.cc index aa83971b3e946d..a554785735d3cd 100644 --- a/third_party/xla/xla/stream_executor/kernel_test.cc +++ b/third_party/xla/xla/stream_executor/kernel_test.cc @@ -68,8 +68,7 @@ static_assert( static StreamExecutor* NewStreamExecutor() { Platform* platform = PlatformManager::PlatformWithName("Host").value(); - StreamExecutorConfig config(/*ordinal=*/0); - return platform->GetExecutor(config).value(); + return platform->ExecutorForDevice(/*ordinal=*/0).value(); } TEST(KernelTest, PackDeviceMemoryArguments) { diff --git a/third_party/xla/xla/stream_executor/lazy_op_runner.h b/third_party/xla/xla/stream_executor/lazy_op_runner.h index c74a03e1ad5226..bf964e05bbaae6 100644 --- a/third_party/xla/xla/stream_executor/lazy_op_runner.h +++ b/third_party/xla/xla/stream_executor/lazy_op_runner.h @@ -280,76 +280,6 @@ struct FusedMatmulOp { } }; -struct FusedMHAOp { - using Signature = FusedMHASignature; - struct Config { - double scale; - const MatmulTensorDescriptor& bmm1_lhs_descriptor; - const MatmulTensorDescriptor& bmm1_rhs_descriptor; - const MatmulTensorDescriptor& bmm2_rhs_descriptor; - const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor; - const TensorDescriptor& output_descriptor; - std::optional bias_descriptor; - std::optional activation_descriptor; - std::optional dropout_rate; - std::optional seed; - FMHAMaskKind mask_type; - }; - - static absl::StatusOr>> - RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, - Stream* stream) { - TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); - return dnn->FusedMHARunnerFromDesc( - stream, desc, config.bmm1_lhs_descriptor, config.bmm1_rhs_descriptor, - config.bmm2_rhs_descriptor, config.intermediate_bmm2_lhs_descriptor, - config.output_descriptor, config.activation_descriptor, - config.bias_descriptor, config.scale, config.dropout_rate, config.seed, - config.mask_type); - } -}; - -struct FusedMHABackwardOp { - using Signature = FusedMHABackwardSignature; - - struct Config { - double scale; - const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor; - const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor; - const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor; - const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor; - const MatmulTensorDescriptor& d_output_descriptor; - const TensorDescriptor& d_bmm1_lhs_descriptor; - const TensorDescriptor& d_bmm1_rhs_descriptor; - const TensorDescriptor& d_bmm2_rhs_descriptor; - std::optional d_s_descriptor; - std::optional d_bias_descriptor; - std::optional fwd_output_descriptor; - std::optional bias_descriptor; - std::optional dropout_rate; - std::optional seed; - FMHAMaskKind mask_type; - bool force_deterministic; - }; - - static absl::StatusOr< - std::unique_ptr>> - RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, - Stream* stream) { - TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); - return dnn->FusedMHABackwardRunnerFromDesc( - stream, desc, config.bmm1_grad_gemm1_rhs_descriptor, - config.bmm1_grad_gemm2_rhs_descriptor, - config.bmm2_grad_gemm1_lhs_descriptor, - config.bmm2_grad_gemm2_rhs_descriptor, config.d_output_descriptor, - config.d_bmm1_lhs_descriptor, config.d_bmm1_rhs_descriptor, - config.d_bmm2_rhs_descriptor, config.d_s_descriptor, - config.d_bias_descriptor, config.fwd_output_descriptor, - config.bias_descriptor, config.scale, config.dropout_rate, config.seed, - config.mask_type, config.force_deterministic); - } -}; - } // namespace dnn } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/mock_platform.h b/third_party/xla/xla/stream_executor/mock_platform.h new file mode 100644 index 00000000000000..7c8e11dcabe7dc --- /dev/null +++ b/third_party/xla/xla/stream_executor/mock_platform.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_MOCK_PLATFORM_H_ +#define XLA_STREAM_EXECUTOR_MOCK_PLATFORM_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/test.h" + +namespace stream_executor { + +// Implements the Platform interface for testing. +class MockPlatform : public Platform { + public: + MockPlatform() = default; + MOCK_METHOD(Id, id, (), (const, override)); + MOCK_METHOD(const std::string&, Name, (), (const, override)); + MOCK_METHOD(int, VisibleDeviceCount, (), (const, override)); + MOCK_METHOD(bool, Initialized, (), (const, override)); + MOCK_METHOD(absl::Status, Initialize, (), (override)); + MOCK_METHOD(absl::StatusOr>, + DescriptionForDevice, (int ordinal), (const, override)); + MOCK_METHOD(absl::StatusOr, ExecutorForDevice, (int ordinal), + (override)); + MOCK_METHOD(absl::StatusOr, FindExisting, (int ordinal), + (override)); +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_MOCK_PLATFORM_H_ diff --git a/third_party/xla/xla/stream_executor/mock_stream.h b/third_party/xla/xla/stream_executor/mock_stream.h new file mode 100644 index 00000000000000..5e9750e124caaa --- /dev/null +++ b/third_party/xla/xla/stream_executor/mock_stream.h @@ -0,0 +1,94 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_MOCK_STREAM_H_ +#define XLA_STREAM_EXECUTOR_MOCK_STREAM_H_ + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/test.h" + +namespace stream_executor { + +// Implements the Stream interface for testing. +class MockStream : public Stream { + public: + MockStream() = default; + MOCK_METHOD(PlatformSpecificHandle, platform_specific_handle, (), + (const, override)); + MOCK_METHOD(bool, ok, (), (const, override)); + MOCK_METHOD(absl::Status, RefreshStatus, (), (override)); + MOCK_METHOD(absl::StatusOr, GetOrCreateSubStream, (), (override)); + MOCK_METHOD(void, ReturnSubStream, (Stream * sub_stream), (override)); + MOCK_METHOD(absl::Status, WaitFor, (Stream * other), (override)); + MOCK_METHOD(absl::Status, WaitFor, (Event * event), (override)); + MOCK_METHOD(absl::Status, RecordEvent, (Event * event), (override)); + MOCK_METHOD(absl::Status, Memcpy, + (void *host_dst, const DeviceMemoryBase &gpu_src, uint64_t size), + (override)); + MOCK_METHOD(absl::Status, Memcpy, + (DeviceMemoryBase * gpu_dst, const void *host_src, uint64_t size), + (override)); + MOCK_METHOD(absl::Status, Memcpy, + (DeviceMemoryBase * gpu_dst, const DeviceMemoryBase &gpu_src, + uint64_t size), + (override)); + MOCK_METHOD(absl::Status, MemZero, + (DeviceMemoryBase * location, uint64_t size), (override)); + MOCK_METHOD(absl::Status, Memset32, + (DeviceMemoryBase * location, uint32_t pattern, uint64_t size), + (override)); + MOCK_METHOD(absl::Status, BlockHostUntilDone, (), (override)); + MOCK_METHOD(absl::Status, DoHostCallbackWithStatus, + (absl::AnyInvocable callback), (override)); + MOCK_METHOD(StreamExecutor *, parent, (), (const, override)); + MOCK_METHOD(CudaComputeCapability, GetCudaComputeCapability, (), + (const, override)); + MOCK_METHOD(RocmComputeCapability, GetRocmComputeCapability, (), + (const, override)); + MOCK_METHOD((std::variant), priority, (), + (const, override)); + MOCK_METHOD(absl::Status, Launch, + (const ThreadDim &thread_dims, const BlockDim &block_dims, + const Kernel &k, const KernelArgs &args), + (override)); + MOCK_METHOD(absl::Status, Launch, + (const ThreadDim &thread_dims, const BlockDim &block_dims, + const ClusterDim &cluster_dims, const Kernel &k, + const KernelArgs &args), + (override)); + MOCK_METHOD(absl::string_view, name, (), (const, override)); + MOCK_METHOD(void, set_name, (absl::string_view name), (override)); + MOCK_METHOD(absl::StatusOr>, + CreateEventBasedTimer, (bool use_delay_kernel), (override)); +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_MOCK_STREAM_H_ diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h index 9e4cdc08fcf62f..0379c2c068dc18 100644 --- a/third_party/xla/xla/stream_executor/mock_stream_executor.h +++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h @@ -1,3 +1,6 @@ +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/fft.h" /* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,7 +25,6 @@ limitations under the License. #include #include -#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -33,7 +35,6 @@ limitations under the License. #include "xla/stream_executor/event.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" @@ -43,16 +44,6 @@ limitations under the License. namespace stream_executor { -namespace fft { -class FftSupport; -} -namespace dnn { -class DnnSupport; -} -namespace blas { -class BlasSupport; -} - // Implements StreamExecutor for testing. class MockStreamExecutor : public StreamExecutor { public: @@ -68,16 +59,6 @@ class MockStreamExecutor : public StreamExecutor { MOCK_METHOD(absl::StatusOr>, CreateOrShareConstant, (Stream * stream, absl::Span content), (override)); - MOCK_METHOD(absl::Status, Launch, - (Stream * stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& k, - const KernelArgs& args), - (override)); - MOCK_METHOD(absl::Status, Launch, - (Stream * stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const ClusterDim& cluster_dims, - const Kernel& k, const KernelArgs& args), - (override)); MOCK_METHOD(DeviceMemoryBase, Allocate, (uint64_t size, int64_t memory_space), (override)); MOCK_METHOD(void, Deallocate, (DeviceMemoryBase * mem), (override)); diff --git a/third_party/xla/xla/stream_executor/platform.cc b/third_party/xla/xla/stream_executor/platform.cc index 47bf5600d20297..9e8d4a8065c8b9 100644 --- a/third_party/xla/xla/stream_executor/platform.cc +++ b/third_party/xla/xla/stream_executor/platform.cc @@ -32,11 +32,6 @@ std::string StreamPriorityToString(StreamPriority priority) { } } -StreamExecutorConfig::StreamExecutorConfig() : ordinal(-1) {} - -StreamExecutorConfig::StreamExecutorConfig(int ordinal_in) - : ordinal(ordinal_in) {} - bool Platform::Initialized() const { return true; } absl::Status Platform::Initialize() { return absl::OkStatus(); } diff --git a/third_party/xla/xla/stream_executor/platform.h b/third_party/xla/xla/stream_executor/platform.h index cba96a6fc564b9..759a4c0acde70f 100644 --- a/third_party/xla/xla/stream_executor/platform.h +++ b/third_party/xla/xla/stream_executor/platform.h @@ -29,7 +29,6 @@ limitations under the License. namespace stream_executor { class StreamExecutor; -class DeviceDescription; // An enum to represent different levels of stream priorities. // This is to avoid platform-specific representations in abstractions. @@ -38,23 +37,6 @@ enum class StreamPriority { Default = 0, Lowest, Highest }; // Returns a printable description of StreamPriority. std::string StreamPriorityToString(StreamPriority priority); -// StreamExecutorConfig encapsulates the set of options for constructing a -// StreamExecutor for a given platform. -struct StreamExecutorConfig { - // Sets members to defaults: -1 for ordinal (must be changed). - StreamExecutorConfig(); - - // Simple ordinal-setting constructor. - explicit StreamExecutorConfig(int ordinal); - - // The GPU stream for which we are searching the executor. - // If this field is specified for the search, others will be ignored. - void* gpu_stream = nullptr; - - // The ordinal of the device to be managed by the returned StreamExecutor. - int ordinal; -}; - // Abstract base class for a platform registered with the PlatformManager. class Platform { public: @@ -105,19 +87,20 @@ class Platform { virtual absl::StatusOr> DescriptionForDevice(int ordinal) const = 0; - // Returns a device with the given ordinal on this platform with a default - // plugin configuration or, if none can be found with the given ordinal or - // there is an error in opening a context to communicate with the device, an - // error status is returned. + // Returns a StreamExecutor for the given ordinal if one has already been + // created, or an error is returned if none exists. Does not create a new + // context with the device. + virtual absl::StatusOr FindExisting(int ordinal) { + return absl::NotFoundError("Not implemented for this platform."); + } + + // Returns a device with the given ordinal on this platform or, if none can + // be found with the given ordinal or there is an error in opening a context + // to communicate with the device, an error status is returned. // // Ownership of the executor is NOT transferred to the caller -- // the Platform owns the executors in a singleton-like fashion. virtual absl::StatusOr ExecutorForDevice(int ordinal) = 0; - - // Returns a device constructed with the options specified in "config". - // Ownership of the executor is NOT transferred to the caller. - virtual absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) = 0; }; } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 7204036aa46399..1fbff0912d4b06 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -8,11 +8,6 @@ load( "//xla/stream_executor:build_defs.bzl", "stream_executor_friends", ) - -# copybara:comment_begin(oss-only) -load("//xla/stream_executor/rocm:build_defs.bzl", "rocm_embedded_test_modules") - -# copybara:comment_end load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_hipblaslt", @@ -705,10 +700,3 @@ cc_library( [":all_runtime"], ), ) - -# copybara:comment_begin(oss-only) -rocm_embedded_test_modules( - name = "add_i32_kernel", - srcs = if_rocm_is_configured(["add_i32_kernel.cu.cc"]), -) -# copybara:comment_end diff --git a/third_party/xla/xla/stream_executor/rocm/build_defs.bzl b/third_party/xla/xla/stream_executor/rocm/build_defs.bzl deleted file mode 100644 index 0be87739c8469f..00000000000000 --- a/third_party/xla/xla/stream_executor/rocm/build_defs.bzl +++ /dev/null @@ -1,68 +0,0 @@ -""" ROCM-specific build macros. -""" - -load("@local_config_rocm//rocm:build_defs.bzl", "rocm_gpu_architectures") - -def rocm_embedded_test_modules(name, srcs, testonly = True, **kwargs): - """Compile srcs into hsaco files and create a header only cc_library. - - Binary files are embedded as constant data. - - Args: - name: name for the generated cc_library target, and the base name for - generated header file - srcs: source files for input modules - testonly: If True, the target can only be used with tests. - **kwargs: keyword arguments passed onto the generated cc_library() rule. - """ - - # Lets piggyback this on top crosstool wrapper for now - hipcc_tool = "@local_config_rocm//crosstool:crosstool_wrapper_driver_is_not_gcc" - target_opts = " ".join(["--amdgpu-target=" + - arch for arch in rocm_gpu_architectures()]) - - header_file = "%s.h" % name - - native.genrule( - name = name + "_header_file", - srcs = srcs, - outs = [header_file], - cmd = """ - tmp_name_for_xxd() { - local filename=$$(basename $$1) - local name="k" - for word in $$(echo $${filename%%%%.*} | tr '_' ' '); do - name="$$name$${word^}" - done - echo "$${name}Module" - } - - echo '#pragma once' > $@ - echo '#include ' >> $@ - for src in $(SRCS); do - tmp=$$(tmp_name_for_xxd $$src); - $(location %s) -x rocm %s --genco -c $$src -o $$tmp && xxd -i $$tmp | sed \ - -e 's/unsigned char/inline constexpr uint8_t/g' \ - -e '$$d' >> $@; - rm -f $$tmp - done - """ % (hipcc_tool, target_opts), - tools = [hipcc_tool], - testonly = testonly, - target_compatible_with = select({ - "@local_config_rocm//rocm:using_hipcc": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - ) - - native.cc_library( - name = name, - srcs = [], - hdrs = [header_file], - testonly = testonly, - target_compatible_with = select({ - "@local_config_rocm//rocm:using_hipcc": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - **kwargs - ) diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc index 465dbbe84b2a00..2f61eae925d846 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc @@ -1148,6 +1148,21 @@ struct BitPatternToValue { return absl::OkStatus(); } +absl::Status GpuDriver::LaunchKernel( + GpuContext* context, absl::string_view kernel_name, + GpuFunctionHandle function, unsigned int cluster_dim_x, + unsigned int cluster_dim_y, unsigned int cluster_dim_z, + unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, + unsigned int block_dim_x, unsigned int block_dim_y, + unsigned int block_dim_z, unsigned int shared_mem_bytes, + GpuStreamHandle stream, void** kernel_params, void** extra) { + if (cluster_dim_x != 1 || cluster_dim_y != 1 || cluster_dim_z != 1) + return absl::UnimplementedError("Not implemented for ROCm"); + return LaunchKernel(context, kernel_name, function, grid_dim_x, grid_dim_y, + grid_dim_z, block_dim_x, block_dim_y, block_dim_z, + shared_mem_bytes, stream, kernel_params, extra); +} + /* static */ absl::Status GpuDriver::LoadPtx(GpuContext* context, const char* ptx_contents, hipModule_t* module) { diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index cf9fe323a9c939..76a4db74ae7a56 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -333,58 +333,6 @@ absl::Status GpuExecutor::GetKernelMetadata(GpuKernel* rocm_kernel, return absl::OkStatus(); } -absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const Kernel& kernel, const KernelArgs& args) { - GpuStreamHandle hipstream = AsGpuStreamValue(stream); - const GpuKernel* rocm_kernel = AsGpuKernel(&kernel); - hipFunction_t hipfunc = rocm_kernel->gpu_function(); - - if (rocm_kernel->cache_config() != KernelCacheConfig::kNoPreference) { - TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( - hipfunc, rocm_kernel->GetGpuCacheConfig())); - } - - auto launch = [&](const KernelArgsPackedArrayBase& packed) { - CHECK_EQ(kernel.Arity() + (args.number_of_shared_bytes() > 0), - packed.number_of_arguments()); - - void** kernel_params = - const_cast(packed.argument_addresses().data()); - - return GpuDriver::LaunchKernel( - GetGpuContext(stream), kernel.name(), hipfunc, block_dims.x, - block_dims.y, block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, - args.number_of_shared_bytes(), hipstream, kernel_params, nullptr); - }; - - auto* packed_args = DynCast(&args); - if (packed_args) return launch(*packed_args); - - if (auto* device_mem = DynCast(&args)) { - auto& pack = kernel.args_packing(); - if (!pack) { - return absl::InternalError( - "Kernel is missing a custom arguments packing function for device " - "memory arguments array"); - } - - TF_ASSIGN_OR_RETURN(auto packed_args, pack(kernel, *device_mem)); - return launch(*packed_args); - } - - return absl::InternalError("Unsupported kernel arguments type"); -} - -absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const ClusterDim& cluster_dims, - const Kernel& kernel, const KernelArgs& args) { - if (cluster_dims.x != 1 || cluster_dims.y != 1 || cluster_dims.z != 1) - return absl::UnimplementedError("Not implemented for ROCm"); - return Launch(stream, thread_dims, block_dims, kernel, args); -} - absl::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle) { // In GpuExecutor we store the pointer to the HSACO binary as diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc index ef7bc09be0c6e7..97413a6347584d 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc @@ -53,26 +53,17 @@ ROCmPlatform::DescriptionForDevice(int ordinal) const { } absl::StatusOr ROCmPlatform::ExecutorForDevice(int ordinal) { - StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); + return executor_cache_.GetOrCreate( + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } -absl::StatusOr ROCmPlatform::GetExecutor( - const StreamExecutorConfig& config) { - if (config.gpu_stream) { - // If the GPU stream was provided, it's not possible to get-or-create a - // stream with a required pointer: so we are looking for previously - // allocated streams. - return executor_cache_.Get(config); - } - return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); +absl::StatusOr ROCmPlatform::FindExisting(int ordinal) { + return executor_cache_.Get(ordinal); } absl::StatusOr> -ROCmPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { - auto executor = std::make_unique(this, config.ordinal); +ROCmPlatform::GetUncachedExecutor(int ordinal) { + auto executor = std::make_unique(this, ordinal); TF_RETURN_IF_ERROR(executor->Init()); return std::move(executor); } diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h index 2a4c6330a2d6d1..6888b64532c0dd 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h @@ -54,16 +54,14 @@ class ROCmPlatform : public Platform { int ordinal) const override; absl::StatusOr ExecutorForDevice(int ordinal) override; - - absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) override; + absl::StatusOr FindExisting(int ordinal) override; private: - // Returns a device constructed with the options specified in "config" without + // Returns a device constructed with ordinal without // looking in or storing to the Platform's executor cache. // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config); + int ordinal); // This platform's name. std::string name_; diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h index 71cdf9a35b8da7..0fccf94270a85c 100644 --- a/third_party/xla/xla/stream_executor/stream.h +++ b/third_party/xla/xla/stream_executor/stream.h @@ -272,7 +272,9 @@ class Stream { // platform driver. virtual absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, const Kernel &k, - const KernelArgs &args) = 0; + const KernelArgs &args) { + return absl::UnimplementedError("Not implemented"); + } // Launches a data parallel kernel with the given thread/block // dimensionality and already-packed args/sizes to pass to the underlying @@ -280,7 +282,9 @@ class Stream { virtual absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, const ClusterDim &cluster_dims, const Kernel &k, - const KernelArgs &args) = 0; + const KernelArgs &args) { + return absl::UnimplementedError("Not implemented"); + } // Get/set a name for a stream, which can be shown in profiling tools virtual absl::string_view name() const = 0; diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc index 048623da37c01a..e7833bfd25dab2 100644 --- a/third_party/xla/xla/stream_executor/stream_common.cc +++ b/third_party/xla/xla/stream_executor/stream_common.cc @@ -21,13 +21,10 @@ limitations under the License. #include #include -#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/blas.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/logging.h" @@ -40,19 +37,6 @@ StreamCommon::StreamCommon(StreamExecutor *parent) CHECK_NE(parent, nullptr); } -absl::Status StreamCommon::Launch(const ThreadDim &thread_dims, - const BlockDim &block_dims, const Kernel &k, - const KernelArgs &args) { - return parent_->Launch(this, thread_dims, block_dims, k, args); -} - -absl::Status StreamCommon::Launch(const ThreadDim &thread_dims, - const BlockDim &block_dims, - const ClusterDim &cluster_dims, - const Kernel &k, const KernelArgs &args) { - return parent_->Launch(this, thread_dims, block_dims, cluster_dims, k, args); -} - StreamCommon::PlatformSpecificHandle StreamCommon::platform_specific_handle() const { PlatformSpecificHandle handle; diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h index 3d2ade72ff12e3..f7029c72fbadbf 100644 --- a/third_party/xla/xla/stream_executor/stream_common.h +++ b/third_party/xla/xla/stream_executor/stream_common.h @@ -28,7 +28,6 @@ limitations under the License. #include #include "absl/base/thread_annotations.h" -#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -36,8 +35,6 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/fft.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -84,11 +81,6 @@ class StreamCommon : public Stream { std::variant priority() const override { return StreamPriority::Default; } - absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, - const Kernel &k, const KernelArgs &args) override; - absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, - const ClusterDim &cluster_dims, const Kernel &k, - const KernelArgs &args) override; // Doesn't do anything interesting by default; GpuStream connects this to NVTX absl::string_view name() const override { return name_; } @@ -107,8 +99,6 @@ class StreamCommon : public Stream { // Checks the status and logs the error message, if any. void CheckStatus(absl::Status status) TF_LOCKS_EXCLUDED(mu_); - void SetError() { CheckError(false /* = operation_retcode */); } - std::string name_; private: diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h index f5fb436dfd2274..a0c3c48e521e30 100644 --- a/third_party/xla/xla/stream_executor/stream_executor.h +++ b/third_party/xla/xla/stream_executor/stream_executor.h @@ -1,5 +1,3 @@ -#include "absl/functional/any_invocable.h" -#include "absl/log/log.h" /* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,12 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// The StreamExecutor is a single-device abstraction for: -// -// * Loading/launching data-parallel-kernels -// * Invoking pre-canned high-performance library routines (like matrix -// multiply) - #ifndef XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ #define XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ @@ -31,6 +23,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -43,7 +36,6 @@ limitations under the License. #include "xla/stream_executor/fft.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" @@ -67,6 +59,12 @@ inline std::string MemoryTypeString(MemoryType memory_type) { } } +/// The StreamExecutor is a single-device abstraction for: +// +// * Loading/launching data-parallel-kernels +// * Invoking pre-canned high-performance library routines (like matrix +// multiply) +// // Interface which defines the method for interacting with an accelerator device // (e.g. GPU, TPU). class StreamExecutor { @@ -138,26 +136,6 @@ class StreamExecutor { return absl::UnimplementedError("Not Implemented"); } - // Launches a data parallel kernel with the given thread/block - // dimensionality and already-packed args/sizes to pass to the underlying - // platform driver. - - virtual absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& k, - const KernelArgs& args) { - return absl::UnimplementedError("Not Implemented"); - } - - // Launches a data parallel kernel with the given thread/block - // dimensionality and already-packed args/sizes to pass to the underlying - // platform driver. - virtual absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const ClusterDim& cluster_dims, const Kernel& k, - const KernelArgs& args) { - return absl::UnimplementedError("Not Implemented"); - } - // Synchronously allocates size bytes on the underlying platform and returns // a DeviceMemoryBase representing that allocation. In the case of failure, // nullptr is returned. diff --git a/third_party/xla/xla/stream_executor/stream_executor_test.cc b/third_party/xla/xla/stream_executor/stream_executor_test.cc index 5f89dc6c8f0ac0..9a2ca572534db8 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_test.cc +++ b/third_party/xla/xla/stream_executor/stream_executor_test.cc @@ -26,9 +26,8 @@ limitations under the License. namespace stream_executor { static absl::StatusOr NewStreamExecutor() { - StreamExecutorConfig config(/*ordinal=*/0); TF_ASSIGN_OR_RETURN(auto platform, PlatformManager::PlatformWithName("Host")); - TF_ASSIGN_OR_RETURN(auto stream_exec, platform->GetExecutor(config)); + TF_ASSIGN_OR_RETURN(auto stream_exec, platform->ExecutorForDevice(0)); return stream_exec; } diff --git a/third_party/xla/xla/stream_executor/stream_finder.cc b/third_party/xla/xla/stream_executor/stream_finder.cc new file mode 100644 index 00000000000000..e9cdcf02c8c65c --- /dev/null +++ b/third_party/xla/xla/stream_executor/stream_finder.cc @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/stream_finder.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" + +namespace stream_executor { + +absl::StatusOr FindStream(Platform* platform, void* gpu_stream) { + int number_devices = platform->VisibleDeviceCount(); + for (int i = 0; i < number_devices; ++i) { + auto stream_executor = platform->FindExisting(i); + if (!stream_executor.ok()) { + continue; + } + Stream* found_stream = nullptr; + if ((found_stream = (*stream_executor)->FindAllocatedStream(gpu_stream)) != + nullptr) { + return found_stream; + } + } + return absl::NotFoundError("Stream not found"); +} + +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/rocm/add_i32_kernel.cu.cc b/third_party/xla/xla/stream_executor/stream_finder.h similarity index 55% rename from third_party/xla/xla/stream_executor/rocm/add_i32_kernel.cu.cc rename to third_party/xla/xla/stream_executor/stream_finder.h index 8a6406fe05e5f6..0503d3fbe4e641 100644 --- a/third_party/xla/xla/stream_executor/rocm/add_i32_kernel.cu.cc +++ b/third_party/xla/xla/stream_executor/stream_finder.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#ifndef XLA_STREAM_EXECUTOR_STREAM_FINDER_H_ +#define XLA_STREAM_EXECUTOR_STREAM_FINDER_H_ -extern "C" __global__ void add(int32_t* a, int32_t* b, int32_t* c) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - c[index] = a[index] + b[index]; -} +#include "absl/status/statusor.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" + +namespace stream_executor { + +// Returns a Stream given the gpu_stream handle. +absl::StatusOr FindStream(Platform* platform, void* gpu_stream); + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_STREAM_FINDER_H_ diff --git a/third_party/xla/xla/stream_executor/stream_finder_test.cc b/third_party/xla/xla/stream_executor/stream_finder_test.cc new file mode 100644 index 00000000000000..6bb8ac86779519 --- /dev/null +++ b/third_party/xla/xla/stream_executor/stream_finder_test.cc @@ -0,0 +1,80 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/stream_finder.h" + +#include "absl/status/status.h" +#include "xla/stream_executor/mock_platform.h" +#include "xla/stream_executor/mock_stream.h" +#include "xla/stream_executor/mock_stream_executor.h" +#include "xla/test.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +using testing::Return; +namespace stream_executor { +namespace { + +TEST(StreamFinderTest, FindStreamFailsWithNoExecutors) { + MockStreamExecutor stream_executor; + MockPlatform platform; + EXPECT_CALL(platform, VisibleDeviceCount()).WillOnce(Return(0)); + EXPECT_FALSE(FindStream(&platform, nullptr).ok()); +} + +TEST(StreamFinderTest, FindStreamFailsWithNoMatchingStream) { + MockStreamExecutor stream_executor; + MockPlatform platform; + EXPECT_CALL(platform, VisibleDeviceCount()).WillOnce(Return(1)); + EXPECT_CALL(platform, FindExisting(0)).WillOnce(Return(&stream_executor)); + void *gpu_stream = reinterpret_cast(0x1234); + EXPECT_CALL(stream_executor, FindAllocatedStream(gpu_stream)) + .WillOnce(Return(nullptr)); + EXPECT_FALSE(FindStream(&platform, gpu_stream).ok()); +} + +TEST(StreamFinderTest, FindStreamSucceeds) { + MockStreamExecutor stream_executor0; + MockStreamExecutor stream_executor1; + MockPlatform platform; + EXPECT_CALL(platform, VisibleDeviceCount()).WillOnce(Return(2)); + EXPECT_CALL(platform, FindExisting(0)).WillOnce(Return(&stream_executor0)); + EXPECT_CALL(platform, FindExisting(1)).WillOnce(Return(&stream_executor1)); + void *gpu_stream = reinterpret_cast(0x1234); + MockStream stream; + EXPECT_CALL(stream_executor0, FindAllocatedStream(gpu_stream)) + .WillOnce(Return(nullptr)); + EXPECT_CALL(stream_executor1, FindAllocatedStream(gpu_stream)) + .WillOnce(Return(&stream)); + TF_ASSERT_OK_AND_ASSIGN(auto found_stream, FindStream(&platform, gpu_stream)); + EXPECT_EQ(found_stream, &stream); +} + +TEST(StreamFinderTest, OnlyExecutor1Exists) { + MockStreamExecutor stream_executor1; + MockPlatform platform; + EXPECT_CALL(platform, VisibleDeviceCount()).WillOnce(Return(2)); + EXPECT_CALL(platform, FindExisting(0)) + .WillRepeatedly(Return(absl::NotFoundError("Nope"))); + EXPECT_CALL(platform, FindExisting(1)).WillOnce(Return(&stream_executor1)); + void *gpu_stream = reinterpret_cast(0x1234); + MockStream stream; + EXPECT_CALL(stream_executor1, FindAllocatedStream(gpu_stream)) + .WillOnce(Return(&stream)); + TF_ASSERT_OK_AND_ASSIGN(auto found_stream, FindStream(&platform, gpu_stream)); + EXPECT_EQ(found_stream, &stream); +} +} // namespace +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/stream_test.cc b/third_party/xla/xla/stream_executor/stream_test.cc index 9fa6bc347fa441..ef5294ebe4260b 100644 --- a/third_party/xla/xla/stream_executor/stream_test.cc +++ b/third_party/xla/xla/stream_executor/stream_test.cc @@ -31,8 +31,7 @@ class StreamTest : public ::testing::Test { protected: StreamExecutor* NewStreamExecutor() { Platform* platform = PlatformManager::PlatformWithName("Host").value(); - StreamExecutorConfig config(/*ordinal=*/0); - return platform->GetExecutor(config).value(); + return platform->ExecutorForDevice(/*ordinal=*/0).value(); } }; diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc b/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc index ac6da36a5ea559..a78e104670bf21 100644 --- a/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc +++ b/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc @@ -64,26 +64,13 @@ SyclPlatform::DescriptionForDevice(int ordinal) const { } absl::StatusOr SyclPlatform::ExecutorForDevice(int ordinal) { - StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); -} - -absl::StatusOr SyclPlatform::GetExecutor( - const StreamExecutorConfig& config) { - if (config.gpu_stream) { - // If the GPU stream was provided, it's not possible to get-or-create a - // stream with a required pointer: so we are looking for previously - // allocated streams. - return executor_cache_.Get(config); - } return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } absl::StatusOr> -SyclPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { - auto executor = std::make_unique(this, config.ordinal); +SyclPlatform::GetUncachedExecutor(int ordinal { + auto executor = std::make_unique(this, ordinal); TF_RETURN_IF_ERROR(executor->Init()); return std::move(executor); } diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_platform.h b/third_party/xla/xla/stream_executor/sycl/sycl_platform.h index adc6cc92a58208..61f0eb3d5372b9 100644 --- a/third_party/xla/xla/stream_executor/sycl/sycl_platform.h +++ b/third_party/xla/xla/stream_executor/sycl/sycl_platform.h @@ -55,15 +55,12 @@ class SyclPlatform : public Platform { absl::StatusOr ExecutorForDevice(int ordinal) override; - absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) override; - private: - // Returns a device constructed with the options specified in "config" without + // Returns a device constructed with ordinal without // looking in or storing to the Platform's executor cache. // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config) override; + int ordinal) override; // This platform's name. std::string name_; diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h index 04f09bedbc92de..6331524142cc7f 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h +++ b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h @@ -64,7 +64,6 @@ typedef struct TpuSerializedProto { typedef struct SE_PlatformId { void* id; // aka stream_executor::Platform::Id } SE_PlatformId; -typedef struct SE_StreamExecutorConfig SE_StreamExecutorConfig; typedef TF_Status* (*SE_StatusCallback)(void*); typedef struct SE_DeviceMemoryBase { diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_defn.h b/third_party/xla/xla/stream_executor/tpu/c_api_defn.h index 59ecd662196daf..2d4f945396ca0b 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_defn.h +++ b/third_party/xla/xla/stream_executor/tpu/c_api_defn.h @@ -49,10 +49,6 @@ struct SE_Event { std::unique_ptr event; }; -struct SE_StreamExecutorConfig { - stream_executor::StreamExecutorConfig config; -}; - // Ignored -- these are just used to enforce the interface types struct XLA_TransferManager {}; struct XLA_ComputationPlacer {}; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h index 1d0087a02fa98b..a415204e85f592 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h @@ -28,8 +28,7 @@ SE_Platform* TpuPlatform_New(); void TpuPlatform_Free(SE_Platform* platform); void TpuPlatform_Initialize(SE_Platform* platform, TF_Status* status); bool TpuPlatform_Initialized(SE_Platform* platform); -SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform, - SE_StreamExecutorConfig* config, +SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform, int ordinal, TF_Status* status); SE_PlatformId TpuPlatform_Id(SE_Platform* platform); int64_t TpuPlatform_VisibleDeviceCount(SE_Platform* platform); @@ -132,10 +131,6 @@ const char* TpuStatus_Message(TF_Status* status); int TpuStatus_Code(TF_Status* status); bool TpuStatus_Ok(TF_Status* status); -SE_StreamExecutorConfig* TpuStreamExecutorConfig_Default(); -void TpuStreamExecutorConfig_SetOrdinal(SE_StreamExecutorConfig*, int ordinal); -void TpuStreamExecutorConfig_Free(SE_StreamExecutorConfig*); - SE_DeviceDescription* TpuDeviceDescription_New(); void TpuDeviceDescription_Free(SE_DeviceDescription* description); void TpuExecutor_CreateDeviceDescription(SE_StreamExecutor* executor, @@ -417,10 +412,6 @@ struct TfTpu_ExecutorApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Code); TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Ok); - TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_Default); - TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_SetOrdinal); - TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_Free); - TFTPU_ADD_FN_IN_STRUCT(TpuDeviceDescription_New); TFTPU_ADD_FN_IN_STRUCT(TpuDeviceDescription_Free); diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor_init_fns.inc b/third_party/xla/xla/stream_executor/tpu/tpu_executor_init_fns.inc index 1b30487f03ffa2..5bc6a8ac9c4086 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor_init_fns.inc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor_init_fns.inc @@ -60,10 +60,6 @@ absl::Status SetExecutorStructFn( TFTPU_SET_FN(executor_fn, TpuStatus_Code); TFTPU_SET_FN(executor_fn, TpuStatus_Ok); - TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_Default); - TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_SetOrdinal); - TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_Free); - TFTPU_SET_FN(executor_fn, TpuDeviceDescription_New); TFTPU_SET_FN(executor_fn, TpuDeviceDescription_Free); diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc b/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc index 16efed10f42179..5aaf8a75e94146 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc @@ -77,32 +77,23 @@ int TpuPlatform::VisibleDeviceCount() const { ->TpuPlatform_VisibleDeviceCountFn(platform_); } -absl::StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor( - const ::stream_executor::StreamExecutorConfig& config) { +absl::StatusOr<::stream_executor::StreamExecutor*> +TpuPlatform::ExecutorForDevice(int ordinal) { return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } absl::StatusOr> -TpuPlatform::GetUncachedExecutor( - const ::stream_executor::StreamExecutorConfig& config) { - SE_StreamExecutorConfig* c_config = stream_executor::tpu::ExecutorApiFn() - ->TpuStreamExecutorConfig_DefaultFn(); - - stream_executor::tpu::ExecutorApiFn()->TpuStreamExecutorConfig_SetOrdinalFn( - c_config, config.ordinal); - +TpuPlatform::GetUncachedExecutor(int ordinal) { StatusHelper status; SE_StreamExecutor* executor = stream_executor::tpu::ExecutorApiFn()->TpuPlatform_GetExecutorFn( - platform_, c_config, status.c_status); - stream_executor::tpu::ExecutorApiFn()->TpuStreamExecutorConfig_FreeFn( - c_config); + platform_, ordinal, status.c_status); if (!status.ok()) { return status.status(); } return std::make_unique(this, executor, - config.ordinal); + ordinal); } ::stream_executor::Platform::Id TpuPlatform::id() const { diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform.h b/third_party/xla/xla/stream_executor/tpu/tpu_platform.h index 6cf3e197b85658..8eb6f19b7cdd6b 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform.h @@ -82,15 +82,13 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface { } absl::StatusOr<::stream_executor::StreamExecutor*> ExecutorForDevice( + int ordinal) override; + + absl::StatusOr<::stream_executor::StreamExecutor*> FindExisting( int ordinal) override { - stream_executor::StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); + return executor_cache_.Get(ordinal); } - absl::StatusOr<::stream_executor::StreamExecutor*> GetExecutor( - const ::stream_executor::StreamExecutorConfig& config) override; - StreamMap* stream_map() { return &stream_map_; } void InsertEvent(stream_executor::Event* key, SE_Event* val); @@ -114,11 +112,11 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface { absl::Mutex& mutex() { return event_map_mu_; } private: - // Returns a device constructed with the options specified in "config" without + // Returns a device constructed with the ordinal without // looking in or storing to the Platform's executor cache. // Ownership IS transferred to the caller. absl::StatusOr> - GetUncachedExecutor(const ::stream_executor::StreamExecutorConfig& config); + GetUncachedExecutor(int ordinal); mutable SE_Platform* platform_; std::string name_; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc b/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc index c7df6619342830..63c83e5696cfc5 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc @@ -42,6 +42,7 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform, << status_or_tpu_platform.status(); return nullptr; } + LOG(INFO) << "Platform manager status: " << status_or_tpu_platform.status(); // Use any other registered TPU platform. auto status_or_other_tpu_platforms = @@ -72,12 +73,14 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform, --tries_left; if (tries_left <= 0) { - LOG(INFO) << "No TPU platform found."; + LOG(INFO) << "No TPU platform found. Platform manager status: " + << status_or_other_tpu_platforms.status(); return nullptr; } LOG(INFO) << "No TPU platform registered. Waiting 1 second and trying again... (" - << tries_left << " tries left)"; + << tries_left << " tries left) Platform manager status: " + << status_or_other_tpu_platforms.status(); tsl::Env::Default()->SleepForMicroseconds(1000000); // 1 second return GetRegisteredPlatformStatic(initialize_platform, tries_left); } diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 94751bfb423bb3..8f240268f20f0d 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -201,6 +201,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/utils:hlo_query", "//xla/service:backend", "//xla/service:computation_layout", "//xla/service:hlo_module_util", @@ -1782,12 +1783,16 @@ xla_test( ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", - "//xla:array2d", + "//xla:array3d", + "//xla:array4d", "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "//xla/hlo/ir:hlo", - "@local_tsl//tsl/platform:protobuf", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:test", ], ) @@ -2237,7 +2242,6 @@ xla_test( tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", - ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:shape_util", @@ -2245,7 +2249,6 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:local_client", "//xla/client:xla_builder", - "//xla/stream_executor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@local_tsl//tsl/platform:ml_dtypes", @@ -2378,8 +2381,12 @@ xla_test( "//xla:literal", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", + "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/tests/build_defs.bzl b/third_party/xla/xla/tests/build_defs.bzl index 8a42642c8e5ff5..8d73f8969255d9 100644 --- a/third_party/xla/xla/tests/build_defs.bzl +++ b/third_party/xla/xla/tests/build_defs.bzl @@ -31,6 +31,8 @@ GPU_BACKENDS = NVIDIA_GPU_BACKENDS + AMD_GPU_DEFAULT_BACKENDS GPU_DEFAULT_BACKENDS = NVIDIA_GPU_DEFAULT_BACKENDS +DEFAULT_DISABLED_BACKENDS = [] + _ALL_BACKENDS = ["cpu", "interpreter"] + NVIDIA_GPU_BACKENDS + AMD_GPU_DEFAULT_BACKENDS + list(plugins.keys()) # buildifier: disable=function-docstring @@ -175,7 +177,7 @@ def xla_test( deps, xla_test_library_deps = [], backends = [], - disabled_backends = [], + disabled_backends = DEFAULT_DISABLED_BACKENDS, real_hardware_only = False, # @unused, all backends are real hardware. args = [], tags = [], @@ -281,6 +283,8 @@ def xla_test( ] if backend in NVIDIA_GPU_BACKENDS: this_backend_tags += tf_gpu_tests_tags() + if backend in AMD_GPU_DEFAULT_BACKENDS: + this_backend_tags.append("gpu") this_backend_copts.append("-DXLA_TEST_BACKEND_GPU=1") elif backend == "interpreter": backend_deps += [ @@ -320,8 +324,23 @@ def xla_test( # b/317293391. For this reason, if we would create an empty `test_suite`, # instead create a `cc_test` with no srcs that links against `main` to have # more predictable behavior that avoids bugs. + # + # Due to b/317293391, we also mark the test suite `manual`, so that wild card builds + # like in the XLA CI won't try to build the test suite target. Instead the wild card + # build will build the individual test targets and therefore respect the tags on each + # individual test target. + # Example: Assume we have an `xla_test(name=my_test)` in `//xla/service/gpu` with backends `cpu` + # and `gpu`. This generates two test targets `//xla/service/gpu:my_test_{cpu|gpu}`. The latter + # has a tag `gpu`. + # + # - `bazel test --test_tag_filters=-gpu //xla/service/gpu/...` will only run the cpu test. + # - `bazel test //xla/service/gpu/...` will run both tests. + # - `bazel test //xla/service/gpu:my_test` will run both tests. + # Caveat: + # - `bazel test --test_tag_filters=-gpu //xla/service/gpu:my_test` will run both tests and + # not respect the tag filter - but it's way better than the previous behavoir. if test_names: - native.test_suite(name = name, tags = tags, tests = test_names) + native.test_suite(name = name, tags = tags + ["manual"], tests = test_names) else: native.cc_test(name = name, deps = ["@local_tsl//tsl/platform:test_main"]) diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index 99085382d07e8d..1e399127318242 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -25,12 +26,15 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" namespace xla { namespace { @@ -70,6 +74,31 @@ class CollectiveOpsTestE2E : public HloTestBase { GetDebugOptionsForTest().xla_gpu_enable_cublaslt(); } + void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text, + const DebugOptions& options) { + if (!HasFp8Support()) { + return; + } + const int64_t kNumReplicas = 1; + const int64_t kNumPartitions = 4; + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + config.set_debug_options(options); + config.set_num_partitions(kNumPartitions); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_text, config)); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CreateExecutable(std::move(module), + /*run_hlo_passes=*/true)); + EXPECT_TRUE(executable->has_module()); + HloInstruction* gemm_op = + FindInstruction(&executable->module(), HloOpcode::kCustomCall); + EXPECT_THAT(gemm_op, NotNull()); + EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8"); + } + absl::StatusOr> ExecuteReplicated(Executable* executable, int64_t num_replicas) { DeviceAssignment device_assignment = MakeDeviceAssn(num_replicas); @@ -825,6 +854,59 @@ ENTRY main.12 { // Custom Calls. CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, /*disable_dot_merger=*/true); + + // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer + // architectures. + DebugOptions opts = GetDebugOptionsForTest(); + opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0); + opts.set_xla_gpu_multi_streamed_windowed_einsum(true); + opts.set_xla_gpu_graph_min_graph_size(200); + opts.set_xla_gpu_enable_triton_gemm(false); + opts.add_xla_disable_hlo_passes("dot-merger"); + CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); +} + +TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, + WindowedEinsumE2EAllGatherMultiConsumerF8) { + absl::string_view kModuleReplicatedStr = R"( +HloModule windowed_einsum_e2e_all_gather_multi_consumer_f8, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[48,192]{1,0}, bf16[], bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 + +ENTRY main { + rhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + lhs0 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + scale_rhs = bf16[] parameter(3) + scale_lhs0 = bf16[] parameter(4) + scale_rhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_rhs), dimensions={} + scale_lhs0_bcast = bf16[48,192]{1,0} broadcast(scale_lhs0), dimensions={} + rhs_bf16 = bf16[2,16,48]{2,1,0} convert(rhs) + lhs0_bf16 = bf16[48,192]{1,0} convert(lhs0) + rhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_rhs_bcast, rhs_bf16) + lhs0_scaled = bf16[48,192]{1,0} multiply(scale_lhs0_bcast, lhs0_bf16) + dot0 = bf16[2,16,192]{2,1,0} dot(rhs_scaled, lhs0_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + lhs1 = f8e4m3fn[48,192]{1,0} parameter(2), sharding={devices=[1,4]<=[4]} + scale_lhs1 = bf16[] parameter(5) + scale_lhs1_bcast = bf16[48,192]{1,0} broadcast(scale_lhs1), dimensions={} + lhs1_bf16 = bf16[48,192]{1,0} convert(lhs1) + lhs1_scaled = bf16[48,192]{1,0} multiply(scale_lhs1_bcast, lhs1_bf16) + dot1 = bf16[2,16,192]{2,1,0} dot(rhs_scaled, lhs1_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT add.8 = bf16[2,16,192]{2,1,0} add(dot0, dot1) +} // main +)"; + + // Disable the dot merger pass which can prevent the creation of FP8 GEMM + // Custom Calls. + CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, + /*disable_dot_merger=*/true); + + // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer + // architectures. + DebugOptions opts = GetDebugOptionsForTest(); + opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0); + opts.set_xla_gpu_multi_streamed_windowed_einsum(true); + opts.set_xla_gpu_graph_min_graph_size(200); + opts.set_xla_gpu_enable_triton_gemm(false); + opts.add_xla_disable_hlo_passes("dot-merger"); + CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); } TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, @@ -972,7 +1054,6 @@ ENTRY entry { const int64_t kNumReplicas = 1; SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); - const int64_t kNumPartitions = 4; HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -980,19 +1061,7 @@ ENTRY entry { opts.set_xla_gpu_run_post_layout_collective_pipeliner(true); opts.set_xla_gpu_enable_pipelined_collectives(true); opts.set_xla_gpu_enable_triton_gemm(false); - config.set_debug_options(opts); - config.set_num_partitions(kNumPartitions); - TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); - - TF_ASSERT_OK_AND_ASSIGN(auto executable, - CreateExecutable(std::move(module), - /*run_hlo_passes=*/true)); - EXPECT_TRUE(executable->has_module()); - HloInstruction* gemm_op = - FindInstruction(&executable->module(), HloOpcode::kCustomCall); - EXPECT_THAT(gemm_op, NotNull()); - EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8"); + CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); } TEST_F(CollectiveOpsTestE2E, diff --git a/third_party/xla/xla/tests/convert_test.cc b/third_party/xla/xla/tests/convert_test.cc index 13ca51a4025ebb..4db6394d1503ce 100644 --- a/third_party/xla/xla/tests/convert_test.cc +++ b/third_party/xla/xla/tests/convert_test.cc @@ -14,20 +14,19 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include +#include #include #include "absl/algorithm/container.h" #include "absl/base/casts.h" #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" -#include "xla/primitive_util.h" #include "xla/shape_util.h" -#include "xla/stream_executor/stream_executor.h" #include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/types.h" #include "xla/xla_data.pb.h" @@ -672,6 +671,59 @@ XLA_TEST_F(ConvertTest, ConvertBF16F32) { } } +XLA_TEST_F(ConvertTest, ConvertF32BF16) { + XlaBuilder builder(TestName()); + + std::vector floats(100); + std::minstd_rand0 generator; + for (int i = 0; i < floats.size(); ++i) { + floats[i] = generator(); + + // Ensure the first 10 cases has rounding. + if (i < 10) { + auto val = absl::bit_cast(floats[i]); + val |= 1 << 15; + floats[i] = absl::bit_cast(val); + } + } + // Test NaN and -Nan. + floats.push_back(std::numeric_limits::quiet_NaN()); + floats.push_back(-std::numeric_limits::quiet_NaN()); + + std::vector expected(floats.size()); + for (int i = 0; i < expected.size(); ++i) { + expected[i] = static_cast(floats[i]); + } + + xla::XlaOp lit_f32 = ConstantR1(&builder, floats); + xla::XlaOp lit_bf16 = ConvertElementType(lit_f32, BF16); + BitcastConvertType(lit_bf16, U16); + + TF_ASSERT_OK_AND_ASSIGN(const auto results, ExecuteAndTransfer(&builder, {})); + for (int i = 0; i < expected.size(); ++i) { + const auto result = results.Get({i}); + const auto correct = absl::bit_cast(expected[i]); + if (floats[i] != 0.0f && floats[i] < std::numeric_limits::min()) { + // Subnormals may not be preserved, zero will do. + const bfloat16 same_signed_zero = + bfloat16(std::signbit(floats[i]) ? -0.0f : 0.0f); + if (result != correct) { + EXPECT_EQ(result, absl::bit_cast(same_signed_zero)); + } + } else if (std::isnan(floats[i])) { + // NaNs may not be preserved, any NaN will do. + ASSERT_TRUE(std::isnan(absl::bit_cast(correct))); + EXPECT_TRUE(std::isnan(absl::bit_cast(result))); + if (client_->platform()->Name() == "Host") { + // The sign bits must match. + EXPECT_EQ(result >> 15, correct >> 15); + } + } else { + EXPECT_EQ(result, correct); + } + } +} + XLA_TEST_F(ConvertTest, ConvertF16F8e5m2Roundtrip) { // Convert from FP16 to FP8, then back to FP16 XlaBuilder builder(TestName()); diff --git a/third_party/xla/xla/tests/copy_test.cc b/third_party/xla/xla/tests/copy_test.cc index 45d94ab0333838..91b7fa2a1473c8 100644 --- a/third_party/xla/xla/tests/copy_test.cc +++ b/third_party/xla/xla/tests/copy_test.cc @@ -13,22 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include +#include -#include "xla/array2d.h" +#include +#include "absl/types/span.h" +#include "xla/array3d.h" +#include "xla/array4d.h" #include "xla/client/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/protobuf.h" #include "tsl/platform/test.h" namespace xla { @@ -50,6 +59,25 @@ class CopyOpTest : public HloTestBase { EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } + // TODO(vsytch): Remove special handling for dynamic shapes once *all* of XLA + // supports those as module inputs/outputs. + void TestDynamicCopyOp(const Literal& literal, const Shape& bounded_shape) { + Literal dynamic_literal = literal.ToBoundedDynamic(bounded_shape); + auto builder = HloComputation::Builder(TestName()); + auto parameter = builder.AddInstruction( + HloInstruction::CreateParameter(0, dynamic_literal.shape(), "param")); + builder.AddInstruction(HloInstruction::CreateUnary( + parameter->shape(), HloOpcode::kCopy, parameter)); + auto computation = builder.Build(); + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(std::move(computation)); + + std::vector args = {&dynamic_literal}; + Literal result = ExecuteAndTransfer(std::move(module), args); + Literal dynamic_result = result.ToBoundedDynamic(bounded_shape); + EXPECT_TRUE(LiteralTestUtil::Equal(dynamic_literal, dynamic_result)); + } + void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4, absl::Span permutation); @@ -67,6 +95,59 @@ XLA_TEST_F(CopyOpTest, CopyR1S3U32) { TestCopyOp(LiteralUtil::CreateR1({1, 2, 3})); } +XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic0) { + // TODO(vsytch): CPU emitter doesn't handle dynamic shapes. + if (backend().platform()->Name() == "Host") { + GTEST_SKIP(); + } + Shape bounded_shape = + ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}, {true}); + TestDynamicCopyOp(LiteralUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(PrimitiveType::F32, {0}), 0, 1) + .value(), + bounded_shape); +} + +XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic106632) { + // TODO(vsytch): CPU emitter doesn't handle dynamic shapes. + if (backend().platform()->Name() == "Host") { + GTEST_SKIP(); + } + Shape bounded_shape = + ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}, {true}); + TestDynamicCopyOp( + LiteralUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(PrimitiveType::F32, {106632}), 0, 1) + .value(), + bounded_shape); +} + +XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic1310720) { + // TODO(vsytch): CPU emitter doesn't handle dynamic shapes. + if (backend().platform()->Name() == "Host") { + GTEST_SKIP(); + } + Shape bounded_shape = + ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}, {true}); + TestDynamicCopyOp( + LiteralUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}), 0, 1) + .value(), + bounded_shape); +} + +XLA_TEST_F(CopyOpTest, CopyDynamicR1S512U32Dynamic64) { + // TODO(vsytch): CPU emitter doesn't handle dynamic shapes. + if (backend().platform()->Name() == "Host") { + GTEST_SKIP(); + } + Shape bounded_shape = ShapeUtil::MakeShape(PrimitiveType::F32, {512}, {true}); + TestDynamicCopyOp(LiteralUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(PrimitiveType::F32, {64}), 0, 1) + .value(), + bounded_shape); +} + XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) { TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); diff --git a/third_party/xla/xla/tests/custom_call_test.cc b/third_party/xla/xla/tests/custom_call_test.cc index 3fd7cf554ee362..2ada7f0b22152b 100644 --- a/third_party/xla/xla/tests/custom_call_test.cc +++ b/third_party/xla/xla/tests/custom_call_test.cc @@ -1664,5 +1664,99 @@ XLA_TEST_F(FfiCustomCallTest, IntraOpThreadPool) { EXPECT_EQ(status, absl::OkStatus()); } +//===----------------------------------------------------------------------===// +// Stateful XLA:FFI handler +//===----------------------------------------------------------------------===// + +struct SomeState { + explicit SomeState(float value) : value(value) {} + float value = 0; +}; + +int instantiate_called_counter = 0; + +// Every time custom call HLO operation is instantiated as a CPU runtime Thunk, +// XLA calls instantiate callback to create a new instance of the handler state, +// that will be passed to all other FFI handler calls. +static absl::StatusOr> InstantiateState() { + ++instantiate_called_counter; + return std::make_unique(42.f); +} + +// At run time we can access the state created by the instantiate callback. +static absl::Status IncrementState(R0F32ResultBuffer out, SomeState* state) { + state->value += 1.f; + auto out_data = out->typed_data(); + *out_data = state->value; + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kInstantiateState, InstantiateState, + ffi::Ffi::BindInstantiate()); + +XLA_FFI_DEFINE_HANDLER( + kIncrementState, IncrementState, + ffi::Ffi::Bind().Ret().Ctx>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$ffi_execution_state", + "Host", + { + /*instantiate=*/kInstantiateState, + /*prepare=*/nullptr, + /*initialize=*/nullptr, + /*execute=*/kIncrementState, + }); + +// This test doesn't care about execution results, its intent is just to test if +// instantiate function was called. +TEST_F(CustomCallTest, FfiExecutionStateInstantiate) { + const char* const kModuleStr = R"( + HloModule m + ENTRY test { + ROOT result = f32[] custom-call(), custom_call_target= + "__xla_test$$ffi_execution_state", api_version=API_VERSION_TYPED_FFI + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + // Execute the module, but don't verify the results. + instantiate_called_counter = 0; + auto result = Execute(std::move(module), {}); + + // Check that instantiate callback was called. + EXPECT_EQ(instantiate_called_counter, 1); +} + +TEST_F(CustomCallTest, FfiExecutionStateExecute) { + // Execution state is only partially implemented at the moment. + GTEST_SKIP() << "Not implemented yet."; + + // TODO(abanas): Actually, this HLO probably creates two custom call thunks, + // each one is called once. If yes then fix it, cause the intent is to call + // the same custom call twice. + const char* const kModuleStr = R"( + HloModule m + ENTRY test { + first = f32[] custom-call(), custom_call_target= + "__xla_test$$ffi_execution_state", api_version=API_VERSION_TYPED_FFI + second = f32[] custom-call(), custom_call_target= + "__xla_test$$ffi_execution_state", api_version=API_VERSION_TYPED_FFI + ROOT result = (f32[], f32[]) tuple(first, second) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal expected0 = + LiteralUtil::CreateR0(43.f); // Incremented once. + Literal expected1 = + LiteralUtil::CreateR0(44.f); // Incremented twice. + Literal expected = LiteralUtil::MakeTuple({&expected0, &expected1}); + + TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {})); + EXPECT_EQ(result, expected); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/BUILD b/third_party/xla/xla/tests/exhaustive/BUILD index ed239be2659944..dcd74bcc34750c 100644 --- a/third_party/xla/xla/tests/exhaustive/BUILD +++ b/third_party/xla/xla/tests/exhaustive/BUILD @@ -165,11 +165,18 @@ xla_test( ], ) +filegroup( + name = "exhaustive_binary_16_bit_test_srcs", + srcs = [ + "exhaustive_binary_16_bit_test.cc", + ], +) + xla_test( name = "exhaustive_binary_16_bit_test", srcs = [ - "exhaustive_binary_16_bit_test.cc", "exhaustive_test_main.cc", + ":exhaustive_binary_16_bit_test_srcs", ], backends = [ "gpu", diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc index c31d97b6df32dc..3f61111c974c84 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include #include #include -#include +#include +#include #include #include #include @@ -45,7 +48,7 @@ namespace { // including float16 and bfloat. // // Test parameter is a pair of (begin, end) for range under test. -template +template class Exhaustive16BitBinaryTest : public ExhaustiveBinaryTest, public ::testing::WithParamInterface> { @@ -57,9 +60,13 @@ class Exhaustive16BitBinaryTest } // Given a range of uint64_t representation, uses bits 0..15 and bits 16..31 - // for the values of src0 and src1 for a 16 bit binary operation being tested, - // and generates the cartesian product of the two sets as the two inputs for - // the test. + // for the values of src0 and src1 (see below for ordering) for the 16 bit + // binary operation being tested, and generates the cartesian product of the + // two sets as the two inputs for the test. + // + // If `kLeftToRightPacking == true`, bit 31..16 become src0 and 15..0 becomes + // src1. If `kLeftToRightPacking == false`, then bits 31..16 become src1 + // and 15..0 becomes src0. void FillInput(std::array* input_literals) override { int64_t input_size = GetInputSize(); CHECK_EQ(input_size, (*input_literals)[0].element_count()); @@ -67,17 +74,53 @@ class Exhaustive16BitBinaryTest int64_t begin, end; std::tie(begin, end) = GetParam(); - VLOG(2) << "Checking range [" << begin << ", " << end << "]"; + + uint16_t left_begin, left_end, right_begin, right_end; + if constexpr (kLeftToRightPacking) { + left_begin = std::bit_cast(static_cast(begin >> 16)); + left_end = std::bit_cast(static_cast(end >> 16)); + right_begin = std::bit_cast(static_cast(begin)); + right_end = std::bit_cast(static_cast(end)); + } else { + left_begin = std::bit_cast(static_cast(begin)); + left_end = std::bit_cast(static_cast(end)); + right_begin = std::bit_cast(static_cast(begin >> 16)); + right_end = std::bit_cast(static_cast(end >> 16)); + } + if (VLOG_IS_ON(2)) { + LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; + LOG(INFO) << "\tfrom=(" << left_begin << ", " << right_begin << "); hex=(" + << std::hex << left_begin << ", " << right_begin << "); float=(" + << *reinterpret_cast(&left_begin) << ", " + << *reinterpret_cast(&right_begin) + << ") (inclusive)"; + LOG(INFO) << "\tto=(" << left_end << ", " << right_end << "); hex=(" + << std::hex << left_end << ", " << right_end << "); float=(" + << *reinterpret_cast(&left_end) << ", " + << *reinterpret_cast(&right_end) + << ") (exclusive)"; + LOG(INFO) << "\ttotal values to test=" << (end - begin); + } absl::Span input_arr_0 = (*input_literals)[0].data(); absl::Span input_arr_1 = (*input_literals)[1].data(); for (int64_t i = 0; i < input_size; i++) { uint32_t input_val = i + begin; - // Convert the lower 16 bits to the NativeT and replaced known incorrect - // input values with 0. - input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0); - input_arr_1[i] = - ConvertAndReplaceKnownIncorrectValueWith(input_val >> 16, 0); + // Convert the packed bits to a pair of NativeT and replace known + // incorrect input values with 0. + // + // In either case, we only use 32 bits out of the 64 bits possible. + if constexpr (kLeftToRightPacking) { + // Left is stored at higher 16 bits. + input_arr_0[i] = + ConvertAndReplaceKnownIncorrectValueWith(input_val >> 16, 0); + input_arr_1[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0); + } else { + // Left is stored at lower 16 bits. + input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0); + input_arr_1[i] = + ConvertAndReplaceKnownIncorrectValueWith(input_val >> 16, 0); + } } } @@ -108,48 +151,99 @@ using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest; BINARY_TEST_F16(test_name, __VA_ARGS__) \ BINARY_TEST_BF16(test_name, __VA_ARGS__) +// Can be thought of as an absolute error of +// `<= |std::numeric_limits::::min()|`. +double AddCpuTpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float output = static_cast(left) + static_cast(right); + + // Hardware flushes subnormal outputs to 0. + if (IsSubnormal(output)) { + return std::numeric_limits::min(); + } + + return 0.0; +} + BINARY_TEST_16BIT(Add, { - auto host_add = [](float x, float y) { return x + y; }; - Run(AddEmptyBroadcastDimension(Add), host_add); + ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if ((IsCpu(platform_) || IsTpu(platform_))) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(AddCpuTpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .build(); + }; + } + } + + Run( + AddEmptyBroadcastDimension(Add), [](float x, float y) { return x + y; }, + error_spec_gen); }) +// Can be thought of as an absolute error of +// `<= |std::numeric_limits::::min()|`. +double SubCpuTpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float output = static_cast(left) - static_cast(right); + + // Hardware flushes subnormal outputs to 0. + if (IsSubnormal(output)) { + return std::numeric_limits::min(); + } + + return 0.0; +} + BINARY_TEST_16BIT(Sub, { - auto host_sub = [](float x, float y) { return x - y; }; - Run(AddEmptyBroadcastDimension(Sub), host_sub); + ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if (IsCpu(platform_) || IsTpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(SubCpuTpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .build(); + }; + } + } + + Run( + AddEmptyBroadcastDimension(Sub), [](float x, float y) { return x - y; }, + error_spec_gen); }) // Can be thought of as an absolute error of // `<= |std::numeric_limits::::min()|`. -double MulCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { +double MulCpuTpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { float output = static_cast(left) * static_cast(right); - // Subnormals are flushed to 0 (as inputs or outputs). In these cases, we - // calculate 0 instead of the expected very small number so we use the minimum - // float value as the absolute error to give a buffer. - auto left_is_subnormal = IsSubnormal(left); - auto right_is_subnormal = IsSubnormal(right); + // CPU BF16 and TPU (all types) flush subnormals to 0. auto output_is_subnormal = IsSubnormal(output); - if (left_is_subnormal || right_is_subnormal || output_is_subnormal) { + if (output_is_subnormal) { return std::numeric_limits::min(); } return 0.0; } -bool MulCpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) { - // For BF16, multiplying a subnormal by infinity will lead to calculating 0 - // multiplied by infinity due to subnormal flushing, which is defined to be - // NaN. However, the calculation in higher precision does not flush the - // subnormal value to 0, leading to a result of infinity. - auto left_is_subnormal = IsSubnormal(left); - auto left_is_infinite = std::isinf(left); - auto right_is_subnormal = IsSubnormal(right); - auto right_is_infinite = std::isinf(right); - if ((left_is_subnormal && right_is_infinite) || - (left_is_infinite && right_is_subnormal)) { +bool MulCpuTpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) { + // For CPU and TPU BF16, multiplying a subnormal by infinity will lead to + // calculating 0 multiplied by infinity due to subnormal flushing, which is + // defined to be NaN. However, the calculation in higher precision does not + // flush the subnormal value to 0, leading to a result of infinity. + if ((IsSubnormal(left) && std::isinf(right)) || + (std::isinf(left) && IsSubnormal(right))) { return true; } - return false; } @@ -157,19 +251,22 @@ BINARY_TEST_16BIT(Mul, { ErrorSpecGen error_spec_gen = +[](NativeT left, NativeT right) { return ErrorSpec::Builder().strict_signed_zeros().build(); }; - if (IsCpu(platform_)) { + + if (IsCpu(platform_) || IsTpu(platform_)) { if constexpr (std::is_same_v) { error_spec_gen = +[](NativeT left, NativeT right) { return ErrorSpec::Builder() - .abs_err(MulCpuBf16AbsErr(static_cast(left), - static_cast(right))) + .abs_err(MulCpuTpuBf16AbsErr(static_cast(left), + static_cast(right))) .strict_signed_zeros() - .skip_comparison(MulCpuBf16Skip(static_cast(left), - static_cast(right))) + .skip_comparison( + MulCpuTpuBf16Skip(static_cast(left), + static_cast(right))) .build(); }; } } + Run( AddEmptyBroadcastDimension(Mul), [](float x, float y) { return x * y; }, error_spec_gen); @@ -182,18 +279,63 @@ double DivCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { // Subnormals are flushed to 0 so we add a absolute error margin that is // larger than any subnormal. - auto output_is_subnormal = IsSubnormal(output); - if (output_is_subnormal) { + if (IsSubnormal(output)) { return std::numeric_limits::min(); } return 0.0; } +double DivTpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float reciprocal = 1.0f / static_cast(right); + xla::bfloat16 output = left / right; + float output_as_float = static_cast(left) / static_cast(right); + + // If we calculate NaN, we don't need to adjust tolerances. + if (std::isnan(output_as_float)) { + return 0.0; + } + + // TPUs perform `left * (1 / right)`, where `left` and `1 / right` are + // flushed to `0` if they are subnormal. Also applies to if reciprocal is min + // normal. + if (IsSubnormal(left) || IsSubnormalOrMinNormal(reciprocal)) { + // Subnormals can have a larger value in BF16 than float due to rounding to + // the nearest BF16 value during conversion while having less representation + // bits. For normals, the float value is usually always bigger due to + // greater precision. + return std::max(std::abs(output), std::abs(output_as_float)); + } + + // For subnormals, we need to set absolute error to the smallest positive + // representable value due to hardware implementations that truncate + // subnormals to zero. + if (IsSubnormalOrMinNormal(output)) { + return std::numeric_limits::min(); + } + + return 0.0; +} + +bool DivTpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) { + float reciprocal = 1.0f / right; + + // TPU calculates `left * (1 / right)` and flushed `(1 / right)` to `0` when + // it is subnormal or min normal. It also follows the IEEE multiplication spec + // that inf * 0 is NaN. However, IEEE division of infinity by a subnormal is + // infinity, so we must skip comparison. + if (std::isinf(left) && IsSubnormalOrMinNormal(reciprocal)) { + return true; + } + + return false; +} + BINARY_TEST_16BIT(Div, { ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { return ErrorSpec::Builder().strict_signed_zeros().build(); }; + if (IsCpu(platform_)) { if constexpr (std::is_same_v) { error_spec_gen = +[](NativeT left, NativeT right) { @@ -205,34 +347,256 @@ BINARY_TEST_16BIT(Div, { }; } } + if (IsGpu(platform_) && std::is_same_v) { error_spec_gen = +[](NativeT, NativeT) { return ErrorSpec::Builder().distance_err(1).strict_signed_zeros().build(); }; } + + if (IsTpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(DivTpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .skip_comparison(DivTpuBf16Skip(static_cast(left), + static_cast(right))) + .build(); + }; + } else if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(std::numeric_limits::min()) + .strict_signed_zeros() + .build(); + }; + } + } + if (IsPreV5Tpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(DivTpuBf16AbsErr(static_cast(left), + static_cast(right))) + .rel_err(std::numeric_limits::epsilon()) + .strict_signed_zeros() + .skip_comparison(DivTpuBf16Skip(static_cast(left), + static_cast(right))) + .build(); + }; + } + } + Run( AddEmptyBroadcastDimension(Div), [](float x, float y) { return x / y; }, error_spec_gen); }) +// Can be thought of as an absolute error of +// `<= |std::numeric_limits::::min()|`. +double MaxMinCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + // Subnormals are treated as 0 and max returns the first if all are + // 0-equivalent. + if (IsSubnormal(left) && (right == 0.0 || IsSubnormal(right))) { + return std::abs(left); + } + return 0.0; +} + BINARY_TEST_16BIT(Max, { - Run(AddEmptyBroadcastDimension(Max), ReferenceMax); + ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if (IsCpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(MaxMinCpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .build(); + }; + } + } + + if (IsGpu(platform_) || IsTpu(platform_)) { + error_spec_gen = +[](NativeT, NativeT) { + // A100 and H100 return -0 for max(-0,0). + // + // TPUs return -0 for max(0,-0) and 0 for max(-0,0). + return ErrorSpec::Builder().strict_signed_zeros(false).build(); + }; + } + + Run(AddEmptyBroadcastDimension(Max), ReferenceMax, error_spec_gen); }) BINARY_TEST_16BIT(Min, { - Run(AddEmptyBroadcastDimension(Min), ReferenceMin); + ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if (IsCpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(MaxMinCpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .build(); + }; + } + } + + if (IsGpu(platform_) || IsTpu(platform_)) { + error_spec_gen = +[](NativeT, NativeT) { + // A100 and H100 return 0 for min(0,-0). + // + // TPUs return 0 for min(-0,0) and -0 for min(0,-0). + return ErrorSpec::Builder().strict_signed_zeros(false).build(); + }; + } + + Run(AddEmptyBroadcastDimension(Min), ReferenceMin, error_spec_gen); }) -// TODO(bixia): Pow fails with bfloat16 on CPU. -BINARY_TEST_16BIT(DISABLED_ON_GPU(DISABLED_ON_CPU(Pow)), { - // See b/162664705. - known_incorrect_fn_ = [](int64_t val) { - Eigen::bfloat16 f; - uint16_t val_16 = val; - memcpy(&f, &val_16, 2); - return std::isnan(f); +template +bool PowCpuGpuF16Skip(NativeT left, NativeT right) { + // Hardware always returns 1 if right is 0, no matter if left is NaN. + if (std::isnan(left) && right == 0.0f) { + return true; + } + // Hardware always returns 1 if left is 1, no matter if right is NaN. + if (left == 1.0f && std::isnan(right)) { + return true; + } + return false; +} + +double PowCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float output = std::pow(static_cast(left), static_cast(right)); + + // Output is flushed to 0 if subnormal. + if (IsSubnormal(output)) { + return std::numeric_limits::min(); + } + + // TODO(b/359325328): pow computation for subnormal bases is different from + // std::pow. + // + // If the base is subnormal, the output computation selects a different base. + // The minimum value ever chosen is slightly greater than the 1e-91 used + // below. We return an absolute error from this value to the "real" output. + // + // Because the exponent (right) can be any floating point value, this allows + // an arbitrary absolute error for subnormal values. + if (IsSubnormal(left)) { + xla::bfloat16 output_as_bf16 = static_cast(output); + auto expected = std::pow(1e-91, static_cast(right)); + auto err = std::abs(expected - output_as_bf16); + if (!std::isnan(err)) { + return err; + } + } + + return 0.0; +} + +double PowTpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float output = std::pow(static_cast(left), static_cast(right)); + + // Output is flushed to 0 if subnormal. + if (IsSubnormal(output)) { + return std::numeric_limits::min(); + } + + return 0.0; +} + +template +bool PowTpuSkip(NativeT left, NativeT right) { + // Hardware always returns 1 if right is 0 (or subnormal due to + // flushing subnormals to zero before the operation), no matter if left is + // NaN. + if (std::isnan(left) && (right == 0.0f || IsSubnormal(right))) { + return true; + } + // Hardware always returns 1 if left is 1, no matter if right is NaN. + if (left == 1.0f && std::isnan(right)) { + return true; + } + + return false; +} + +BINARY_TEST_16BIT(Pow, { + ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); }; - Run(AddEmptyBroadcastDimension(Pow), std::pow); + + if (IsCpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .strict_signed_zeros() + .skip_comparison(PowCpuGpuF16Skip(left, right)) + .build(); + }; + } else if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(PowCpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .build(); + }; + } else if constexpr (std::is_same_v || + std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + }; + } + } + + if (IsGpu(platform_)) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .skip_comparison(PowCpuGpuF16Skip(left, right)) + .build(); + }; + } + + if (IsTpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(PowTpuBf16AbsErr(static_cast(left), + static_cast(right))) + .distance_err(1) + .strict_signed_zeros() + .skip_comparison(PowTpuSkip(left, right)) + .build(); + }; + } else if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .skip_comparison(PowTpuSkip(left, right)) + .build(); + }; + } + } + + Run(AddEmptyBroadcastDimension(Pow), std::pow, error_spec_gen); }) // Can be thought of as an absolute error of @@ -243,8 +607,7 @@ double Atan2CpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { // If the output would be a subnormal float, we allow some error to account // for BF16 implementation flushing subnormals to zero. - auto output_is_subnormal = IsSubnormal(output); - if (output_is_subnormal) { + if (IsSubnormal(output)) { return std::numeric_limits::min(); } @@ -262,10 +625,40 @@ bool Atan2CpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) { return false; } +double Atan2TpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + xla::bfloat16 output = static_cast(std::atan2(left, right)); + float output_as_float = + std::atan2(static_cast(left), static_cast(right)); + + // If the output would be a subnormal float, we allow some error to account + // for BF16 implementation flushing subnormals to zero. TPUs also seem to + // flush the minimum value to 0 along with subnormals. + if (IsSubnormalOrMinNormal(output_as_float)) { + return std::numeric_limits::min(); + } + + // Implementation of Atan2 on TPUs is that they take the reciprocal of the + // larger of left or right. If this is subnormal or the minimum value, the TPU + // flushes it to 0 before using it in multiplication. When this happens, the + // error is the output calculation, either in BF16 or float, or PI/2, + // depending on which of the three is bigger. + float reciprocal_as_float = + 1.0f / std::max(std::abs(static_cast(left)), + std::abs(static_cast(right))); + if (!std::isnan(output_as_float) && + IsSubnormalOrMinNormal(reciprocal_as_float)) { + return std::max({std::abs(output_as_float), std::abs(output), + static_cast(M_PI_2)}); + } + + return 0.0; +} + BINARY_TEST_16BIT(Atan2, { auto error_spec_gen = +[](NativeT, NativeT) { return ErrorSpec::Builder().strict_signed_zeros().build(); }; + if (IsCpu(platform_)) { if constexpr (std::is_same_v) { error_spec_gen = +[](NativeT left, NativeT right) { @@ -280,11 +673,33 @@ BINARY_TEST_16BIT(Atan2, { }; } } + if (IsGpu(platform_)) { error_spec_gen = +[](NativeT, NativeT) { return ErrorSpec::Builder().distance_err(1).strict_signed_zeros().build(); }; } + + if (IsTpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(Atan2TpuBf16AbsErr(static_cast(left), + static_cast(right))) + .distance_err(1) + .strict_signed_zeros() + .build(); + }; + } else if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + }; + } + } + Run(AddEmptyBroadcastDimension(Atan2), std::atan2, error_spec_gen); }) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_f32_f64_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_f32_f64_test.cc index 57d1c3fd2a371a..06cc4b0822f153 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_f32_f64_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_f32_f64_test.cc @@ -63,9 +63,12 @@ class Exhaustive32BitOrMoreBinaryTest FpValues values_0; FpValues values_1; std::tie(values_0, values_1) = GetParam(); - - VLOG(2) << " testing " << values_0.ToString() << " " << values_1.ToString() - << "total values " << input_size; + if (VLOG_IS_ON(2)) { + LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; + LOG(INFO) << "\tleft values=" << values_0.ToString(); + LOG(INFO) << "\tright values=" << values_1.ToString(); + LOG(INFO) << "\ttotal values to test=" << input_size; + } CHECK(input_size == (*input_literals)[0].element_count() && input_size == (*input_literals)[1].element_count()); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc index f677539c677c99..ccea6e55388c0c 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -653,6 +653,13 @@ void ExhaustiveOpTestBase::ExpectNear( StringifyNum(actual))); PrintMismatch(&mismatches, [mismatch] { return mismatch; }); + + // If we have emitted debug logging, we fail the test execution at the first + // comparison failure to avoid dumping too much log data and ensure the + // relevant debugging information is the last logged data. + if (should_emit_debug_logging_) { + ASSERT_TRUE(false); + } } EXPECT_EQ(mismatches, 0); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h index 64653507ab7600..30c9acbe69c86f 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -32,6 +32,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -335,6 +336,21 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { mutable_debug_options()->clear_xla_disable_hlo_passes(); } + // Enable debug logging for the invocation of the lambda. + // + // This is intended to be used to wrap a call to `Run`, which will then log + // extra debug information for a failure such as the calculated absolute, + // relative, and distance errors. In addition, in an effort to reduce output + // log size, this will trigger an ASSERT failure to early return from a test + // at the first failure. + template , int> = 0> + void EnableDebugLoggingForScope(Callable&& work) { + should_emit_debug_logging_ = true; + work(); + should_emit_debug_logging_ = false; + } + void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op, OutputRangeCheck check_valid_range = nullptr) { Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(), @@ -657,8 +673,17 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // will be wildly off. We convert back to NativeT for this comparison. int64_t distance_err = GetDistanceErr(NativeT(expected), NativeT(actual)); - return abs_err <= spec.abs_err || rel_err <= spec.rel_err || - distance_err <= spec.distance_err; + bool passed = abs_err <= spec.abs_err || rel_err <= spec.rel_err || + distance_err <= spec.distance_err; + if (should_emit_debug_logging_ && !passed) { + LOG(INFO) << "actual: " << actual << "; expected: " << expected + << "\n\tabs_err: " << abs_err + << "; spec.abs_err: " << spec.abs_err + << "\n\trel_err: " << rel_err << "; spec.rel_err: " << rel_err + << "\n\tdistance_err: " << distance_err + << "; spec.distance_err: " << spec.distance_err; + } + return passed; } // Converts part or all bits in an uint64_t to the value of the floating point @@ -712,6 +737,10 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // Indicates if files of the expected and actual values should be dumped. bool should_dump_values_ = false; + + // Indicates if additional (potentially costly) logging should be emitted to + // ease with debugging. + bool should_emit_debug_logging_ = false; }; // Represents a set of 64 bit chunks by representing the starting bit chunk, diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc index e34d5cf2a3eb05..3bba8b80f9967d 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc @@ -72,14 +72,16 @@ class ExhaustiveComplexUnaryTestBase void FillInput(std::array* input_literal) override { FpValues real_values = std::get<0>(GetParam()); FpValues imag_values = std::get<1>(GetParam()); - - VLOG(2) << " testing input total " - << real_values.GetTotalNumValues() * imag_values.GetTotalNumValues() - << ", range " << real_values.ToString() << " " - << imag_values.ToString(); + if (VLOG_IS_ON(2)) { + LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; + LOG(INFO) << "\treal values=" << real_values.ToString(); + LOG(INFO) << "\timag values=" << imag_values.ToString(); + LOG(INFO) << "\ttotal values to test=" + << real_values.GetTotalNumValues() * + imag_values.GetTotalNumValues(); + } absl::Span input_arr = (*input_literal)[0].data(); - uint64_t i = 0; for (auto real : real_values) { for (auto imag : imag_values) { diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc index 7a9857f927d7bc..c8354875823980 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -202,7 +203,6 @@ class Exhaustive32BitOrLessUnaryTest private: int64_t GetInputSize() override { auto [begin, end] = GetParam(); - VLOG(2) << "Checking range [" << begin << ", " << end << ")"; return end - begin; } @@ -217,8 +217,18 @@ class Exhaustive32BitOrLessUnaryTest typename ExhaustiveOpTestBase::ComponentIntegralNativeT; auto [begin, end] = GetParam(); + if (VLOG_IS_ON(2)) { + LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; + LOG(INFO) << "\tfrom=" << begin << "; hex=" << std::hex << begin + << "; float=" << *reinterpret_cast(&begin) + << " (inclusive)"; + LOG(INFO) << "\tto=" << end << "; hex=" << std::hex << end + << "; float=" << *reinterpret_cast(&end) + << " (exclusive)"; + LOG(INFO) << "\ttotal values to test=" << (end - begin); + } + int64_t input_size = (*input_literal)[0].element_count(); - VLOG(2) << "Checking range [" << begin << ", " << end << ")"; CHECK_EQ(input_size, end - begin); absl::Span input_arr = (*input_literal)[0].data(); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f64_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f64_test.cc index 0792b017813416..3f3b9de811fa60 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f64_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f64_test.cc @@ -53,11 +53,14 @@ class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, void FillInput(std::array* input_literal) override { FpValues fp_values = GetParam(); int64_t input_size = (*input_literal)[0].element_count(); - LOG(INFO) << "Checking fp values " << fp_values.ToString() << ", " - << input_size; - absl::Span input_arr = (*input_literal)[0].data(); + if (VLOG_IS_ON(2)) { + LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; + LOG(INFO) << "\t" << fp_values.ToString(); + LOG(INFO) << "\ttotal values to test=" << input_size; + } uint64_t i = 0; + absl::Span input_arr = (*input_literal)[0].data(); for (auto bits : fp_values) { input_arr[i] = this->ConvertAndReplaceKnownIncorrectValueWith(bits, 1); ++i; diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc index 8bfcca3b4ef676..fe44f7020cbf54 100644 --- a/third_party/xla/xla/tests/hlo_test_base.cc +++ b/third_party/xla/xla/tests/hlo_test_base.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" +#include #include #include #include @@ -30,6 +31,7 @@ limitations under the License. #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/layout_util.h" #include "xla/service/hlo_module_util.h" #include "xla/service/hlo_parser.h" @@ -535,8 +537,22 @@ ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const Literal& literal) { return const_cast(&literal); }); - return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, - reference_preprocessor); + auto assertion_result = RunAndCompareNoHloPasses( + std::move(module), fake_argument_ptrs, error, reference_preprocessor); + if (!assertion_result) { + for (const auto& literal : fake_arguments) { + uint64_t total_elements = 1; + absl::c_for_each(literal.shape().dimensions(), + [&](int64_t dim) { total_elements *= dim; }); + if (total_elements > 1000) { + LOG(ERROR) << "argument literal is too large to print: " + << literal.shape().ToString(); + continue; + } + LOG(ERROR) << "argument literal: " << literal.ToString(); + } + } + return assertion_result; } ::testing::AssertionResult HloTestBase::Run(std::unique_ptr module, @@ -1013,23 +1029,15 @@ ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( HloComputation* HloTestBase::FindComputation(HloModule* module, absl::string_view name) { - auto computations = module->computations(); - auto it = absl::c_find_if( - computations, [&](HloComputation* c) { return c->name() == name; }); - if (it == computations.end()) { - return nullptr; - } - return *it; + return hlo_query::FindComputation(module, name); } HloInstruction* HloTestBase::FindInstruction(HloModule* module, absl::string_view name) { - for (const HloComputation* c : module->computations()) { - auto instructions = c->instructions(); - auto it = absl::c_find_if( - instructions, [&](HloInstruction* i) { return i->name() == name; }); - if (it != instructions.end()) { - return *it; + for (const HloComputation* computation : module->computations()) { + if (auto instruction = hlo_query::FindFirstInstruction(computation, name); + instruction.first != nullptr) { + return instruction.first; } } return nullptr; @@ -1037,12 +1045,10 @@ HloInstruction* HloTestBase::FindInstruction(HloModule* module, HloInstruction* HloTestBase::FindInstruction(HloModule* module, HloOpcode opcode) { - for (const HloComputation* c : module->computations()) { - auto instructions = c->instructions(); - auto it = absl::c_find_if( - instructions, [&](HloInstruction* i) { return i->opcode() == opcode; }); - if (it != instructions.end()) { - return *it; + for (const HloComputation* computation : module->computations()) { + if (auto instruction = hlo_query::FindFirstInstruction(computation, opcode); + instruction.first != nullptr) { + return instruction.first; } } return nullptr; diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h index 24bcc948a6f465..4c194c8de351a0 100644 --- a/third_party/xla/xla/tests/hlo_test_base.h +++ b/third_party/xla/xla/tests/hlo_test_base.h @@ -425,7 +425,10 @@ class HloTestBase : public ManifestCheckingTest { } // Gets the computation/instruction from the given module with the given name. - // + // Note that it is encouraged to use these functions directly via the + // hlo_query.h header instead since they are independent from any test-time + // variables or contexts. + // This is useful for tests which create HLOs from a string and then want to // inspect a particular computation or instruction. HloComputation* FindComputation(HloModule* module, absl::string_view name); diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 1cb36a556f72c9..f904e11f51fad1 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -314,6 +314,7 @@ xla_cc_binary( xla_cc_binary( name = "hlo-opt", testonly = True, + linkopts = ["-Wl,-rpath,$$ORIGIN/../lit_lib"], deps = [ "//xla/tools/hlo_opt:opt_main", ], @@ -833,13 +834,46 @@ tsl_gpu_library( ) xla_test( - name = "xla_compile_lib_test", - srcs = ["xla_compile_lib_test.cc"], + name = "xla_cpu_compile_lib_test", + srcs = ["xla_cpu_compile_lib_test.cc"], + backends = [ + "cpu", + ], + data = [ + ":data/add.hlo", + ], + deps = [ + ":xla_compile_lib", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:platform_util", + "//xla/service:symbol_repository", + "//xla/service:xla_compile_result_proto_cc_impl", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_time", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_tsl//tsl/protobuf:status_proto_cc", + ] + if_google(["@com_google_protobuf//:duration_cc_proto"]), +) + +xla_test( + name = "xla_gpu_compile_lib_test", + srcs = ["xla_gpu_compile_lib_test.cc"], backend_tags = { "gpu": ["requires-gpu-nvidia"] + if_google(["config-cuda-only"]), }, backends = [ - "cpu", "gpu", ], data = [ @@ -847,9 +881,6 @@ xla_test( "//xla/service:xla_aot_compile_test_gpu_target_config.prototxt", "//xla/service/gpu:gpu_compiler_test_autotune_db.textproto", ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), deps = [ ":xla_compile_lib", "//xla:util", @@ -861,21 +892,17 @@ xla_test( "//xla/service/gpu/autotuning:autotuner_util", "//xla/stream_executor:device_description_proto_cc", "//xla/tests:hlo_test_base", - "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_time", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", "@local_tsl//tsl/protobuf:status_proto_cc", - ] + if_google(["@com_google_protobuf//:duration_cc_proto"]), + ], ) xla_test( diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD index c50fb9a25941f2..d9cefa0eddad8a 100644 --- a/third_party/xla/xla/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -175,6 +175,7 @@ lit_test_suite( cfg = "//xla:lit.cfg.py", data = [":test_utilities"], default_tags = tf_cuda_tests_tags(), + hermetic_cuda_data_dir = "%S/../../../../cuda_nvcc", tags_override = { "gpu_hlo_ptx.hlo": ["no_rocm"], }, diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD index 1e1450beb51699..f485145a5f286c 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD +++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD @@ -172,12 +172,9 @@ xla_test( name = "functional_hlo_runner_test", srcs = ["functional_hlo_runner_test.cc"], backend_tags = { - # This test is tagged "manual" because it requires multiple (2) GPUs. "gpu": [ - "manual", - "multi_gpu", + "multi_gpu_h100", "no_oss", - "notap", ], }, backends = [ diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc index 1a490a0802ef13..3aabf56650af23 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc @@ -57,43 +57,41 @@ static absl::StatusOr> GetPjRtClient( if (enable_mock_nccl) { CHECK_GT(num_nodes, 1); return CreateMockGpuClient(num_nodes); - } else { - if (num_nodes == 1) { - return CreateGpuClient({}); - } else { - TF_RET_CHECK(!address.empty()); - TF_RET_CHECK(node_id >= 0) - << "Node id is expected to be in range [0, num_nodes)"; - TF_RET_CHECK(node_id < num_nodes) - << "Node id is expected to be in range [0, num_nodes)"; - - CHECK_GT(address.length(), 0); - // Multinode. Start service on task 0. - if (node_id == 0) { - std::string coordinator_bind_address = - "[::]:" + std::string(address).substr(address.rfind(':') + 1); - xla::CoordinationServiceImpl::Options options; - options.num_nodes = num_nodes; - auto status_or = xla::GetDistributedRuntimeService( - coordinator_bind_address, options); - TF_QCHECK_OK(status_or.status()); - service = std::move(status_or.value()); - } - xla::DistributedRuntimeClient::Options options; - options.node_id = node_id; - options.init_timeout = init_timeout; - distributed_client = - GetDistributedRuntimeClient(std::string(address), options); - TF_QCHECK_OK(distributed_client->Connect()); - kv_store = GetDistributedKeyValueStore(distributed_client, - /*key_prefix=*/"gpu:"); - GpuClientOptions gpu_client_options; - gpu_client_options.node_id = node_id; - gpu_client_options.num_nodes = num_nodes; - gpu_client_options.kv_store = kv_store; - return CreateGpuClient(std::move(gpu_client_options)); - } } + + if (num_nodes == 1) { + return CreateGpuClient({}); + } + + TF_RET_CHECK(!address.empty()); + TF_RET_CHECK(node_id >= 0) + << "Node id is expected to be in range [0, num_nodes)"; + TF_RET_CHECK(node_id < num_nodes) + << "Node id is expected to be in range [0, num_nodes)"; + + CHECK_GT(address.length(), 0); + // Multinode. Start service on task 0. + if (node_id == 0) { + std::string coordinator_bind_address = + "[::]:" + std::string(address).substr(address.rfind(':') + 1); + xla::CoordinationServiceImpl::Options options; + options.num_nodes = num_nodes; + TF_ASSIGN_OR_RETURN(service, xla::GetDistributedRuntimeService( + coordinator_bind_address, options)); + } + xla::DistributedRuntimeClient::Options options; + options.node_id = node_id; + options.init_timeout = init_timeout; + distributed_client = + GetDistributedRuntimeClient(std::string(address), options); + TF_QCHECK_OK(distributed_client->Connect()); + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"gpu:"); + GpuClientOptions gpu_client_options; + gpu_client_options.node_id = node_id; + gpu_client_options.num_nodes = num_nodes; + gpu_client_options.kv_store = kv_store; + return CreateGpuClient(std::move(gpu_client_options)); } absl::StatusOr GetPjRtClient(absl::string_view device_type, diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc index 405e421a8d86f3..a7f986c8fc3d66 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc @@ -355,9 +355,9 @@ int main(int argc, char* argv[]) { xla::AppendDebugOptionsFlags(&flag_list); std::string usage = tsl::Flags::Usage(argv[0], flag_list); tsl::Flags::Parse(&argc, argv, flag_list); + testing::InitGoogleTest(&argc, argv); if (node_id >= 0) { return !xla::ShardedAutotuningWorksTestBody(node_id).ok(); } - testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc b/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc new file mode 100644 index 00000000000000..62c06734ddb990 --- /dev/null +++ b/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc @@ -0,0 +1,167 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "google/protobuf/duration.pb.h" +#include +#include +#include "absl/synchronization/mutex.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/platform_util.h" +#include "xla/service/symbol_repository.h" +#include "xla/service/xla_compile_result.pb.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tools/xla_compile_lib.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/env_time.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/path.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/protobuf/error_codes.pb.h" +#include "tsl/protobuf/status.pb.h" + +namespace xla { +namespace { + +using ::testing::IsEmpty; +using ::testing::IsNull; +using ::testing::Not; +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; + +class XlaCompileLibTest : public HloTestBase { + protected: + XlaCompileLibTest() + : HloTestBase(*PlatformUtil::GetPlatform("Host"), + GetReferencePlatform()) {} + void SetUp() override { + const std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), + "tools", "data", "add.hlo"); + std::string hlo; + TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), hlo_path, &hlo)); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo)); + } + + std::unique_ptr module_; +}; + +TEST_F(XlaCompileLibTest, CompilesForCpu) { + CompilationResult result; + EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kCpu, + std::nullopt, result), + IsOkAndHolds(Not(IsEmpty()))); +} + +TEST_F(XlaCompileLibTest, ErrorsOnUnexpectedPlatform) { + XlaCompileOptions options; + options.platform = "tpu"; + EXPECT_THAT(XlaCompileMain(options), StatusIs(tsl::error::UNIMPLEMENTED)); +} + +TEST_F(XlaCompileLibTest, WriteResultFilePropagatesErrors) { + TimerStats stats; + CompilationResult result; + EXPECT_THAT(WriteResultFile("/does/not/exist", stats, result), Not(IsOk())); +} + +TEST_F(XlaCompileLibTest, WriteResultFileWritesTheFile) { + std::string result_output_file; + ASSERT_TRUE(tsl::Env::Default()->LocalTempFilename(&result_output_file)); + + TimerStats stats; + { + absl::MutexLock ml(&stats.stats_mutex); + stats.cumulative_secs = 5.5; + stats.max_secs = 5.5; + } + + CompilationResult result; + google::protobuf::Duration duration; + duration.set_seconds(5); + duration.set_nanos(0.5 * tsl::EnvTime::kSecondsToNanos); + *result.mutable_perf_stats()->mutable_compilation_duration() = duration; + *result.mutable_perf_stats()->mutable_total_duration() = duration; + + TF_ASSERT_OK(WriteResultFile(result_output_file, stats, result)); + + CompilationResult got_result; + TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_output_file, + &got_result)); + // Sadly EqualsProto isn't OSS, so we inspect a few fields manually. + // See googletest#1761 and b/229726259. + EXPECT_EQ(5, got_result.perf_stats().compilation_duration().seconds()); + EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos, + got_result.perf_stats().compilation_duration().nanos()); + EXPECT_EQ(5, got_result.perf_stats().total_duration().seconds()); + EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos, + got_result.perf_stats().total_duration().nanos()); +} + +TEST_F(XlaCompileLibTest, LoadModuleErrors) { + EXPECT_THAT(LoadModule("/does/not/exist"), Not(IsOk())); +} + +TEST_F(XlaCompileLibTest, LoadModuleLoadsTextFormat) { + const std::string module_file = + tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt"); + TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file, + module_->ToString())); + + EXPECT_THAT(LoadModule(module_file), IsOkAndHolds(Not(IsNull()))); +} + +TEST_F(XlaCompileLibTest, MainForCpu) { + const std::string module_file = + tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt"); + TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file, + module_->ToString())); + + const std::string output_path = + tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_output"); + const std::string result_file = + tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_result.pb"); + + XlaCompileOptions options; + options.module_path = module_file; + options.output_path = output_path; + options.platform = "cpu"; + options.result_output_file = result_file; + TF_EXPECT_OK(XlaCompileMain(options)); + + CompilationResult result; + TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result)); + EXPECT_TRUE(result.has_status()); + EXPECT_EQ(result.status().code(), tensorflow::error::OK); +} + +TEST_F(XlaCompileLibTest, LoadAutotuneDataCpu) { + HloModuleAndMetadata mod; + mod.hlo_module = std::move(module_); + + EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kCpu), + IsOkAndHolds(false)); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/tools/xla_compile_lib_test.cc b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc similarity index 56% rename from third_party/xla/xla/tools/xla_compile_lib_test.cc rename to third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc index 538821eaf57856..bc34c8790fb14e 100644 --- a/third_party/xla/xla/tools/xla_compile_lib_test.cc +++ b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc @@ -13,30 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tools/xla_compile_lib.h" - #include #include #include #include -#include "google/protobuf/duration.pb.h" #include #include -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/gpu_symbol_repository.h" #include "xla/service/platform_util.h" #include "xla/service/symbol_repository.h" #include "xla/service/xla_compile_result.pb.h" #include "xla/stream_executor/device_description.pb.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tools/xla_compile_lib.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "tsl/platform/env.h" -#include "tsl/platform/env_time.h" #include "tsl/platform/path.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -44,35 +39,17 @@ limitations under the License. #include "tsl/protobuf/error_codes.pb.h" #include "tsl/protobuf/status.pb.h" -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/autotuning/autotuner_util.h" -#endif - namespace xla { namespace { using ::testing::IsEmpty; -using ::testing::IsNull; using ::testing::Not; -using ::tsl::testing::IsOk; using ::tsl::testing::IsOkAndHolds; -using ::tsl::testing::StatusIs; - -#if XLA_TEST_BACKEND_CPU -static constexpr absl::string_view kPlatformName = "Host"; -#elif XLA_TEST_BACKEND_GPU -static constexpr absl::string_view kPlatformName = -#if TENSORFLOW_USE_ROCM - "ROCM"; -#else - "CUDA"; -#endif -#endif // XLA_TEST_BACKEND_CPU class XlaCompileLibTest : public HloTestBase { protected: XlaCompileLibTest() - : HloTestBase(*PlatformUtil::GetPlatform(std::string(kPlatformName)), + : HloTestBase(*PlatformUtil::GetPlatform(std::string("GPU")), GetReferencePlatform()) {} void SetUp() override { const std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), @@ -85,14 +62,7 @@ class XlaCompileLibTest : public HloTestBase { std::unique_ptr module_; }; -TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(CompilesForCpu)) { - CompilationResult result; - EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kCpu, - std::nullopt, result), - IsOkAndHolds(Not(IsEmpty()))); -} - -TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithDevice)) { +TEST_F(XlaCompileLibTest, CompilesForGpuWithDevice) { CompilationResult result; EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu, std::nullopt, result), @@ -100,7 +70,7 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithDevice)) { EXPECT_TRUE(result.has_hlo_module()) << result.DebugString(); } -TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithoutDevice)) { +TEST_F(XlaCompileLibTest, CompilesForGpuWithoutDevice) { const std::string target_config_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "xla_aot_compile_test_gpu_target_config.prototxt"); @@ -114,89 +84,7 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithoutDevice)) { EXPECT_TRUE(result.has_hlo_module()) << result.DebugString(); } -TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(ErrorsOnUnexpectedPlatform)) { - XlaCompileOptions options; - options.platform = "tpu"; - EXPECT_THAT(XlaCompileMain(options), StatusIs(tsl::error::UNIMPLEMENTED)); -} - -TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(WriteResultFilePropagatesErrors)) { - TimerStats stats; - CompilationResult result; - EXPECT_THAT(WriteResultFile("/does/not/exist", stats, result), Not(IsOk())); -} - -TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(WriteResultFileWritesTheFile)) { - std::string result_output_file; - ASSERT_TRUE(tsl::Env::Default()->LocalTempFilename(&result_output_file)); - - TimerStats stats; - { - absl::MutexLock ml(&stats.stats_mutex); - stats.cumulative_secs = 5.5; - stats.max_secs = 5.5; - } - - CompilationResult result; - google::protobuf::Duration duration; - duration.set_seconds(5); - duration.set_nanos(0.5 * tsl::EnvTime::kSecondsToNanos); - *result.mutable_perf_stats()->mutable_compilation_duration() = duration; - *result.mutable_perf_stats()->mutable_total_duration() = duration; - - TF_ASSERT_OK(WriteResultFile(result_output_file, stats, result)); - - CompilationResult got_result; - TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_output_file, - &got_result)); - // Sadly EqualsProto isn't OSS, so we inspect a few fields manually. - // See googletest#1761 and b/229726259. - EXPECT_EQ(5, got_result.perf_stats().compilation_duration().seconds()); - EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos, - got_result.perf_stats().compilation_duration().nanos()); - EXPECT_EQ(5, got_result.perf_stats().total_duration().seconds()); - EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos, - got_result.perf_stats().total_duration().nanos()); -} - -TEST_F(XlaCompileLibTest, LoadModuleErrors) { - EXPECT_THAT(LoadModule("/does/not/exist"), Not(IsOk())); -} - -TEST_F(XlaCompileLibTest, LoadModuleLoadsTextFormat) { - const std::string module_file = - tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt"); - TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file, - module_->ToString())); - - EXPECT_THAT(LoadModule(module_file), IsOkAndHolds(Not(IsNull()))); -} - -TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(MainForCpu)) { - const std::string module_file = - tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt"); - TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file, - module_->ToString())); - - const std::string output_path = - tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_output"); - const std::string result_file = - tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_result.pb"); - - XlaCompileOptions options; - options.module_path = module_file; - options.output_path = output_path; - options.platform = "cpu"; - options.result_output_file = result_file; - TF_EXPECT_OK(XlaCompileMain(options)); - - CompilationResult result; - TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result)); - EXPECT_TRUE(result.has_status()); - EXPECT_EQ(result.status().code(), tensorflow::error::OK); -} - -TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(MainForGpu)) { +TEST_F(XlaCompileLibTest, MainForGpu) { const std::string module_file = tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt"); TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file, @@ -221,17 +109,7 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(MainForGpu)) { EXPECT_EQ(result.status().code(), tensorflow::error::OK); } -TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(LoadAutotuneDataCpu)) { - HloModuleAndMetadata mod; - mod.hlo_module = std::move(module_); - - EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kCpu), - IsOkAndHolds(false)); -} - -TEST_F(XlaCompileLibTest, - DISABLED_ON_CPU(LoadAutotuneDataGpuDataPresentAndAutotuningEnabled)) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +TEST_F(XlaCompileLibTest, LoadAutotuneDataGpuDataPresentAndAutotuningEnabled) { gpu::AutotunerUtil::ClearAutotuneResults(); HloModuleAndMetadata mod; @@ -254,12 +132,9 @@ TEST_F(XlaCompileLibTest, EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu), IsOkAndHolds(true)); EXPECT_FALSE(gpu::AutotunerUtil::ResultCacheIsEmpty()); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } -TEST_F(XlaCompileLibTest, - DISABLED_ON_CPU(LoadAutotuneDataGpuDataPresentAndAutotuningDisabled)) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +TEST_F(XlaCompileLibTest, LoadAutotuneDataGpuDataPresentAndAutotuningDisabled) { gpu::AutotunerUtil::ClearAutotuneResults(); HloModuleAndMetadata mod; @@ -282,12 +157,10 @@ TEST_F(XlaCompileLibTest, EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu), IsOkAndHolds(false)); EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty()); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } TEST_F(XlaCompileLibTest, - DISABLED_ON_CPU(LoadAutotuneDataGpuDataNotPresentAndAutotuningEnabled)) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + LoadAutotuneDataGpuDataNotPresentAndAutotuningEnabled) { gpu::AutotunerUtil::ClearAutotuneResults(); HloModuleAndMetadata mod; @@ -300,13 +173,10 @@ TEST_F(XlaCompileLibTest, EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu), IsOkAndHolds(false)); EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty()); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } -TEST_F( - XlaCompileLibTest, - DISABLED_ON_CPU(LoadAutotuneDataGpuDataNotPresentAndAutotuningDisabled)) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +TEST_F(XlaCompileLibTest, + LoadAutotuneDataGpuDataNotPresentAndAutotuningDisabled) { gpu::AutotunerUtil::ClearAutotuneResults(); HloModuleAndMetadata mod; @@ -319,7 +189,6 @@ TEST_F( EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu), IsOkAndHolds(false)); EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty()); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } } // namespace diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc index 9b87ab5510ee07..64908395b52ba8 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -102,7 +102,7 @@ constexpr char kParameterReplicationAttr[] = "mhlo.parameter_replication"; // Note: This sanitization function causes an irreversible many-to-one mapping // and any solution to mitigate this would cause issues with the reverse -// direction. Longterm solution is to add a function attribute to maintain the +// direction. Long-term solution is to add a function attribute to maintain the // original HLO naming. std::string SanitizeFunctionName(llvm::StringRef name) { std::string output(name); diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD index 931448842644ea..40e05c8873d229 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD +++ b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD @@ -185,18 +185,26 @@ cc_library( deps = [ ":mlir_hlo_to_hlo", ":type_to_shape", + "//xla:debug_options_flags", + "//xla:shape_util", + "//xla/client:xla_builder", + "//xla/client:xla_computation", "//xla/hlo/ir:hlo", "//xla/mlir_hlo:hlo_dialect_registration", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", "//xla/service:hlo_proto_util", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h index 2c85a82680345a..2ecd4e3ef3ba3d 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h @@ -19,8 +19,9 @@ limitations under the License. #define XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ #include -#include +#include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/client/xla_builder.h" #include "xla/hlo/ir/hlo_sharding.h" @@ -30,10 +31,10 @@ limitations under the License. namespace mlir { // XLA Layout preferences. Currently, when it comes to TPU, there are two -// primary layout choices for any XLA argumetns (parameter or resource): (1) +// primary layout choices for any XLA arguments (parameter or resource): (1) // CompactChunkPadded and (2) Linear. CompactChunkPadded is the native TPU // layout while Linear is native host (CPU) layout. -// This enum allows the caller of XLA to progogate layout preference to the XLA +// This enum allows the caller of XLA to propagate layout preference to the XLA // compiler. // kNoPreference: the generic layout where the XLA compiler has the freedom // to assign any layout. diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 287b455083cd5f..3e965fdaeb89a1 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -1585,6 +1585,8 @@ LogicalResult ExportXlaOp(DotOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(DotGeneralOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; xla::XlaOp lhs, rhs; + // TODO: Support algorithm lowering in followup. + if (op.getAlgorithm().has_value()) return mlir::failure(); if (failed(GetXlaOp(op.getLhs(), value_map, &lhs, op))) return mlir::failure(); if (failed(GetXlaOp(op.getRhs(), value_map, &rhs, op))) diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc b/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc index 88f05238846b7a..7dad7c322e2d84 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc @@ -22,18 +22,20 @@ limitations under the License. namespace mlir { namespace mhlo { namespace { -constexpr char kConfigNumPartitions[] = "mhlo.num_partitions"; -constexpr char kConfigNumReplicas[] = "mhlo.num_replicas"; + +constexpr char kMhloNumPartitions[] = "mhlo.num_partitions"; +constexpr char kMhloNumReplicas[] = "mhlo.num_replicas"; + } // namespace void ExportHloModuleConfig(xla::HloModuleConfig& config, mlir::ModuleOp module) { if (auto num_partitions = - module->getAttrOfType(kConfigNumPartitions)) { + module->getAttrOfType(kMhloNumPartitions)) { config.set_num_partitions(num_partitions.getInt()); } if (auto num_replicas = - module->getAttrOfType(kConfigNumReplicas)) { + module->getAttrOfType(kMhloNumReplicas)) { config.set_replica_count(num_replicas.getInt()); } } diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir index 680341f0d899ac..3d44aff99a7226 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir @@ -1,5 +1,5 @@ // RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text -with-layouts -print-layouts %s | FileCheck %s -// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text -with-layouts -print-layouts --via-builder=true %s | FileCheck %s +// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text -with-layouts -print-layouts --via-builder=true %s | FileCheck %s #CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir index c39048f5663e05..6672e62daf04de 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir @@ -1,6 +1,16 @@ // RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics %s | FileCheck %s // RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics --via-builder=true %s | FileCheck %s +// CHECK: HloModule foo +// CHECK: ENTRY %main +module @foo { + func.func @main(%arg: tensor) -> tensor { + func.return %arg : tensor + } +} + +// ----- + // CHECK: HloModule func.func @main(%arg0: tensor<2xi1>) -> tensor<2xi1> { %0 = "mhlo.add"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc b/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc index 8cff1c99592b1c..7c07582a46c794 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc @@ -16,26 +16,42 @@ limitations under the License. #include #include +#include #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Value.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "xla/client/xla_builder.h" +#include "xla/client/xla_computation.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_proto_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" constexpr char kParameterReplicationAttr[] = "mhlo.parameter_replication"; @@ -123,6 +139,8 @@ absl::Status ConvertMlirHloToHloViaBuilder( mlir::cast(b).getValue()); auto hlo_module = computation.proto(); + mlir::StringRef module_name = module.getName() ? *module.getName() : "main"; + hlo_module.set_name(module_name.str()); hlo_proto->mutable_hlo_module()->Swap(&hlo_module); return absl::OkStatus(); diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ref.h b/third_party/xla/xla/tsl/concurrency/async_value_ref.h index 65fb655448ced8..1065a7b5fcc3dc 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_ref.h +++ b/third_party/xla/xla/tsl/concurrency/async_value_ref.h @@ -63,6 +63,56 @@ AsyncValueRef MakeConstructedAsyncValueRef(Args&&... args); template AsyncValueRef MakeAvailableAsyncValueRef(Args&&... args); +// A collection of type traits used by AsyncValueRef and AsyncValuePtr. +namespace internal { + +// Detects if a type is a specialization of an AsyncValueRef template. +template +struct IsAsyncValueRef : std::false_type {}; +template +struct IsAsyncValueRef> : std::true_type {}; + +template +inline constexpr bool is_async_value_ref_v = IsAsyncValueRef::value; + +// Detects types that are `absl::StatusOr` container. +template +struct IsStatusOr : std::false_type {}; +template +struct IsStatusOr> : std::true_type {}; + +// Type predicates for detecting absl::Status-like types. +template +static constexpr bool is_status_v = std::is_same_v; +template +static constexpr bool is_status_or_v = IsStatusOr::value; +template +static constexpr bool is_status_like_v = is_status_v || is_status_or_v; + +// Deduces the result type of invoking `F` with a first compatible `Arg`. +template +struct FirstInvokeResult { + template > + struct is_invocable : std::false_type { + using type = void; + }; + + template + struct is_invocable : std::true_type { + using type = std::invoke_result_t; + }; + + using type = typename std::disjunction...>::type; +}; + +// In contrast to `std::invoke_result_t` `Args` are not passed to `F` all +// together, but instead they are passed one-by-one, and the first valid one +// determines the result type. +template +using first_invoke_result_t = typename FirstInvokeResult::type; + +} // namespace internal + // AsyncValueRef is an asynchronous container for a payload of type `T` or an // error of type `absl::Status`. It is similar to an `absl::StatusOr`, but // does not require immediate value or error to be constructed. It is a promise @@ -295,6 +345,7 @@ class AsyncValueRef { return value_->SetError(std::move(status)); } + ABSL_DEPRECATED("Use SetError with absl::Status argument") void SetError(std::string_view message) const { SetError(absl::InternalError(message)); } @@ -335,35 +386,12 @@ class AsyncValueRef { RCReference value_; }; -// Detects if a type is a specialization of an AsyncValueRef template. -template -struct IsAsyncValueRef : std::false_type {}; -template -struct IsAsyncValueRef> : std::true_type {}; - -template -inline constexpr bool is_async_value_ref_v = IsAsyncValueRef::value; - // Non owning typed pointer for the AsyncValue. Can be cheaply passed around // when the lifetime of the underlying async value is clear from the context. // It is the user responsibility to construct an owning AsyncValueRef to extend // the lifetime of the underlying value if needed. template class AsyncValuePtr { - // Detect result types that are `absl::StatusOr` container. - template - struct IsStatusOr : std::false_type {}; - template - struct IsStatusOr> : std::true_type {}; - - // Type predicates for detecting absl::Status-like types. - template - static constexpr bool is_status_v = std::is_same_v; - template - static constexpr bool is_status_or_v = IsStatusOr::value; - template - static constexpr bool is_status_like_v = is_status_v || is_status_or_v; - // Wait for async value availability: AndThen([] {}) template using SimpleWaiter = std::enable_if_t>; @@ -383,26 +411,25 @@ class AsyncValuePtr { using StatusWaiter = std::enable_if_t<(std::is_invocable_v && !std::is_invocable_v> && - !is_status_v)>; - - // Because AsyncValue itself is a discriminated union of absl::Status and - // typed payload (error or value) the use of AsyncValueRef is - // discouraged (work in progress to disable with static assert) and `Map` - // automatically folds returned status-like object into the returned async - // value error. + !internal::is_status_v)>; - // Async value map functor: Map([](T& value) -> U); - // - R must be constructible from U - template + // Map async value of type `T` to an async value of type `R`. + template > using MapFunctor = std::enable_if_t>; - // Async value try map functor: TryMap([](T& value) -> absl::StatusOr); - // - R must be constructible from U - template + // Try map async value of type `T` to an async value of type `R`. + template > using TryMapFunctor = - std::enable_if_t && is_status_or_v && - std::is_constructible_v && - !std::is_constructible_v>; + std::enable_if_t && + std::is_constructible_v>; + + // Flat map async value of type `T` to an async value `R` (`R` itself is an + // async value ref). Returns `R` value type (async payload type). + template >> + using FlatMapFunctor = std::enable_if_t, + typename R::value_type>; public: // AsyncValuePtr::value_type @@ -593,8 +620,7 @@ class AsyncValuePtr { // return U(value); // R must be constructible from U // }) // - template , - MapFunctor* = nullptr> + template * = nullptr> AsyncValueRef Map(F&& f) { auto result = MakeUnconstructedAsyncValueRef(); AndThen([f = std::forward(f), result, ptr = *this]() mutable { @@ -608,8 +634,7 @@ class AsyncValuePtr { } // An overload that executes `f` on a user-provided executor. - template , - MapFunctor* = nullptr> + template * = nullptr> AsyncValueRef Map(AsyncValue::Executor& executor, F&& f) { auto result = MakeUnconstructedAsyncValueRef(); // We don't know when the executor will run the callback, so we need to @@ -639,8 +664,7 @@ class AsyncValuePtr { // // If returned status container will have an error status, it will be // automatically converted to async value error. - template , - TryMapFunctor* = nullptr> + template * = nullptr> AsyncValueRef TryMap(F&& f) { auto result = MakeUnconstructedAsyncValueRef(); AndThen([f = std::forward(f), result, ptr = *this]() mutable { @@ -659,8 +683,7 @@ class AsyncValuePtr { } // An overload that executes `f` on a user-provided executor. - template , - TryMapFunctor* = nullptr> + template * = nullptr> AsyncValueRef TryMap(AsyncValue::Executor& executor, F&& f) { auto result = MakeUnconstructedAsyncValueRef(); // We don't know when the executor will run the callback, so we need to @@ -696,7 +719,7 @@ class AsyncValuePtr { // A `TryMap` overload that automatically infers the type of result from `f`. template , - std::enable_if_t>* = nullptr> + std::enable_if_t>* = nullptr> auto TryMap(F&& f) { return TryMap(std::forward(f)); } @@ -704,12 +727,12 @@ class AsyncValuePtr { // A `TryMap` overload that automatically infers the type of result from `f` // and executes `f` on user-provided executor. template , - std::enable_if_t>* = nullptr> + std::enable_if_t>* = nullptr> auto TryMap(AsyncValue::Executor& executor, F&& f) { return TryMap(executor, std::forward(f)); } - // Returns and AsyncValueRef that will be forwarded to the AsyncValueRef + // Returns an AsyncValueRef that will be forwarded to the AsyncValueRef // returned from a functor. // // Sample usage: @@ -718,14 +741,25 @@ class AsyncValuePtr { // return LaunchAsyncTask(value); // }) // - template , - std::enable_if_t>* = nullptr> - AsyncValueRef FlatMap(F&& f) { + // Functor argument can be a `T&` or an `AsyncValueRef`, where async value + // pointer is guaranteed to be in concrete state. Async value pointer allows + // the functor to extend the lifetime of underlying async value if needed. + // + // async_value_ptr.FlatMap([](AsyncValuePtr ptr) -> AsyncValueRef { + // return LaunchAsyncTask([ref = ptr.CopyRef()] { ... }); + // }) + // + template > + AsyncValueRef FlatMap(F&& f) { // If async value is in concrete state, we can immediately call the functor. // We don't handle errors here and prefer a generic code path below because // error handling is never on a performance critical path. if (ABSL_PREDICT_TRUE(IsConcrete())) { - return f(get()); + if constexpr (std::is_invocable_v) { + return f(get()); + } else { + return f(*this); + } } auto promise = MakePromise(); @@ -733,17 +767,19 @@ class AsyncValuePtr { if (ABSL_PREDICT_FALSE(ptr.IsError())) { promise->SetError(ptr.GetError()); } else { - promise->ForwardTo(f(*ptr)); + if constexpr (std::is_invocable_v) { + promise->ForwardTo(f(*ptr)); + } else { + promise->ForwardTo(f(ptr)); + } } }); - return AsyncValueRef(promise); + return AsyncValueRef(promise); } // An overload that executes `f` on a user-provided executor. - template , - std::enable_if_t>* = nullptr> - AsyncValueRef FlatMap(AsyncValue::Executor& executor, - F&& f) { + template > + AsyncValueRef FlatMap(AsyncValue::Executor& executor, F&& f) { // We don't have a special handling for concrete values here because // we must execute user functor on a separate executor and can't call it in // the caller thread. @@ -755,10 +791,14 @@ class AsyncValuePtr { if (ABSL_PREDICT_FALSE(ref.IsError())) { promise->SetError(ref.GetError()); } else { - promise->ForwardTo(f(*ref)); + if constexpr (std::is_invocable_v) { + promise->ForwardTo(f(*ref)); + } else { + promise->ForwardTo(f(ref.AsPtr())); + } } }); - return AsyncValueRef(promise); + return AsyncValueRef(promise); } private: @@ -767,8 +807,8 @@ class AsyncValuePtr { // types and this will be a run time error. template RCReference MakePromise() { - if constexpr (std::is_final_v) { - return MakeIndirectAsyncValue(); + if constexpr (std::is_final_v) { + return MakeIndirectAsyncValue(); } else { return MakeIndirectAsyncValue(); }; @@ -918,6 +958,47 @@ AsyncValueRef MakeAvailableAsyncValueRef(Args&&... args) { std::forward(args)...))); } +// Allocates an AsyncValueRef that is constructed from the result of calling an +// `f` on a user-provided `executor`. +// +// Sample usage: +// +// MakeAsyncValueRef(executor, []() -> int32_t { ... }); +// +template , + std::enable_if_t>* = nullptr> +AsyncValueRef MakeAsyncValueRef(AsyncValue::Executor& executor, F&& f) { + auto result = MakeUnconstructedAsyncValueRef(); + executor.Execute([result, f = std::forward(f)] { result.emplace(f()); }); + return result; +} + +// Allocates an AsyncValueRef that is constructed from the result of calling an +// `f` on a user-provided `executor`. `F` must return an absl::StatusOr, and +// result of type `T` must be constructible from `U`. +// +// Sample usage: +// +// TryMakeAsyncValueRef(executor, +// []() -> absl::StatusOr { ... }); +// +template , + std::enable_if_t< + internal::is_status_or_v && + std::is_constructible_v>* = nullptr> +AsyncValueRef TryMakeAsyncValueRef(AsyncValue::Executor& executor, F&& f) { + auto result = MakeUnconstructedAsyncValueRef(); + executor.Execute([result, f = std::forward(f)] { + absl::StatusOr status_or = f(); + if (ABSL_PREDICT_TRUE(status_or.ok())) { + result.emplace(std::move(status_or).value()); + } else { + result.SetError(std::move(status_or).status()); + } + }); + return result; +} + //===----------------------------------------------------------------------===// // Constructing non-reference-counted values in user provided storage. //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc b/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc index 646b05b2246e82..0cb4aad9b3b588 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc +++ b/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc @@ -420,6 +420,44 @@ struct DeferredExecutor : public AsyncValue::Executor { std::vector tasks; }; +TEST(AsyncValueRefTest, MakeAsyncValueRef) { + DeferredExecutor executor; + + { // Make AsyncValueRef from a function that returns a value. + AsyncValueRef ref = + MakeAsyncValueRef(executor, []() -> float { return 42.0f; }); + + EXPECT_FALSE(ref.IsAvailable()); + EXPECT_EQ(executor.Quiesce(), 1); + + EXPECT_TRUE(ref.IsAvailable()); + EXPECT_EQ(ref.get(), 42.0f); + } + + { // Make AsyncValueRef from a function that returns a StatusOr value. + AsyncValueRef ref = TryMakeAsyncValueRef( + executor, []() -> absl::StatusOr { return 42.0f; }); + + EXPECT_FALSE(ref.IsAvailable()); + EXPECT_EQ(executor.Quiesce(), 1); + + EXPECT_TRUE(ref.IsAvailable()); + EXPECT_EQ(ref.get(), 42.0f); + } + + { // Make AsyncValueRef from a function that returns a StatusOr error. + AsyncValueRef ref = TryMakeAsyncValueRef( + executor, + []() -> absl::StatusOr { return absl::InternalError("test"); }); + + EXPECT_FALSE(ref.IsAvailable()); + EXPECT_EQ(executor.Quiesce(), 1); + + EXPECT_TRUE(ref.IsError()); + EXPECT_EQ(ref.GetError(), absl::InternalError("test")); + } +} + TEST(AsyncValueRefTest, MapAvailableOnExecutor) { AsyncValueRef ref = MakeAvailableAsyncValueRef(42); @@ -521,6 +559,52 @@ TEST(AsyncValueRefTest, FlatMapAvailableOnExecutor) { EXPECT_EQ(fmapped_to_float.get(), 42.0f); } +TEST(AsyncValueRefTest, FlatMapDeferredAsyncValueOnExecutor) { + DeferredExecutor executor0; + DeferredExecutor executor1; + + // Use non-copyable std::unique_ptr to make sure that we don't + // accidentally copy the value into the FlatMap functor. + + { // Use a regular FlatMap. + AsyncValueRef fmapped_to_float = + MakeAsyncValueRef>(executor0, [] { + return std::make_unique(42); + }).FlatMap([&](AsyncValuePtr> ptr) { + return MakeAsyncValueRef( + executor1, [ref = ptr.CopyRef()] { return **ref; }); + }); + + EXPECT_FALSE(fmapped_to_float.IsAvailable()); + EXPECT_EQ(executor0.Quiesce(), 1); + + EXPECT_FALSE(fmapped_to_float.IsAvailable()); + EXPECT_EQ(executor1.Quiesce(), 1); + + EXPECT_TRUE(fmapped_to_float.IsAvailable()); + EXPECT_EQ(fmapped_to_float.get(), 42.0f); + } + + { // Use a FlatMap that itself executed on given executor. + AsyncValueRef fmapped_to_float = + MakeAsyncValueRef>(executor0, [] { + return std::make_unique(42); + }).FlatMap(executor1, [&](AsyncValuePtr> ptr) { + return MakeAsyncValueRef( + executor1, [ref = ptr.CopyRef()] { return **ref; }); + }); + + EXPECT_FALSE(fmapped_to_float.IsAvailable()); + EXPECT_EQ(executor0.Quiesce(), 1); + + EXPECT_FALSE(fmapped_to_float.IsAvailable()); + EXPECT_EQ(executor1.Quiesce(), 2); + + EXPECT_TRUE(fmapped_to_float.IsAvailable()); + EXPECT_EQ(fmapped_to_float.get(), 42.0f); + } +} + TEST(AsyncValueRefTest, BlockUntilReady) { AsyncValueRef ref = MakeAvailableAsyncValueRef(42); BlockUntilReady(ref); diff --git a/third_party/xla/xla/tsl/concurrency/async_value_test.cc b/third_party/xla/xla/tsl/concurrency/async_value_test.cc index f03034d5c67517..eb14685f37903f 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_test.cc +++ b/third_party/xla/xla/tsl/concurrency/async_value_test.cc @@ -132,7 +132,7 @@ TEST(AsyncValueTest, KeepPayloadOnError) { EXPECT_TRUE(!value.IsError()); - value.SetError("error"); + value.SetError(absl::InternalError("error")); EXPECT_EQ(1, *value->value); EXPECT_TRUE(value.IsError()); diff --git a/third_party/xla/xla/tsl/cuda/BUILD.bazel b/third_party/xla/xla/tsl/cuda/BUILD.bazel index dabb8f5f4b11df..704e0b9c50e5d4 100644 --- a/third_party/xla/xla/tsl/cuda/BUILD.bazel +++ b/third_party/xla/xla/tsl/cuda/BUILD.bazel @@ -10,6 +10,10 @@ load( "cuda_rpath_flags", "if_cuda_is_configured", ) +load( + "//xla/tsl:tsl.bzl", + "if_hermetic_cuda_libs", +) load("//xla/tsl/cuda:stub.bzl", "cuda_stub") package( @@ -22,7 +26,7 @@ cuda_stub( ) cc_library( - name = "cublas", # buildifier: disable=duplicated-name + name = "cublas_stub", srcs = if_cuda_is_configured([ "cublas_stub.cc", "cublas.tramp.S", @@ -44,13 +48,19 @@ cc_library( ]), ) +alias( + name = "cublas", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cublas//:cublas", ":cublas_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cublasLt", srcs = ["cublasLt.symbols"], ) cc_library( - name = "cublas_lt", + name = "cublas_lt_stub", srcs = if_cuda_is_configured([ "cublasLt_stub.cc", "cublasLt.tramp.S", @@ -68,6 +78,12 @@ cc_library( ]), ) +alias( + name = "cublas_lt", + actual = if_hermetic_cuda_libs("@cuda_cublas//:cublasLt", ":cublas_lt_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cuda", srcs = ["cuda.symbols"], @@ -98,7 +114,7 @@ cuda_stub( ) cc_library( - name = "cudart", # buildifier: disable=duplicated-name + name = "cudart_stub", srcs = select({ # include dynamic loading implementation only when if_cuda_is_configured and build dynamically "@local_xla//xla/tsl:is_cuda_enabled_and_oss": [ @@ -129,13 +145,19 @@ cc_library( }), ) +alias( + name = "cudart", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cudart//:cudart", ":cudart_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cudnn", srcs = ["cudnn.symbols"], ) cc_library( - name = "cudnn", # buildifier: disable=duplicated-name + name = "cudnn_stub", srcs = if_cuda_is_configured([ "cudnn_stub.cc", "cudnn.tramp.S", @@ -155,12 +177,24 @@ cc_library( ]), ) +alias( + name = "cudnn", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cudnn//:cudnn", ":cudnn_stub"), + visibility = ["//visibility:public"], +) + cc_library( - name = "nccl_rpath", + name = "nccl_rpath_flags", linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/nccl/lib")), visibility = ["//visibility:public"], ) +alias( + name = "nccl_rpath", + actual = if_hermetic_cuda_libs("@cuda_nccl//:nccl", ":nccl_rpath_flags"), + visibility = ["//visibility:public"], +) + cc_library( name = "tensorrt_rpath", linkopts = if_cuda_is_configured(cuda_rpath_flags("tensorrt")), @@ -173,7 +207,7 @@ cuda_stub( ) cc_library( - name = "cufft", # buildifier: disable=duplicated-name + name = "cufft_stub", srcs = if_cuda_is_configured([ "cufft_stub.cc", "cufft.tramp.S", @@ -192,13 +226,19 @@ cc_library( ]), ) +alias( + name = "cufft", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cufft//:cufft", ":cufft_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cupti", srcs = ["cupti.symbols"], ) cc_library( - name = "cupti", # buildifier: disable=duplicated-name + name = "cupti_stub", srcs = if_cuda_is_configured([ "cupti_stub.cc", "cupti.tramp.S", @@ -219,13 +259,19 @@ cc_library( ]), ) +alias( + name = "cupti", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cupti//:cupti", ":cupti_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cusolver", srcs = ["cusolver.symbols"], ) cc_library( - name = "cusolver", # buildifier: disable=duplicated-name + name = "cusolver_stub", srcs = if_cuda_is_configured([ "cusolver_stub.cc", "cusolver.tramp.S", @@ -244,13 +290,19 @@ cc_library( ]), ) +alias( + name = "cusolver", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cusolver//:cusolver", ":cusolver_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cusparse", srcs = ["cusparse.symbols"], ) cc_library( - name = "cusparse", # buildifier: disable=duplicated-name + name = "cusparse_stub", srcs = if_cuda_is_configured([ "cusparse_stub.cc", "cusparse.tramp.S", @@ -270,13 +322,19 @@ cc_library( ]), ) +alias( + name = "cusparse", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cusparse//:cusparse", ":cusparse_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "nccl", srcs = ["nccl.symbols"], ) cc_library( - name = "nccl_stub", + name = "nccl", # buildifier: disable=duplicated-name srcs = if_cuda_is_configured([ "nccl_stub.cc", "nccl.tramp.S", @@ -296,3 +354,9 @@ cc_library( "@local_tsl//tsl/platform:load_library", ]), ) + +alias( + name = "nccl_stub", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_nccl//:nccl", ":nccl"), + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD index a198b4c7a95b4b..30e3c5c32df348 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD @@ -78,6 +78,7 @@ tsl_gpu_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc index c3ba0da797a07e..e73985c668ed16 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/bind_front.h" #include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -150,7 +151,15 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { const DeviceInfo& ListClusterDevices() override ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); uint64_t GetServiceIncarnation() override; - void StartCheckStaleness(); // Checks both heartbeat and barrier timeouts. + // Checks if any task has stopped sending heartbeats. + void CheckHeartbeatTimeout(); + // Checks if any barrier has timed out. + void CheckBarrierTimeout(); + // Checks both heartbeat and barrier timeouts. Use a single function so they + // can be run in the same thread as threads are a constrained resource. + void CheckStaleness(); + // Starts a thread to check staleness. + void StartCheckStaleness(); void Stop(bool shut_staleness_thread = true); bool ServiceHasStopped() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); // Report service error to a specified task. @@ -447,137 +456,134 @@ CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl( StartCheckStaleness(); } -// Checks both heartbeat and barrier timeouts in the same thread, since threads -// are a constrained resource. -void CoordinationServiceStandaloneImpl::StartCheckStaleness() { - check_staleness_thread_.reset( - env_.StartThread({}, kHealthCheckThread, [this]() { - const bool has_service_to_client_connection = client_cache_ != nullptr; - // Used to store stale tasks and barriers. - std::vector stale_task_names; - absl::flat_hash_map expired_barriers; - while (true) { - { - absl::MutexLock l(&state_mu_); - check_staleness_thread_cv_.WaitWithTimeout(&state_mu_, - absl::Seconds(1)); - if (shutting_down_) { - return; - } - } - // Heartbeat check. - absl::Status status = absl::OkStatus(); - { - absl::MutexLock l(&state_mu_); - for (const auto& [task_name, task_state] : cluster_state_) { - // Skip tasks that are not registered or in error state - if (task_state->GetState() != - CoordinatedTaskState::TASKSTATE_CONNECTED) { - continue; - } - const bool is_stale = task_state->TimeSinceLastHeartbeatMs() > - heartbeat_timeout_ms_; - VLOG(10) << "Checking staleness for " << task_name - << " stale?=" << is_stale; - if (is_stale) { - stale_task_names.push_back(task_name); - status = MakeCoordinationError(absl::UnavailableError( - absl::StrCat("Task ", task_name, - " heartbeat timeout. This indicates that the " - "remote task has failed, got preempted, or " - "crashed unexpectedly. Check the task logs " - "for an earlier error to debug further."))); - SetTaskError(task_name, status); - } - } - } - // Propagate heartbeat timeout errors to other connected tasks. - if (!stale_task_names.empty()) { - if (!has_service_to_client_connection) { - absl::Status heartbeat_timeout_error = - MakeCoordinationError(absl::UnavailableError(absl::StrCat( - "The following tasks are unhealthy (stopped sending " - "heartbeats):\n", - absl::StrJoin(stale_task_names, "\n"), - "\nCheck the task logs for an earlier error to debug " - "further."))); - if (SendErrorPollingResponseOrStopService( - heartbeat_timeout_error)) { - return; - } - } else { - for (const auto& stale_task_name : stale_task_names) { - PropagateError(GetTaskFromName(stale_task_name)); - } - stale_task_names.clear(); - } - } - - // Barrier timeout check. - uint64_t current_time_micros = Env::Default()->NowMicros(); - { - absl::MutexLock l(&state_mu_); - // Gather barriers which have timed out. - for (std::string_view barrier_id : ongoing_barriers_) { - auto* barrier = &barriers_[barrier_id]; - if (current_time_micros > barrier->deadline_in_micros) { - expired_barriers[barrier_id] = barrier; - } - } - // Pass these barriers with the time out error. - for (const auto& [barrier_id, barrier] : expired_barriers) { - std::string pending_tasks; - int pending_task_count = 0; - for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { - if (!at_barrier) { - ++pending_task_count; - if (pending_task_count <= kPendingTaskLogLimit) { - absl::StrAppend(&pending_tasks, GetTaskName(task), "\n"); - } else { - break; - } - } - } - std::string error_message = absl::StrFormat( - "Barrier timed out. This usually happens because a task " - "triggered the barrier unexpectedly early, or some tasks are " - "too slow. Please look at the other task logs to debug " - "further. Barrier_id: %s. The first task at the barrier: " - "%s. ", - barrier_id, GetTaskName(barrier->initiating_task)); - if (pending_task_count > kPendingTaskLogLimit) { - absl::StrAppend(&error_message, - "Too many tasks have timed out. The first ", - kPendingTaskLogLimit, - " timed out task names:\n", pending_tasks); - } else { - absl::StrAppend( - &error_message, - "Total Number of tasks already at the barrier: ", - barrier->tasks_at_barrier.size() - pending_task_count, "/", - barrier->tasks_at_barrier.size(), - ". Timed out task names:\n%s", pending_tasks); - } - const absl::Status error = MakeCoordinationError( - absl::DeadlineExceededError(error_message)); - PassBarrier(barrier_id, error, barrier); - } - } - if (!has_service_to_client_connection && - expired_barriers.contains(shutdown_barrier_id_)) { - // Error cannot be propagated through service-to-client connection. - // Note: we cannot destroy the thread within its own function. - // However, this thread will be destroyed once the function returns. - SendErrorPollingResponseOrStopService( - MakeCoordinationError(absl::DeadlineExceededError( - "Shutdown barrier timed out. Check the task logs for an " - "earlier error."))); - } +void CoordinationServiceStandaloneImpl::CheckHeartbeatTimeout() { + absl::Status status = absl::OkStatus(); + std::vector stale_task_names; + const bool has_service_to_client_connection = client_cache_ != nullptr; + { + absl::MutexLock l(&state_mu_); + for (const auto& [task_name, task_state] : cluster_state_) { + // Skip tasks that are not registered or in error state + if (task_state->GetState() != CoordinatedTaskState::TASKSTATE_CONNECTED) { + continue; + } + const bool is_stale = + task_state->TimeSinceLastHeartbeatMs() > heartbeat_timeout_ms_; + VLOG(10) << "Checking staleness for " << task_name + << " stale?=" << is_stale; + if (is_stale) { + stale_task_names.push_back(task_name); + status = MakeCoordinationError(absl::UnavailableError( + absl::StrCat("Task ", task_name, + " heartbeat timeout. This indicates that the " + "remote task has failed, got preempted, or " + "crashed unexpectedly. Check the task logs " + "for an earlier error to debug further."))); + SetTaskError(task_name, status); + } + } + } + // Propagate heartbeat timeout errors to other connected tasks. + if (!stale_task_names.empty()) { + if (!has_service_to_client_connection) { + absl::Status heartbeat_timeout_error = + MakeCoordinationError(absl::UnavailableError(absl::StrCat( + "The following tasks are unhealthy (stopped sending " + "heartbeats):\n", + absl::StrJoin(stale_task_names, "\n"), + "\nCheck the task logs for an earlier error to debug " + "further."))); + if (SendErrorPollingResponseOrStopService(heartbeat_timeout_error)) { + return; + } + } else { + for (const auto& stale_task_name : stale_task_names) { + PropagateError(GetTaskFromName(stale_task_name)); + } + } + } +} - // Reset this for the next barrier check. - expired_barriers.clear(); +void CoordinationServiceStandaloneImpl::CheckBarrierTimeout() { + const bool has_service_to_client_connection = client_cache_ != nullptr; + absl::flat_hash_map expired_barriers; + uint64_t current_time_micros = Env::Default()->NowMicros(); + { + absl::MutexLock l(&state_mu_); + // Gather barriers which have timed out. + for (std::string_view barrier_id : ongoing_barriers_) { + auto* barrier = &barriers_[barrier_id]; + if (current_time_micros > barrier->deadline_in_micros) { + expired_barriers[barrier_id] = barrier; + } + } + // Pass these barriers with the time out error. + for (const auto& [barrier_id, barrier] : expired_barriers) { + std::string pending_tasks; + int pending_task_count = 0; + for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { + if (at_barrier) { + continue; } - })); + ++pending_task_count; + if (pending_task_count > kPendingTaskLogLimit) { + break; + } + absl::StrAppend(&pending_tasks, GetTaskName(task), "\n"); + } + std::string error_message = absl::StrFormat( + "Barrier timed out. This usually happens because a task " + "triggered the barrier unexpectedly early, or some tasks are " + "too slow. Please look at the other task logs to debug " + "further. Barrier_id: %s. The first task at the barrier: " + "%s. ", + barrier_id, GetTaskName(barrier->initiating_task)); + if (pending_task_count > kPendingTaskLogLimit) { + absl::StrAppend( + &error_message, "Too many tasks have timed out. The first ", + kPendingTaskLogLimit, " timed out task names:\n", pending_tasks); + } else { + absl::StrAppend(&error_message, + "Total Number of tasks already at the barrier: ", + barrier->tasks_at_barrier.size() - pending_task_count, + "/", barrier->tasks_at_barrier.size(), + ". Timed out task names:\n%s", pending_tasks); + } + const absl::Status error = + MakeCoordinationError(absl::DeadlineExceededError(error_message)); + PassBarrier(barrier_id, error, barrier); + } + } + if (!has_service_to_client_connection && + expired_barriers.contains(shutdown_barrier_id_)) { + // Error cannot be propagated through service-to-client connection. + SendErrorPollingResponseOrStopService( + MakeCoordinationError(absl::DeadlineExceededError( + "Shutdown barrier timed out. Check the task logs for an " + "earlier error."))); + } +} + +void CoordinationServiceStandaloneImpl::CheckStaleness() { + // Used to store stale tasks and barriers. + while (true) { + { + absl::MutexLock l(&state_mu_); + check_staleness_thread_cv_.WaitWithTimeout(&state_mu_, absl::Seconds(1)); + if (shutting_down_) { + return; + } + } + CheckHeartbeatTimeout(); + CheckBarrierTimeout(); + } +} + +void CoordinationServiceStandaloneImpl::StartCheckStaleness() { + check_staleness_thread_.reset(env_.StartThread( + {}, kHealthCheckThread, + absl::bind_front(&CoordinationServiceStandaloneImpl::CheckStaleness, + this))); } void CoordinationServiceStandaloneImpl::Stop(bool shut_staleness_thread) { diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc index 45a7ddb1e8ff20..00845e5001b7ff 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc @@ -170,11 +170,7 @@ absl::Status PreemptionSyncManagerImpl::Initialize( call_opts_ = agent_->GetKeyValueAsync( kPreemptionNoticeKey, [this, agent = agent_](absl::StatusOr status_or_death_time) { - if (absl::IsCancelled(status_or_death_time.status()) || - // TODO(b/349613356): Investigate if we can always ensure that - // the RPC is cancelled before the server goes away, so we can - // differentiate between network failure and shutdown behaviour. - absl::IsUnavailable(status_or_death_time.status())) { + if (absl::IsCancelled(status_or_death_time.status())) { // The agent cancels pending GetKeyValue RPCs because of shutdown, // so simply log and return. LOG(INFO) << "Cancelled call to retrieve preemption notice. This is " diff --git a/third_party/xla/xla/tsl/framework/BUILD b/third_party/xla/xla/tsl/framework/BUILD index 84775bd3782b75..52faa0be9359cf 100644 --- a/third_party/xla/xla/tsl/framework/BUILD +++ b/third_party/xla/xla/tsl/framework/BUILD @@ -194,6 +194,7 @@ cc_library( ":allocator", ":metrics", ":shared_counter", + "//xla/tsl/protobuf:bfc_memory_map_proto_cc", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -210,7 +211,6 @@ cc_library( "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:scoped_memory_debug_annotation", "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/protobuf:bfc_memory_map_proto_cc", ], ) diff --git a/third_party/xla/xla/tsl/framework/allocator.h b/third_party/xla/xla/tsl/framework/allocator.h index 29db454ec871b7..c289532c78a75e 100644 --- a/third_party/xla/xla/tsl/framework/allocator.h +++ b/third_party/xla/xla/tsl/framework/allocator.h @@ -146,6 +146,13 @@ class Allocator { // REQUIRES: "ptr" was previously returned by a call to AllocateRaw virtual void DeallocateRaw(void* ptr) = 0; + virtual void DeallocateRaw(void* ptr, size_t alignment, size_t num_bytes) { + (void)alignment; + (void)num_bytes; + + DeallocateRaw(ptr); + } + // Returns true if this allocator tracks the sizes of allocations. // RequestedSize and AllocatedSize must be overridden if // TracksAllocationSizes is overridden to return true. diff --git a/third_party/xla/xla/tsl/framework/bfc_allocator.cc b/third_party/xla/xla/tsl/framework/bfc_allocator.cc index c2d8f8b121f64f..a5f3401bbe86d4 100644 --- a/third_party/xla/xla/tsl/framework/bfc_allocator.cc +++ b/third_party/xla/xla/tsl/framework/bfc_allocator.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/tsl/framework/allocator_retry.h" +#include "xla/tsl/protobuf/bfc_memory_map.pb.h" #include "tsl/lib/core/bits.h" #include "tsl/platform/file_system.h" #include "tsl/platform/logging.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tsl/platform/types.h" #include "tsl/profiler/lib/scoped_memory_debug_annotation.h" #include "tsl/profiler/lib/traceme.h" -#include "tsl/protobuf/bfc_memory_map.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/cpu_allocator_impl.cc b/third_party/xla/xla/tsl/framework/cpu_allocator_impl.cc index a9cbf0c4650ac6..9c9de966cfb67d 100644 --- a/third_party/xla/xla/tsl/framework/cpu_allocator_impl.cc +++ b/third_party/xla/xla/tsl/framework/cpu_allocator_impl.cc @@ -121,6 +121,17 @@ class CPUAllocator : public Allocator { port::AlignedFree(ptr); } + void DeallocateRaw(void* ptr, size_t alignment, size_t num_bytes) override { + if (cpu_allocator_collect_stats) { + const std::size_t alloc_size = + port::MallocExtension_GetAllocatedSize(ptr); + mutex_lock l(mu_); + stats_.bytes_in_use -= alloc_size; + AddTraceMe("MemoryDeallocation", ptr, 0, alloc_size); + } + port::AlignedSizedFree(ptr, alignment, num_bytes); + } + void AddTraceMe(absl::string_view traceme_name, const void* chunk_ptr, std::size_t req_bytes, std::size_t alloc_bytes) { tsl::profiler::TraceMe::InstantActivity( diff --git a/third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD b/third_party/xla/xla/tsl/lib/histogram/BUILD similarity index 62% rename from third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD rename to third_party/xla/xla/tsl/lib/histogram/BUILD index 4de34f8e390755..cbd206f6bd8083 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD +++ b/third_party/xla/xla/tsl/lib/histogram/BUILD @@ -1,13 +1,13 @@ load( - "@local_tsl//tsl/platform:rules_cc.bzl", - "cc_library", + "@local_tsl//tsl/platform:build_config.bzl", + "tsl_cc_test", ) -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup") load( - "//tsl/platform:build_config.bzl", - "tsl_cc_test", + "@local_tsl//tsl/platform:rules_cc.bzl", + "cc_library", ) +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -20,12 +20,12 @@ cc_library( hdrs = ["histogram.h"], visibility = ["//visibility:public"], deps = [ - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:mutex", - "//tsl/platform:thread_annotations", - "//tsl/platform:types", - "//tsl/protobuf:histogram_proto_cc", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:mutex", + "@local_tsl//tsl/platform:thread_annotations", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/protobuf:histogram_proto_cc", ], alwayslink = True, ) @@ -55,9 +55,9 @@ tsl_cc_test( ], deps = [ ":histogram", - "//tsl/platform:logging", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/protobuf:histogram_proto_cc", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/protobuf:histogram_proto_cc", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/histogram/histogram.cc b/third_party/xla/xla/tsl/lib/histogram/histogram.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/histogram/histogram.cc rename to third_party/xla/xla/tsl/lib/histogram/histogram.cc index d6dc8aa4a5ab20..e8203549272547 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/histogram/histogram.cc +++ b/third_party/xla/xla/tsl/lib/histogram/histogram.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/histogram/histogram.h" +#include "xla/tsl/lib/histogram/histogram.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/histogram/histogram.h b/third_party/xla/xla/tsl/lib/histogram/histogram.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/histogram/histogram.h rename to third_party/xla/xla/tsl/lib/histogram/histogram.h index a024e2275b4d29..64b0cd188e7222 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/histogram/histogram.h +++ b/third_party/xla/xla/tsl/lib/histogram/histogram.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_HISTOGRAM_HISTOGRAM_H_ -#define TENSORFLOW_TSL_LIB_HISTOGRAM_HISTOGRAM_H_ +#ifndef XLA_TSL_LIB_HISTOGRAM_HISTOGRAM_H_ +#define XLA_TSL_LIB_HISTOGRAM_HISTOGRAM_H_ #include #include @@ -121,7 +121,7 @@ class ThreadSafeHistogram { void Clear(); - // TODO(touts): It might be a good idea to provide a AddN() + // TODO(mdevin): It might be a good idea to provide a AddN() // method to avoid grabbing/releasing the lock when adding many values. void Add(double value); @@ -140,4 +140,4 @@ class ThreadSafeHistogram { } // namespace histogram } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_HISTOGRAM_HISTOGRAM_H_ +#endif // XLA_TSL_LIB_HISTOGRAM_HISTOGRAM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/histogram/histogram_test.cc b/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/histogram/histogram_test.cc rename to third_party/xla/xla/tsl/lib/histogram/histogram_test.cc index cda166f943d208..4051d98f49ab97 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/histogram/histogram_test.cc +++ b/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/histogram/histogram.h" +#include "xla/tsl/lib/histogram/histogram.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD b/third_party/xla/xla/tsl/lib/strings/BUILD similarity index 79% rename from third_party/xla/third_party/tsl/tsl/lib/strings/BUILD rename to third_party/xla/xla/tsl/lib/strings/BUILD index 699965e401c526..03f82a366f78c6 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD +++ b/third_party/xla/xla/tsl/lib/strings/BUILD @@ -2,8 +2,8 @@ load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", ) -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -13,13 +13,13 @@ cc_library( hdrs = ["proto_serialization.h"], visibility = ["//visibility:public"], deps = [ - "//tsl/lib/gtl:inlined_vector", - "//tsl/platform:hash", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:protobuf", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@local_tsl//tsl/lib/gtl:inlined_vector", + "@local_tsl//tsl/platform:hash", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.cc b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.cc rename to third_party/xla/xla/tsl/lib/strings/proto_serialization.cc index 139849e306a8b7..06ef0747ee553d 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.cc +++ b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.h b/third_party/xla/xla/tsl/lib/strings/proto_serialization.h similarity index 92% rename from third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.h rename to third_party/xla/xla/tsl/lib/strings/proto_serialization.h index 96a5c55f647694..b79e9aff6c21df 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.h +++ b/third_party/xla/xla/tsl/lib/strings/proto_serialization.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_ -#define TENSORFLOW_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_ +#ifndef XLA_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_ +#define XLA_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_ #include "tsl/platform/protobuf.h" @@ -45,4 +45,4 @@ uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto, } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_ +#endif // XLA_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_ diff --git a/third_party/xla/xla/tsl/protobuf/BUILD b/third_party/xla/xla/tsl/protobuf/BUILD new file mode 100644 index 00000000000000..1a6ce0e4277571 --- /dev/null +++ b/third_party/xla/xla/tsl/protobuf/BUILD @@ -0,0 +1,27 @@ +load( + "@local_tsl//tsl/platform:build_config.bzl", + "tf_proto_library", +) +load( + "//xla/tsl:tsl.bzl", + "if_google", + "internal_visibility", +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tensorflow/core:__subpackages__", + "//xla/tsl:internal", + "//tensorflow_models:__subpackages__", + ]), + features = if_google(["-parse_headers"]), + licenses = ["notice"], +) + +tf_proto_library( + name = "bfc_memory_map_proto", + srcs = ["bfc_memory_map.proto"], + make_default_target_header_only = True, + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/bfc_memory_map.proto b/third_party/xla/xla/tsl/protobuf/bfc_memory_map.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/bfc_memory_map.proto rename to third_party/xla/xla/tsl/protobuf/bfc_memory_map.proto diff --git a/third_party/xla/xla/tsl/tsl.bzl b/third_party/xla/xla/tsl/tsl.bzl index 33571902eb052a..2882b9b96861f1 100644 --- a/third_party/xla/xla/tsl/tsl.bzl +++ b/third_party/xla/xla/tsl/tsl.bzl @@ -221,6 +221,17 @@ def if_with_tpu_support(if_true, if_false = []): "//conditions:default": if_false, }) +# These configs are used to determine whether we should use the hermetic CUDA +# tools in cc_libraries. +# They are intended for the OSS builds only. +def if_hermetic_cuda_tools(if_true, if_false = []): # buildifier: disable=unused-variable + """Shorthand for select()'ing on whether we're building with hermetic CUDA tools.""" + return select({"@local_config_cuda//cuda:hermetic_cuda_tools": if_true, "//conditions:default": if_false}) # copybara:comment_replace return if_false + +def if_hermetic_cuda_libs(if_true, if_false = []): # buildifier: disable=unused-variable + """Shorthand for select()'ing on whether we need to include hermetic CUDA libraries.""" + return select({"@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": if_true, "//conditions:default": if_false}) # copybara:comment_replace return if_false + def get_win_copts(is_external = False): WINDOWS_COPTS = [ # copybara:uncomment_begin(no MSVC flags in google) diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 620f1b6af02c12..64d092261b645e 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -28,7 +28,100 @@ message CompilationEnvironmentsProto { // Debugging options for XLA. These options may change at any time - there are // no guarantees about backward or forward compatibility for these fields. +// +// Debug options naming and organization: +// +// 1. Backend-agnostic options: `xla_$flag_name` - go first, and sorted +// alphabetically by the flag name. +// +// 2. Backend-specific options: `xla_$backend_$flag_name` - must be in the +// corresponding backend section, and sorted alphabetically by the flag name. +// message DebugOptions { + //--------------------------------------------------------------------------// + // XLA backend-agnostic options. + //--------------------------------------------------------------------------// + // go/keep-sorted start + + // go/keep-sorted end + + //--------------------------------------------------------------------------// + // XLA:CPU options. + //--------------------------------------------------------------------------// + + // go/keep-sorted start newline_separated=yes + // + // When true, XLA:CPU uses HLO module scheduler that is optimized for + // extracting concurrency at the cost of extra memory: we extend the live + // ranges of temporaries to allow XLA runtime to schedule independent + // operations in parallel on separate threads. + bool xla_cpu_enable_concurrency_optimized_scheduler = 307; + + // When true, "unsafe" mathematical optimizations are enabled. These + // transformations include but are not limited to: + // + // - Reducing the precision of operations (e.g. using an approximate sin + // function, or transforming x/y into x * (1/y)). + // - Assuming that operations never produce or consume NaN or +/- Inf (this + // behavior can be adjusted using xla_cpu_fast_math_allow_{nans|infs}). + // - Assuming that +0 and -0 are indistinguishable. + bool xla_cpu_enable_fast_math = 99; + + // When false we lower the Minimum and Maximum hlos in the CPU backend such + // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if flag + // this is false we always propagate NaNs through Min and Max. + // + // Note, this does not correspond to the exact same behavior as the gpu flag + // below! + bool xla_cpu_enable_fast_min_max = 140; + + // When xla_cpu_enable_fast_math is true then this controls whether we forbid + // to use the reciprocal of an argument instead of division. Ignored when + // xla_cpu_enable_fast_math is false. + bool xla_cpu_fast_math_honor_division = 126; + + // When xla_cpu_enable_fast_math is true then this controls whether we forbid + // to approximate calculations for functions. Ignored when + // xla_cpu_enable_fast_math is false. + bool xla_cpu_fast_math_honor_functions = 129; + + // When xla_cpu_enable_fast_math is true then this controls whether we allow + // operations to produce infinites. Ignored when xla_cpu_enable_fast_math is + // false. + bool xla_cpu_fast_math_honor_infs = 121; + + // When xla_cpu_enable_fast_math is true then this controls whether we allow + // operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is + // false. + bool xla_cpu_fast_math_honor_nans = 120; + + // When true, XLA:CPU uses the thunk runtime to execute compiled program. + bool xla_cpu_use_thunk_runtime = 298; + + // A `prefer-vector-width` value that is passed to the LLVM backend. Default + // value is `256` (AVX2 on x86 platforms). + int32 xla_cpu_prefer_vector_width = 308; + + // go/keep-sorted end + + //--------------------------------------------------------------------------// + // XLA:GPU options. + //--------------------------------------------------------------------------// + // go/keep-sorted start + + // go/keep-sorted end + + //--------------------------------------------------------------------------// + // XLA:TPU options. + //--------------------------------------------------------------------------// + // go/keep-sorted start + + // go/keep-sorted end + + //--------------------------------------------------------------------------// + // A bag of XLA options that have to be categorized. + //--------------------------------------------------------------------------// + // Show addresses of HLO ops in graph dump. bool xla_hlo_graph_addresses = 2; @@ -115,58 +208,9 @@ message DebugOptions { bool xla_cpu_use_mkl_dnn = 97; reserved 177; // Was xla_cpu_use_xla_runtime - bool xla_cpu_use_thunk_runtime = 298; - - // When true, XLA:CPU uses HLO module scheduler that is optimized for - // extracting concurrency at the cost of extra memory: we extend the live - // ranges of temporaries to allow XLA runtime to schedule independent - // operations in parallel on separate threads. - bool xla_cpu_enable_concurrency_optimized_scheduler = 307; - - // A `prefer-vector-width` value that is passed to the LLVM backend. Default - // value is `256` (AVX2 on x86 platforms). - int32 xla_cpu_prefer_vector_width = 308; reserved 98; // Was xla_gpu_max_kernel_unroll_factor - // When true, "unsafe" mathematical optimizations are enabled. These - // transformations include but are not limited to: - // - // - Reducing the precision of operations (e.g. using an approximate sin - // function, or transforming x/y into x * (1/y)). - // - Assuming that operations never produce or consume NaN or +/- Inf (this - // behavior can be adjusted using xla_cpu_fast_math_allow_{nans|infs}). - // - Assuming that +0 and -0 are indistinguishable. - bool xla_cpu_enable_fast_math = 99; - - // When xla_cpu_enable_fast_math is true then this controls whether we allow - // operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is - // false. - bool xla_cpu_fast_math_honor_nans = 120; - - // When xla_cpu_enable_fast_math is true then this controls whether we allow - // operations to produce infinites. Ignored when xla_cpu_enable_fast_math is - // false. - bool xla_cpu_fast_math_honor_infs = 121; - - // When xla_cpu_enable_fast_math is true then this controls whether we forbid - // to use the reciprocal of an argument instead of division. Ignored when - // xla_cpu_enable_fast_math is false. - bool xla_cpu_fast_math_honor_division = 126; - - // When xla_cpu_enable_fast_math is true then this controls whether we forbid - // to approximate calculations for functions. Ignored when - // xla_cpu_enable_fast_math is false. - bool xla_cpu_fast_math_honor_functions = 129; - - // When false we lower the Minimum and Maximum hlos in the CPU backend such - // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if flag - // this is false we always propagate NaNs through Min and Max. - // - // Note, this does not correspond to the exact same behavior as the gpu flag - // below! - bool xla_cpu_enable_fast_min_max = 140; - // When true we lower the Minimum and Maximum hlos in the GPU backend such // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag // this is true we don't propagate NaNs through Min and Max. @@ -883,7 +927,14 @@ message DebugOptions { // If enabled, uses the libnvjitlink library for PTX compilation and linking bool xla_gpu_enable_libnvjitlink = 319; - // Next id: 320 + // If enabled, generates triton gemm kernels for int4 inputs. + bool xla_gpu_enable_triton_gemm_int4 = 320; + + // If true, XLA will wrap `dot` operations into async computations in an + // effort to parallelize matrix operations. + bool xla_gpu_async_dot = 321; + + // Next id: 322 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.