diff --git a/alpa/torch/nn/__init__.py b/alpa/torch/nn/__init__.py index 1918b3c35..433b8cbd2 100644 --- a/alpa/torch/nn/__init__.py +++ b/alpa/torch/nn/__init__.py @@ -7,6 +7,7 @@ from torch import Tensor, nn from torch.fx.experimental.normalize import NormalizeOperators from torchdistx import deferred_init as torchdistx_deferred_init +from torchdistx.fake import meta_like import alpa.torch as atorch from alpa.torch.tensor_utils import make_shaped_array_from_pt_tensor from alpa.torch.ops.mapping import zeros_like_on_device @@ -463,5 +464,5 @@ def _zeros_init_dict(tensor_dict): def meta_init(module_fn: Callable[..., torch.nn.Module], *args, **kwargs): pt_module = torchdistx_deferred_init.deferred_init(module_fn, *args, **kwargs) - pt_module = pt_module.to(device="meta") - return pt_module + # pylint: disable=protected-access + return pt_module._apply(meta_like) diff --git a/tests/torch/test_torch_reshape.py b/tests/torch/test_torch_reshape.py index c40e26cae..08f65a928 100644 --- a/tests/torch/test_torch_reshape.py +++ b/tests/torch/test_torch_reshape.py @@ -35,7 +35,6 @@ class TorchReshapeTest(unittest.TestCase): def test_reshape(self): B = 64 - # `meta_init` allows a PyTorch model to be created with shape-only tensors as weights. pt_module_gen = lambda: MyModule() dataloader = [