diff --git a/.github/container/Dockerfile.triton b/.github/container/Dockerfile.triton
index 11368b5ef..25b80b456 100644
--- a/.github/container/Dockerfile.triton
+++ b/.github/container/Dockerfile.triton
@@ -2,10 +2,21 @@
 ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
 ARG SRC_PATH_TRITON=/opt/openxla-triton
 
+FROM ${BASE_IMAGE} as base
+# Triton setup.py downloads and installs CUDA binaries at specific versions
+# hardcoded in the script itself:
+# https://github.com/openxla/triton/blob/84f9d9de158fb866fac67970f0f5d323999d9db1/python/setup.py#L373-L393
+# Tell Triton to use CUDA binaries from the host container instead. These should be set
+# both during the build stage and in the final container.
+ENV TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas
+ENV TRITON_CUOBJDUMP_PATH=/usr/local/cuda/bin/cuobjdump
+ENV TRITON_NVDISASM_PATH=/usr/local/cuda/bin/nvdisasm
+RUN [ -x "${TRITON_PTXAS_PATH}" ] && [ -x "${TRITON_CUOBJDUMP_PATH}" ] && [ -x "${TRITON_NVDISASM_PATH}" ]
+
 ###############################################################################
 ## Check out Triton source and build a wheel
 ###############################################################################
-FROM ${BASE_IMAGE} as builder
+FROM base as builder
 
 ARG SRC_PATH_TRITON
 
@@ -38,7 +49,7 @@ RUN rm -rf "${SRC_PATH_TRITON}/python/build"
 ###############################################################################
 ## Download source and add auxiliary scripts
 ###############################################################################
-FROM ${BASE_IMAGE} as mealkit
+FROM base as mealkit
 
 ARG SRC_PATH_TRITON
 