From 35767dd6407312f2496bb8068f6761ab6f4c57c7 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Tue, 30 Jan 2024 14:13:01 -0800 Subject: [PATCH] stablehlo patches --- externals/stablehlo | 2 +- lib/Conversion/TorchToStablehlo/GatherScatter.cpp | 2 +- lib/Conversion/TorchToStablehlo/Reduction.cpp | 14 +++++++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/externals/stablehlo b/externals/stablehlo index f2b687771302..7767179364a9 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit f2b6877713027370f26569929fb5d841a79d3589 +Subproject commit 7767179364a921e31c3a008613a645a84b0e2232 diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index ce56a51b181e..fc8e924a959f 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -334,7 +334,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return failure(); auto stablehloReduceOp = rewriter.create( - op.getLoc(), gatherOutput, initValue, rewriter.getI64TensorAttr({0}), + op.getLoc(), gatherOutput, initValue, rewriter.getDenseI64ArrayAttr({0}), elementTy); Region ®ion = stablehloReduceOp.getBody(); diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index 36f4d49e9a99..97196489559c 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -130,7 +130,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, initValue, initIndex, }, - rewriter.getI64TensorAttr(dim)); + rewriter.getDenseI64ArrayAttr(dim)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); @@ -412,7 +412,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -473,7 +473,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return failure(); llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -535,7 +535,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return failure(); llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -625,7 +625,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = stablehloReduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -729,7 +729,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( auto reduceOp = rewriter.create( op->getLoc(), squareOp.getResult(), initValue, - rewriter.getI64TensorAttr(dims)); + rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = reduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -848,7 +848,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( ord, nullptr); auto reduceOp = rewriter.create( - op->getLoc(), powValue, initValue, rewriter.getI64TensorAttr(dims)); + op->getLoc(), powValue, initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = reduceOp.getBody(); Block &block = region.emplaceBlock();