From acd2024d6dee89ac9aa65ea81571419abcaa624a Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Thu, 18 Jul 2024 12:08:34 -0700
Subject: [PATCH 01/17] feat: added weighted losses

Signed-off-by: Lee, Kin Long Kelvin
---
 matsciml/models/losses.py | 57 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)
 create mode 100644 matsciml/models/losses.py

diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py
new file mode 100644
index 00000000..8a995efb
--- /dev/null
+++ b/matsciml/models/losses.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+
+__all__ = ["AtomWeightedL1", "AtomWeightedMSE"]
+
+
+class AtomWeightedL1(nn.Module):
+    """
+    Calculates the L1 loss between predicted and targets,
+    weighted by the number of atoms within each graph.
+    """
+
+    def forward(
+        self,
+        predicted: torch.Tensor,
+        targets: torch.Tensor,
+        atoms_per_graph: torch.Tensor,
+    ) -> torch.Tensor:
+        # check to make sure we are broadcasting correctly
+        if (predicted.ndim != targets.ndim) and targets.size(-1) == 1:
+            predicted.unsqueeze_(-1)
+        # for N-d targets, we might want to keep unsqueezing
+        while atoms_per_graph.ndim < targets.ndim:
+            atoms_per_graph.unsqueeze_(-1)
+        # ensures that atoms_per_graph is type cast correctly
+        squared_error = (
+            (predicted - targets) / atoms_per_graph.to(predicted.dtype)
+        ).abs()
+        return squared_error.mean()
+
+
+class AtomWeightedMSE(nn.Module):
+    """
+    Calculates the mean-squared-error between predicted and targets,
+    weighted by the number of atoms within each graph.
+    """
+
+    def forward(
+        self,
+        predicted: torch.Tensor,
+        targets: torch.Tensor,
+        atoms_per_graph: torch.Tensor,
+    ) -> torch.Tensor:
+        # check to make sure we are broadcasting correctly
+        if (predicted.ndim != targets.ndim) and targets.size(-1) == 1:
+            predicted.unsqueeze_(-1)
+        # for N-d targets, we might want to keep unsqueezing
+        while atoms_per_graph.ndim < targets.ndim:
+            atoms_per_graph.unsqueeze_(-1)
+        # ensures that atoms_per_graph is type cast correctly
+        squared_error = (
+            (predicted - targets) / atoms_per_graph.to(predicted.dtype)
+        ) ** 2.0
+        return squared_error.mean()

From 18cd5cce9bee4a9cd05c081f1466c60f24126cef Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Thu, 18 Jul 2024 13:11:21 -0700
Subject: [PATCH 02/17] refactor: adding exception handling for impossible broadcasting

---
 matsciml/models/losses.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py
index 8a995efb..d15721aa 100644
--- a/matsciml/models/losses.py
+++ b/matsciml/models/losses.py
@@ -44,6 +44,12 @@ def forward(
         targets: torch.Tensor,
         atoms_per_graph: torch.Tensor,
     ) -> torch.Tensor:
+        if atoms_per_graph.size(0) != targets.size(0):
+            raise RuntimeError(
+                "Dimensions for atom-weighted loss do not match:"
+                f" expected atoms_per_graph to have {targets.size(0)} elements; got {atoms_per_graph.size(0)}."
+                "This loss is intended to be applied to scalar targets only."
+            )
         # check to make sure we are broadcasting correctly
         if (predicted.ndim != targets.ndim) and targets.size(-1) == 1:
             predicted.unsqueeze_(-1)

From 84950bcb1237e924ffd660cc7d4a7e1032a43720 Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Thu, 18 Jul 2024 15:41:08 -0700
Subject: [PATCH 03/17] refactor: specifying energy and force loss calculators in forceregressiontask

---
 matsciml/models/base.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index 4a3e168c..f966970d 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -22,6 +22,7 @@
 from matsciml.common.types import AbstractGraph, BatchDict, DataDict, Embeddings
 from matsciml.models.common import OutputHead
 from matsciml.modules.normalizer import Normalizer
+from matsciml.models import losses as matsciml_losses
 
 logger = getLogger("matsciml")
 logger.setLevel("INFO")
@@ -1656,12 +1657,18 @@ def __init__(
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module = nn.L1Loss,
+        loss_func: type[nn.Module] | nn.Module | None = None,
+        energy_loss_func: type[nn.Module] | nn.Module = matsciml_losses.AtomWeightedMSE,
+        force_loss_func: type[nn.Module] | nn.Module = nn.MSELoss,
         task_keys: list[str] | None = None,
         output_kwargs: dict[str, Any] = {},
         embedding_reduction_type: str = "sum",
         **kwargs,
     ) -> None:
+        if loss_func is not None:
+            logger.warning(
+                "loss_func is now ignored; please set energy/force_loss_func instead."
+            )
         super().__init__(
             encoder,
             encoder_class,
@@ -1672,7 +1679,15 @@ def __init__(
             embedding_reduction_type=embedding_reduction_type,
             **kwargs,
         )
-        self.save_hyperparameters(ignore=["encoder", "loss_func"])
+        if isinstance(energy_loss_func, type):
+            energy_loss_func = energy_loss_func()
+        if isinstance(force_loss_func, type):
+            force_loss_func = force_loss_func()
+        self.energy_loss_func = energy_loss_func
+        self.force_loss_func = force_loss_func
+        self.save_hyperparameters(
+            ignore=["encoder", "loss_func", "energy_loss_func", "force_loss_func"]
+        )
         # have to enable double backprop
         self.automatic_optimization = False
 

From 737912c58d8391927b64a9123ab4dcb9f56f7b94 Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Thu, 18 Jul 2024 15:44:16 -0700
Subject: [PATCH 04/17] Revert "refactor: specifying energy and force loss calculators in forceregressiontask"

This reverts commit 84950bcb1237e924ffd660cc7d4a7e1032a43720.

Realizing it's probably better to make `loss_func` a mapping instead.
---
 matsciml/models/base.py | 19 ++-----------------
 1 file changed, 2 insertions(+), 17 deletions(-)

diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index f966970d..4a3e168c 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -22,7 +22,6 @@
 from matsciml.common.types import AbstractGraph, BatchDict, DataDict, Embeddings
 from matsciml.models.common import OutputHead
 from matsciml.modules.normalizer import Normalizer
-from matsciml.models import losses as matsciml_losses
 
 logger = getLogger("matsciml")
 logger.setLevel("INFO")
@@ -1657,18 +1656,12 @@ def __init__(
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module | None = None,
-        energy_loss_func: type[nn.Module] | nn.Module = matsciml_losses.AtomWeightedMSE,
-        force_loss_func: type[nn.Module] | nn.Module = nn.MSELoss,
+        loss_func: type[nn.Module] | nn.Module = nn.L1Loss,
         task_keys: list[str] | None = None,
         output_kwargs: dict[str, Any] = {},
         embedding_reduction_type: str = "sum",
         **kwargs,
     ) -> None:
-        if loss_func is not None:
-            logger.warning(
-                "loss_func is now ignored; please set energy/force_loss_func instead."
-            )
         super().__init__(
             encoder,
             encoder_class,
@@ -1679,15 +1672,7 @@ def __init__(
             embedding_reduction_type=embedding_reduction_type,
             **kwargs,
         )
-        if isinstance(energy_loss_func, type):
-            energy_loss_func = energy_loss_func()
-        if isinstance(force_loss_func, type):
-            force_loss_func = force_loss_func()
-        self.energy_loss_func = energy_loss_func
-        self.force_loss_func = force_loss_func
-        self.save_hyperparameters(
-            ignore=["encoder", "loss_func", "energy_loss_func", "force_loss_func"]
-        )
+        self.save_hyperparameters(ignore=["encoder", "loss_func"])
         # have to enable double backprop
         self.automatic_optimization = False
 

From 35618b18943b67bab81f7ff000776880fdb0a59c Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Thu, 18 Jul 2024 15:52:18 -0700
Subject: [PATCH 05/17] refactor: allowing loss_func arg to be dictionary mapping

---
 matsciml/models/base.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index 4a3e168c..188b45a2 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -674,7 +674,10 @@ def __init__(
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module | None = None,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = None,
         task_keys: list[str] | None = None,
         output_kwargs: dict[str, Any] = {},
         lr: float = 1e-4,
@@ -704,6 +707,24 @@ def __init__(
             raise ValueError("No valid encoder passed.")
         if isinstance(loss_func, type):
             loss_func = loss_func()
+        # if we have a dictionary mapping, we specify the loss function
+        # for each target
+        if isinstance(loss_func, dict):
+            for key in loss_func:
+                # initialize objects if types are provided
+                if isinstance(loss_func[key], type):
+                    loss_func[key] = loss_func[key]()
+                if task_keys and key not in task_keys:
+                    raise KeyError(
+                        f"Loss dict configured with {key}/{loss_func[key]} that was not provided in task keys."
+                    )
+            if not task_keys:
+                logger.warning(
+                    f"Task keys were not specified, using loss_func keys instead {loss_func.keys()}"
+                )
+                task_keys = list(loss_func.keys())
+            # convert to a module dict for consistent API usage
+            loss_func = nn.ModuleDict({key: value for key, value in loss_func.items()})
         self.loss_func = loss_func
         default_heads = {"act_last": None, "hidden_dim": 128}
         default_heads.update(output_kwargs)

From c0e9bd76c6c7d2e41c69d18cc90c7a43f24eb31c Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Thu, 18 Jul 2024 16:03:28 -0700
Subject: [PATCH 06/17] refactor: make task keys setter convert single loss funcs to dict mapping

---
 matsciml/models/base.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index 188b45a2..a050c927 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -5,6 +5,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from contextlib import ExitStack, nullcontext
+from copy import deepcopy
 from pathlib import Path
 from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, Union
 from warnings import warn
@@ -763,6 +764,10 @@ def task_keys(self, values: set | list[str] | None) -> None:
         if not self.has_initialized:
             self.output_heads = self._make_output_heads()
             self.normalizers = self._make_normalizers()
+        # homogenize it into a dictionary mapping
+        if isinstance(self.loss_func, nn.Module):
+            loss_dict = nn.ModuleDict({key: deepcopy(self.loss_func) for key in values})
+            self.loss_func = loss_dict
         self.hparams["task_keys"] = self._task_keys
 
     @property

From ad8771b52df389084b66bdd320ffb83a326dabbe Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Thu, 18 Jul 2024 16:05:12 -0700
Subject: [PATCH 07/17] refactor: treat loss_func as moduledict now

---
 matsciml/models/base.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index a050c927..606774c2 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -1038,7 +1038,8 @@ def _compute_losses(
             target_val = targets[key]
             if self.uses_normalizers:
                 target_val = self.normalizers[key].norm(target_val)
-            loss = self.loss_func(predictions[key], target_val)
+            loss_func = self.loss_func[key]
+            loss = loss_func(predictions[key], target_val)
             losses[key] = loss * self.task_loss_scaling[key]
 
         total_loss: torch.Tensor = sum(losses.values())

From 1bb29ae0aeb86066e59baab8c896c3b22acf9c39 Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Thu, 18 Jul 2024 16:08:44 -0700
Subject: [PATCH 08/17] refactor: adding exception handling for missing loss functions

---
 matsciml/models/base.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index 606774c2..483a2338 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -768,6 +768,13 @@ def task_keys(self, values: set | list[str] | None) -> None:
         if isinstance(self.loss_func, nn.Module):
             loss_dict = nn.ModuleDict({key: deepcopy(self.loss_func) for key in values})
             self.loss_func = loss_dict
+        # if a task key was given but not contained in loss_func
+        # user needs to figure out what to do
+        for key in values:
+            if key not in self.loss_func.keys():
+                raise KeyError(
+                    f"Task key {key} was specified but no loss function was specified."
+                )
         self.hparams["task_keys"] = self._task_keys
 
     @property

From 95c7f6059f1961cfab62a2479ebaa3708a9f4e81 Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Thu, 18 Jul 2024 16:17:43 -0700
Subject: [PATCH 09/17] refactor: making child tasks accept dict mapping loss_func

---
 matsciml/models/base.py | 40 ++++++++++++++++++++++++++++++++++------
 1 file changed, 34 insertions(+), 6 deletions(-)

diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index 483a2338..11bfb499 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -23,6 +23,7 @@
 from matsciml.common.types import AbstractGraph, BatchDict, DataDict, Embeddings
 from matsciml.models.common import OutputHead
 from matsciml.modules.normalizer import Normalizer
+from matsciml.models import losses as matsciml_losses
 
 logger = getLogger("matsciml")
 logger.setLevel("INFO")
@@ -1269,7 +1270,10 @@ def __init__(
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module = nn.MSELoss,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = nn.MSELoss,
         task_keys: list[str] | None = None,
         output_kwargs: dict[str, Any] = {},
         **kwargs: Any,
     ) -> None:
@@ -1387,7 +1391,10 @@ def __init__(
         encoder: Optional[nn.Module] = None,
         encoder_class: Optional[Type[nn.Module]] = None,
         encoder_kwargs: Optional[Dict[str, Any]] = None,
-        loss_func: Union[Type[nn.Module], nn.Module] = nn.MSELoss,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = nn.MSELoss,
         loss_coeff: Optional[Dict[str, Any]] = None,
         task_keys: Optional[List[str]] = None,
         output_kwargs: Dict[str, Any] = {},
@@ -1614,7 +1621,10 @@ def __init__(
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module = nn.BCEWithLogitsLoss,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = nn.BCEWithLogitsLoss,
         task_keys: list[str] | None = None,
         output_kwargs: dict[str, Any] = {},
         **kwargs,
     ) -> None:
@@ -1690,12 +1700,24 @@ def __init__(
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module = nn.L1Loss,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = None,
         task_keys: list[str] | None = None,
         output_kwargs: dict[str, Any] = {},
         embedding_reduction_type: str = "sum",
         **kwargs,
     ) -> None:
+        if not loss_func:
+            logger.warning(
+                "Loss functions were not specified. "
+                "Defaulting to AtomWeightedMSE for energy and MSE for force."
+            )
+            loss_func = {
+                "energy": matsciml_losses.AtomWeightedMSE(),
+                "force": nn.MSELoss(),
+            }
         super().__init__(
             encoder,
             encoder_class,
@@ -2109,7 +2131,10 @@ def __init__(
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module = nn.MSELoss,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = nn.MSELoss,
         output_kwargs: dict[str, Any] = {},
         **kwargs: Any,
     ) -> None:
@@ -2246,7 +2271,10 @@ def __init__(
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module = nn.CrossEntropyLoss,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = nn.CrossEntropyLoss,
         output_kwargs: dict[str, Any] = {},
         normalize_kwargs: dict[str, float] | None = None,
         freeze_embedding: bool = False,

From 8bc17bc4ef721dcdd17f6880b5bf637d99dbe291 Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Thu, 18 Jul 2024 16:32:16 -0700
Subject: [PATCH 10/17] fix: stopping redundant nesting with moduledict

---
 matsciml/models/base.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index 11bfb499..18e07aef 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -766,7 +766,9 @@ def task_keys(self, values: set | list[str] | None) -> None:
             self.output_heads = self._make_output_heads()
             self.normalizers = self._make_normalizers()
         # homogenize it into a dictionary mapping
-        if isinstance(self.loss_func, nn.Module):
+        if isinstance(self.loss_func, nn.Module) and not isinstance(
+            self.loss_func, nn.ModuleDict
+        ):
             loss_dict = nn.ModuleDict({key: deepcopy(self.loss_func) for key in values})
             self.loss_func = loss_dict

From 6718b201204e0be5f68396ce7171464934f11cb1 Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Thu, 18 Jul 2024 16:38:37 -0700
Subject: [PATCH 11/17] refactor: aligning loss args with pytorch forward signature

---
 matsciml/models/losses.py | 32 ++++++++++++++------------------
 1 file changed, 14 insertions(+), 18 deletions(-)

diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py
index d15721aa..2b436d56 100644
--- a/matsciml/models/losses.py
+++ b/matsciml/models/losses.py
@@ -15,20 +15,18 @@ class AtomWeightedL1(nn.Module):
 
     def forward(
         self,
-        predicted: torch.Tensor,
-        targets: torch.Tensor,
+        input: torch.Tensor,
+        target: torch.Tensor,
         atoms_per_graph: torch.Tensor,
     ) -> torch.Tensor:
         # check to make sure we are broadcasting correctly
-        if (predicted.ndim != targets.ndim) and targets.size(-1) == 1:
-            predicted.unsqueeze_(-1)
+        if (input.ndim != target.ndim) and target.size(-1) == 1:
+            input.unsqueeze_(-1)
         # for N-d targets, we might want to keep unsqueezing
-        while atoms_per_graph.ndim < targets.ndim:
+        while atoms_per_graph.ndim < target.ndim:
             atoms_per_graph.unsqueeze_(-1)
         # ensures that atoms_per_graph is type cast correctly
-        squared_error = (
-            (predicted - targets) / atoms_per_graph.to(predicted.dtype)
-        ).abs()
+        squared_error = ((input - target) / atoms_per_graph.to(input.dtype)).abs()
         return squared_error.mean()
 
 
@@ -40,24 +38,22 @@ class AtomWeightedMSE(nn.Module):
 
     def forward(
         self,
-        predicted: torch.Tensor,
-        targets: torch.Tensor,
+        input: torch.Tensor,
+        target: torch.Tensor,
         atoms_per_graph: torch.Tensor,
     ) -> torch.Tensor:
-        if atoms_per_graph.size(0) != targets.size(0):
+        if atoms_per_graph.size(0) != target.size(0):
             raise RuntimeError(
                 "Dimensions for atom-weighted loss do not match:"
-                f" expected atoms_per_graph to have {targets.size(0)} elements; got {atoms_per_graph.size(0)}."
+                f" expected atoms_per_graph to have {target.size(0)} elements; got {atoms_per_graph.size(0)}."
                 "This loss is intended to be applied to scalar targets only."
             )
         # check to make sure we are broadcasting correctly
-        if (predicted.ndim != targets.ndim) and targets.size(-1) == 1:
-            predicted.unsqueeze_(-1)
+        if (input.ndim != target.ndim) and target.size(-1) == 1:
+            input.unsqueeze_(-1)
         # for N-d targets, we might want to keep unsqueezing
-        while atoms_per_graph.ndim < targets.ndim:
+        while atoms_per_graph.ndim < target.ndim:
             atoms_per_graph.unsqueeze_(-1)
         # ensures that atoms_per_graph is type cast correctly
-        squared_error = (
-            (predicted - targets) / atoms_per_graph.to(predicted.dtype)
-        ) ** 2.0
+        squared_error = ((input - target) / atoms_per_graph.to(input.dtype)) ** 2.0
         return squared_error.mean()

From 1f34c33cfb1f0b3d1dd5b42a749cb5b15a7db408 Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Fri, 19 Jul 2024 09:13:58 -0700
Subject: [PATCH 12/17] refactor & fix: making denoising task use MSELoss by default

---
 matsciml/models/base.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index 18e07aef..7ff557d5 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -3437,7 +3437,7 @@ def __init__(
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module | None = None,
+        loss_func: type[nn.Module] | nn.Module | None = nn.MSELoss,
         task_keys: list[str] | None = None,
         output_kwargs: dict[str, Any] = {},
         lr: float = 0.0001,
@@ -3464,7 +3464,6 @@ def __init__(
             scheduler_kwargs,
             **kwargs,
         )
-        self.loss_func = nn.MSELoss()
 
     def _make_output_heads(self) -> nn.ModuleDict:
         # make a single output head for noise prediction applied to nodes

From 025eeafca2bc8525d4827ea1d550ff9327dbcb85 Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Fri, 19 Jul 2024 09:14:32 -0700
Subject: [PATCH 13/17] refactor: using logger.warning instead of warn

---
 matsciml/models/base.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index 7ff557d5..5815a451 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -3448,7 +3448,9 @@ def __init__(
         **kwargs,
     ) -> None:
         if task_keys is not None:
-            warn("Task keys were passed to NodeDenoisingTask, but is not used.")
+            logger.warning(
+                "Task keys were passed to NodeDenoisingTask, but is not used."
+            )
         task_keys = ["denoise"]
         super().__init__(
             encoder,

From 7abf08c739e6443f142843666da52c7e744f06cf Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Fri, 19 Jul 2024 09:27:23 -0700
Subject: [PATCH 14/17] refactor: determining atoms counts to be passed into weighted calculation

---
 matsciml/models/base.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index 5815a451..0b4084e4 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -10,6 +10,7 @@
 from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, Union
 from warnings import warn
 from logging import getLogger
+from inspect import signature
 
 import pytorch_lightning as pl
 from pytorch_lightning import loggers as pl_loggers
@@ -1049,7 +1050,27 @@ def _compute_losses(
             if self.uses_normalizers:
                 target_val = self.normalizers[key].norm(target_val)
             loss_func = self.loss_func[key]
-            loss = loss_func(predictions[key], target_val)
+            # determine if we need additional arguments
+            loss_func_signature = signature(loss_func.forward).parameters
+            kwargs = {"input": predictions[key], "target": target_val}
+            # pack atoms per graph information too
+            if "atoms_per_graph" in loss_func_signature:
+                if graph := batch.get("graph", None):
+                    if isinstance(graph, dgl.DGLGraph):
+                        num_atoms = graph.batch_num_nodes()
+                    else:
+                        # in the pyg case we use the pointer tensor
+                        num_atoms = graph.ptr[1:] - graph.ptr[:-1]
+                else:
+                    # in MP at least this is provided by the dataset class
+                    num_atoms = batch.get("sizes", None)
+                    if not num_atoms:
+                        raise NotImplementedError(
+                            "Unable to determine number of atoms for dataset. "
+                            "This is required for the atom-weighted loss functions."
+                        )
+                kwargs["atoms_per_graph"] = num_atoms
+            loss = loss_func(**kwargs)
             losses[key] = loss * self.task_loss_scaling[key]
 
         total_loss: torch.Tensor = sum(losses.values())

From 9d0d66909c0cea96640f7a0aae4751795e26ad7a Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Fri, 19 Jul 2024 09:27:54 -0700
Subject: [PATCH 15/17] fix: correcting task keys for force test

---
 matsciml/models/tests/test_task_loss_scaling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/matsciml/models/tests/test_task_loss_scaling.py b/matsciml/models/tests/test_task_loss_scaling.py
index f89cccc0..88f477dd 100644
--- a/matsciml/models/tests/test_task_loss_scaling.py
+++ b/matsciml/models/tests/test_task_loss_scaling.py
@@ -89,7 +89,7 @@ def test_force_regression(egnn_config):
     # Scenario where one task_key is set. Expect to use one task_loss_scaling value.
     task = ForceRegressionTask(
         **egnn_config,
-        task_keys=["force"],
+        task_keys=["force", "energy"],
         task_loss_scaling={"force": 10},
     )
     trainer = pl.Trainer(max_steps=5, logger=False, enable_checkpointing=False)

From ba6930667baa481272b18b8ec22da95cf53ed777 Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Fri, 19 Jul 2024 09:40:02 -0700
Subject: [PATCH 16/17] refactor: adding exception handling for L1 weighted loss as well

---
 matsciml/models/losses.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py
index 2b436d56..02205310 100644
--- a/matsciml/models/losses.py
+++ b/matsciml/models/losses.py
@@ -19,6 +19,12 @@ def forward(
         target: torch.Tensor,
         atoms_per_graph: torch.Tensor,
     ) -> torch.Tensor:
+        if atoms_per_graph.size(0) != target.size(0):
+            raise RuntimeError(
+                "Dimensions for atom-weighted loss do not match:"
+                f" expected atoms_per_graph to have {target.size(0)} elements; got {atoms_per_graph.size(0)}."
+                "This loss is intended to be applied to scalar targets only."
+            )
         # check to make sure we are broadcasting correctly
         if (input.ndim != target.ndim) and target.size(-1) == 1:
             input.unsqueeze_(-1)

From 04d359f46a411fecca34d067c8d97599592d8909 Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Fri, 19 Jul 2024 11:26:15 -0700
Subject: [PATCH 17/17] test: added unit test for MSE loss test

Signed-off-by: Lee, Kin Long Kelvin
---
 matsciml/models/tests/test_losses.py | 31 ++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 matsciml/models/tests/test_losses.py

diff --git a/matsciml/models/tests/test_losses.py b/matsciml/models/tests/test_losses.py
new file mode 100644
index 00000000..b39373e9
--- /dev/null
+++ b/matsciml/models/tests/test_losses.py
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+import pytest
+import torch
+
+from matsciml.models import losses
+
+
+@pytest.fixture
+def atom_weighted_l1():
+    return losses.AtomWeightedL1()
+
+
+@pytest.fixture
+def atom_weighted_mse():
+    return losses.AtomWeightedMSE()
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (10,),
+        (10, 3),
+        (120, 1, 5),
+    ],
+)
+def test_weighted_mse(atom_weighted_mse, shape):
+    pred = torch.rand(*shape)
+    target = torch.rand_like(pred)
+    ptr = torch.randint(1, 100, (shape[0],))
+    atom_weighted_mse(pred, target, ptr)
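
The atom-weighted losses introduced in this series divide each graph's residual by its atom count before reducing, so large structures do not dominate batch losses for extensive targets such as total energy. A minimal, self-contained sketch of that behaviour follows; the tensor values below are made-up examples and are not part of the patch series, while the `matsciml.models.losses` import path and the `(input, target, atoms_per_graph)` call signature come from the patches above.

from __future__ import annotations

import torch
from torch import nn

from matsciml.models import losses

# Two graphs in a batch: one with 4 atoms, one with 10 atoms.
atoms_per_graph = torch.tensor([4.0, 10.0])
# Per-graph predictions and targets for a scalar property such as total energy.
pred = torch.tensor([1.0, 1.0])
target = torch.tensor([0.0, 0.0])

plain_mse = nn.MSELoss()
weighted_mse = losses.AtomWeightedMSE()

# Plain MSE weights both graphs equally: mean(1.0, 1.0) = 1.0
print(plain_mse(pred, target))
# The atom-weighted version divides each residual by its graph's atom count
# before squaring: mean((1 / 4) ** 2, (1 / 10) ** 2) = 0.03625
print(weighted_mse(pred, target, atoms_per_graph))

With the task-side changes (patches 05-10 and 14), the same classes can be selected per target by passing a dictionary to `loss_func`, for example `loss_func={"energy": losses.AtomWeightedMSE(), "force": nn.MSELoss()}`, which is also the default that ForceRegressionTask falls back to when no loss function is given (patch 09).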