Skip to content

Commit

Permalink
[Codegen][ROCM] Add validation for quantization data types (iree-org#…
Browse files Browse the repository at this point in the history
…18423)

Add a check for valid quantized data types for the ROCm backend 
fixes iree-org#18367

Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
  • Loading branch information
nithinsubbiah authored Sep 4, 2024
1 parent d8caa8d commit a072369
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ static void populateConvertGPUToAMDGPUPatterns(RewritePatternSet &patterns) {

} // namespace

// Function to check valid data types on the ROCm backend.
static LogicalResult validateDataTypes(Operation *op) {
auto operandTypes = llvm::to_vector(op->getOperandTypes());
auto resultTypes = llvm::to_vector(op->getResultTypes());
if (llvm::any_of(llvm::concat<Type>(operandTypes, resultTypes),
llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>)) {
op->emitOpError()
<< "F8E5M2 and F8E4M3FN types are not supported on "
"the ROCm backend; try F8E5M2FNUZ or F8E4M3FNUZ instead.";
return failure();
}

return success();
}

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding ROCDL equivalent.
///
Expand All @@ -90,6 +105,11 @@ struct ConvertToROCDLPass final
void runOnOperation() override {
ModuleOp m = getOperation();

m.walk([&](Operation *op) {
if (failed(validateDataTypes(op)))
return signalPassFailure();
});

if (clROCMIndexingBits != 32 && clROCMIndexingBits != 64) {
m.emitOpError() << "unsupported: ROCm index bit widths must either be "
"64 or 32, got "
Expand Down

0 comments on commit a072369

Please sign in to comment.