diff --git a/compiler/plugins/input/Torch/InputConversion/BitCastQuantTensor.cpp b/compiler/plugins/input/Torch/InputConversion/BitCastQuantTensor.cpp
index 6a3dccdb97f9..64a99f4782ad 100644
--- a/compiler/plugins/input/Torch/InputConversion/BitCastQuantTensor.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/BitCastQuantTensor.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -20,6 +21,47 @@ namespace mlir::iree_compiler::TorchInput {
 
 namespace {
 
+// Rewrites `torch.aten.view.dtype` as a `flow.tensor.bitcast`, bridging
+// through the builtin tensor type (torch tensor -> builtin -> bitcast ->
+// torch tensor). Signless integer types are substituted for float element
+// types of the same bit width, since the bitcast reinterprets raw bits.
+class BitCastViewDtype
+    : public OpRewritePattern<torch::Torch::AtenViewDtypeOp> {
+public:
+  using OpRewritePattern<torch::Torch::AtenViewDtypeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(torch::Torch::AtenViewDtypeOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Value in = op.getSelf();
+    auto loc = op.getLoc();
+    auto inType = cast<torch::Torch::ValueTensorType>(in.getType());
+    auto resultType = cast<torch::Torch::ValueTensorType>(op.getType());
+
+    auto bType = inType.toBuiltinTensor();
+
+    // Bitcasting treats the payload as raw bits, so model float element
+    // types as same-width signless integers.
+    if (auto dtype = dyn_cast<FloatType>(bType.getElementType())) {
+      bType = bType.clone(
+          rewriter.getType<IntegerType>(dtype.getIntOrFloatBitWidth()));
+    }
+
+    // Cast to the builtin tensor type.
+    Value builtinCast =
+        rewriter.create<torch::TorchConversion::ToBuiltinTensorOp>(loc, bType,
+                                                                   in);
+
+    auto rType = resultType.toBuiltinTensor();
+    if (auto dtype = dyn_cast<FloatType>(rType.getElementType())) {
+      rType = rType.clone(
+          rewriter.getType<IntegerType>(dtype.getIntOrFloatBitWidth()));
+    }
+
+    Value flowBitcast = rewriter.create<IREE::Flow::TensorBitCastOp>(
+        loc, rType, builtinCast, ValueRange(), ValueRange());
+
+    auto torchCast =
+        rewriter.create<torch::TorchConversion::FromBuiltinTensorOp>(
+            loc, resultType, flowBitcast);
+    rewriter.replaceOp(op, torchCast);
+    return success();
+  }
+};
+
 class BitCastQuantizedMatmul
     : public OpRewritePattern<torch::Torch::OperatorOp> {
 public:
@@ -117,7 +159,7 @@ class BitCastQuantTensorPass final
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns.add<BitCastQuantizedMatmul>(context);
+    patterns.add<BitCastQuantizedMatmul, BitCastViewDtype>(context);
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       signalPassFailure();
diff --git a/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir b/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
index fad4e7c9194a..95465956a1d2 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
@@ -14,3 +14,13 @@ func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8]
   %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,8],f16>, !torch.vtensor<[8,4],ui8>, !torch.vtensor<[8,4,1],f16>, !torch.vtensor<[8,4,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,8],f16>
   return %output : !torch.vtensor<[1,1,8],f16>
 }
+
+// -----
+
+// CHECK-LABEL: @view_type
+func.func @view_type(%arg0 : !torch.vtensor<[295501824],ui8>) -> !torch.vtensor<[147750912],si16> {
+  %int4 = torch.constant.int 4
+  // CHECK: flow.tensor.bitcast %[[IN:.+]] : tensor<295501824xi8> -> tensor<147750912xi16>
+  %0 = torch.aten.view.dtype %arg0, %int4 : !torch.vtensor<[295501824],ui8>, !torch.int -> !torch.vtensor<[147750912],si16>
+  return %0 : !torch.vtensor<[147750912],si16>
+}