Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCm] Fix FP32 atomic_rmw #18

Open
wants to merge 1 commit into
base: rocm-jaxlib-v0.4.28-qa
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions third_party/triton/temporary/amd_pr7.patch
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,31 @@ index f59efd6..cf601f0 100644
if (useFP16IntermediateSrc)
for (Value &v : inVals)
v = convertFp32ToFp16NZ(loc, rewriter, v);
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
index 83f24d711..82aad06c5 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -599,12 +599,23 @@ struct AtomicRMWOpConversion
auto maybeKind = matchAtomicOp(atomicRmwAttr);
// TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient
// atomics for MI-* series of AMD GPU.
+ if(isa<FloatType>(valElements[i].getType()) &&
+ (*maybeKind != mlir::LLVM::AtomicBinOp::fadd)) {
+ valElem = bitcast(valElements[i],
+ int_ty(valElements[i].getType().getIntOrFloatBitWidth()));
+ }
+
Value atom = rewriter
.create<LLVM::AtomicRMWOp>(
loc, *maybeKind, rmwPtr, valElements[i],
atomicMemOrdering, StringRef("agent"))
.getResult();

+ if(isa<FloatType>(valElements[i].getType()) &&
+ (*maybeKind != mlir::LLVM::AtomicBinOp::fadd)) {
+ atom = bitcast(atom, valElements[i].getType());
+ }
+
// NV for the f16v2 case generates one packed instruction. We have to
// create two separate instructions since LLVM::AtomicRMWOp doesn't
// support this. Can be optimized out with rocdl.raw.buffer.atomic.
Loading