Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
Use torchdistx.fake.meta_like for fake->meta conversion (#502)
Browse files Browse the repository at this point in the history
  • Loading branch information
yf225 authored Jun 11, 2022
1 parent 9f41aa3 commit b47da7b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
5 changes: 3 additions & 2 deletions alpa/torch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch import Tensor, nn
from torch.fx.experimental.normalize import NormalizeOperators
from torchdistx import deferred_init as torchdistx_deferred_init
from torchdistx.fake import meta_like
import alpa.torch as atorch
from alpa.torch.tensor_utils import make_shaped_array_from_pt_tensor
from alpa.torch.ops.mapping import zeros_like_on_device
Expand Down Expand Up @@ -463,5 +464,5 @@ def _zeros_init_dict(tensor_dict):
def meta_init(module_fn: Callable[..., torch.nn.Module], *args, **kwargs):
pt_module = torchdistx_deferred_init.deferred_init(module_fn, *args,
**kwargs)
pt_module = pt_module.to(device="meta")
return pt_module
# pylint: disable=protected-access
return pt_module._apply(meta_like)
1 change: 0 additions & 1 deletion tests/torch/test_torch_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class TorchReshapeTest(unittest.TestCase):
def test_reshape(self):
B = 64

# `meta_init` allows a PyTorch model to be created with shape-only tensors as weights.
pt_module_gen = lambda: MyModule()

dataloader = [
Expand Down

0 comments on commit b47da7b

Please sign in to comment.