Skip to content

Commit

Permalink
fix: check block type
Browse files Browse the repository at this point in the history
  • Loading branch information
yyp0 committed Aug 15, 2024
1 parent 2c7e56c commit 963ce67
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
27 changes: 22 additions & 5 deletions compiler/lib/Conversion/HloToTensor/ConvertHloToTensor.cpp
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ struct ConvertScatterToInsertSlice
return failure();
}

auto inputs = op.getInputs();
Value input =
llvm::cast<mlir::TypedValue<mlir::RankedTensorType>>(*inputs.begin());
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
auto inputShape = inputType.getShape();

auto dimNumAttr = op.getScatterDimensionNumbersAttr();
auto insertedWindowDims = dimNumAttr.getInsertedWindowDims();
if (insertedWindowDims.size() != 1) {
Expand All @@ -49,6 +55,10 @@ struct ConvertScatterToInsertSlice
return failure();
}
}
if (updatedWindowDims.size() + 1 != inputShape.size()) {
return failure();
}

auto scatterDimsToOperands = dimNumAttr.getScatterDimsToOperandDims();
if (scatterDimsToOperands.size() != 1) {
return failure();
Expand All @@ -67,12 +77,19 @@ struct ConvertScatterToInsertSlice
return failure();
}

Region &region = op.getUpdateComputation();
if (region.getBlocks().size() != 1) {
return failure();
}

auto &block = region.front();
Operation *retOp = block.getTerminator();
auto computeOp = retOp->getOperand(0).getDefiningOp();
if (computeOp) {
return failure();
}

// Prepare arguments for InsertSlice.
auto inputs = op.getInputs();
Value input =
llvm::cast<mlir::TypedValue<mlir::RankedTensorType>>(*inputs.begin());
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
auto inputShape = inputType.getShape();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
SmallVector<Value> indices0 = {zero, zero};
Expand Down
1 change: 0 additions & 1 deletion tests/numerical_test/testset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def _get_test_files_from_dir(directory):
"transpose1203.mlir",
"transpose2013.mlir",
"transpose120.mlir",
"scatter.mlir",
}

CUDA_ALL_SET = (CUDA_MLIR_TEST_SET | CUDA_TORCH_TEST_SET) - CUDA_XFAIL_SET
Expand Down

0 comments on commit 963ce67

Please sign in to comment.