From b567b054cf3fc3acec418f920005dc1363ba917d Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Thu, 16 May 2024 15:57:07 +0200 Subject: [PATCH] Fix Triton containers (#828) Mainly: add some extra include paths to account for Google/OpenXLA changes like openxla/triton@10c56aa7 that are not accounted for in the CMake build scripts. Also, the `add_subdirectory(unittest)` line in Triton's CMakeLists.txt is now indented, so fix the sed regex to match it. Separate the Bazel/LLVM/Triton steps into different commands, as combining them into a single command didn't actually help with layer caching in CI. --- .github/container/Dockerfile.triton | 34 +++++++++++++++++++---------- .github/workflows/_ci.yaml | 2 +- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/.github/container/Dockerfile.triton b/.github/container/Dockerfile.triton index df51de9a7..0b65cdfce 100644 --- a/.github/container/Dockerfile.triton +++ b/.github/container/Dockerfile.triton @@ -18,19 +18,25 @@ ENV TRITON_NVDISASM_PATH=/usr/local/cuda/bin/nvdisasm RUN [ -x "${TRITON_PTXAS_PATH}" ] && [ -x "${TRITON_CUOBJDUMP_PATH}" ] && [ -x "${TRITON_NVDISASM_PATH}" ] ############################################################################### -## Check out LLVM and Triton sources that match XLA, build them. +## Check out LLVM and Triton sources that match XLA. This uses XLA's Bazel +## configuration to get the relevant tag from the openxla/triton fork's +## llvm-head branch and apply XLA's extra patches to it. Also fetches the +## compatible LLVM sources. ############################################################################### FROM base as builder ARG SRC_PATH_JAX ARG SRC_PATH_XLA RUN <<"EOF" bash -ex -# Use XLA's Bazel configuration to get the relevant tag from the openxla/triton -# fork's llvm-head branch and apply XLA's extra patches to it. Also fetches the -# compatible LLVM sources. 
pushd "${SRC_PATH_XLA}" BAZEL=$(find "${SRC_PATH_JAX}/build" -type f -executable -name 'bazel-*') "${BAZEL}" --output_base=/opt/checkout fetch @triton//:BUILD -# Build XLA's version of LLVM +rm -rf /root/.cache +EOF + +############################################################################### +## Build LLVM +############################################################################### +RUN <<"EOF" bash -ex mkdir /opt/llvm-build pushd /opt/llvm-build cmake -G Ninja \ @@ -42,7 +48,12 @@ cmake -G Ninja \ -DLLVM_TARGETS_TO_BUILD="host;NVPTX" \ /opt/checkout/external/llvm-raw/llvm ninja -# Build XLA's version of Triton against that LLVM +EOF + +############################################################################### +## Build Triton +############################################################################### +RUN <<"EOF" bash -ex pushd /opt/checkout/external/triton mkdir dist # Do not compile with -Werror @@ -50,9 +61,12 @@ sed -i -e 's|-Werror||g' CMakeLists.txt # The LLVM build above does not enable these libraries sed -i -e 's|\(LLVMAMDGPU.*\)|# \1|g' CMakeLists.txt # Do not build tests -sed -i -e 's|^add_subdirectory(unittest)|# unit tests disabled|' CMakeLists.txt +sed -i -e 's|^\s*add_subdirectory(unittest)|# unit tests disabled|' CMakeLists.txt # Do not build the AMD GPU backend sed -i -e 's|BackendInstaller.copy(\["nvidia", "amd"\])|BackendInstaller.copy(["nvidia"])|g' python/setup.py +# The openxla/triton fork carries changes (e.g. openxla/triton@10c56aa7) that alter include paths in source files; add the include directories that the CMake build scripts do not account for +sed -i -e '/include_directories(${PROJECT_SOURCE_DIR}\/third_party)/a include_directories(${PROJECT_SOURCE_DIR}/third_party/nvidia/include)' CMakeLists.txt +sed -i -e '/include_directories(${PROJECT_BINARY_DIR}\/third_party)/a include_directories(${PROJECT_BINARY_DIR}/third_party/nvidia/include)' CMakeLists.txt # Extra patches to Triton maintained in XLA. These are already applied in the working directory. 
XLA_TRITON_PATCHES="${SRC_PATH_XLA}/third_party/triton" # This patch adds two files that are not known to CMake @@ -68,10 +82,6 @@ LLVM_INCLUDE_DIRS=/opt/llvm-build/include \ pip wheel --verbose --wheel-dir=dist/ python/ # Clean up the wheel build directory so it doesn't end up bloating the container rm -rf python/build -# Make the layer for the *current* step smaller, so it is more likely to stay -# resident in the Docker cache -cp -r /opt/checkout/external/triton /opt/triton-copy -rm -rf /opt/checkout /opt/llvm-build /root/.cache EOF ############################################################################### @@ -83,7 +93,7 @@ ARG SRC_PATH_JAX_TRITON ARG SRC_PATH_TRITON # Get the triton source + wheel from the build step -COPY --from=builder /opt/triton-copy ${SRC_PATH_TRITON} +COPY --from=builder /opt/checkout/external/triton ${SRC_PATH_TRITON} RUN echo "triton @ file://$(ls ${SRC_PATH_TRITON}/dist/triton-*.whl)" >> /opt/pip-tools.d/requirements-triton.in # Check out jax-triton diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index c16f82dbc..cf2425ad1 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -378,7 +378,7 @@ jobs: ${{ needs.build-triton.outputs.DOCKER_TAG_FINAL }} \ bash <<"EOF" |& tee test-triton.log # autotuner tests from jax-triton now hit a triton code path that uses utilities from pytorch... - pip install --no-deps torch + pip install --no-deps torch --index-url https://download.pytorch.org/whl/cpu python /opt/jax-triton/tests/triton_call_test.py --xml_output_file /output/triton_test.xml EOF STATISTICS_SCRIPT: |