Skip to content

Commit

Permalink
[examples][MLIRPython] add operators folder
Browse files Browse the repository at this point in the history
  • Loading branch information
xTayEx committed Oct 15, 2023
1 parent b4c1949 commit a58079f
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 23 deletions.
10 changes: 9 additions & 1 deletion examples/MLIRPython/addmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@
import torch._dynamo as dynamo
from buddy.compiler import BuddyDynamoCompiler
from torch._inductor.decomposition import decompositions as inductor_decomp
from buddy.operators.tosa_operators import (
operators_registry as tosa_operators_registry,
)
from buddy.operators.math_operators import (
operators_registry as math_operators_registry,
)


def foo(c, a, b):
return torch.addmm(c, a, b)


dynamo_compiler = BuddyDynamoCompiler(
func_name="forward", aot_autograd_decomposition=inductor_decomp
func_name="forward",
aot_autograd_decomposition=inductor_decomp,
operators_registry={**tosa_operators_registry, **math_operators_registry},
)

foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
Expand Down
10 changes: 9 additions & 1 deletion examples/MLIRPython/arith_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@
import torch._dynamo as dynamo
from buddy.compiler import BuddyDynamoCompiler
from torch._inductor.decomposition import decompositions as inductor_decomp
from buddy.operators.tosa_operators import (
operators_registry as tosa_operators_registry,
)
from buddy.operators.math_operators import (
operators_registry as math_operators_registry,
)


def foo(x, y):
return x + y


dynamo_compiler = BuddyDynamoCompiler(
func_name="forward", aot_autograd_decomposition=inductor_decomp
func_name="forward",
aot_autograd_decomposition=inductor_decomp,
operators_registry={**tosa_operators_registry, **math_operators_registry},
)
foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
float32_in1 = torch.randn(10).to(torch.float32)
Expand Down
27 changes: 15 additions & 12 deletions examples/MLIRPython/buddy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,25 @@
# ===---------------------------------------------------------------------------

import operator
from typing import Any, List, Union, Optional, Callable
from typing import Any, List, Union, Optional

import mlir.dialects.func as func
import mlir.ir as ir
import torch
from mlir.passmanager import PassManager
from torch._functorch.aot_autograd import aot_module_simplified
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.operators_gen import operation_func


class BuddyDynamoCompiler:
def __init__(
self,
operators_registry: dict,
func_name: str = "main",
aot_autograd_decomposition: Optional[dict] = None,
) -> None:
self.func_name = func_name
self.aot_autograd_decoposition = aot_autograd_decomposition
self._operators_registry = operators_registry
self._func_name = func_name
self._aot_autograd_decoposition = aot_autograd_decomposition
self._bufferize_pipelines = [
"func.func(tosa-to-linalg-named)",
"func.func(tosa-to-linalg)",
Expand Down Expand Up @@ -74,7 +73,9 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
# Initialize the MLIR context.
ctx = ir.Context()
with ir.Location.unknown(ctx):
fx_importer = FXGraphImporter(_gm, _inputs)
fx_importer = FXGraphImporter(
_gm, _inputs, self._operators_registry
)
llvm_lowerer = LLVMLowerer(
self._bufferize_pipelines, self._llvm_lower_pipelines
)
Expand All @@ -87,7 +88,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
gm,
inputs,
fw_compiler=_compiler,
decompositions=self.aot_autograd_decoposition,
decompositions=self._aot_autograd_decoposition,
)


Expand All @@ -98,6 +99,7 @@ def __init__(
self,
gm: torch.fx.GraphModule,
inputs: List[torch.Tensor],
operators_registry: dict,
func_name: str = "forward",
):
"""
Expand All @@ -113,6 +115,7 @@ def __init__(
self._inputs = inputs
self._num_input_visited = 0
self._module = ir.Module.create()
self._operators_registry = operators_registry

def import_graph(self) -> ir.Module:
"""Import the FX graph, generate an MLIR module in high-level dialects.
Expand Down Expand Up @@ -178,9 +181,9 @@ def _import_placeholder(self, node: torch.fx.Node, args_list):
def _import_op(self, node: torch.fx.Node):
op_name = node.target.__name__

op_ret: Union[ir.Operation, tuple] = operation_func[op_name](
node, self._symbol_table
)
op_ret: ir.Operation | ir.Value | tuple = self._operators_registry[
op_name
](node, self._symbol_table)
if isinstance(op_ret, tuple):
for i, operation in enumerate(op_ret):
self._symbol_table[(str(node.name), i)] = operation.result
Expand Down Expand Up @@ -216,5 +219,5 @@ def lower(self, module: ir.Module) -> Any:
pm.add(pipeline)
pm.run(module.operation)
print(module)

return module
12 changes: 12 additions & 0 deletions examples/MLIRPython/buddy/operators/math_operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from mlir.dialects import math


def erf_op(node, symbol_table):
input_ = symbol_table.get((str(node.args[0]), 0))
op = math.ErfOp(input_)
return op


operators_registry = {
"erf.default": erf_op,
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Dict, List, Union, Tuple

import mlir.ir as ir
from mlir.dialects import tosa, math, tensor
from mlir.dialects import tosa, tensor
import torch


Expand Down Expand Up @@ -195,12 +195,6 @@ def _inner_op(result_type, input1, input2):
return _gen_arith_binary_op(input1, input2, _inner_op)


def erf_op(node, symbol_table):
input_ = symbol_table.get((str(node.args[0]), 0))
op = math.ErfOp(input_)
return op


def tanh_op(node, symbol_table):
input1 = symbol_table.get((str(node.args[0]), 0))
sizes = ir.RankedTensorType(input1.type).shape
Expand Down Expand Up @@ -629,7 +623,7 @@ def sum_op(node, symbol_table):
# div, embedding, erf, exp, expand, getitem, gt, inductor_lookup_seed
# inductor_random, inductor_seeds, mul, permute, reshape, rsqrt
# select, slice, sub, tanh, unsqueeze, var_mean
operation_func = {
operators_registry = {
"add.Tensor": add_op,
"mul.Tensor": mul_op,
"sub.Tensor": sub_op,
Expand All @@ -640,7 +634,6 @@ def sum_op(node, symbol_table):
"bmm.default": bmm_op,
"clone.default": clone_op,
"div.Tensor": div_op,
"erf.default": erf_op,
"exp.default": exp_op,
"expand.default": expand_op,
"var_mean.correction": var_mean_op,
Expand Down

0 comments on commit a58079f

Please sign in to comment.