diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp index 415937422482..2e3055eef12b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp @@ -72,6 +72,21 @@ static void populateConvertGPUToAMDGPUPatterns(RewritePatternSet &patterns) { } // namespace +// Function to check valid data types on the ROCm backend. +static LogicalResult validateDataTypes(Operation *op) { + auto operandTypes = llvm::to_vector(op->getOperandTypes()); + auto resultTypes = llvm::to_vector(op->getResultTypes()); + if (llvm::any_of(llvm::concat(operandTypes, resultTypes), + llvm::IsaPred)) { + op->emitOpError() + << "F8E5M2 and F8E4M3FN types are not supported on " + "the ROCm backend; try F8E5M2FNUZ or F8E4M3FNUZ instead."; + return failure(); + } + + return success(); +} + /// A pass that replaces all occurrences of GPU device operations with their /// corresponding ROCDL equivalent. /// @@ -90,6 +105,11 @@ struct ConvertToROCDLPass final void runOnOperation() override { ModuleOp m = getOperation(); + m.walk([&](Operation *op) { + if (failed(validateDataTypes(op))) + return signalPassFailure(); + }); + if (clROCMIndexingBits != 32 && clROCMIndexingBits != 64) { m.emitOpError() << "unsupported: ROCm index bit widths must either be " "64 or 32, got "