From 17fe0c2f75ac086ec5efbdcc65f068b4952a385b Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 19 Jul 2024 10:03:24 -0700 Subject: [PATCH 01/12] chore: removing unused module for ema Signed-off-by: Lee, Kin Long Kelvin --- .../modules/exponential_moving_average.py | 207 ------------------ 1 file changed, 207 deletions(-) delete mode 100644 matsciml/modules/exponential_moving_average.py diff --git a/matsciml/modules/exponential_moving_average.py b/matsciml/modules/exponential_moving_average.py deleted file mode 100644 index d87e622b..00000000 --- a/matsciml/modules/exponential_moving_average.py +++ /dev/null @@ -1,207 +0,0 @@ -""" -Copied (and improved) from: -https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py (MIT license) -""" -from __future__ import annotations - -import copy -import weakref -from collections.abc import Iterable -from typing import Optional - -import torch - - -# Partially based on: -# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py -class ExponentialMovingAverage: - """ - Maintains (exponential) moving average of a set of parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter` (typically from - `model.parameters()`). - decay: The exponential decay. - use_num_updates: Whether to use number of updates when computing - averages. - """ - - def __init__( - self, - parameters: Iterable[torch.nn.Parameter], - decay: float, - use_num_updates: bool = False, - ): - if decay < 0.0 or decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - self.decay = decay - self.num_updates = 0 if use_num_updates else None - parameters = list(parameters) - self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] - self.collected_params = [] - # By maintaining only a weakref to each parameter, - # we maintain the old GC behaviour of ExponentialMovingAverage: - # if the model goes out of scope but the ExponentialMovingAverage - # is kept, no references to the model or its parameters will be - # maintained, and the model will be cleaned up. - self._params_refs = [weakref.ref(p) for p in parameters if p.requires_grad] - - def _get_parameters( - self, - parameters: Iterable[torch.nn.Parameter] | None, - ) -> Iterable[torch.nn.Parameter]: - if parameters is None: - parameters = [p() for p in self._params_refs] - if any(p is None for p in parameters): - raise ValueError( - "(One of) the parameters with which this " - "ExponentialMovingAverage " - "was initialized no longer exists (was garbage collected);" - " please either provide `parameters` explicitly or keep " - "the model to which they belong from being garbage " - "collected.", - ) - return parameters - else: - return [p for p in parameters if p.requires_grad] - - def update( - self, - parameters: Iterable[torch.nn.Parameter] | None = None, - ) -> None: - """ - Update currently maintained parameters. - - Call this every time the parameters are updated, such as the result of - the `optimizer.step()` call. - - Args: - parameters: Iterable of `torch.nn.Parameter`; usually the same set of - parameters used to initialize this object. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. 
- """ - parameters = self._get_parameters(parameters) - decay = self.decay - if self.num_updates is not None: - self.num_updates += 1 - decay = min( - decay, - (1 + self.num_updates) / (10 + self.num_updates), - ) - one_minus_decay = 1.0 - decay - with torch.no_grad(): - for s_param, param in zip(self.shadow_params, parameters): - tmp = param - s_param - s_param.add_(tmp, alpha=one_minus_decay) - - def copy_to( - self, - parameters: Iterable[torch.nn.Parameter] | None = None, - ) -> None: - """ - Copy current parameters into given collection of parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. - """ - parameters = self._get_parameters(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.data) - - def store( - self, - parameters: Iterable[torch.nn.Parameter] | None = None, - ) -> None: - """ - Save the current parameters for restoring later. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. If `None`, the parameters of with which this - `ExponentialMovingAverage` was initialized will be used. - """ - parameters = self._get_parameters(parameters) - self.collected_params = [param.clone() for param in parameters] - - def restore( - self, - parameters: Iterable[torch.nn.Parameter] | None = None, - ) -> None: - """ - Restore the parameters stored with the `store` method. - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. After validation (or model saving), use this to - restore the former parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. - """ - parameters = self._get_parameters(parameters) - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) - - def state_dict(self) -> dict: - r"""Returns the state of the ExponentialMovingAverage as a dict.""" - # Following PyTorch conventions, references to tensors are returned: - # "returns a reference to the state and not its copy!" - - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict - return { - "decay": self.decay, - "num_updates": self.num_updates, - "shadow_params": self.shadow_params, - "collected_params": self.collected_params, - } - - def load_state_dict(self, state_dict: dict) -> None: - r"""Loads the ExponentialMovingAverage state. - - Args: - state_dict (dict): EMA state. Should be an object returned - from a call to :meth:`state_dict`. 
- """ - # deepcopy, to be consistent with module API - state_dict = copy.deepcopy(state_dict) - - self.decay = state_dict["decay"] - if self.decay < 0.0 or self.decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.num_updates = state_dict["num_updates"] - assert self.num_updates is None or isinstance( - self.num_updates, - int, - ), "Invalid num_updates" - - assert isinstance( - state_dict["shadow_params"], - list, - ), "shadow_params must be a list" - self.shadow_params = [ - p.to(self.shadow_params[i].device) - for i, p in enumerate(state_dict["shadow_params"]) - ] - assert all( - isinstance(p, torch.Tensor) for p in self.shadow_params - ), "shadow_params must all be Tensors" - - assert isinstance( - state_dict["collected_params"], - list, - ), "collected_params must be a list" - # collected_params is empty at initialization, - # so use shadow_params for device instead - self.collected_params = [ - p.to(self.shadow_params[i].device) - for i, p in enumerate(state_dict["collected_params"]) - ] - assert all( - isinstance(p, torch.Tensor) for p in self.collected_params - ), "collected_params must all be Tensors" From b2c1f3430863e54e711458f98ac8d3ca7492e274 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 19 Jul 2024 10:22:31 -0700 Subject: [PATCH 02/12] feat: added EMA callback --- matsciml/lightning/callbacks.py | 59 +++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 107a18eb..1c31adc7 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -24,6 +24,7 @@ from torch.optim import Optimizer from dgl import DGLGraph from scipy.signal import correlate +from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn, update_bn from matsciml.common.packages import package_registry from matsciml.datasets.utils import concatenate_keys @@ -1407,3 +1408,61 @@ def on_before_optimizer_step( _ = self.history["grads"].get() if self.is_active: self.run_analysis(pl_module.logger) + + +class ExponentialMovingAverageCallback(Callback): + def __init__(self, decay: float = 0.99) -> None: + """ + Initialize an exponential moving average callback. + + This callback attaches a ``ema_module`` attribute to + the current training task, which duplicates the model + weights that are tracked with an exponential moving + average, parametrized by the ``decay`` value. + + This will double the memory footprint of your model, + but has been shown to considerably improve generalization. + + Parameters + ---------- + decay : float + Exponential decay factor to apply to updates. + """ + super().__init__() + self.decay = decay + self.logger = getLogger("matsciml.ema_callback") + + def on_fit_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + # in the case that there is already an EMA state we don't initialize + if hasattr(pl_module, "ema_module"): + self.logger.info( + "Task has an existing EMA state; not initializing a new one." 
+ ) + self.ema_module = pl_module.ema_module + else: + # hook to the task module and in the current callback + ema_module = AveragedModel(pl_module, get_ema_multi_avg_fn(self.decay)) + self.logger.info("Task does not have an existing EMA state; creating one.") + # setting the callback ema_module attribute allows ease of access + self.ema_module = ema_module + pl_module.ema_module = ema_module + + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs, + batch: Any, + batch_idx: int, + ) -> None: + self.logger.info("Updating EMA state.") + pl_module.ema_module.update_parameters(pl_module) + + def on_fit_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + loader = trainer.train_dataloader() + self.logger.info("Fit finished - updating EMA batch normalization state.") + update_bn(loader, pl_module.ema_module) From 3ddedf1a101b99f32cc50c188b6dc7446711d859 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 19 Jul 2024 10:35:55 -0700 Subject: [PATCH 03/12] fix: correctly mapping avg func as argument --- matsciml/lightning/callbacks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 1c31adc7..a904f2e8 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -1443,7 +1443,9 @@ def on_fit_start( self.ema_module = pl_module.ema_module else: # hook to the task module and in the current callback - ema_module = AveragedModel(pl_module, get_ema_multi_avg_fn(self.decay)) + ema_module = AveragedModel( + pl_module, multi_avg_fn=get_ema_multi_avg_fn(self.decay) + ) self.logger.info("Task does not have an existing EMA state; creating one.") # setting the callback ema_module attribute allows ease of access self.ema_module = ema_module From 81e096a150e1b70e21f5bd4215d735fbbaa56cd4 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 19 Jul 2024 10:41:24 -0700 Subject: [PATCH 04/12] refactor: adding exception handling for lazy modules --- matsciml/lightning/callbacks.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index a904f2e8..666bd824 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -1435,6 +1435,14 @@ def __init__(self, decay: float = 0.99) -> None: def on_fit_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" ) -> None: + # check to make sure the module has no lazy layers + for layer in pl_module.modules(): + if isinstance(layer, nn.modules.lazy.LazyModuleMixin): + if layer.has_uninitialized_params(): + raise RuntimeError( + "EMA callback does not support lazy layers. Please " + "re-run without using lazy layers." 
+ ) # in the case that there is already an EMA state we don't initialize if hasattr(pl_module, "ema_module"): self.logger.info( From 24665cbc4ac371940e3e227fdebb2f4c68aaab9c Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 19 Jul 2024 10:41:58 -0700 Subject: [PATCH 05/12] fix: grabbing dataloader as property not method --- matsciml/lightning/callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 666bd824..e0be2b31 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -1473,6 +1473,6 @@ def on_train_batch_end( def on_fit_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" ) -> None: - loader = trainer.train_dataloader() + loader = trainer.train_dataloader self.logger.info("Fit finished - updating EMA batch normalization state.") update_bn(loader, pl_module.ema_module) From 6230cb720100647479d0b6dcd34d9fe96a65283f Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 19 Jul 2024 10:55:13 -0700 Subject: [PATCH 06/12] test: added ema unit test suite Signed-off-by: Lee, Kin Long Kelvin --- matsciml/lightning/tests/test_ema.py | 330 +++++++++++++++++++++++++++ 1 file changed, 330 insertions(+) create mode 100644 matsciml/lightning/tests/test_ema.py diff --git a/matsciml/lightning/tests/test_ema.py b/matsciml/lightning/tests/test_ema.py new file mode 100644 index 00000000..c743eb66 --- /dev/null +++ b/matsciml/lightning/tests/test_ema.py @@ -0,0 +1,330 @@ +from __future__ import annotations + +import pytorch_lightning as pl +import pytest + +from matsciml.datasets.transforms import ( + PeriodicPropertiesTransform, + PointCloudToGraphTransform, + FrameAveraging, + UnitCellCalculator, +) + +from matsciml.lightning import MatSciMLDataModule, MultiDataModule +from matsciml.datasets import MultiDataset, IS2REDataset, S2EFDataset +from matsciml.models.pyg import EGNN +from matsciml.lightning.callbacks import ExponentialMovingAverageCallback +from matsciml.models.pyg import FAENet +from torch import nn +from e3nn.o3 import Irreps +from mace.modules.blocks import RealAgnosticInteractionBlock +from matsciml.models.pyg.mace import MACEWrapper +from matsciml.models.dgl import PLEGNNBackbone +from matsciml.models.base import ( + MultiTaskLitModule, + ForceRegressionTask, + ScalarRegressionTask, +) + + +def test_egnn_end_to_end_with_ema(): + """ + Test the end to end pipeline using a devset with EGNN and ema callback. + + The idea is that this basically mimics an example script to + try and maximize coverage across dataset to training, which + is particularly useful for checking new dependencies, etc. 
+ """ + dm = MatSciMLDataModule.from_devset( + "MaterialsProjectDataset", + dset_kwargs={ + "transforms": [ + PeriodicPropertiesTransform(5.0, adaptive_cutoff=True), + PointCloudToGraphTransform("pyg"), + ] + }, + batch_size=8, + ) + + # this specifies a whole lot to make sure we have coverage + task = ScalarRegressionTask( + encoder_class=EGNN, + encoder_kwargs={ + "hidden_dim": 48, + "output_dim": 32, + "num_conv": 2, + "num_atom_embedding": 200, + }, + scheduler_kwargs={ + "CosineAnnealingLR": { + "T_max": 5, + "eta_min": 1e-7, + } + }, + lr=1e-3, + weight_decay=0.0, + output_kwargs={ + "lazy": False, + "hidden_dim": 48, + "input_dim": 48, + "dropout": 0.2, + "num_hidden": 2, + }, + task_keys=["band_gap"], + ) + + trainer = pl.Trainer(fast_dev_run=5, callbacks=[ExponentialMovingAverageCallback()]) + trainer.fit(task, datamodule=dm) + assert hasattr(task, "ema_module") + + +def test_lazy_fail(): + """Lazy modules are not supported right now, so this ensures they fail""" + dm = MatSciMLDataModule.from_devset( + "MaterialsProjectDataset", + dset_kwargs={ + "transforms": [ + PeriodicPropertiesTransform(5.0, adaptive_cutoff=True), + PointCloudToGraphTransform("pyg"), + ] + }, + batch_size=8, + ) + + # this specifies a whole lot to make sure we have coverage + task = ScalarRegressionTask( + encoder_class=EGNN, + encoder_kwargs={ + "hidden_dim": 48, + "output_dim": 32, + "num_conv": 2, + "num_atom_embedding": 200, + }, + scheduler_kwargs={ + "CosineAnnealingLR": { + "T_max": 5, + "eta_min": 1e-7, + } + }, + lr=1e-3, + weight_decay=0.0, + output_kwargs={"lazy": True}, + task_keys=["band_gap"], + ) + + trainer = pl.Trainer(fast_dev_run=5, callbacks=[ExponentialMovingAverageCallback()]) + with pytest.raises(RuntimeError): + trainer.fit(task, datamodule=dm) + + +def test_mace_with_ema(): + """ + Test the MACE Wrapper with ema callback. + """ + # Construct MACE relaxed energy regression with PyG implementation of E(n)-GNN + task = ScalarRegressionTask( + encoder_class=MACEWrapper, + encoder_kwargs={ + "r_max": 6.0, + "num_bessel": 3, + "num_polynomial_cutoff": 3, + "max_ell": 2, + "interaction_cls": RealAgnosticInteractionBlock, + "interaction_cls_first": RealAgnosticInteractionBlock, + "num_interactions": 2, + "atom_embedding_dim": 64, + "MLP_irreps": Irreps("256x0e"), + "avg_num_neighbors": 10.0, + "correlation": 1, + "radial_type": "bessel", + "gate": nn.Identity(), + }, + task_keys=["energy_relaxed"], + output_kwargs={"lazy": False, "hidden_dim": 128, "input_dim": 128}, + ) + + # Prepare data module + dm = MatSciMLDataModule.from_devset( + "IS2REDataset", + dset_kwargs={ + "transforms": [ + PeriodicPropertiesTransform(5.0, adaptive_cutoff=True), + PointCloudToGraphTransform( + "pyg", + node_keys=["pos", "atomic_numbers"], + ), + ], + }, + ) + + # Run a quick training loop + trainer = pl.Trainer( + fast_dev_run=5, callbacks=[ExponentialMovingAverageCallback(0.999)] + ) + trainer.fit(task, datamodule=dm) + assert hasattr(task, "ema_module") + + +def test_faenet_with_ema(): + """ + Test FAENet with ema Callback. 
+ """ + task = ScalarRegressionTask( + encoder_class=FAENet, + encoder_kwargs={ + "pred_as_dict": False, + "hidden_dim": 128, + "out_dim": 64, + "tag_hidden_channels": 0, + "input_dim": 128, + }, + output_kwargs={"lazy": False, "input_dim": 64, "hidden_dim": 64}, + task_keys=["band_gap"], + ) + + dm = MatSciMLDataModule.from_devset( + "MaterialsProjectDataset", + dset_kwargs={ + "transforms": [ + UnitCellCalculator(), + PointCloudToGraphTransform( + "pyg", + cutoff_dist=20.0, + node_keys=["pos", "atomic_numbers"], + ), + FrameAveraging(frame_averaging="3D", fa_method="stochastic"), + ], + }, + ) + + # run a quick training loop + trainer = pl.Trainer(fast_dev_run=5, callbacks=[ExponentialMovingAverageCallback()]) + trainer.fit(task, datamodule=dm) + assert hasattr(task, "ema_module") + + +def test_force_regression_with_ema(): + """ + Tests force regression with ema using PLEGNNBackbone. + """ + devset = MatSciMLDataModule.from_devset( + "S2EFDataset", + dset_kwargs={ + "transforms": [ + PointCloudToGraphTransform( + "dgl", + cutoff_dist=20.0, + node_keys=["pos", "atomic_numbers"], + ), + ], + }, + ) + model_args = { + "embed_in_dim": 128, + "embed_hidden_dim": 32, + "embed_out_dim": 128, + "embed_depth": 5, + "embed_feat_dims": [128, 128, 128], + "embed_message_dims": [128, 128, 128], + "embed_position_dims": [64, 64], + "embed_edge_attributes_dim": 0, + "embed_activation": "relu", + "embed_residual": True, + "embed_normalize": True, + "embed_tanh": True, + "embed_activate_last": False, + "embed_k_linears": 1, + "embed_use_attention": False, + "embed_attention_norm": "sigmoid", + "readout": "sum", + "node_projection_depth": 3, + "node_projection_hidden_dim": 128, + "node_projection_activation": "relu", + "prediction_out_dim": 1, + "prediction_depth": 3, + "prediction_hidden_dim": 128, + "prediction_activation": "relu", + "encoder_only": True, + } + + task = ForceRegressionTask( + encoder_class=PLEGNNBackbone, + encoder_kwargs=model_args, + output_kwargs={"lazy": False, "hidden_dim": 128, "input_dim": 128}, + ) + trainer = pl.Trainer(fast_dev_run=5, callbacks=[ExponentialMovingAverageCallback()]) + trainer.fit(task, datamodule=devset) + # make sure losses are tracked + for key in ["energy", "force"]: + assert f"train_{key}" in trainer.logged_metrics + assert hasattr(task, "ema_module") + + +def test_multitask_ema(): + transforms = [ + PeriodicPropertiesTransform(6.0, adaptive_cutoff=True), + PointCloudToGraphTransform( + "dgl", + cutoff_dist=6.0, + node_keys=["pos", "atomic_numbers"], + ), + ] + dm = MultiDataModule( + train_dataset=MultiDataset( + [ + IS2REDataset.from_devset(transforms=transforms), + S2EFDataset.from_devset(transforms=transforms), + ], + ), + batch_size=8, + ) + model_args = { + "embed_in_dim": 128, + "embed_hidden_dim": 32, + "embed_out_dim": 128, + "embed_depth": 5, + "embed_feat_dims": [128, 128, 128], + "embed_message_dims": [128, 128, 128], + "embed_position_dims": [64, 64], + "embed_edge_attributes_dim": 0, + "embed_activation": "relu", + "embed_residual": True, + "embed_normalize": True, + "embed_tanh": True, + "embed_activate_last": False, + "embed_k_linears": 1, + "embed_use_attention": False, + "embed_attention_norm": "sigmoid", + "readout": "sum", + "node_projection_depth": 3, + "node_projection_hidden_dim": 128, + "node_projection_activation": "relu", + "prediction_out_dim": 1, + "prediction_depth": 3, + "prediction_hidden_dim": 128, + "prediction_activation": "relu", + "encoder_only": True, + } + + is2re = ScalarRegressionTask( + encoder_class=PLEGNNBackbone, 
+ encoder_kwargs=model_args, + task_keys=["energy_init", "energy_relaxed"], + output_kwargs={"lazy": False, "hidden_dim": 128, "input_dim": 128}, + ) + s2ef = ForceRegressionTask( + encoder_class=PLEGNNBackbone, + encoder_kwargs=model_args, + task_keys=["energy", "force"], + output_kwargs={"lazy": False, "hidden_dim": 128, "input_dim": 128}, + ) + + task = MultiTaskLitModule( + ("IS2REDataset", is2re), + ("S2EFDataset", s2ef), + ) + trainer = pl.Trainer( + fast_dev_run=5, callbacks=ExponentialMovingAverageCallback(0.9) + ) + trainer.fit(task, datamodule=dm) + assert hasattr(task, "ema_module") From d4e0b72444f4e20f9c309a7b40e1bee6ea63c1de Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 19 Jul 2024 10:58:28 -0700 Subject: [PATCH 07/12] refactor: using ema module weights if available and not training --- matsciml/models/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 4a3e168c..f911ae13 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -1006,7 +1006,12 @@ def _compute_losses( containing each individual target loss. """ targets = self._get_targets(batch) - predictions = self(batch) + # if we have EMA weights, use them for prediction instead + if hasattr(self, "ema_module") and not self.training: + wrapper = self.ema_module + else: + wrapper = self + predictions = wrapper(batch) losses = {} for key in self.task_keys: target_val = targets[key] From ff240ccac4edc2594a1a7a0e73b7594d1778ca59 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 19 Jul 2024 11:01:39 -0700 Subject: [PATCH 08/12] refactor: making predict method use ema weights as well --- matsciml/models/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index f911ae13..a2c14a3f 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -1157,7 +1157,12 @@ def predict(self, batch: BatchDict) -> dict[str, torch.Tensor]: normalizers are available for a given task, we apply the inverse norm on the value. 
""" - outputs = self(batch) + # use EMA weights instead if they are available + if hasattr(self, "ema_module"): + wrapper = self.ema_module + else: + wrapper = self + outputs = wrapper(batch) if self.uses_normalizers: for key in self.task_keys: if key in self.normalizers: From 98c8eb9a3a25edd00817fb762b2af65028d1f178 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 19 Jul 2024 11:12:01 -0700 Subject: [PATCH 09/12] feat: added example script for ema Signed-off-by: Lee, Kin Long Kelvin --- .../callbacks/exponential_moving_average.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 examples/callbacks/exponential_moving_average.py diff --git a/examples/callbacks/exponential_moving_average.py b/examples/callbacks/exponential_moving_average.py new file mode 100644 index 00000000..db969a7e --- /dev/null +++ b/examples/callbacks/exponential_moving_average.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import StochasticWeightAveraging + +from matsciml.datasets.transforms import DistancesTransform, PointCloudToGraphTransform +from matsciml.lightning.data_utils import MatSciMLDataModule +from matsciml.lightning.callbacks import ExponentialMovingAverageCallback +from matsciml.models import SchNet +from matsciml.models.base import ScalarRegressionTask + +""" +This script demonstrates how to use the EMA and SWA callbacks, +which are pretty necessary for models such as MACE. + +EMA is implemented within ``matsciml`` using native PyTorch, whereas SWA uses +the PyTorch Lightning implementation. +""" + +# construct a scalar regression task with SchNet encoder +task = ScalarRegressionTask( + encoder_class=SchNet, + # kwargs to be passed into the creation of SchNet model + encoder_kwargs={ + "encoder_only": True, + "hidden_feats": [128, 128, 128], + "atom_embedding_dim": 128, + }, + output_kwargs={"lazy": False, "hidden_dim": 128, "input_dim": 128}, + # which keys to use as targets + task_keys=["energy_relaxed"], +) + +# Use IS2RE devset to test workflow +# SchNet uses RBFs, and expects edge features corresponding to atom-atom distances +dm = MatSciMLDataModule.from_devset( + "IS2REDataset", + dset_kwargs={ + "transforms": [ + PointCloudToGraphTransform( + "dgl", + cutoff_dist=20.0, + node_keys=["pos", "atomic_numbers"], + ), + DistancesTransform(), + ], + }, +) + +# run several epochs with a limited number of train batches +# to make sure nothing breaks between updates +trainer = pl.Trainer( + max_epochs=5, + limit_train_batches=10, + logger=False, + enable_checkpointing=False, + callbacks=[ + StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=1), + ExponentialMovingAverageCallback(decay=0.99), + ], +) +trainer.fit(task, datamodule=dm) From a4a35f0253f252d656a8e00a6e8963d2f014ad7f Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 19 Jul 2024 11:13:40 -0700 Subject: [PATCH 10/12] refactor: setting callback level to warn by default Signed-off-by: Lee, Kin Long Kelvin --- matsciml/lightning/callbacks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index e0be2b31..6a9b1165 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -1431,6 +1431,7 @@ def __init__(self, decay: float = 0.99) -> None: super().__init__() self.decay = decay self.logger = getLogger("matsciml.ema_callback") + self.logger.setLevel("WARN") def on_fit_start( self, trainer: "pl.Trainer", pl_module: 
"pl.LightningModule" From 17c69c26c5eb73d3b2f07c16954776a5faae731c Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 19 Jul 2024 11:16:19 -0700 Subject: [PATCH 11/12] refactor: allowing control over verbosity of callback --- matsciml/lightning/callbacks.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 6a9b1165..1c829010 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -1411,7 +1411,11 @@ def on_before_optimizer_step( class ExponentialMovingAverageCallback(Callback): - def __init__(self, decay: float = 0.99) -> None: + def __init__( + self, + decay: float = 0.99, + verbose: bool | Literal["WARN", "INFO", "DEBUG"] = "WARN", + ) -> None: """ Initialize an exponential moving average callback. @@ -1431,7 +1435,18 @@ def __init__(self, decay: float = 0.99) -> None: super().__init__() self.decay = decay self.logger = getLogger("matsciml.ema_callback") - self.logger.setLevel("WARN") + if isinstance(verbose, bool): + if not verbose: + verbose = "WARN" + else: + verbose = "INFO" + if isinstance(verbose, str): + assert verbose in [ + "WARN", + "INFO", + "DEBUG", + ], "Invalid verbosity setting in EMA callback." + self.logger.setLevel(verbose) def on_fit_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" From f19fa79c1e5822b1fe503e3c2e57feaf62cb4804 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 19 Jul 2024 11:24:16 -0700 Subject: [PATCH 12/12] refactor: making validation and test log to progress bar as well --- matsciml/models/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index a2c14a3f..7bcf61b6 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -1083,7 +1083,7 @@ def validation_step( "Unable to parse batch size from data, defaulting to `None` for logging.", ) batch_size = None - self.log_dict(metrics, batch_size=batch_size) + self.log_dict(metrics, batch_size=batch_size, on_step=True, prog_bar=True) if self.hparams.log_embeddings and "embeddings" in batch: self._log_embedding(batch["embeddings"]) return loss_dict @@ -1105,7 +1105,7 @@ def test_step( "Unable to parse batch size from data, defaulting to `None` for logging.", ) batch_size = None - self.log_dict(metrics, batch_size=batch_size) + self.log_dict(metrics, batch_size=batch_size, on_epoch=True, prog_bar=True) if self.hparams.log_embeddings and "embeddings" in batch: self._log_embedding(batch["embeddings"]) return loss_dict @@ -1510,7 +1510,7 @@ def validation_step( "Unable to parse batch size from data, defaulting to `None` for logging." ) batch_size = None - self.log_dict(metrics, batch_size=batch_size) + self.log_dict(metrics, batch_size=batch_size, on_step=True, prog_bar=True) return loss_dict def test_step( @@ -1531,7 +1531,7 @@ def test_step( "Unable to parse batch size from data, defaulting to `None` for logging." ) batch_size = None - self.log_dict(metrics, batch_size=batch_size) + self.log_dict(metrics, batch_size=batch_size, on_step=True, prog_bar=True) return loss_dict def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]: