diff --git a/examples/MLIRPython/buddy/operators_gen.py b/examples/MLIRPython/buddy/operators_gen.py
index 3119f1ff69..32610995ba 100644
--- a/examples/MLIRPython/buddy/operators_gen.py
+++ b/examples/MLIRPython/buddy/operators_gen.py
@@ -10,7 +10,7 @@
 def _broadcast_shape(tensor_input1: ir.Value,
                      tensor_input2: ir.Value) -> List[int]:
-  """Calculate the broadcast shape of two tensors with broadcastable shapes"""
+  """Calculate the broadcast shape of two tensors with broadcastable shapes according to PyTorch's broadcast semantics: https://pytorch.org/docs/stable/notes/broadcasting.html"""
   shp1 = ir.RankedTensorType(tensor_input1.type).shape
   shp2 = ir.RankedTensorType(tensor_input2.type).shape
@@ -40,9 +40,10 @@ def AddOp(node: torch.fx.Node,
   input2 = symbol_table.get((str(node.args[1]), 0))
   broadcasted_shp = _broadcast_shape(input1, input2)
   sizes = broadcasted_shp
-  f32 = ir.F32Type.get()
-  addResultTensorType = ir.RankedTensorType.get(sizes, f32)
-  op = tosa.AddOp(addResultTensorType, input1, input2)
+  # f32 = ir.F32Type.get()
+  result_element_type = ir.RankedTensorType(input1.type).element_type
+  add_result_tensor_type = ir.RankedTensorType.get(sizes, result_element_type)
+  op = tosa.AddOp(add_result_tensor_type, input1, input2)
   return op
diff --git a/examples/MLIRPython/test/test_different_dtype.py b/examples/MLIRPython/test/test_different_dtype.py
new file mode 100644
index 0000000000..02d9f29950
--- /dev/null
+++ b/examples/MLIRPython/test/test_different_dtype.py
@@ -0,0 +1,16 @@
+from buddy import compiler
+import torch
+import torch._dynamo as dynamo
+
+def foo(x, y):
+  return x + y
+
+foo_mlir = dynamo.optimize(compiler.DynamoCompiler)(foo)
+float32_in1 = torch.randn(10).to(torch.float32)
+float32_in2 = torch.randn(10).to(torch.float32)
+foo_mlir(float32_in1, float32_in2)
+
+foo_int32_mlir = dynamo.optimize(compiler.DynamoCompiler)(foo)
+int32_in1 = torch.randint(0, 10, (10,)).to(torch.int32)
+int32_in2 = torch.randint(0, 10, (10,)).to(torch.int32)
+foo_int32_mlir(int32_in1, int32_in2)
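
Note on the element-type change: the patched AddOp derives result_element_type from the first operand's tensor type instead of hard-coding ir.F32Type, so the same lowering path serves both the float32 and int32 inputs exercised by test_different_dtype.py. The standalone sketch below assumes the upstream MLIR Python bindings (the mlir package) are importable; the context setup and variable names are illustrative and not part of this patch.

from mlir import ir

with ir.Context():
  # Two operand tensor types that differ only in element type.
  f32_tensor = ir.RankedTensorType.get([10], ir.F32Type.get())
  i32_tensor = ir.RankedTensorType.get([10], ir.IntegerType.get_signless(32))
  for operand_type in (f32_tensor, i32_tensor):
    # Same pattern as the patched AddOp: reuse the operand's element type
    # when building the result tensor type.
    element_type = ir.RankedTensorType(operand_type).element_type
    result_type = ir.RankedTensorType.get([10], element_type)
    print(result_type)  # tensor<10xf32>, then tensor<10xi32>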
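
Note on the broadcasting rule: AddOp still computes its result shape with _broadcast_shape, whose body is only partially visible in the first hunk. The sketch below is an illustrative, plain-Python restatement of the rule the docstring refers to (PyTorch's broadcast semantics: align shapes from the trailing dimension; each dimension pair must match or contain a 1, and the result keeps the larger extent). The function name broadcast_shape and the example shapes are hypothetical, not code from this patch.

from typing import List

def broadcast_shape(shp1: List[int], shp2: List[int]) -> List[int]:
  # Walk both shapes from the trailing dimension toward the front.
  result = []
  for i in range(1, max(len(shp1), len(shp2)) + 1):
    d1 = shp1[-i] if i <= len(shp1) else 1
    d2 = shp2[-i] if i <= len(shp2) else 1
    if d1 != d2 and d1 != 1 and d2 != 1:
      raise ValueError(f"shapes {shp1} and {shp2} are not broadcastable")
    # A size-1 dimension stretches to match the other operand.
    result.append(max(d1, d2))
  return list(reversed(result))

assert broadcast_shape([5, 1, 3], [4, 3]) == [5, 4, 3]
assert broadcast_shape([10], [10]) == [10]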