Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 #2821

Merged · 7 commits · Jan 31, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 528 files
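Note: this submodule bump crosses StableHLO's migration of 1-D integer op attributes from the tensor-typed DenseIntElementsAttr to the flat DenseI64ArrayAttr, which is what every hunk below tracks. A minimal sketch of the recurring before/after pattern (values illustrative):

// Before the bump: a tensor<2xi64> elements attribute.
DenseIntElementsAttr oldDims = rewriter.getI64TensorAttr({0, 1});
// After the bump: a dense array attribute over raw int64_t storage,
// with no RankedTensorType involved.
DenseI64ArrayAttr newDims = rewriter.getDenseI64ArrayAttr({0, 1});

Both helpers are existing mlir::Builder methods; only the attribute type the regenerated StableHLO/CHLO builders accept has changed.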
18 changes: 9 additions & 9 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -377,12 +377,12 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
if (!skipMultiplyAlpha(op.getAlpha())) {
Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
adaptor.getAlpha(), outElemTy);
- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
bcastDimensions);
}

- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
bcastDimensions);
return success();
@@ -424,7 +424,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
outElemTy);
}
- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
auto loc = op.getLoc();
@@ -542,7 +542,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
} else {
return op.emitError("operator haven't been supported");
}
- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastCompareOp>(
op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr,
compareTypeAttr);
@@ -570,7 +570,7 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern<AtenOpT> {
Value rhs =
hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType);

- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
bcastDimensions);
return success();
@@ -757,7 +757,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>(
op, outType, self, bcastShapeTensor,
- rewriter.getI64TensorAttr(dimensionNumbers));
+ rewriter.getDenseI64ArrayAttr(dimensionNumbers));
}
return success();
}
@@ -887,7 +887,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
if (!rhsType) {
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
}
- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
auto loc = op.getLoc();
@@ -1478,7 +1478,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(

Value window =
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
- DenseIntElementsAttr broadcastDimensions;
+ DenseI64ArrayAttr broadcastDimensions;
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
broadcastDimensions);
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, mulOut, start,
@@ -1721,7 +1721,7 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
rewriter.create<shape::ShapeOfOp>(op->getLoc(), adaptor.getSelf());
Value bcastScalar = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
op->getLoc(), outType, scalarTensor, shapeTensor,
- rewriter.getI64TensorAttr({}));
+ rewriter.getDenseI64ArrayAttr({}));
rewriter.replaceOp(op, bcastScalar);
return success();
}
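Note: every Basic.cpp hunk is mechanical. The deliberately null bcastDimensions handed to the chlo::Broadcast* builders only changes its declared type, and the two getI64TensorAttr call sites become getDenseI64ArrayAttr. A sketch of the null-attribute idiom these patterns rely on (assuming, as the code above suggests, that broadcast_dimensions stays optional on the chlo ops):

// A default-constructed attribute is null; the chlo broadcast ops then
// fall back to their default implicit broadcasting behavior.
DenseI64ArrayAttr bcastDimensions;
Value prod = rewriter.create<chlo::BroadcastMulOp>(loc, lhs, rhs, bcastDimensions);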
10 changes: 6 additions & 4 deletions lib/Conversion/TorchToStablehlo/GatherScatter.cpp
@@ -334,7 +334,8 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
return failure();

auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
- op.getLoc(), gatherOutput, initValue, rewriter.getI64TensorAttr({0}));
+ op.getLoc(), gatherOutput, initValue, rewriter.getDenseI64ArrayAttr({0}),
+ elementTy);

Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock();
@@ -510,7 +511,7 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(

rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
op, input, gatherIndicies, dimsAttr,
- rewriter.getI64TensorAttr(sliceSizes));
+ rewriter.getDenseI64ArrayAttr(sliceSizes));
return success();
}

@@ -666,7 +667,8 @@ LogicalResult ConvertAtenOp<AtenScatterSrcOp>::matchAndRewrite(
/*indexVectorDim=*/indexVecDim);

auto stablehloScatterOp = rewriter.create<stablehlo::ScatterOp>(
- loc, input, scatterIndicies, src, scatterDimensionNumbers, false, false);
+ loc, inputType, input, scatterIndicies, src, scatterDimensionNumbers,
+ false, false);

// config update computation function: just return the element from src.
Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock();
@@ -833,7 +835,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(

rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
op, resultType, input, finalIndexTensor, dimsAttr,
- rewriter.getI64TensorAttr(sliceSizes));
+ rewriter.getDenseI64ArrayAttr(sliceSizes));
return success();
}

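Note: two hunks here go beyond the attribute swap because the regenerated builders changed shape as well: the stablehlo::ReduceOp builder used in the embedding_bag lowering now takes the reduction element type, and the stablehlo::ScatterOp builder now takes its result type explicitly instead of inferring it. A sketch of both updated call shapes, mirroring the code above (names illustrative):

// ReduceOp: dimensions as a dense array, element type passed explicitly.
auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
    loc, gatherOutput, initValue, rewriter.getDenseI64ArrayAttr({0}), elemTy);
// ScatterOp: the result type (here the input's type) is now an argument.
auto scatterOp = rewriter.create<stablehlo::ScatterOp>(
    loc, inputType, input, scatterIndices, src, scatterDimNumbers,
    /*indicesAreSorted=*/false, /*uniqueIndices=*/false);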
34 changes: 12 additions & 22 deletions lib/Conversion/TorchToStablehlo/Linear.cpp
@@ -39,10 +39,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
RankedTensorType outTy =
RankedTensorType::get(shape, tensorTy.getElementType());

- RankedTensorType attrTy =
- RankedTensorType::get({static_cast<int64_t>(broadcastDims.size())},
- rewriter.getIntegerType(64));
- auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
+ auto broadcastAttr = rewriter.getDenseI64ArrayAttr(broadcastDims);

auto broadcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
loc, outTy, tensor, stablehloShape, broadcastAttr);
@@ -549,8 +546,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {

// Prepare for transposed convolution
SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
- DenseIntElementsAttr stablehloStride =
- rewriter.getI64TensorAttr(stablehloStrideVec);
+ auto stablehloStride = rewriter.getDenseI64ArrayAttr(stablehloStrideVec);
SmallVector<int64_t> stablehloPaddingVec(nSpatialDims * 2, 0);
for (int i = 0; i < nSpatialDims; ++i) {
int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
@@ -563,15 +559,15 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
stablehloPaddingVec);
SmallVector<int64_t> stablehloLhsDilationVec(nSpatialDims);
std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin());
- DenseIntElementsAttr stablehloLhsDilation =
- rewriter.getI64TensorAttr(stablehloLhsDilationVec);
+ auto stablehloLhsDilation =
+ rewriter.getDenseI64ArrayAttr(stablehloLhsDilationVec);
SmallVector<int64_t> stablehloRhsDilationVec(nSpatialDims);
std::copy(dilation.begin(), dilation.end(),
stablehloRhsDilationVec.begin());
- DenseIntElementsAttr stablehloRhsDilation =
- rewriter.getI64TensorAttr(stablehloRhsDilationVec);
+ auto stablehloRhsDilation =
+ rewriter.getDenseI64ArrayAttr(stablehloRhsDilationVec);

- DenseElementsAttr windowReversal;
+ DenseBoolArrayAttr windowReversal;
ArrayAttr precisionConfig;

SmallVector<int64_t> spatialDims;
@@ -614,10 +610,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
int64_t nDims = outType.getRank();

// Get stablehlo::ConvolutionOp attributes
- DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<long int>(stride.size())},
- rewriter.getI64Type()),
- stride);
+ auto stablehloWindowStride = rewriter.getDenseI64ArrayAttr(stride);
std::vector<int64_t> stablehloPaddingVec;
for (size_t i = 0; i < padding.size(); i++) {
stablehloPaddingVec.emplace_back(padding[i]);
@@ -628,10 +621,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
{static_cast<long int>(padding.size()), static_cast<long int>(2)},
rewriter.getI64Type()),
stablehloPaddingVec);
- DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<long int>(dilation.size())},
- rewriter.getI64Type()),
- dilation);
+ auto stablehloRhsDilation = rewriter.getDenseI64ArrayAttr(dilation);
SmallVector<int64_t> spatialDimensions;
for (int64_t i = 2; i < nDims; i++) {
spatialDimensions.emplace_back(i);
@@ -648,8 +638,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
/*outputSpatialDimensions=*/spatialDimensions);

// stablehlo::ConvolutionOp's optional attributes, leave them as default
- DenseIntElementsAttr stablehloLhsDilation;
- DenseElementsAttr windowReversal;
+ DenseI64ArrayAttr stablehloLhsDilation;
+ DenseBoolArrayAttr windowReversal;
ArrayAttr precisionConfig;

auto stablehloConvOp = rewriter.create<stablehlo::ConvolutionOp>(
@@ -781,7 +771,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
options.dimSizeIndexBits);
bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);

- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
op, outTy, stablehloConvResult, bias, bcastDimensions);
return success();
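Note: the convolution hunks show where the migration stops. The 1-D attributes (window stride, lhs/rhs dilation) become DenseI64ArrayAttr and window_reversal becomes DenseBoolArrayAttr, but padding keeps its DenseIntElementsAttr form because it is a rank-2 [nSpatialDims, 2] value and MLIR's dense array attributes are one-dimensional. A sketch of the resulting mix (names illustrative):

// 1-D attributes: plain dense arrays after the bump.
auto windowStrides = rewriter.getDenseI64ArrayAttr(strideVec);
auto rhsDilation = rewriter.getDenseI64ArrayAttr(dilationVec);
DenseBoolArrayAttr windowReversal; // optional, left null
// 2-D padding: still an elements attribute over tensor<Nx2xi64>.
auto padTy = RankedTensorType::get(
    {static_cast<int64_t>(strideVec.size()), 2}, rewriter.getI64Type());
DenseIntElementsAttr padding = DenseIntElementsAttr::get(padTy, paddingVec);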
73 changes: 18 additions & 55 deletions lib/Conversion/TorchToStablehlo/Pooling.cpp
@@ -136,19 +136,10 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];

- DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
- rewriter.getI64Type()),
- stablehloKernelSize);
- DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
- rewriter.getI64Type()),
- stablehloStride);
- DenseIntElementsAttr baseDilations;
- DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
- rewriter.getI64Type()),
- stablehloDilation);
+ auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+ auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+ DenseI64ArrayAttr baseDilations;
+ auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
@@ -242,19 +233,10 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];

- DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
- rewriter.getI64Type()),
- stablehloKernelSize);
- DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
- rewriter.getI64Type()),
- stablehloStride);
- DenseIntElementsAttr baseDilations;
- DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
- rewriter.getI64Type()),
- stablehloDilation);
+ auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+ auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+ DenseI64ArrayAttr baseDilations;
+ auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
@@ -453,20 +435,10 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
Value initVal =
createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

- DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
- RankedTensorType::get(
- {static_cast<int64_t>(stablehloKernelSize.size())},
- rewriter.getI64Type()),
- stablehloKernelSize);
- DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
- rewriter.getI64Type()),
- stablehloStride);
- DenseIntElementsAttr baseDilations;
- DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
- rewriter.getI64Type()),
- stablehloDilation);
+ auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+ auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+ DenseI64ArrayAttr baseDilations;
+ auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
@@ -508,7 +480,7 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
.value();
}
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
return success();
@@ -528,7 +500,7 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
op->getLoc(),
RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
- windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
+ windowSizeConst, inputShapeTensor, rewriter.getDenseI64ArrayAttr({}));

Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
@@ -599,19 +571,10 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
stablehloPadding[dim * 2] = inputShape[dim] - 1;

- DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
- rewriter.getI64Type()),
- stablehloKernelSize);
- DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
- rewriter.getI64Type()),
- stablehloStride);
- DenseIntElementsAttr baseDilations;
- DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
- rewriter.getI64Type()),
- stablehloDilation);
+ auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+ auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+ DenseI64ArrayAttr baseDilations;
+ auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
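Note: each pooling lowering collapses four DenseIntElementsAttr::get(RankedTensorType::get(...), ...) constructions into one-line getDenseI64ArrayAttr calls, repeated across max_pool2d, max_pool2d_with_indices, avg_pool and cumsum, while the rank-2 padding attribute stays tensor-typed as in Linear.cpp. One detail worth calling out: baseDilations is never populated, yet its declared type still has to change, because a null optional attribute is passed to the builder by its static C++ type:

// Optional base_dilations left absent, but typed for the new builders;
// a null DenseIntElementsAttr would no longer type-check after the bump.
DenseI64ArrayAttr baseDilations;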
14 changes: 7 additions & 7 deletions lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -130,7 +130,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
initValue,
initIndex,
},
- rewriter.getI64TensorAttr(dim));
+ rewriter.getDenseI64ArrayAttr(dim));

Block &block = stablehloReduceOp.getBody().emplaceBlock();

Expand Down Expand Up @@ -412,7 +412,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(

llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
- op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
+ op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@@ -473,7 +473,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
- op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
+ op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@@ -535,7 +535,7 @@ LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
- op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
+ op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@@ -625,7 +625,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(

llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
- op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
+ op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock();
@@ -729,7 +729,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(

auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), squareOp.getResult(), initValue,
- rewriter.getI64TensorAttr(dims));
+ rewriter.getDenseI64ArrayAttr(dims));

Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
@@ -848,7 +848,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
ord, nullptr);

auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
- op->getLoc(), powValue, initValue, rewriter.getI64TensorAttr(dims));
+ op->getLoc(), powValue, initValue, rewriter.getDenseI64ArrayAttr(dims));

Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
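Note: all seven Reduction.cpp changes are the same substitution in the dimensions argument of the stablehlo::ReduceOp builder. For orientation, a minimal sketch of the reduce-lowering pattern these hunks sit inside, as this file structures it (a sum reduction; names illustrative):

auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
    op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
// The reduction body is a fresh block with two scalar-tensor arguments.
Block &block = reduceOp.getBody().emplaceBlock();
auto argTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(argTy, op.getLoc());
block.addArgument(argTy, op.getLoc());
{
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(&block);
  Value sum = rewriter.create<stablehlo::AddOp>(
      op.getLoc(), block.getArgument(0), block.getArgument(1));
  rewriter.create<stablehlo::ReturnOp>(op.getLoc(), sum);
}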