Skip to content

Commit

Permalink
[frontend] Add tests for tosa operator conversion functions
Browse files Browse the repository at this point in the history
  • Loading branch information
xTayEx committed Oct 24, 2023
1 parent b0ee68e commit 5c4c44e
Show file tree
Hide file tree
Showing 23 changed files with 757 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/BuddyPython/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,5 @@ In PyTorch FX graph, there exist dependencies between operators. These dependenc
### Import Strategy
In order to make the importing procedure more robust, we've implemented a fallback importing strategy. This mechanism consists of two parts: a primary registry and a fallback registry. When the importer is going to import a PyTorch operator, it first searches the primary registry for the operator's mapping function. If the operator is not found in the primary registry, the importer tries the fallback registry. By default, the importer uses the `tosa` registry as the primary registry, and all the other registries as the fallback registry.

## Known Limitations
## Limitations
Currently, we only support AOT execution of the generated MLIR code. To execute the generated MLIR code, one needs to use the LLVM toolchain to compile it to an executable binary. We are working on JIT execution of the generated MLIR code.
3 changes: 1 addition & 2 deletions frontend/Python/ops/tosa.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,7 @@ def reshape_op(node, symbol_table):
Import the reshape operation.
From PyTorch `aten.reshape.default` operator to MLIR TOSA `reshape` operation.
Note: If the new shape contains one and only one `-1`, the size of the
new shape will be inferred automatically.
Note: If the new shape contains one and only one `-1`, the size of the new shape will be inferred automatically.
"""
input1 = symbol_table.get((str(node.args[0]), 0))
new_shape = node.args[1]
Expand Down
35 changes: 35 additions & 0 deletions tests/Python/test_addmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa


def foo(x, y, z):
    """Compute z + x @ y through the aten.addmm operator."""
    return torch.ops.aten.addmm(z, x, y)


# Initialize the dynamo compiler with the TOSA registry as primary.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

mat1 = torch.randn(4, 2)
mat2 = torch.randn(2, 4)
bias = torch.randn(4, 4)

# Trace foo and lower the captured FX graph to TOSA-dialect MLIR.
foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
foo_mlir(mat1, mat2, bias)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = "tosa.matmul"
# CHECK: %{{.*}} = "tosa.add"
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
print(dynamo_compiler.imported_module)
34 changes: 34 additions & 0 deletions tests/Python/test_amax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import random
import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa


def foo(x, dim):
    """Reduce-max of x over `dim`, keeping the reduced dimension."""
    return torch.ops.aten.amax(x, dim, True)


in1 = torch.randn(4, 5, 2, 9)
# Seed the RNG so the chosen reduction dimension — and therefore the
# emitted MLIR — is reproducible from run to run; an unseeded choice
# makes test failures impossible to reproduce reliably.
random.seed(0)
dim = [random.randint(0, 3)]

# Initialize the dynamo compiler with the TOSA registry as primary.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

# Trace foo and lower the captured FX graph to TOSA-dialect MLIR.
foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
foo_mlir(in1, dim)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = "tosa.reduce_max"
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
print(dynamo_compiler.imported_module)
34 changes: 34 additions & 0 deletions tests/Python/test_arith_div.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa


def foo(x, y):
    """Element-wise division of two tensors."""
    return x / y


# Initialize the dynamo compiler with the TOSA registry as primary.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

numerator = torch.randn(10)
denominator = torch.randn(10)

# Trace foo and lower the captured FX graph to TOSA-dialect MLIR.
foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
foo_mlir(numerator, denominator)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = "tosa.reciprocal"
# CHECK: %{{.*}} = "tosa.mul"
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
print(dynamo_compiler.imported_module)
33 changes: 33 additions & 0 deletions tests/Python/test_arith_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa


def foo(x, y):
    """Element-wise multiplication of two tensors."""
    return x * y


# Initialize the dynamo compiler with the TOSA registry as primary.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

lhs = torch.randn(10)
rhs = torch.randn(10)

# Trace foo and lower the captured FX graph to TOSA-dialect MLIR.
foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
foo_mlir(lhs, rhs)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = "tosa.mul"
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
print(dynamo_compiler.imported_module)
33 changes: 33 additions & 0 deletions tests/Python/test_arith_sub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa


def foo(x, y):
    """Element-wise subtraction of two tensors."""
    return x - y


# Initialize the dynamo compiler with the TOSA registry as primary.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

minuend = torch.randn(10)
subtrahend = torch.randn(10)

# Trace foo and lower the captured FX graph to TOSA-dialect MLIR.
foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
foo_mlir(minuend, subtrahend)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = "tosa.sub"
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
print(dynamo_compiler.imported_module)
33 changes: 33 additions & 0 deletions tests/Python/test_bmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa


def foo(x, y):
    """Batched matrix multiplication through aten.bmm."""
    return torch.ops.aten.bmm(x, y)


# Initialize the dynamo compiler with the TOSA registry as primary.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

batch_lhs = torch.randn(10, 3, 2)
batch_rhs = torch.randn(10, 2, 3)

# Trace foo and lower the captured FX graph to TOSA-dialect MLIR.
foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
foo_mlir(batch_lhs, batch_rhs)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = "tosa.matmul"
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
print(dynamo_compiler.imported_module)
32 changes: 32 additions & 0 deletions tests/Python/test_clone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa


def foo(x):
    """Copy a tensor through aten.clone."""
    return torch.ops.aten.clone(x)


# Initialize the dynamo compiler with the TOSA registry as primary.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

source = torch.randn(10)

# Trace foo and lower the captured FX graph to TOSA-dialect MLIR.
foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
foo_mlir(source)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = "tosa.identity"
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
print(dynamo_compiler.imported_module)
33 changes: 33 additions & 0 deletions tests/Python/test_convert_element_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa


def foo(x, to_cast_type):
    """Cast a tensor to `to_cast_type` via prims.convert_element_type."""
    return torch.ops.prims.convert_element_type(x, to_cast_type)


# Initialize the dynamo compiler with the TOSA registry as primary.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

# Start from float32 and cast down to float16.
source = torch.randn(10, dtype=torch.float32)
target_dtype = torch.float16

# Trace foo and lower the captured FX graph to TOSA-dialect MLIR.
foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
foo_mlir(source, target_dtype)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = "tosa.cast"
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
print(dynamo_compiler.imported_module)
58 changes: 58 additions & 0 deletions tests/Python/test_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa


def foo(weight, indices):
    """Look up rows of `weight` at `indices` via aten.embedding."""
    return torch.ops.aten.embedding(weight, indices)


# Initialize the dynamo compiler with the TOSA registry as primary.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

# Case 1: int32 indices — no cast should be emitted.
weight = torch.randn(10, 5)
indices = torch.randint(10, (3, 3)).to(torch.int32)

foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
foo_mlir(weight, indices)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = "tosa.reshape"
# CHECK: %{{.*}} = "tosa.reshape"
# CHECK: %{{.*}} = "tosa.gather"
# CHECK: %{{.*}} = "tosa.reshape"
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
print(dynamo_compiler.imported_module)


# Case 2: int64 indices — a tosa.cast should be inserted before the gather.
weight = torch.randn(10, 5)
indices = torch.randint(10, (3, 3)).to(torch.int64)

foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
foo_mlir(weight, indices)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = "tosa.reshape"
# CHECK: %{{.*}} = "tosa.cast"
# CHECK: %{{.*}} = "tosa.reshape"
# CHECK: %{{.*}} = "tosa.gather"
# CHECK: %{{.*}} = "tosa.reshape"
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
print(dynamo_compiler.imported_module)
32 changes: 32 additions & 0 deletions tests/Python/test_exp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa


def foo(x):
    """Element-wise exponential through aten.exp."""
    return torch.ops.aten.exp(x)


# Initialize the dynamo compiler with the TOSA registry as primary.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

source = torch.randn(10)

# Trace foo and lower the captured FX graph to TOSA-dialect MLIR.
foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
foo_mlir(source)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = "tosa.exp"
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
print(dynamo_compiler.imported_module)
Loading

0 comments on commit 5c4c44e

Please sign in to comment.