Skip to content

Commit

Permalink
add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
yyp0 committed Aug 15, 2024
1 parent e5500c7 commit 03745ab
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 2 deletions.
2 changes: 1 addition & 1 deletion compiler/include/byteir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def MhloToCat : Pass<"mhlo-to-cat", "func::FuncOp"> {
// HloToTensor
//===----------------------------------------------------------------------===//

def ConvertHloToTensor : Pass<"hlo-to-tensor", "func::FuncOp"> {
def ConvertHloToTensor : Pass<"convert-hlo-to-tensor", "func::FuncOp"> {
let summary = "Convert hlo op to Tensor op.";
let constructor = "mlir::createConvertHloToTensorPass()";
let dependentDialects = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: byteir-opt -hlo-to-tensor --canonicalize --split-input-file %s | FileCheck %s
// RUN: byteir-opt -convert-hlo-to-tensor --canonicalize --split-input-file %s | FileCheck %s

func.func @forward(%arg0: tensor<6x8x5xf32>, %arg1: tensor<6x1x5xf32>) -> tensor<6x8x5xf32> {
%0 = mhlo.constant dense<0> : tensor<1x1xi64>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Scatter test case (stablehlo dialect): writes %arg1 (6x1x5) into %arg0
// (6x8x5) at the constant index 0 along the scattered dimension (dim 1).
// The scatter region returns the update value unconditionally, i.e. a pure
// overwrite with no combining computation.
// NOTE(review): presumably intended to be rewritten to a tensor.insert_slice
// by the hlo-to-tensor conversion — confirm against the pass pattern.
func.func @forward(%arg0: tensor<6x8x5xf32>, %arg1: tensor<6x1x5xf32>) -> tensor<6x8x5xf32> {
%c = stablehlo.constant dense<0> : tensor<1x1xi64>
%0 = "stablehlo.scatter"(%arg0, %c, %arg1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
// Combiner region: ignore the old value (%arg2), return the update (%arg3).
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
stablehlo.return %arg3 : tensor<f32>
}) : (tensor<6x8x5xf32>, tensor<1x1xi64>, tensor<6x1x5xf32>) -> tensor<6x8x5xf32>
return %0 : tensor<6x8x5xf32>
}
8 changes: 8 additions & 0 deletions tests/numerical_test/mlir_tests/ops/scatter_insert_slice.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Scatter test case (mhlo dialect), mirror of the stablehlo variant: writes
// %arg1 (6x1x5) into %arg0 (6x8x5) at the constant index 0 along dim 1.
// The region returns the update value unconditionally — a pure overwrite.
// NOTE(review): looks like the insert_slice-style scatter targeted by the
// numerical test "scatter_insert_slice.mlir" — confirm with the test driver.
func.func @forward(%arg0: tensor<6x8x5xf32>, %arg1: tensor<6x1x5xf32>) -> tensor<6x8x5xf32> {
%0 = mhlo.constant dense<0> : tensor<1x1xi64>
%1 = "mhlo.scatter"(%arg0, %0, %arg1) <{indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0, 2], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
// Combiner region: discard the old value (%arg2), keep the update (%arg3).
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
mhlo.return %arg3 : tensor<f32>
}) : (tensor<6x8x5xf32>, tensor<1x1xi64>, tensor<6x1x5xf32>) -> tensor<6x8x5xf32>
return %1 : tensor<6x8x5xf32>
}
13 changes: 13 additions & 0 deletions tests/numerical_test/torch_e2e_testing/test_suite/basic.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,16 @@ def forward(self, x0):
@register_test_case(module_factory=lambda: ContiguousSliceModule())
def ContiguousSliceModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 64))

# ==============================================================================

class InsertSliceScatterModule(torch.nn.Module):
    """Scatter `src` into the first slice of `x` along dim 1.

    Exercises `aten.slice_scatter`, which embeds `src` into `x` over the
    slice `x[:, 0:1, :]` and returns the combined tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, src):
        # Fix: `src` was an undefined free variable (NameError on every call);
        # it must be a parameter — the registered test invokes forward with
        # two tensors, (6, 8, 5) and (6, 1, 5).
        return torch.ops.aten.slice_scatter(x, src, dim=1, start=0, end=1, step=1)

# Fix: the factory previously returned the class object itself (missing call
# parentheses); it must produce an instance, matching the sibling pattern
# `lambda: ContiguousSliceModule()` used by the other registered tests.
@register_test_case(module_factory=lambda: InsertSliceScatterModule())
def InsertSliceScatterModule_basic(module, tu: TestUtils):
    """Run InsertSliceScatterModule on a (6, 8, 5) base and (6, 1, 5) slice."""
    module.forward(tu.rand(6, 8, 5), tu.rand(6, 1, 5))

0 comments on commit 03745ab

Please sign in to comment.