diff --git a/iree/turbine/ops/iree.py b/iree/turbine/ops/iree.py index 05484fe7..c65f40f2 100644 --- a/iree/turbine/ops/iree.py +++ b/iree/turbine/ops/iree.py @@ -89,7 +89,7 @@ def select(self, ksel: KernelSelection): ksel.return_tensor(ta.t).specialize_dims(*spec) def eager_execute(self, device_moniker, tensor): - return tensor + return tensor.clone() def generate(self, ksel: KernelSelection, kb: KernelBuilder): moniker = cast(AttrArg, ksel.arg_descs[0]).v diff --git a/tests/ops/iree_test.py b/tests/ops/iree_test.py index facbf545..b2dca1b0 100644 --- a/tests/ops/iree_test.py +++ b/tests/ops/iree_test.py @@ -36,7 +36,7 @@ class TransferToLogicalDeviceTest(unittest.TestCase): def testEager(self): t1 = torch.randn(3, 4) t2 = ops.iree.transfer_to_logical_device("1", t1) - self.assertIs(t1, t2) + assert torch.all(t1 == t2) def testAOT(self): class MyModule(nn.Module):