diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index 7bcf61b6..1637b936 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -5,10 +5,12 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from contextlib import ExitStack, nullcontext
+from copy import deepcopy
 from pathlib import Path
 from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, Union
 from warnings import warn
 from logging import getLogger
+from inspect import signature
 
 import pytorch_lightning as pl
 from pytorch_lightning import loggers as pl_loggers
@@ -22,6 +24,7 @@
 from matsciml.common.types import AbstractGraph, BatchDict, DataDict, Embeddings
 from matsciml.models.common import OutputHead
 from matsciml.modules.normalizer import Normalizer
+from matsciml.models import losses as matsciml_losses
 
 logger = getLogger("matsciml")
 logger.setLevel("INFO")
@@ -674,7 +677,10 @@ def __init__(
         self,
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module | None = None,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = None,
         task_keys: list[str] | None = None,
         output_kwargs: dict[str, Any] = {},
         lr: float = 1e-4,
@@ -704,6 +710,24 @@
             raise ValueError("No valid encoder passed.")
         if isinstance(loss_func, type):
             loss_func = loss_func()
+        # if we have a dictionary mapping, we specify the loss function
+        # for each target
+        if isinstance(loss_func, dict):
+            for key in loss_func:
+                # initialize objects if types are provided
+                if isinstance(loss_func[key], type):
+                    loss_func[key] = loss_func[key]()
+                if task_keys and key not in task_keys:
+                    raise KeyError(
+                        f"Loss dict configured with {key}/{loss_func[key]} that was not provided in task keys."
+                    )
+            if not task_keys:
+                logger.warning(
+                    f"Task keys were not specified, using loss_func keys instead {loss_func.keys()}"
+                )
+                task_keys = list(loss_func.keys())
+            # convert to a module dict for consistent API usage
+            loss_func = nn.ModuleDict({key: value for key, value in loss_func.items()})
         self.loss_func = loss_func
         default_heads = {"act_last": None, "hidden_dim": 128}
         default_heads.update(output_kwargs)
@@ -742,6 +766,19 @@ def task_keys(self, values: set | list[str] | None) -> None:
         if not self.has_initialized:
             self.output_heads = self._make_output_heads()
             self.normalizers = self._make_normalizers()
+        # homogenize it into a dictionary mapping
+        if isinstance(self.loss_func, nn.Module) and not isinstance(
+            self.loss_func, nn.ModuleDict
+        ):
+            loss_dict = nn.ModuleDict({key: deepcopy(self.loss_func) for key in values})
+            self.loss_func = loss_dict
+        # if a task key was given but not contained in loss_func
+        # user needs to figure out what to do
+        for key in values:
+            if key not in self.loss_func.keys():
+                raise KeyError(
+                    f"Task key {key} was specified but no loss function was provided for it."
+                )
         self.hparams["task_keys"] = self._task_keys
 
     @property
@@ -1017,7 +1054,28 @@ def _compute_losses(
             target_val = targets[key]
             if self.uses_normalizers:
                 target_val = self.normalizers[key].norm(target_val)
-            loss = self.loss_func(predictions[key], target_val)
+            loss_func = self.loss_func[key]
+            # determine if we need additional arguments
+            loss_func_signature = signature(loss_func.forward).parameters
+            kwargs = {"input": predictions[key], "target": target_val}
+            # pack atoms per graph information too
+            if "atoms_per_graph" in loss_func_signature:
+                if graph := batch.get("graph", None):
+                    if isinstance(graph, dgl.DGLGraph):
+                        num_atoms = graph.batch_num_nodes()
+                    else:
+                        # in the pyg case we use the pointer tensor
+                        num_atoms = graph.ptr[1:] - graph.ptr[:-1]
+                else:
+                    # in MP at least this is provided by the dataset class
+                    num_atoms = batch.get("sizes", None)
+                    if not num_atoms:
+                        raise NotImplementedError(
+                            "Unable to determine number of atoms for dataset. "
+                            "This is required for the atom-weighted loss functions."
+                        )
+                kwargs["atoms_per_graph"] = num_atoms
+            loss = loss_func(**kwargs)
             losses[key] = loss * self.task_loss_scaling[key]
 
         total_loss: torch.Tensor = sum(losses.values())
@@ -1245,7 +1303,10 @@ def __init__(
         self,
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module = nn.MSELoss,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = nn.MSELoss,
         task_keys: list[str] | None = None,
         output_kwargs: dict[str, Any] = {},
         **kwargs: Any,
     ) -> None:
@@ -1363,7 +1424,10 @@ def __init__(
         self,
         encoder: Optional[nn.Module] = None,
         encoder_class: Optional[Type[nn.Module]] = None,
         encoder_kwargs: Optional[Dict[str, Any]] = None,
-        loss_func: Union[Type[nn.Module], nn.Module] = nn.MSELoss,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = nn.MSELoss,
         loss_coeff: Optional[Dict[str, Any]] = None,
         task_keys: Optional[List[str]] = None,
         output_kwargs: Dict[str, Any] = {},
@@ -1590,7 +1654,10 @@ def __init__(
         self,
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module = nn.BCEWithLogitsLoss,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = nn.BCEWithLogitsLoss,
         task_keys: list[str] | None = None,
         output_kwargs: dict[str, Any] = {},
         **kwargs,
@@ -1666,12 +1733,24 @@ def __init__(
         self,
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module = nn.L1Loss,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = None,
         task_keys: list[str] | None = None,
         output_kwargs: dict[str, Any] = {},
         embedding_reduction_type: str = "sum",
         **kwargs,
     ) -> None:
+        if not loss_func:
+            logger.warning(
+                "Loss functions were not specified. "
+                "Defaulting to AtomWeightedMSE for energy and MSE for force."
+            )
+            loss_func = {
+                "energy": matsciml_losses.AtomWeightedMSE(),
+                "force": nn.MSELoss(),
+            }
         super().__init__(
             encoder,
             encoder_class,
@@ -2085,7 +2164,10 @@ def __init__(
         self,
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module = nn.MSELoss,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = nn.MSELoss,
         output_kwargs: dict[str, Any] = {},
         **kwargs: Any,
     ) -> None:
@@ -2222,7 +2304,10 @@ def __init__(
         self,
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module = nn.CrossEntropyLoss,
+        loss_func: type[nn.Module]
+        | nn.Module
+        | dict[str, nn.Module | type[nn.Module]]
+        | None = nn.CrossEntropyLoss,
         output_kwargs: dict[str, Any] = {},
         normalize_kwargs: dict[str, float] | None = None,
         freeze_embedding: bool = False,
@@ -3383,7 +3468,7 @@ def __init__(
         self,
         encoder: nn.Module | None = None,
         encoder_class: type[nn.Module] | None = None,
         encoder_kwargs: dict[str, Any] | None = None,
-        loss_func: type[nn.Module] | nn.Module | None = None,
+        loss_func: type[nn.Module] | nn.Module | None = nn.MSELoss,
         task_keys: list[str] | None = None,
         output_kwargs: dict[str, Any] = {},
         lr: float = 0.0001,
@@ -3394,7 +3479,9 @@
         **kwargs,
     ) -> None:
         if task_keys is not None:
-            warn("Task keys were passed to NodeDenoisingTask, but is not used.")
+            logger.warning(
+                "Task keys were passed to NodeDenoisingTask, but are not used."
+            )
         task_keys = ["denoise"]
         super().__init__(
             encoder,
@@ -3410,7 +3497,6 @@
             scheduler_kwargs,
             **kwargs,
         )
-        self.loss_func = nn.MSELoss()
 
     def _make_output_heads(self) -> nn.ModuleDict:
         # make a single output head for noise prediction applied to nodes
diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py
new file mode 100644
index 00000000..02205310
--- /dev/null
+++ b/matsciml/models/losses.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+
+__all__ = ["AtomWeightedL1", "AtomWeightedMSE"]
+
+
+class AtomWeightedL1(nn.Module):
+    """
+    Calculates the L1 loss between predictions and targets,
+    weighted by the number of atoms within each graph.
+    """
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        target: torch.Tensor,
+        atoms_per_graph: torch.Tensor,
+    ) -> torch.Tensor:
+        if atoms_per_graph.size(0) != target.size(0):
+            raise RuntimeError(
+                "Dimensions for atom-weighted loss do not match:"
+                f" expected atoms_per_graph to have {target.size(0)} elements; got {atoms_per_graph.size(0)}."
+                " This loss is intended to be applied to scalar targets only."
+            )
+        # check to make sure we are broadcasting correctly
+        if (input.ndim != target.ndim) and target.size(-1) == 1:
+            input.unsqueeze_(-1)
+        # for N-d targets, we might want to keep unsqueezing
+        while atoms_per_graph.ndim < target.ndim:
+            atoms_per_graph.unsqueeze_(-1)
+        # ensures that atoms_per_graph is type cast correctly
+        abs_error = ((input - target) / atoms_per_graph.to(input.dtype)).abs()
+        return abs_error.mean()
+
+
+class AtomWeightedMSE(nn.Module):
+    """
+    Calculates the mean-squared-error between predictions and targets,
+    weighted by the number of atoms within each graph.
+ """ + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + atoms_per_graph: torch.Tensor, + ) -> torch.Tensor: + if atoms_per_graph.size(0) != target.size(0): + raise RuntimeError( + "Dimensions for atom-weighted loss do not match:" + f" expected atoms_per_graph to have {target.size(0)} elements; got {atoms_per_graph.size(0)}." + "This loss is intended to be applied to scalar targets only." + ) + # check to make sure we are broad casting correctly + if (input.ndim != target.ndim) and target.size(-1) == 1: + input.unsqueeze_(-1) + # for N-d targets, we might want to keep unsqueezing + while atoms_per_graph.ndim < target.ndim: + atoms_per_graph.unsqueeze_(-1) + # ensures that atoms_per_graph is type cast correctly + squared_error = ((input - target) / atoms_per_graph.to(input.dtype)) ** 2.0 + return squared_error.mean() diff --git a/matsciml/models/tests/test_losses.py b/matsciml/models/tests/test_losses.py new file mode 100644 index 00000000..b39373e9 --- /dev/null +++ b/matsciml/models/tests/test_losses.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import pytest +import torch + +from matsciml.models import losses + + +@pytest.fixture +def atom_weighted_l1(): + return losses.AtomWeightedL1() + + +@pytest.fixture +def atom_weighted_mse(): + return losses.AtomWeightedMSE() + + +@pytest.mark.parametrize( + "shape", + [ + (10,), + (10, 3), + (120, 1, 5), + ], +) +def test_weighted_mse(atom_weighted_mse, shape): + pred = torch.rand(*shape) + target = torch.rand_like(pred) + ptr = torch.randint(1, 100, (shape[0],)) + atom_weighted_mse(pred, target, ptr) diff --git a/matsciml/models/tests/test_task_loss_scaling.py b/matsciml/models/tests/test_task_loss_scaling.py index f89cccc0..88f477dd 100644 --- a/matsciml/models/tests/test_task_loss_scaling.py +++ b/matsciml/models/tests/test_task_loss_scaling.py @@ -89,7 +89,7 @@ def test_force_regression(egnn_config): # Scenario where one task_key is set. Expect to use one task_loss_scaling value. task = ForceRegressionTask( **egnn_config, - task_keys=["force"], + task_keys=["force", "energy"], task_loss_scaling={"force": 10}, ) trainer = pl.Trainer(max_steps=5, logger=False, enable_checkpointing=False)