Skip to content

Commit

Permalink
[torch] Support torch.aten.view.dtype conversion to flow (#18346)
Browse files Browse the repository at this point in the history
Convert to `flow.tensorbitcast` for supporting `i4` operations.
  • Loading branch information
rsuderman authored Aug 24, 2024
1 parent f8d8e60 commit bd78854
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
Expand All @@ -20,6 +21,47 @@ namespace mlir::iree_compiler::TorchInput {

namespace {

class BitCastViewDtype
: public OpRewritePattern<torch::Torch::AtenViewDtypeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(torch::Torch::AtenViewDtypeOp op,
PatternRewriter &rewriter) const override {

Value in = op.getSelf();
auto loc = op.getLoc();
auto inType = cast<torch::Torch::ValueTensorType>(in.getType());
auto resultType = cast<torch::Torch::ValueTensorType>(op.getType());

auto bType = inType.toBuiltinTensor();

if (auto dtype = dyn_cast<IntegerType>(bType.getElementType())) {
bType = bType.clone(
rewriter.getType<IntegerType>(dtype.getIntOrFloatBitWidth()));
}

// Cast to the builtin tensor type.
Value builtinCast =
rewriter.create<torch::TorchConversion::ToBuiltinTensorOp>(loc, bType,
in);

auto rType = resultType.toBuiltinTensor();
if (auto dtype = dyn_cast<IntegerType>(rType.getElementType())) {
rType = rType.clone(
rewriter.getType<IntegerType>(dtype.getIntOrFloatBitWidth()));
}

Value flowBitcast = rewriter.create<IREE::Flow::TensorBitCastOp>(
loc, rType, builtinCast, ValueRange(), ValueRange());

auto torchCast =
rewriter.create<torch::TorchConversion::FromBuiltinTensorOp>(
loc, resultType, flowBitcast);
rewriter.replaceOp(op, torchCast);
return success();
}
};

class BitCastQuantizedMatmul
: public OpRewritePattern<torch::Torch::OperatorOp> {
public:
Expand Down Expand Up @@ -117,7 +159,7 @@ class BitCastQuantTensorPass final
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<BitCastQuantizedMatmul>(context);
patterns.add<BitCastQuantizedMatmul, BitCastViewDtype>(context);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,13 @@ func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8]
%output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,8],f16>, !torch.vtensor<[8,4],ui8>, !torch.vtensor<[8,4,1],f16>, !torch.vtensor<[8,4,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,8],f16>
return %output : !torch.vtensor<[1,1,8],f16>
}

// -----

// CHECK-LABEL: @view_type
func.func @view_type(%arg0 : !torch.vtensor<[295501824],ui8>) -> !torch.vtensor<[147750912],si16> {
%int4 = torch.constant.int 4
// CHECK: flow.tensor.bitcast %[[IN:.+]] : tensor<295501824xi8> -> tensor<147750912xi16>
%0 = torch.aten.view.dtype %arg0, %int4 : !torch.vtensor<[295501824],ui8>, !torch.int -> !torch.vtensor<[147750912],si16>
return %0 : !torch.vtensor<[147750912],si16>
}

0 comments on commit bd78854

Please sign in to comment.