diff --git a/examples/MLIRPython/.style.yapf b/examples/MLIRPython/.style.yapf deleted file mode 100644 index 9ef1dc15ba..0000000000 --- a/examples/MLIRPython/.style.yapf +++ /dev/null @@ -1,4 +0,0 @@ -[style] - based_on_style = google - column_limit = 80 - indent_width = 2 diff --git a/examples/MLIRPython/addmm.py b/examples/MLIRPython/addmm.py deleted file mode 100644 index 29cf695635..0000000000 --- a/examples/MLIRPython/addmm.py +++ /dev/null @@ -1,20 +0,0 @@ -from buddy.compiler import DynamoCompiler -import torch -import torch._dynamo as dynamo - - -def foo(c, a, b): - return torch.addmm(c, a, b) - - -foo_mlir = dynamo.optimize(DynamoCompiler)(foo) - -a_float32 = torch.randn(3, 2) -b_float32 = torch.randn(2, 3) -c_float32 = torch.randn(3, 3) -foo_mlir(c_float32, a_float32, b_float32) - -a_int32 = torch.randint(10, (3, 2)).to(torch.int32) -b_int32 = torch.randint(10, (2, 3)).to(torch.int32) -c_int32 = torch.randint(10, (3, 3)).to(torch.int32) -foo_mlir(c_int32, a_int32, b_int32) diff --git a/examples/MLIRPython/arith_add.py b/examples/MLIRPython/arith_add.py deleted file mode 100644 index 0c2b7a11cf..0000000000 --- a/examples/MLIRPython/arith_add.py +++ /dev/null @@ -1,17 +0,0 @@ -from buddy import compiler -import torch -import torch._dynamo as dynamo - - -def foo(x, y): - return x + y - - -foo_mlir = dynamo.optimize(compiler.DynamoCompiler)(foo) -float32_in1 = torch.randn(10).to(torch.float32) -float32_in2 = torch.randn(10).to(torch.float32) -foo_mlir(float32_in1, float32_in2) - -int32_in1 = torch.randint(0, 10, (10,)).to(torch.int32) -int32_in2 = torch.randint(0, 10, (10,)).to(torch.int32) -foo_mlir(int32_in1, int32_in2) diff --git a/examples/MLIRPython/buddy/compiler.py b/examples/MLIRPython/buddy/compiler.py deleted file mode 100644 index 6a24fcfef3..0000000000 --- a/examples/MLIRPython/buddy/compiler.py +++ /dev/null @@ -1,177 +0,0 @@ -"""The buddy compiler backend for torch dynamo. -""" -import operator -from typing import List, Union, Callable - -import torch -from torch._functorch.aot_autograd import aot_module_simplified -import mlir.ir as ir -import mlir.dialects.func as func -from mlir.passmanager import PassManager - -from .operators_gen import operation_func - - -def DynamoCompiler(gm: torch.fx.GraphModule, - inputs: List[torch.Tensor]) -> Callable: - """The main entry point of buddy compiler for torch dynamo. It takes a FX - graph module and a list of inputs as parameters. The compiler will first use - PyTorch's AOT autograd to lower FX graph in Torch IR to Aten/Prims IR. Then - it will map the operators in Aten/Prims IR to MLIR operations and generate an - MLIR module. Finally, It will lower the MLIR module to LLVM dialect. - - Args: - gm (torch.fx.GraphModule): The FX graph module to be compiled. - inputs (List[torch.Tensor]): The inputs of the FX graph module. - - Returns: - Callable: A compiled function that equivalent to the FX graph. - - """ - - def _compiler(gm: torch.fx.GraphModule, inputs: List[torch.Tensor]): - """Compile a FX graph in Aten/Prims IR to MLIR.""" - print("Custom Compiler from FX Graph to MLIR:") - print("-------------------------------------------------------------------") - gm.graph.print_tabular() - # Initialize the MLIR context. - ctx = ir.Context() - with ir.Location.unknown(ctx): - fx_importer = FXGraphImporter(gm, inputs) - module = fx_importer.import_graph() - module = Lowering(module) - return gm.forward - - return aot_module_simplified(gm, inputs, fw_compiler=_compiler) - - -class FXGraphImporter: - """The FX graph importer class.""" - - def __init__( - self, - gm: torch.fx.GraphModule, - inputs: List[torch.Tensor], - func_name: str = "main", - ): - """ - Args: - gm (torch.fx.GraphModule): The FX graph module that will be imported. - inputs (List[torch.Tensor]): Input tensor(s) of the FX graph. - func_name (str): Name of the generated MLIR func. - - """ - self._symbol_table = {} - self._gm = gm - self._func_name = func_name - self._inputs = inputs - self._num_input_visited = 0 - self._module = ir.Module.create() - - def import_graph(self) -> ir.Module: - """Import the FX graph, generate an MLIR module in high-level dialects. - - Returns: - mlir.ir.Module: An MLIR module in high-level dialects. - - """ - with ir.InsertionPoint(self._module.body): - arguments = [] - for arg in self._inputs: - shape_list = list(arg.shape) - dtype = arg.dtype - match dtype: - case torch.int32: - mlir_dtype = ir.IntegerType.get_signless(32) - case torch.float32: - mlir_dtype = ir.F32Type.get() - case _: - raise NotImplementedError( - f"Unsupported dtype {dtype} for argument {arg}") - tensor_arg = ir.RankedTensorType.get(shape_list, mlir_dtype) - arguments.append(tensor_arg) - - @func.FuncOp.from_py_func(*arguments, name=self._func_name) - def generated_func(*args): - args_list = list(args) - for node in self._gm.graph.nodes: - if node.op == "output": - output_node_args = node.args[0] - returns = [] - for output_arg in output_node_args: - op = self._symbol_table.get((str(output_arg), 0)) - returns.append(op) - - self._symbol_table[("output", 0)] = returns - elif node.op == "placeholder": - self._import_placeholder(node, args_list) - else: - if node.target is operator.getitem: - self._symbol_table[(str(node.name), - 0)] = self._symbol_table[(node.args[0], - node.args[1])] - else: - self._import_op(node) - - return self._symbol_table.get(("output", 0)) - - print("Printing the generated MLIR...") - print(self._module) - return self._module - - def _import_placeholder(self, node: torch.fx.Node, args_list): - placeholder_name = args_list[self._num_input_visited] - self._symbol_table[(str(node.name), 0)] = placeholder_name - self._num_input_visited += 1 - - def _import_op(self, node: torch.fx.Node): - op_name = node.target.__name__ - - op_ret: Union[ir.Operation, - tuple] = operation_func[op_name](node, self._symbol_table) - if isinstance(op_ret, tuple): - for i, operation in op_ret: - self._symbol_table[(str(node.name), i)] = operation.result - else: - self._symbol_table[(str(node.name), 0)] = op_ret.result - - -def Lowering(module: ir.Module): - """Lower an MLIR module to LLVM dialect. - - Args: - module (mlir.ir.Module): An MLIR module that need to be lowered. - - Returns: - mlir.ir.Module: An MLIR module in LLVM dialect. - - """ - print("-------------------------------------------------------------------") - print("Bufferizing the module ...") - pm = PassManager("builtin.module") - pm.add("func.func(tosa-to-linalg-named)") - pm.add("func.func(tosa-to-linalg)") - pm.add("func.func(tosa-to-tensor)") - pm.add("func.func(tosa-to-arith)") - pm.add("empty-tensor-to-alloc-tensor") - pm.add("convert-elementwise-to-linalg") - pm.add("arith-bufferize") - pm.add("func.func(linalg-bufferize)") - pm.add("func.func(tensor-bufferize)") - pm.add("func-bufferize") - pm.run(module.operation) - print(module) - print("-------------------------------------------------------------------") - print("Lowering the module to LLVM dialect ...") - pm.add("func.func(buffer-deallocation)") - pm.add("func.func(convert-linalg-to-loops)") - pm.add("convert-scf-to-cf") - pm.add("convert-linalg-to-llvm") - pm.add("convert-arith-to-llvm") - pm.add("expand-strided-metadata") - pm.add("finalize-memref-to-llvm") - pm.add("convert-func-to-llvm") - pm.add("reconcile-unrealized-casts") - pm.run(module.operation) - print(module) - return module diff --git a/examples/MLIRPython/buddy/operators_gen.py b/examples/MLIRPython/buddy/operators_gen.py deleted file mode 100644 index 378d40ee92..0000000000 --- a/examples/MLIRPython/buddy/operators_gen.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Generate the MLIR operations for the operators in the FX graph. -""" -from typing import Dict, Tuple, List - -import torch - -import mlir.ir as ir -from mlir.dialects import tosa, linalg, arith - - -def _broadcast_shape(tensor_input1: ir.Value, - tensor_input2: ir.Value) -> List[int]: - """Calculate the broadcast shape of two tensors with broadcastable shapes - according to PyTorch's broadcast semantics: https://pytorch.org/docs/stable/notes/broadcasting.html""" - shp1 = ir.RankedTensorType(tensor_input1.type).shape - shp2 = ir.RankedTensorType(tensor_input2.type).shape - if len(shp1) < len(shp2): - shp1, shp2 = shp2, shp1 - while len(shp2) < len(shp1): - shp2.insert(0, 1) - for idx, (dim1, dim2) in enumerate(zip(shp1, shp2)): - shp1[idx] = shp2[idx] = max(dim1, dim2) - - return shp1 - - -def AddOp(node: torch.fx.Node, - symbol_table: Dict[Tuple[str, int], ir.Operation]) -> ir.Operation: - """Map aten.add.Tensor to tosa.add. - - Args: - node: A FX graph containing the aten.add.Tensor operator and its parameter. - symbol_table: The symbol table that records the mapping between symbols and operations. - - Returns: - ir.Operation: The generated tosa.add operation. - - """ - input1 = symbol_table.get((str(node.args[0]), 0)) - input2 = symbol_table.get((str(node.args[1]), 0)) - broadcasted_shp = _broadcast_shape(input1, input2) - sizes = broadcasted_shp - result_element_type = ir.RankedTensorType(input1.type).element_type - add_result_tensor_type = ir.RankedTensorType.get(sizes, result_element_type) - op = tosa.AddOp(add_result_tensor_type, input1, input2) - return op - - -def AddMMOp(node: torch.fx.Node, - symbol_table: Dict[Tuple[str, int], ir.Operation]) -> ir.Operation: - """Map aten.addmm.default to MLIR operation. - - Args: - node (torch.fx.Node): A FX graph containing the aten.addmm.default operator and its parameter. - symbol_table (Dict[Tuple[str, int], ir.Operation]): The symbol table that records the mapping between symbols and operations. - - Returns: - ir.Operation: The generated MLIR operation representing aten.addmm.default - - """ - input_ = symbol_table.get((str(node.args[0]), 0)) - mat1 = symbol_table.get((str(node.args[1]), 0)) - mat2 = symbol_table.get((str(node.args[2]), 0)) - mat1_shp = ir.RankedTensorType(mat1.type).shape - mat2_shp = ir.RankedTensorType(mat2.type).shape - mat1 = tosa.ReshapeOp(mat1, [1, *mat1_shp]).output - mat2 = tosa.ReshapeOp(mat2, [1, *mat2_shp]).output - - matmul_result_shp = [1, mat1_shp[0], mat2_shp[1]] - result_element_type = ir.RankedTensorType(input_.type).element_type - matmul_result_type = ir.RankedTensorType.get(matmul_result_shp, result_element_type) - matmul_op = tosa.MatMulOp(matmul_result_type, mat1, mat2) - matmul_result = tosa.ReshapeOp(matmul_op.c, matmul_result_shp[1:]) - - add_result_shp = [mat1_shp[0], mat2_shp[1]] - add_result_tensor_type = ir.RankedTensorType.get(add_result_shp, result_element_type) - op = tosa.AddOp(add_result_tensor_type, input_, matmul_result) - return op - - -operation_func = {"add.Tensor": AddOp, "addmm.default": AddMMOp} diff --git a/examples/MLIRPython/matmul.py b/examples/MLIRPython/matmul.py deleted file mode 100644 index d15ae1df68..0000000000 --- a/examples/MLIRPython/matmul.py +++ /dev/null @@ -1,11 +0,0 @@ -from buddy import compiler -import torch -import torch._dynamo as dynamo - -def foo(x, y): - return torch.matmul(x, y) - -foo_mlir = dynamo.optimize(compiler.DynamoCompiler)(foo) -in1 = torch.randn(2, 3) -in2 = torch.randn(3, 5) -foo_mlir(in1, in2) diff --git a/frontend/Python/ops/tosa.py b/frontend/Python/ops/tosa.py index 68c278d1c4..bcda5121d3 100644 --- a/frontend/Python/ops/tosa.py +++ b/frontend/Python/ops/tosa.py @@ -18,12 +18,12 @@ # # ===--------------------------------------------------------------------------- +import torch import array -from typing import Dict, List, Union, Tuple +from typing import Dict, List, Tuple, Union import mlir.ir as ir -from mlir.dialects import tosa, tensor -import torch +from mlir.dialects import tensor, tosa def _normalize_binary_operator_shape(shp1, shp2): @@ -118,12 +118,71 @@ def _normalize_binary_operator_args(arg1, arg2): ) +def addmm_op( + node, symbol_table: Dict[Tuple[str, int], ir.Operation] +) -> ir.Operation: + # get input + input_ = symbol_table.get((str(node.args[0]), 0)) + mat1 = symbol_table.get((str(node.args[1]), 0)) + mat2 = symbol_table.get((str(node.args[2]), 0)) + # get input shape + mat1_shp = ir.RankedTensorType(mat1.type).shape + mat2_shp = ir.RankedTensorType(mat2.type).shape + # append index because tosa.MatMulOp doesn't accept 2D tensor + mat1_reshape_op = tosa.ReshapeOp( + mat1, memoryview(array.array("i", [1, *mat1_shp])) + ) + mat2_reshape_op = tosa.ReshapeOp( + mat2, memoryview(array.array("i", [1, *mat2_shp])) + ) + # do matmul + result_element_type = ir.RankedTensorType(mat1.type).element_type + matmul_result_shp = [1, mat1_shp[0], mat2_shp[1]] + matmul_result_type = ir.RankedTensorType.get( + matmul_result_shp, result_element_type + ) + matmul_op = tosa.MatMulOp( + matmul_result_type, mat1_reshape_op.result, mat2_reshape_op.result + ) + # restore the shape + final_result_shape = [mat1_shp[0], mat2_shp[1]] + matmul_result_reshape_op = tosa.ReshapeOp( + matmul_op.c, memoryview(array.array("i", final_result_shape)) + ) + add_result_tensor_type = ir.RankedTensorType.get( + final_result_shape, result_element_type + ) + + op = _gen_arith_binary_op( + input_, matmul_result_reshape_op.result, tosa.AddOp + ) + return op + + +def bmm_op(node, symbol_table) -> ir.Operation: + input_ = symbol_table.get((str(node.args[0]), 0)) + mat2 = symbol_table.get((str(node.args[1]), 0)) + input_shp = ir.RankedTensorType(input_.type).shape + mat2_shp = ir.RankedTensorType(mat2.type).shape + sizes = [input_shp[0], input_shp[1], mat2_shp[2]] + result_element_type = ir.RankedTensorType(input_.type).element_type + result_type = ir.RankedTensorType.get(sizes, result_element_type) + op = tosa.MatMulOp(result_type, input_, mat2) + return op + + def add_op(node, symbol_table): input1 = symbol_table.get((str(node.args[0]), 0), node.args[0]) input2 = symbol_table.get((str(node.args[1]), 0), node.args[1]) return _gen_arith_binary_op(input1, input2, tosa.AddOp) +def sub_op(node, symbol_table): + input1 = symbol_table.get((str(node.args[0]), 0), node.args[0]) + input2 = symbol_table.get((str(node.args[1]), 0), node.args[1]) + return _gen_arith_binary_op(input1, input2, tosa.SubOp) + + def mul_op(node, symbol_table): def _inner_op(result_type, input1, input2): return tosa.MulOp( @@ -139,7 +198,464 @@ def _inner_op(result_type, input1, input2): return _gen_arith_binary_op(input1, input2, _inner_op) +def div_op(node, symbol_table): + def _inner_op(result_type, input1, input2): + return tosa.MulOp( + result_type, + input1, + tosa.ReciprocalOp(input2.type, input2).result, + ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0), + ) + + input1 = symbol_table.get((str(node.args[0]), 0), node.args[0]) + input2 = symbol_table.get((str(node.args[1]), 0), node.args[1]) + + return _gen_arith_binary_op(input1, input2, _inner_op) + + +def tanh_op(node, symbol_table): + input1 = symbol_table.get((str(node.args[0]), 0)) + sizes = ir.RankedTensorType(input1.type).shape + result_element_type = ir.RankedTensorType(input1.type).element_type + tanhResultTensorType = ir.RankedTensorType.get(sizes, result_element_type) + op = tosa.TanhOp(tanhResultTensorType, input1) + return op + + +def exp_op(node, symbol_table): + input1 = symbol_table.get((str(node.args[0]), 0)) + sizes = ir.RankedTensorType(input1.type).shape + result_element_type = ir.RankedTensorType(input1.type).element_type + expResultTensorType = ir.RankedTensorType.get(sizes, result_element_type) + op = tosa.ExpOp(expResultTensorType, input1) + return op + + +def rsqrt_op(node, symbol_table): + input1 = symbol_table.get((str(node.args[0]), 0)) + sizes = ir.RankedTensorType(input1.type).shape + result_element_type = ir.RankedTensorType(input1.type).element_type + rsqrt_result_tensor_type = ir.RankedTensorType.get( + sizes, result_element_type + ) + op = tosa.RsqrtOp(rsqrt_result_tensor_type, input1) + return op + + +def amax_op(node, symbol_table): + input1 = symbol_table.get((str(node.args[0]), 0)) + dim_val = node.args[1][0] + if dim_val < 0: + dim_val += len(ir.RankedTensorType(input1.type).shape) + signless_type = ir.IntegerType.get_signless(64) + dim_attr = ir.IntegerAttr.get(signless_type, dim_val) + op = tosa.ReduceMaxOp(input1, dim_attr) + return op + + +def reshape_op(node, symbol_table): + input1 = symbol_table.get((str(node.args[0]), 0)) + new_shape = node.args[1] + total_size = 1 + now_shape = ir.RankedTensorType(input1.type).shape + for dim_siz in now_shape: + total_size *= dim_siz + + neg_one_cnt = 0 + rest_size = 1 + for dim_siz in new_shape: + if dim_siz == -1: + neg_one_cnt += 1 + continue + rest_size *= dim_siz + + if neg_one_cnt != 0: + if neg_one_cnt > 1 or total_size % rest_size != 0: + raise ValueError("Can not infer the new shape!") + infer_dim_size = total_size // rest_size + for i, _ in enumerate(new_shape): + if new_shape[i] == -1: + new_shape[i] = infer_dim_size + + new_shape_content = array.array("i", new_shape) + new_shape_content = memoryview(new_shape_content) + op = tosa.ReshapeOp(input1, new_shape_content) + + return op + + +def unsqueeze_op(node, symbol_table): + input_tensor = symbol_table.get((str(node.args[0]), 0)) + dim = node.args[1] + sizes = ir.RankedTensorType(input_tensor.type).shape + sizes.insert(dim, 1) + new_shape_content = array.array("i", sizes) + new_shape_content = memoryview(new_shape_content) + op = tosa.ReshapeOp(input_tensor, new_shape_content) + return op + + +def select_op(node, symbol_table): + input_tensor = symbol_table.get((str(node.args[0]), 0)) + dim = node.args[1] + index = node.args[2] + + sizes = ir.RankedTensorType(input_tensor.type).shape + + new_sizes = sizes[:dim] + [1] + sizes[dim + 1 :] + new_sizes_attr = ir._denseI64ArrayAttr(new_sizes, None) + + start = [0] * len(sizes) + start[dim] = index + start_attr = ir._denseI64ArrayAttr(start, None) + + result_element_type = ir.RankedTensorType(input_tensor.type).element_type + output_type = ir.RankedTensorType.get(new_sizes, result_element_type) + op = tosa.SliceOp(output_type, input_tensor, start_attr, new_sizes_attr) + + reshape_sizes = sizes[:dim] + sizes[dim + 1 :] + reshape_sizes_content = array.array("i", reshape_sizes) + reshape_sizes_content = memoryview(reshape_sizes_content) + op = tosa.ReshapeOp(op.results[0], reshape_sizes_content) + + return op + + +def slice_op(node, symbol_table): + input_tensor = symbol_table.get((str(node.args[0]), 0)) + dim = node.args[1] + start_idx = node.args[2] + end_idx = node.args[3] + + sizes = ir.RankedTensorType(input_tensor.type).shape + + if start_idx < 0: + start_idx += sizes[dim] + + if end_idx < 0: + end_idx += sizes[dim] + + if start_idx < 0: + start_idx = 0 + elif start_idx >= sizes[dim]: + start_idx = sizes[dim] + + if end_idx < start_idx: + end_idx = start_idx + elif end_idx >= sizes[dim]: + end_idx = sizes[dim] + + new_sizes = [x for x in sizes] + new_sizes[dim] = end_idx - start_idx + new_sizes_attr = ir._denseI64ArrayAttr(new_sizes, None) + + offsets = [0] * len(sizes) + offsets[dim] = start_idx + offsets_attr = ir._denseI64ArrayAttr(offsets, None) + + strides = [1] * len(sizes) + strides_attr = ir._denseI64ArrayAttr(strides, None) + + result_element_type = ir.RankedTensorType(input_tensor.type).element_type + extract_slice_result_type = ir.RankedTensorType.get( + new_sizes, result_element_type + ) + op = tensor.ExtractSliceOp( + extract_slice_result_type, + input_tensor, + [], + [], + [], + offsets_attr, + new_sizes_attr, + strides_attr, + ) + + return op + + +def convert_element_type_op(node, symbol_table): + # maintain a mapping of torch types and mlir types + types_mapping = { + torch.float64: ir.F64Type.get(), + torch.float32: ir.F32Type.get(), + torch.float16: ir.F16Type.get(), + } + input_tensor = symbol_table.get((str(node.args[0]), 0)) + to_cast_type = types_mapping[node.args[1]] + sizes = ir.RankedTensorType(input_tensor.type).shape + output_type = ir.RankedTensorType.get(sizes, to_cast_type) + return tosa.CastOp(output_type, input_tensor) + + +def clone_op(node, symbol_table): + input_tensor = symbol_table.get((str(node.args[0]), 0)) + sizes = ir.RankedTensorType(input_tensor.type).shape + result_element_type = ir.RankedTensorType(input_tensor.type).element_type + output_type = ir.RankedTensorType.get(sizes, result_element_type) + + return tosa.IdentityOp(output_type, input_tensor) + + +def var_mean_op(node, symbol_table): + def mean_dim_op(_input_tensor: ir.Value, _dim) -> ir.Operation: + if isinstance(_dim, int): + _dim = [_dim] + + # `_input_tensor` is the first tensor we need to reduce + reduce_sum_result = _input_tensor + + # reduce along each dimension in `_dim` + for _dim_item in _dim: + reduce_dim_attr = ir.IntegerAttr.get( + ir.IntegerType.get_signless(64), _dim_item + ) + reduce_sum_op: ir.Operation = tosa.ReduceSumOp( + reduce_sum_result, reduce_dim_attr + ) + # Next reduction is executed based on this time's reduction result + reduce_sum_result = reduce_sum_op.results[0] + + tensor_shp = ir.RankedTensorType(_input_tensor.type).shape + dim_size = 1 + # calculate the total size on all reduction dimensions to get the denominator + for _dim_item in _dim: + dim_size *= tensor_shp[_dim_item] + + denominator_const_op: ir.Operation = tosa.ConstOp( + ir.DenseElementsAttr.get(memoryview(array.array("f", [dim_size]))) + ) + + reciprocal_op: ir.Operation = tosa.ReciprocalOp( + denominator_const_op.results[0].type, + denominator_const_op.results[0], + ) + + return tosa.MulOp( + reduce_sum_op.results[0].type, + reciprocal_op.results[0], + reduce_sum_op.results[0], + ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0), + ) + + def var_dim_op( + _input_tensor: ir.Value, _mean_tensor: ir.Value, _dim, _correction + ) -> ir.Operation: + if isinstance(_dim, int): + _dim = [_dim] + # get (\bar{x} - x_i) + sub_op: ir.Operation = tosa.SubOp( + _input_tensor.type, _input_tensor, _mean_tensor + ) + + # get (\bar{x} - x_i) ^ 2 + mul_op: ir.Operation = tosa.MulOp( + _input_tensor.type, + sub_op.results[0], + sub_op.results[0], + ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0), + ) + + # the result of `mul_op` is the first tensor we need to reduce + reduce_sum_op = mul_op + for _dim_item in _dim: + reduce_dim_attr = ir.IntegerAttr.get( + ir.IntegerType.get_signless(64), _dim_item + ) + reduce_sum_op: ir.Operation = tosa.ReduceSumOp( + reduce_sum_op.results[0], reduce_dim_attr + ) + + tensor_shp = ir.RankedTensorType(_input_tensor.type).shape + dim_size = 1 + # calculate the denominator + for _dim_item in _dim: + dim_size *= tensor_shp[_dim_item] + biased_denominator_const_op: ir.Operation = tosa.ConstOp( + ir.DenseElementsAttr.get( + memoryview(array.array("f", [dim_size - _correction])) + ) + ) + reciprocal_op: ir.Operation = tosa.ReciprocalOp( + biased_denominator_const_op.results[0].type, + biased_denominator_const_op.results[0], + ) + + return tosa.MulOp( + reduce_sum_op.results[0].type, + reciprocal_op.results[0], + reduce_sum_op.results[0], + ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0), + ) + + mean_input_tensor = symbol_table.get((str(node.args[0]), 0)) + var_input_tensor = symbol_table.get((str(node.args[0]), 0)) + + kwargs = node.kwargs + keepdim = kwargs.get("keepdim", False) + correction = kwargs.get("correction", 1.0) + + mean_op = None + var_op = None + if len(node.args) == 1: + calc_dims = range( + len(ir.RankedTensorType(mean_input_tensor.type).shape) + ) + else: + calc_dims = node.args[1] + + mean_op = mean_dim_op(mean_input_tensor, calc_dims) + var_op = var_dim_op( + var_input_tensor, mean_op.results[0], calc_dims, correction + ) + mean_input_tensor = mean_op.results[0] + var_input_tensor = var_op.results[0] + + if not keepdim: + result_shp = ir.RankedTensorType(var_op.results[0].type).shape + result_shp = [siz for siz in result_shp if siz != 1] + var_op = tosa.ReshapeOp( + var_op.results[0], memoryview(array.array("i", result_shp)) + ) + mean_op = tosa.ReshapeOp( + mean_op.results[0], memoryview(array.array("i", result_shp)) + ) + + return var_op, mean_op + + +def permute_op(node, symbol_table): + input_tensor = symbol_table.get((str(node.args[0]), 0)) + perm = node.args[1] + perm_const_op = tosa.ConstOp( + ir.DenseElementsAttr.get(memoryview(array.array("i", perm))) + ) + result_element_type = ir.RankedTensorType(input_tensor.type).element_type + init_shape = ir.RankedTensorType(input_tensor.type).shape + new_shape = [] + for perm_item in perm: + new_shape.append(init_shape[perm_item]) + + permute_result_type = ir.RankedTensorType.get( + new_shape, result_element_type + ) + permute_op = tosa.TransposeOp( + permute_result_type, input_tensor, perm_const_op.results[0] + ) + return permute_op + + +def embedding_op(node, symbol_table): + indices = symbol_table.get((str(node.args[1]), 0)) + weight = symbol_table.get((str(node.args[0]), 0)) + padding_idx = None if len(node.args) < 3 else node.args[2] + + indices_size = ir.RankedTensorType(indices.type).shape + weight_size = ir.RankedTensorType(weight.type).shape + result_element_type = ir.RankedTensorType(weight.type).element_type + assert len(indices_size) == 2 + + if indices_size[0] != 1: + total_size = 1 + for x in indices_size: + total_size *= x + indices_reshape_op = tosa.ReshapeOp( + indices, memoryview(array.array("i", [1, total_size])) + ) + indices = indices_reshape_op.result + gather_result_type = ir.RankedTensorType.get( + [1, total_size, weight_size[1]], result_element_type + ) + else: + gather_result_type = ir.RankedTensorType.get( + [*indices_size, weight_size[1]], result_element_type + ) + + # tosa.gather doesn't support i64, so we need to cast it to i32 + if str(ir.RankedTensorType(indices.type).element_type) != "i32": + indices = tosa.CastOp( + ir.RankedTensorType.get( + ir.RankedTensorType(indices.type).shape, + ir.IntegerType.get_signless(32), + ), + indices, + ) + + weight_reshape_op = tosa.ReshapeOp( + weight, memoryview(array.array("i", [1, *weight_size])) + ) + + gather_op = tosa.GatherOp( + gather_result_type, weight_reshape_op.result, indices + ) + op = tosa.ReshapeOp( + gather_op.output, + memoryview(array.array("i", [*indices_size, weight_size[1]])), + ) + + return op + + +def expand_op(node, symbol_table) -> ir.Operation: + to_expand_tensor = symbol_table.get((str(node.args[0]), 0)) + new_size = node.args[1] + result_element_type = ir.RankedTensorType( + to_expand_tensor.type + ).element_type + element = ir.FloatAttr.get(result_element_type, 0.0) + new_size_tensor_type = ir.RankedTensorType.get( + new_size, result_element_type + ) + new_size_attr = ir.DenseElementsAttr.get_splat( + new_size_tensor_type, element + ) + new_size_tensor = tosa.ConstOp(new_size_attr).results[0] + op = _gen_arith_binary_op(to_expand_tensor, new_size_tensor, tosa.AddOp) + return op + + +def sum_op(node, symbol_table): + input_tensor = symbol_table.get((str(node.args[0]), 0)) + reduce_sum_dims = node.args[1] + dim_cnt = len(ir.RankedTensorType(input_tensor.type).shape) + reduce_sum_dims = [ + dim if dim >= 0 else dim_cnt + dim for dim in reduce_sum_dims + ] + _reduce_sum_input_tensor = input_tensor + reduce_sum_op = None + for dim in reduce_sum_dims: + reduce_dim_attr = ir.IntegerAttr.get( + ir.IntegerType.get_signless(64), dim + ) + reduce_sum_op = tosa.ReduceSumOp( + _reduce_sum_input_tensor, reduce_dim_attr + ) + _reduce_sum_input_tensor = reduce_sum_op.results[0] + + return reduce_sum_op + + ops_registry = { "add.Tensor": add_op, "mul.Tensor": mul_op, + "sub.Tensor": sub_op, + "sum.dim_IntList": sum_op, + "tanh.default": tanh_op, + "amax.default": amax_op, + "rsqrt.default": rsqrt_op, + "bmm.default": bmm_op, + "clone.default": clone_op, + "div.Tensor": div_op, + "exp.default": exp_op, + "expand.default": expand_op, + "var_mean.correction": var_mean_op, + "addmm.default": addmm_op, + "reshape.default": reshape_op, + "view.default": reshape_op, + "select.int": select_op, + "slice.Tensor": slice_op, + "embedding.default": embedding_op, + "convert_element_type.default": convert_element_type_op, + "permute.default": permute_op, + "unsqueeze.default": unsqueeze_op, }