Merge pull request IntelLabs#257 from laserkelvin/ema-support
Implementing exponential moving average Lightning callback
laserkelvin authored Jul 19, 2024
2 parents ab4c1d2 + 73b00a9 commit 269b85d
Showing 5 changed files with 493 additions and 213 deletions.
62 changes: 62 additions & 0 deletions examples/callbacks/exponential_moving_average.py
@@ -0,0 +1,62 @@
from __future__ import annotations

import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

from matsciml.datasets.transforms import DistancesTransform, PointCloudToGraphTransform
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.lightning.callbacks import ExponentialMovingAverageCallback
from matsciml.models import SchNet
from matsciml.models.base import ScalarRegressionTask

"""
This script demonstrates how to use the EMA and SWA callbacks,
which are pretty necessary for models such as MACE.
EMA is implemented within ``matsciml`` using native PyTorch, whereas SWA uses
the PyTorch Lightning implementation.
"""

# construct a scalar regression task with SchNet encoder
task = ScalarRegressionTask(
    encoder_class=SchNet,
    # kwargs to be passed into the creation of SchNet model
    encoder_kwargs={
        "encoder_only": True,
        "hidden_feats": [128, 128, 128],
        "atom_embedding_dim": 128,
    },
    output_kwargs={"lazy": False, "hidden_dim": 128, "input_dim": 128},
    # which keys to use as targets
    task_keys=["energy_relaxed"],
)

# Use IS2RE devset to test workflow
# SchNet uses RBFs, and expects edge features corresponding to atom-atom distances
dm = MatSciMLDataModule.from_devset(
    "IS2REDataset",
    dset_kwargs={
        "transforms": [
            PointCloudToGraphTransform(
                "dgl",
                cutoff_dist=20.0,
                node_keys=["pos", "atomic_numbers"],
            ),
            DistancesTransform(),
        ],
    },
)

# run several epochs with a limited number of train batches
# to make sure nothing breaks between updates
trainer = pl.Trainer(
    max_epochs=5,
    limit_train_batches=10,
    logger=False,
    enable_checkpointing=False,
    callbacks=[
        StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=1),
        ExponentialMovingAverageCallback(decay=0.99),
    ],
)
trainer.fit(task, datamodule=dm)
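
Note (not part of the diff): after ``trainer.fit`` completes, the averaged weights live on the task as ``task.ema_module``, a ``torch.optim.swa_utils.AveragedModel`` wrapping the task. A minimal, hypothetical sketch of evaluating with the EMA weights is to copy them back into the live task:

# ``task.ema_module.module`` is the deep-copied task holding the averaged weights;
# loading its state dict into the live task lets the usual validation loop use them
task.load_state_dict(task.ema_module.module.state_dict())
trainer.validate(task, datamodule=dm)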
85 changes: 85 additions & 0 deletions matsciml/lightning/callbacks.py
@@ -24,6 +24,7 @@
from torch.optim import Optimizer
from dgl import DGLGraph
from scipy.signal import correlate
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn, update_bn

from matsciml.common.packages import package_registry
from matsciml.datasets.utils import concatenate_keys
@@ -1407,3 +1408,87 @@ def on_before_optimizer_step(
        _ = self.history["grads"].get()
        if self.is_active:
            self.run_analysis(pl_module.logger)


class ExponentialMovingAverageCallback(Callback):
    def __init__(
        self,
        decay: float = 0.99,
        verbose: bool | Literal["WARN", "INFO", "DEBUG"] = "WARN",
    ) -> None:
        """
        Initialize an exponential moving average callback.

        This callback attaches an ``ema_module`` attribute to
        the current training task, which holds a duplicate of the
        model weights tracked as an exponential moving average
        parametrized by the ``decay`` value.
        This will double the memory footprint of your model,
        but has been shown to considerably improve generalization.

        Parameters
        ----------
        decay : float
            Exponential decay factor to apply to updates.
        verbose : bool | Literal["WARN", "INFO", "DEBUG"]
            Logging level for the callback; ``False`` maps to
            "WARN" and ``True`` to "INFO".
        """
        super().__init__()
        self.decay = decay
        self.logger = getLogger("matsciml.ema_callback")
        if isinstance(verbose, bool):
            if not verbose:
                verbose = "WARN"
            else:
                verbose = "INFO"
        if isinstance(verbose, str):
            assert verbose in [
                "WARN",
                "INFO",
                "DEBUG",
            ], "Invalid verbosity setting in EMA callback."
        self.logger.setLevel(verbose)

    def on_fit_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        # check to make sure the module has no lazy layers
        for layer in pl_module.modules():
            if isinstance(layer, nn.modules.lazy.LazyModuleMixin):
                if layer.has_uninitialized_params():
                    raise RuntimeError(
                        "EMA callback does not support lazy layers. Please "
                        "re-run without using lazy layers."
                    )
        # if there is already an EMA state, don't initialize a new one
        if hasattr(pl_module, "ema_module"):
            self.logger.info(
                "Task has an existing EMA state; not initializing a new one."
            )
            self.ema_module = pl_module.ema_module
        else:
            # attach the averaged copy to both the task and the current callback
            ema_module = AveragedModel(
                pl_module, multi_avg_fn=get_ema_multi_avg_fn(self.decay)
            )
            self.logger.info("Task does not have an existing EMA state; creating one.")
            # setting the callback ema_module attribute allows ease of access
            self.ema_module = ema_module
            pl_module.ema_module = ema_module

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs,
        batch: Any,
        batch_idx: int,
    ) -> None:
        self.logger.info("Updating EMA state.")
        pl_module.ema_module.update_parameters(pl_module)

    def on_fit_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        loader = trainer.train_dataloader
        self.logger.info("Fit finished - updating EMA batch normalization state.")
        update_bn(loader, pl_module.ema_module)
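
For reference (not part of the diff), ``get_ema_multi_avg_fn(decay)`` makes ``AveragedModel`` apply the standard EMA recurrence on every ``update_parameters`` call. A minimal standalone sketch with a toy ``nn.Linear`` model, independent of ``matsciml``:

import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn

decay = 0.99
model = nn.Linear(4, 2)
ema = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))

# the first ``update_parameters`` call copies the weights; every later call applies
#   ema_param = decay * ema_param + (1 - decay) * param
with torch.no_grad():
    model.weight.add_(0.1)  # stand-in for an optimizer step
ema.update_parameters(model)   # first call: copy
ema.update_parameters(model)   # subsequent calls: EMA blend as above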