Skip to content

Commit

Permalink
bump stablehlo
Browse files Browse the repository at this point in the history
stablehlo patch

stablehlo patches
  • Loading branch information
sjain-stanford committed Jan 30, 2024
1 parent 032f225 commit abdeec4
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
2 changes: 1 addition & 1 deletion externals/stablehlo
6 changes: 4 additions & 2 deletions lib/Conversion/TorchToStablehlo/GatherScatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
return failure();

auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), gatherOutput, initValue, rewriter.getI64TensorAttr({0}));
op.getLoc(), gatherOutput, initValue, rewriter.getDenseI64ArrayAttr({0}),
elementTy);

Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock();
Expand Down Expand Up @@ -666,7 +667,8 @@ LogicalResult ConvertAtenOp<AtenScatterSrcOp>::matchAndRewrite(
/*indexVectorDim=*/indexVecDim);

auto stablehloScatterOp = rewriter.create<stablehlo::ScatterOp>(
loc, input, scatterIndicies, src, scatterDimensionNumbers, false, false);
loc, inputType, input, scatterIndicies, src, scatterDimensionNumbers,
false, false);

// config update computation function: just return the element from src.
Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock();
Expand Down
14 changes: 7 additions & 7 deletions lib/Conversion/TorchToStablehlo/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
initValue,
initIndex,
},
rewriter.getI64TensorAttr(dim));
rewriter.getDenseI64ArrayAttr(dim));

Block &block = stablehloReduceOp.getBody().emplaceBlock();

Expand Down Expand Up @@ -412,7 +412,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(

llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
Expand Down Expand Up @@ -473,7 +473,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
Expand Down Expand Up @@ -535,7 +535,7 @@ LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
Expand Down Expand Up @@ -625,7 +625,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(

llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock();
Expand Down Expand Up @@ -729,7 +729,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(

auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), squareOp.getResult(), initValue,
rewriter.getI64TensorAttr(dims));
rewriter.getDenseI64ArrayAttr(dims));

Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
Expand Down Expand Up @@ -848,7 +848,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
ord, nullptr);

auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), powValue, initValue, rewriter.getI64TensorAttr(dims));
op->getLoc(), powValue, initValue, rewriter.getDenseI64ArrayAttr(dims));

Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
Expand Down

0 comments on commit abdeec4

Please sign in to comment.