From a072369a76b67381711936a65ae614f44ff206a7 Mon Sep 17 00:00:00 2001
From: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com>
Date: Wed, 4 Sep 2024 13:26:18 -0700
Subject: [PATCH] [Codegen][ROCM] Add validation for quantization data types (#18423)

Add a check for valid quantized data types for the ROCm backend

fixes #18367

Signed-off-by: nithinsubbiah
---
 .../Codegen/LLVMGPU/ConvertToROCDL.cpp        | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
index 415937422482..2e3055eef12b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
@@ -72,6 +72,21 @@ static void populateConvertGPUToAMDGPUPatterns(RewritePatternSet &patterns) {
 
 } // namespace
 
+// Function to check valid data types on the ROCm backend.
+static LogicalResult validateDataTypes(Operation *op) {
+  auto operandTypes = llvm::to_vector(op->getOperandTypes());
+  auto resultTypes = llvm::to_vector(op->getResultTypes());
+  if (llvm::any_of(llvm::concat<Type>(operandTypes, resultTypes),
+                   llvm::IsaPred<Float8E5M2Type, Float8E4M3FNType>)) {
+    op->emitOpError()
+        << "F8E5M2 and F8E4M3FN types are not supported on "
+           "the ROCm backend; try F8E5M2FNUZ or F8E4M3FNUZ instead.";
+    return failure();
+  }
+
+  return success();
+}
+
 /// A pass that replaces all occurrences of GPU device operations with their
 /// corresponding ROCDL equivalent.
 ///
@@ -90,6 +105,11 @@ struct ConvertToROCDLPass final
   void runOnOperation() override {
     ModuleOp m = getOperation();
 
+    m.walk([&](Operation *op) {
+      if (failed(validateDataTypes(op)))
+        return signalPassFailure();
+    });
+
     if (clROCMIndexingBits != 32 && clROCMIndexingBits != 64) {
       m.emitOpError() << "unsupported: ROCm index bit widths must either be "
                          "64 or 32, got "