Skip to content

Commit

Permalink
stablehlo patches
Browse files Browse the repository at this point in the history
  • Loading branch information
sjain-stanford committed Jan 30, 2024
1 parent c01b129 commit 35767dd
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion externals/stablehlo
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToStablehlo/GatherScatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
return failure();

auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), gatherOutput, initValue, rewriter.getI64TensorAttr({0}),
op.getLoc(), gatherOutput, initValue, rewriter.getDenseI64ArrayAttr({0}),
elementTy);

Region &region = stablehloReduceOp.getBody();
Expand Down
14 changes: 7 additions & 7 deletions lib/Conversion/TorchToStablehlo/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
initValue,
initIndex,
},
rewriter.getI64TensorAttr(dim));
rewriter.getDenseI64ArrayAttr(dim));

Block &block = stablehloReduceOp.getBody().emplaceBlock();

Expand Down Expand Up @@ -412,7 +412,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(

llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
Expand Down Expand Up @@ -473,7 +473,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
Expand Down Expand Up @@ -535,7 +535,7 @@ LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
Expand Down Expand Up @@ -625,7 +625,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(

llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock();
Expand Down Expand Up @@ -729,7 +729,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(

auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), squareOp.getResult(), initValue,
rewriter.getI64TensorAttr(dims));
rewriter.getDenseI64ArrayAttr(dims));

Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
Expand Down Expand Up @@ -848,7 +848,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
ord, nullptr);

auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), powValue, initValue, rewriter.getI64TensorAttr(dims));
op->getLoc(), powValue, initValue, rewriter.getDenseI64ArrayAttr(dims));

Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
Expand Down

0 comments on commit 35767dd

Please sign in to comment.