Skip to content

Commit

Permalink
stablehlo patch
Browse files Browse the repository at this point in the history
  • Loading branch information
sjain-stanford committed Jan 29, 2024
1 parent a6c9590 commit c01b129
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion externals/stablehlo
6 changes: 4 additions & 2 deletions lib/Conversion/TorchToStablehlo/GatherScatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
return failure();

auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), gatherOutput, initValue, rewriter.getI64TensorAttr({0}));
op.getLoc(), gatherOutput, initValue, rewriter.getI64TensorAttr({0}),
elementTy);

Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock();
Expand Down Expand Up @@ -666,7 +667,8 @@ LogicalResult ConvertAtenOp<AtenScatterSrcOp>::matchAndRewrite(
/*indexVectorDim=*/indexVecDim);

auto stablehloScatterOp = rewriter.create<stablehlo::ScatterOp>(
loc, input, scatterIndicies, src, scatterDimensionNumbers, false, false);
loc, inputType, input, scatterIndicies, src, scatterDimensionNumbers,
false, false);

// config update computation function: just return the element from src.
Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock();
Expand Down

0 comments on commit c01b129

Please sign in to comment.