Skip to content

Commit

Permalink
Fix Triton containers (#828)
Browse files Browse the repository at this point in the history
Mainly: add some extra include paths to account for Google/OpenXLA
changes like openxla/triton@10c56aa7 that are not accounted for in
the CMake build scripts.

Also, the `unittest` line is now indented so fix the sed regex.

Separate the Bazel/LLVM/Triton steps into different commands as it
didn't actually help with layer caching in CI.
  • Loading branch information
olupton authored May 16, 2024
1 parent fe7e193 commit b567b05
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
34 changes: 22 additions & 12 deletions .github/container/Dockerfile.triton
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,25 @@ ENV TRITON_NVDISASM_PATH=/usr/local/cuda/bin/nvdisasm
RUN [ -x "${TRITON_PTXAS_PATH}" ] && [ -x "${TRITON_CUOBJDUMP_PATH}" ] && [ -x "${TRITON_NVDISASM_PATH}" ]

###############################################################################
## Check out LLVM and Triton sources that match XLA, build them.
## Check out LLVM and Triton sources that match XLA. This uses XLA's Bazel
## configuration to get the relevant tag from the openxla/triton fork's
## llvm-head branch and apply XLA's extra patches to it. Also fetches the
## compatible LLVM sources.
###############################################################################
FROM base as builder
ARG SRC_PATH_JAX
ARG SRC_PATH_XLA
RUN <<"EOF" bash -ex
# Use XLA's Bazel configuration to get the relevant tag from the openxla/triton
# fork's llvm-head branch and apply XLA's extra patches to it. Also fetches the
# compatible LLVM sources.
pushd "${SRC_PATH_XLA}"
BAZEL=$(find "${SRC_PATH_JAX}/build" -type f -executable -name 'bazel-*')
"${BAZEL}" --output_base=/opt/checkout fetch @triton//:BUILD
# Build XLA's version of LLVM
rm -rf /root/.cache
EOF

###############################################################################
## Build LLVM
###############################################################################
RUN <<"EOF" bash -ex
mkdir /opt/llvm-build
pushd /opt/llvm-build
cmake -G Ninja \
Expand All @@ -42,17 +48,25 @@ cmake -G Ninja \
-DLLVM_TARGETS_TO_BUILD="host;NVPTX" \
/opt/checkout/external/llvm-raw/llvm
ninja
# Build XLA's version of Triton against that LLVM
EOF

###############################################################################
## Build Triton
###############################################################################
RUN <<"EOF" bash -ex
pushd /opt/checkout/external/triton
mkdir dist
# Do not compile with -Werror
sed -i -e 's|-Werror||g' CMakeLists.txt
# The LLVM build above does not enable these libraries
sed -i -e 's|\(LLVMAMDGPU.*\)|# \1|g' CMakeLists.txt
# Do not build tests
sed -i -e 's|^add_subdirectory(unittest)|# unit tests disabled|' CMakeLists.txt
sed -i -e 's|^\s*add_subdirectory(unittest)|# unit tests disabled|' CMakeLists.txt
# Do not build the AMD GPU backend
sed -i -e 's|BackendInstaller.copy(\["nvidia", "amd"\])|BackendInstaller.copy(["nvidia"])|g' python/setup.py
# Google has patches that mess with include paths in source files
sed -i -e '/include_directories(${PROJECT_SOURCE_DIR}\/third_party)/a include_directories(${PROJECT_SOURCE_DIR}/third_party/nvidia/include)' CMakeLists.txt
sed -i -e '/include_directories(${PROJECT_BINARY_DIR}\/third_party)/a include_directories(${PROJECT_BINARY_DIR}/third_party/nvidia/include)' CMakeLists.txt
# Extra patches to Triton maintained in XLA. These are already applied in the working directory.
XLA_TRITON_PATCHES="${SRC_PATH_XLA}/third_party/triton"
# This patch adds two files that are not known to CMake
Expand All @@ -68,10 +82,6 @@ LLVM_INCLUDE_DIRS=/opt/llvm-build/include \
pip wheel --verbose --wheel-dir=dist/ python/
# Clean up the wheel build directory so it doesn't end up bloating the container
rm -rf python/build
# Make the layer for the *current* step smaller, so it is more likely to stay
# resident in the Docker cache
cp -r /opt/checkout/external/triton /opt/triton-copy
rm -rf /opt/checkout /opt/llvm-build /root/.cache
EOF

###############################################################################
Expand All @@ -83,7 +93,7 @@ ARG SRC_PATH_JAX_TRITON
ARG SRC_PATH_TRITON

# Get the triton source + wheel from the build step
COPY --from=builder /opt/triton-copy ${SRC_PATH_TRITON}
COPY --from=builder /opt/checkout/external/triton ${SRC_PATH_TRITON}
RUN echo "triton @ file://$(ls ${SRC_PATH_TRITON}/dist/triton-*.whl)" >> /opt/pip-tools.d/requirements-triton.in

# Check out jax-triton
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/_ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ jobs:
${{ needs.build-triton.outputs.DOCKER_TAG_FINAL }} \
bash <<"EOF" |& tee test-triton.log
# autotuner tests from jax-triton now hit a triton code path that uses utilities from pytorch...
pip install --no-deps torch
pip install --no-deps torch --index-url https://download.pytorch.org/whl/cpu
python /opt/jax-triton/tests/triton_call_test.py --xml_output_file /output/triton_test.xml
EOF
STATISTICS_SCRIPT: |
Expand Down

0 comments on commit b567b05

Please sign in to comment.