Triton: use CUDA 12.3 tools from the base image #656
Conversation
Given that the paths are hard-coded, can we add a test so that we get notified if the binaries change locations? e.g.
RUN if [[ ! -x ${TRITON_PTXAS_PATH} ]]; then <THROW-ERROR>; fi
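A runnable sketch of the suggested guard (the `check_tool` helper name is an assumption for illustration; in the Dockerfile the check would be inlined in a single RUN step, and `/bin/sh` stands in for `${TRITON_PTXAS_PATH}` so the sketch runs anywhere):

```shell
#!/bin/sh
# Hypothetical helper sketching the reviewer's suggestion: fail the build
# fast when a hard-coded tool path no longer points at an executable.
check_tool() {
    if [ ! -x "$1" ]; then
        echo "ERROR: expected an executable at $1" >&2
        return 1
    fi
    echo "OK: found $1"
}

# In the real Dockerfile this argument would be "${TRITON_PTXAS_PATH}".
check_tool /bin/sh
```

Because the check runs at image-build time, a CUDA base-image update that moves the binaries breaks the build immediately instead of failing later at runtime.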
Thank you for the fix! Left a minor suggestion adding some context.
Should we fix this upstream?
@andportnoy approved this pull request.
In .github/container/Dockerfile.triton:
@@ -2,10 +2,18 @@
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
ARG SRC_PATH_TRITON=/opt/openxla-triton
+FROM ${BASE_IMAGE} as base
+# Tell Triton to use CUDA binaries from the host container. These should be set
⬇️ Suggested change
-# Tell Triton to use CUDA binaries from the host container. These should be set
+# Triton setup.py downloads and installs CUDA binaries at specific versions
+# hardcoded in the script itself:
+# https://github.com/openxla/triton/blob/84f9d9de158fb866fac67970f0f5d323999d9db1/python/setup.py#L373-L393
+# Tell Triton to use CUDA binaries from the host container instead. These should be set
Great suggestion, Andrey. Good job.
Co-authored-by: Andrey Portnoy <aportnoy@nvidia.com>
Force-pushed from 4289497 to fe2f64f
I'm not sure what that would look like. What did you have in mind?
I thought it was for Triton via Pallas. All is good.
Previously, Triton would download its own copies of ptxas, cuobjdump and nvdisasm: https://github.com/openxla/triton/blob/cl617459344/python/setup.py#L373-L393

This began to cause problems when those pinned versions were bumped to CUDA 12.4, meaning that Triton started to generate PTX with ISA version 8.4. When that PTX was compiled inside XLA using the older ptxas from the base container, there were errors in the nightly tests, which are taken from JAX-Triton.
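The failure mode can be sketched as a version-compatibility check. This is an illustrative model, not Triton or XLA code; the CUDA-to-PTX-ISA mapping below is a subset of the version history in NVIDIA's PTX ISA documentation:

```python
# Each CUDA toolkit's compiler emits a specific PTX ISA version, and its
# ptxas cannot compile PTX whose .version is newer than that.
PTX_ISA_BY_CUDA = {"12.2": (8, 2), "12.3": (8, 3), "12.4": (8, 4)}

def ptxas_accepts(ptxas_cuda: str, ptx_isa: tuple) -> bool:
    """True if a ptxas from the given CUDA toolkit can compile PTX at
    the given ISA version (tuple comparison handles major.minor)."""
    return ptx_isa <= PTX_ISA_BY_CUDA[ptxas_cuda]

# Triton's downloaded CUDA 12.4 tools emit PTX ISA 8.4, which the base
# container's CUDA 12.3 ptxas rejects, while 8.3 PTX still compiles.
print(ptxas_accepts("12.3", PTX_ISA_BY_CUDA["12.4"]))  # → False
print(ptxas_accepts("12.3", PTX_ISA_BY_CUDA["12.3"]))  # → True
```

Pinning Triton to the base image's tools keeps both sides of this comparison at the same CUDA version, which is why the fix works.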
Setting environment variables like TRITON_PTXAS_PATH has two effects: Triton's setup.py does not download its own copies of the binaries, and Triton invokes the binaries at the given paths at runtime.
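The runtime half of this can be sketched as an environment-variable override lookup (`resolve_tool` and the bundled paths are hypothetical names for illustration; Triton's actual lookup lives in its own internals):

```python
import os

def resolve_tool(name: str, bundled_path: str) -> str:
    # Hypothetical helper: an env var such as TRITON_PTXAS_PATH, when set,
    # overrides the path to the copy Triton would otherwise have downloaded.
    return os.environ.get(f"TRITON_{name.upper()}_PATH", bundled_path)

# Assumed host path, as set in the Dockerfile:
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda/bin/ptxas"

print(resolve_tool("ptxas", "/opt/triton/bin/ptxas"))
# → /usr/local/cuda/bin/ptxas (env var set, host copy wins)
print(resolve_tool("nvdisasm", "/opt/triton/bin/nvdisasm"))
# → /opt/triton/bin/nvdisasm (no env var, bundled copy used)
```

Setting the variables in the image (rather than per-process) makes the override apply both at build time, when setup.py runs, and later at runtime.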
If Triton starts depending on new features before the base container is updated to CUDA 12.4, problems may resurface.
Thanks to @andportnoy for help debugging.