[examples][MLIRPython] add test for different dtype support
xTayEx committed Aug 7, 2023
1 parent a507a84 · commit e98ee65
Showing 2 changed files with 21 additions and 4 deletions.
examples/MLIRPython/buddy/operators_gen.py (9 changes: 5 additions & 4 deletions)
@@ -10,7 +10,7 @@
 
 def _broadcast_shape(tensor_input1: ir.Value,
                      tensor_input2: ir.Value) -> List[int]:
-    """Calculate the broadcast shape of two tensors with broadcastable shapes
+    """Calculate the broadcast shape of two tensors with broadcastable shapes
     according to PyTorch's broadcast semantics: https://pytorch.org/docs/stable/notes/broadcasting.html"""
     shp1 = ir.RankedTensorType(tensor_input1.type).shape
     shp2 = ir.RankedTensorType(tensor_input2.type).shape
@@ -40,9 +40,10 @@ def AddOp(node: torch.fx.Node,
     input2 = symbol_table.get((str(node.args[1]), 0))
     broadcasted_shp = _broadcast_shape(input1, input2)
     sizes = broadcasted_shp
-    f32 = ir.F32Type.get()
-    addResultTensorType = ir.RankedTensorType.get(sizes, f32)
-    op = tosa.AddOp(addResultTensorType, input1, input2)
+    # f32 = ir.F32Type.get()
+    result_element_type = ir.RankedTensorType(input1.type).element_type
+    add_result_tensor_type = ir.RankedTensorType.get(sizes, result_element_type)
+    op = tosa.AddOp(add_result_tensor_type, input1, input2)
     return op


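The substance of the change above is that AddOp no longer hard-codes an f32 result type; it reuses the element type of its first operand, so the same lowering works for float32 and int32 inputs alike. Below is a minimal standalone sketch of that pattern, assuming the MLIR Python bindings are importable as mlir.ir; the [10] shape and the two element types are illustrative and not part of the commit.

# Sketch: derive a result tensor type from an operand's element type instead of
# hard-coding f32. Assumes the MLIR Python bindings (mlir.ir) are on the path;
# the shape and dtypes are illustrative only.
from mlir import ir

with ir.Context():
    for elem_type in (ir.F32Type.get(), ir.IntegerType.get_signless(32)):
        operand_type = ir.RankedTensorType.get([10], elem_type)
        # Same pattern as the patched AddOp: reuse the operand's element type.
        result_type = ir.RankedTensorType.get([10], operand_type.element_type)
        print(result_type)  # tensor<10xf32>, then tensor<10xi32>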
examples/MLIRPython/test/test_different_dtype.py (16 changes: 16 additions & 0 deletions)
@@ -0,0 +1,16 @@
+from buddy import compiler
+import torch
+import torch._dynamo as dynamo
+
+def foo(x, y):
+    return x + y
+
+foo_mlir = dynamo.optimize(compiler.DynamoCompiler)(foo)
+float32_in1 = torch.randn(10).to(torch.float32)
+float32_in2 = torch.randn(10).to(torch.float32)
+foo_mlir(float32_in1, float32_in2)
+
+foo_int32_mlir = dynamo.optimize(compiler.DynamoCompiler)(foo)
+int32_in1 = torch.randint(0, 10, (10,)).to(torch.int32)
+int32_in2 = torch.randint(0, 10, (10,)).to(torch.int32)
+foo_int32_mlir(int32_in1, int32_in2)
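The new test only exercises compilation for both dtypes; it does not check results. A possible follow-up, assuming the Dynamo-compiled callable still returns ordinary tensors (as the eager fallback does), is to compare it against eager PyTorch and assert the dtype is preserved. The check_dtype helper below is illustrative and not part of the commit.

# Illustrative extension of test_different_dtype.py: verify values and dtypes,
# assuming the compiled callable returns regular tensors.
import torch
import torch._dynamo as dynamo
from buddy import compiler


def foo(x, y):
    return x + y


def check_dtype(in1, in2):
    compiled = dynamo.optimize(compiler.DynamoCompiler)(foo)
    out = compiled(in1, in2)
    ref = foo(in1, in2)
    assert out.dtype == ref.dtype  # dtype should follow the operands
    torch.testing.assert_close(out, ref)  # values should match eager execution


check_dtype(torch.randn(10, dtype=torch.float32), torch.randn(10, dtype=torch.float32))
check_dtype(torch.randint(0, 10, (10,), dtype=torch.int32),
            torch.randint(0, 10, (10,), dtype=torch.int32))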
