
Using torch.compile in Pyro models #2256

Open
vitkl opened this issue Sep 5, 2023 · 10 comments
@vitkl
Contributor

vitkl commented Sep 5, 2023

Hi @adamgayoso and others (also cc @fritzo, @martinjankowiak @eb8680)

It would be great if the new torch.compile function could be used with the Pyro model and guide in scvi-tools.

I am happy to contribute this functionality; however, I need your recommendations on the following problem. Suppose we add torch.compile as shown below:

import torch
from scvi.module.base import PyroBaseModuleClass


class MyBaseModule(PyroBaseModuleClass):
    def __init__(
        self,
        model,
        guide_class,  # an AutoGuide class such as AutoNormalMessenger
        guide_kwargs=None,
        **kwargs,
    ):
        """Module class which defines an AutoGuide given the model."""
        super().__init__()
        self.hist = []

        _model = model(**kwargs)
        self._model = torch.compile(_model)
        # AutoGuide classes take the model as their first argument
        _guide = guide_class(_model, **(guide_kwargs or {}))
        self._guide = torch.compile(_guide)

The problem is that Pyro creates guide parameters lazily, when they are first needed, which is why scvi-tools needs these callbacks:

class PyroJitGuideWarmup(Callback):
    """A callback to warmup a Pyro guide.

    This helps initialize all the relevant parameters by running
    one minibatch through the Pyro model.
    """

    def __init__(self, dataloader: AnnDataLoader = None) -> None:
        super().__init__()
        self.dataloader = dataloader

    def on_train_start(self, trainer, pl_module):
        """Way to warmup Pyro Guide in an automated way.

        Also device agnostic.
        """
        # warmup guide for JIT
        pyro_guide = pl_module.module.guide
        if self.dataloader is None:
            dl = trainer.datamodule.train_dataloader()
        else:
            dl = self.dataloader
        for tensors in dl:
            tens = {k: t.to(pl_module.device) for k, t in tensors.items()}
            args, kwargs = pl_module.module._get_fn_args_from_batch(tens)
            pyro_guide(*args, **kwargs)
            break


class PyroModelGuideWarmup(Callback):
    """A callback to warmup a Pyro guide and model.

    This helps initialize all the relevant parameters by running
    one minibatch through the Pyro model. This warmup occurs on the CPU.
    """

    def __init__(self, dataloader: AnnDataLoader) -> None:
        super().__init__()
        self.dataloader = dataloader

    def setup(self, trainer, pl_module, stage=None):
        """Way to warmup Pyro Model and Guide in an automated way.

        Setup occurs before any device movement, so params are initialized on CPU.
        """
        if stage == "fit":
            pyro_guide = pl_module.module.guide
            dl = self.dataloader
            for tensors in dl:
                tens = {k: t.to(pl_module.device) for k, t in tensors.items()}
                args, kwargs = pl_module.module._get_fn_args_from_batch(tens)
                pyro_guide(*args, **kwargs)
                break
My understanding is that this means torch.compile(_guide) should similarly be called only after the parameters are created.
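To illustrate the lazy initialization outside scvi-tools, here is a minimal plain-Pyro sketch (the toy model is illustrative only, not one of our models):

import pyro
import pyro.distributions as dist
import torch
from pyro.infer.autoguide import AutoNormal

def toy_model(x):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=x)

guide = AutoNormal(toy_model)
print(len(list(guide.parameters())))  # 0: no parameters exist yet

guide(torch.randn(8))  # the first call creates the guide parameters
print(len(list(guide.parameters())))  # > 0: only now does torch.compile(guide) see them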

I see one solution to this: run the following warm-up code manually in model.train() (without using a callback), after creating the data loaders but before creating TrainRunner and TrainingPlan:

pyro_guide = pl_module.module.guide
dl = self.dataloader
for tensors in dl:
    tens = {k: t.to(pl_module.device) for k, t in tensors.items()}
    args, kwargs = pl_module.module._get_fn_args_from_batch(tens)
    pyro_guide(*args, **kwargs)
    break

Then modify the training plan as follows:

class PyroCompiledTrainingPlan(LowLevelPyroTrainingPlan):
    """
    Lightning module task to train Pyro scvi-tools modules.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.module.model_compiled = torch.compile(self.module.model)
        self.module.guide_compiled = torch.compile(self.module.guide)

        self.svi = pyro.infer.SVI(
                model=self.module.model_compiled,
                guide=self.module.guide_compiled,
                optim=self.optim,
                loss=self.loss_fn,
        )
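For intuition, the same warm-up-then-compile idea outside of Lightning would look roughly like this (a sketch reusing the toy model and guide from the snippet above; SVI, Trace_ELBO, and pyro.optim.Adam are the standard Pyro APIs, the rest is illustrative):

import torch
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

guide(torch.randn(8))  # warm up once so that the guide parameters exist

compiled_model = torch.compile(toy_model)
compiled_guide = torch.compile(guide)

svi = SVI(compiled_model, compiled_guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
for _ in range(100):
    svi.step(torch.randn(8))  # one ELBO gradient step on a fresh toy batch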

What do you think about this? Do you have any better ideas on how to implement this?

@canergen
Member

canergen commented Sep 6, 2023

Hi @vitkl,
What's the purpose of the request?
In our hands, the speed-up was not reproducibly high. Have you had a different experience?
We see a larger speed-up using JAX, so if it's about speed, rewriting the model in numpyro should be the best option.

@vitkl
Contributor Author

vitkl commented Sep 7, 2023

The purpose is to enable general support for Pyro scvi-tools models. It is possible that some models benefit from this more than others, but it is good to have the option. Pyro adds additional challenges to using torch.compile (as mentioned above), so I was not able to try torch.compile without additional input on how to solve those issues. Once I understand how to implement this, I will test it on cell2location and other unpublished models, which both use all available GPU memory (incl. multi-GPU), as opposed to scVI, which mostly uses a few GB.

Re-implementing models in numpyro is not always practical because i) numpyro doesn't cover all of Pyro's functionality, and ii) we observed in the past that JAX uses 2-4x the GPU memory for the same data size, which makes it less practical for larger datasets where every bit of GPU memory matters.

@canergen
Member

canergen commented Sep 8, 2023

I agree that the speed-up is expected to be largely model-dependent and that scVI is small and might be a bad proxy. Adam and Martin experimented with torch.compile, however, only in the PyTorch models. I would expect it's more straightforward to train the model/guide for one step, similar to our current load procedure in on_load(self, model). Afterwards the guide can be compiled.

@vitkl
Contributor Author

vitkl commented Sep 8, 2023

Do you suggest modifying model.train() as shown below?

self.module.on_load(self)
self.module._model = torch.compile(self.module.model)
self.module._guide = torch.compile(self.module.guide)

Are self.module.model and self.module.guide modifiable as shown here?

@vitkl
Contributor Author

vitkl commented Sep 8, 2023

As a proxy for the compilation effect on cell2location, I can mention that our old theano+pymc3 implementation was 2-4 times faster for the same number of training steps. It would be great to see what happens here; a 2-4x speedup would be really nice.

@canergen
Member

I tried it out on my side and got some cryptic error messages (it was on a private repo with an unpublished model, though). My idea was to call self.train(max_steps=1) once and compile afterwards, i.e. warming up the guide by running a single training step. I'm happy to review if you have a PR.

@vitkl
Contributor Author

vitkl commented Sep 14, 2023

I will try your suggestion. Do I get this right that you suggest the following?

def train(self, **kwargs):
    # one warm-up step so that the guide parameters get created
    # (super().train avoids recursing into this override)
    super().train(**kwargs, max_steps=1)
    self.module._model = torch.compile(self.module.model)
    self.module._guide = torch.compile(self.module.guide)
    super().train(**kwargs)


@canergen
Member

Yes, that's my understanding of how we do guide warmups for Pyro (e.g. when loading a trained model). I don't think pyro.clear_param_store() is necessary here.

@vitkl
Contributor Author

vitkl commented Sep 14, 2023

This is a good point. I will test this. Let's see what happens with cell2location.

@vitkl
Contributor Author

vitkl commented Nov 8, 2023

Looks like torch.compile works for cell2location using the modified train method below. The code runs, but there is no speed benefit (5h12min with vs 5h25min without; using torch.set_float32_matmul_precision('high') on an A100 is more impactful: 4h45min).

class MyModelClass(PyroSampleMixin, PyroSviTrainMixin, BaseModelClass):
    def train_compiled(self, **kwargs):
        import torch

        self.train(**kwargs, max_steps=1)
        self.module._model = torch.compile(self.module.model)
        self.module._guide = torch.compile(self.module.guide)
        self.train(**kwargs)
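A hypothetical usage sketch (assuming the usual scvi-tools setup applies to this model; the argument values are illustrative, not from my runs):

MyModelClass.setup_anndata(adata)  # standard scvi-tools data registration
model = MyModelClass(adata)
model.train_compiled(max_epochs=500, accelerator="gpu")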

The model and guide are successfully replaced:

LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel(
  (dropout): Dropout(p=0.0, inplace=False)
)
AutoNormal(
  (locs): PyroModule()
  (scales): PyroModule()
)
OptimizedModule(
  (_orig_mod): LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel(
    (dropout): Dropout(p=0.0, inplace=False)
  )
)
OptimizedModule(
  (_orig_mod): AutoNormal(
    (locs): PyroModule()
    (scales): PyroModule()
  )
)
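For reference, the matmul precision setting mentioned above is a single call (this is the standard PyTorch API):

import torch

# Allow TF32 matmul kernels on Ampere GPUs such as the A100;
# trades a little float32 precision for throughput.
torch.set_float32_matmul_precision("high")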

The PyTorch documentation (https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) says:

Speedup mainly comes from reducing Python overhead and GPU read/writes, and so the observed speedup may vary on factors such as model architecture and batch size. For example, if a model’s architecture is simple and the amount of data is large, then the bottleneck would be GPU compute and the observed speedup may be less significant.

I wonder if this means that speedups only materialise for models that don't already have 100% GPU utilisation. Cell2location mainly uses very large full-data batches.

I also get errors if I attempt to use amortised inference (using an encoder NN as part of the guide).

File /nfs/team283/vk7/software/miniconda3farm5/envs/cell2loc_env_2023/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py:1544, in ShapeGuardPrinter._print_Symbol(self, expr)
   1538 def repr_symbol_to_source():
   1539     return repr({
   1540         symbol: [s.name() for s in sources]
   1541         for symbol, sources in self.symbol_to_source.items()
   1542     })
-> 1544 assert self.symbol_to_source.get(expr), (
   1545     f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) "
   1546     f"not in {repr_symbol_to_source()}.  If this assert is failing, it could be "
   1547     "due to the issue described in https://github.com/pytorch/pytorch/pull/90665"
   1548 )
   1549 return self.source_ref(self.symbol_to_source[expr][0])

AssertionError: s2 (could be from ["L['msg']['infer']['prior']._batch_shape[0]"]) not in {s0: ["L['msg']['value'].size()[0]"], s1: ["L['msg']['value'].size()[1]", "L['msg']['value'].stride()[0]"], s5: [], s2: [], s4: [], s3: []}.  If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
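One thing that might be worth trying for this shape-guard failure (untested here; the dynamic flag itself is part of the public torch.compile API) is to opt into dynamic-shape tracing, since the assertion comes from the symbolic-shapes machinery:

self.module._guide = torch.compile(self.module.guide, dynamic=True)  # assume dynamic shapes instead of specializing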
