Merge branch 'main' into ema-support
laserkelvin committed Jul 19, 2024
2 parents f19fa79 + 8f8ec4b commit 73b00a9
Showing 4 changed files with 194 additions and 12 deletions.
108 changes: 97 additions & 11 deletions matsciml/models/base.py
@@ -5,10 +5,12 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from contextlib import ExitStack, nullcontext
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, Union
from warnings import warn
from logging import getLogger
from inspect import signature

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
@@ -22,6 +24,7 @@
from matsciml.common.types import AbstractGraph, BatchDict, DataDict, Embeddings
from matsciml.models.common import OutputHead
from matsciml.modules.normalizer import Normalizer
from matsciml.models import losses as matsciml_losses

logger = getLogger("matsciml")
logger.setLevel("INFO")
@@ -674,7 +677,10 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module] | nn.Module | None = None,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = None,
task_keys: list[str] | None = None,
output_kwargs: dict[str, Any] = {},
lr: float = 1e-4,
@@ -704,6 +710,24 @@
raise ValueError("No valid encoder passed.")
if isinstance(loss_func, type):
loss_func = loss_func()
# if we have a dictionary mapping, we specify the loss function
# for each target
if isinstance(loss_func, dict):
for key in loss_func:
# initialize objects if types are provided
if isinstance(loss_func[key], type):
loss_func[key] = loss_func[key]()
if task_keys and key not in task_keys:
raise KeyError(
f"Loss dict configured with {key}/{loss_func[key]} that was not provided in task keys."
)
if not task_keys:
logger.warning(
f"Task keys were not specified, using loss_func keys instead {loss_func.keys()}"
)
task_keys = list(loss_func.keys())
# convert to a module dict for consistent API usage
loss_func = nn.ModuleDict({key: value for key, value in loss_func.items()})
self.loss_func = loss_func
default_heads = {"act_last": None, "hidden_dim": 128}
default_heads.update(output_kwargs)
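
A standalone sketch of the dictionary handling above (illustrative only, not part of this commit; the "energy"/"force" keys are arbitrary examples): bare loss classes are instantiated, task keys fall back to the dictionary's keys when they are not given, and the mapping is wrapped in an nn.ModuleDict.

import torch.nn as nn

loss_func = {"energy": nn.MSELoss, "force": nn.L1Loss()}  # classes or instances
task_keys = None

for key in loss_func:
    if isinstance(loss_func[key], type):
        loss_func[key] = loss_func[key]()  # instantiate bare classes
    if task_keys and key not in task_keys:
        raise KeyError(f"{key} is not a configured task key.")
if not task_keys:
    task_keys = list(loss_func.keys())  # derive task keys from the mapping
loss_func = nn.ModuleDict(loss_func)
print(task_keys)  # ['energy', 'force']
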
@@ -742,6 +766,19 @@ def task_keys(self, values: set | list[str] | None) -> None:
if not self.has_initialized:
self.output_heads = self._make_output_heads()
self.normalizers = self._make_normalizers()
# homogenize it into a dictionary mapping
if isinstance(self.loss_func, nn.Module) and not isinstance(
self.loss_func, nn.ModuleDict
):
loss_dict = nn.ModuleDict({key: deepcopy(self.loss_func) for key in values})
self.loss_func = loss_dict
# if a task key was given but is not contained in loss_func,
# the user needs to provide a loss function for it explicitly
for key in values:
if key not in self.loss_func.keys():
raise KeyError(
f"Task key {key} was specified but no loss function was specified."
)
self.hparams["task_keys"] = self._task_keys

@property
@@ -1017,7 +1054,28 @@ def _compute_losses(
target_val = targets[key]
if self.uses_normalizers:
target_val = self.normalizers[key].norm(target_val)
loss = self.loss_func(predictions[key], target_val)
loss_func = self.loss_func[key]
# determine if we need additional arguments
loss_func_signature = signature(loss_func.forward).parameters
kwargs = {"input": predictions[key], "target": target_val}
# pack atoms per graph information too
if "atoms_per_graph" in loss_func_signature:
if graph := batch.get("graph", None):
if isinstance(graph, dgl.DGLGraph):
num_atoms = graph.batch_num_nodes()
else:
# in the pyg case we use the pointer tensor
num_atoms = graph.ptr[1:] - graph.ptr[:-1]
else:
# in MP at least this is provided by the dataset class
num_atoms = batch.get("sizes", None)
if num_atoms is None:
raise NotImplementedError(
"Unable to determine number of atoms for dataset. "
"This is required for the atom-weighted loss functions."
)
kwargs["atoms_per_graph"] = num_atoms
loss = loss_func(**kwargs)
losses[key] = loss * self.task_loss_scaling[key]

total_loss: torch.Tensor = sum(losses.values())
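
A minimal standalone sketch of the dispatch used in _compute_losses above (not part of the commit; the tensors and pointer values are made up): per-graph atom counts are recovered from a PyG-style pointer tensor, and inspect.signature decides whether a loss's forward accepts atoms_per_graph.

from inspect import signature

import torch
from torch import nn

from matsciml.models.losses import AtomWeightedMSE

pred = torch.rand(4, 1)
target = torch.rand(4, 1)
ptr = torch.tensor([0, 5, 9, 14, 20])  # PyG-style pointer tensor
atoms_per_graph = ptr[1:] - ptr[:-1]   # -> tensor([5, 4, 5, 6])

for loss_fn in (nn.MSELoss(), AtomWeightedMSE()):
    kwargs = {"input": pred, "target": target}
    # only pass atom counts to losses whose forward() accepts them
    if "atoms_per_graph" in signature(loss_fn.forward).parameters:
        kwargs["atoms_per_graph"] = atoms_per_graph
    print(type(loss_fn).__name__, loss_fn(**kwargs).item())
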
@@ -1245,7 +1303,10 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module] | nn.Module = nn.MSELoss,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = nn.MSELoss,
task_keys: list[str] | None = None,
output_kwargs: dict[str, Any] = {},
**kwargs: Any,
@@ -1363,7 +1424,10 @@ def __init__(
encoder: Optional[nn.Module] = None,
encoder_class: Optional[Type[nn.Module]] = None,
encoder_kwargs: Optional[Dict[str, Any]] = None,
loss_func: Union[Type[nn.Module], nn.Module] = nn.MSELoss,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = nn.MSELoss,
loss_coeff: Optional[Dict[str, Any]] = None,
task_keys: Optional[List[str]] = None,
output_kwargs: Dict[str, Any] = {},
@@ -1590,7 +1654,10 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module] | nn.Module = nn.BCEWithLogitsLoss,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = nn.BCEWithLogitsLoss,
task_keys: list[str] | None = None,
output_kwargs: dict[str, Any] = {},
**kwargs,
@@ -1666,12 +1733,24 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module] | nn.Module = nn.L1Loss,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = None,
task_keys: list[str] | None = None,
output_kwargs: dict[str, Any] = {},
embedding_reduction_type: str = "sum",
**kwargs,
) -> None:
if not loss_func:
logger.warning(
"Loss functions were not specified. "
"Defaulting to AtomWeightedMSE for energy and MSE for force."
)
loss_func = {
"energy": matsciml_losses.AtomWeightedMSE(),
"force": nn.MSELoss(),
}
super().__init__(
encoder,
encoder_class,
@@ -2085,7 +2164,10 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module] | nn.Module = nn.MSELoss,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = nn.MSELoss,
output_kwargs: dict[str, Any] = {},
**kwargs: Any,
) -> None:
@@ -2222,7 +2304,10 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module] | nn.Module = nn.CrossEntropyLoss,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = nn.CrossEntropyLoss,
output_kwargs: dict[str, Any] = {},
normalize_kwargs: dict[str, float] | None = None,
freeze_embedding: bool = False,
@@ -3383,7 +3468,7 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module] | nn.Module | None = None,
loss_func: type[nn.Module] | nn.Module | None = nn.MSELoss,
task_keys: list[str] | None = None,
output_kwargs: dict[str, Any] = {},
lr: float = 0.0001,
@@ -3394,7 +3479,9 @@
**kwargs,
) -> None:
if task_keys is not None:
warn("Task keys were passed to NodeDenoisingTask, but is not used.")
logger.warning(
"Task keys were passed to NodeDenoisingTask, but is not used."
)
task_keys = ["denoise"]
super().__init__(
encoder,
Expand All @@ -3410,7 +3497,6 @@ def __init__(
scheduler_kwargs,
**kwargs,
)
self.loss_func = nn.MSELoss()

def _make_output_heads(self) -> nn.ModuleDict:
# make a single output head for noise prediction applied to nodes
65 changes: 65 additions & 0 deletions matsciml/models/losses.py
@@ -0,0 +1,65 @@
from __future__ import annotations

import torch
from torch import nn


__all__ = ["AtomWeightedL1", "AtomWeightedMSE"]


class AtomWeightedL1(nn.Module):
"""
Calculates the L1 loss between predicted and targets,
weighted by the number of atoms within each graph.
"""

def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
atoms_per_graph: torch.Tensor,
) -> torch.Tensor:
if atoms_per_graph.size(0) != target.size(0):
raise RuntimeError(
"Dimensions for atom-weighted loss do not match:"
f" expected atoms_per_graph to have {target.size(0)} elements; got {atoms_per_graph.size(0)}."
"This loss is intended to be applied to scalar targets only."
)
# check to make sure we are broadcasting correctly
if (input.ndim != target.ndim) and target.size(-1) == 1:
input.unsqueeze_(-1)
# for N-d targets, we might want to keep unsqueezing
while atoms_per_graph.ndim < target.ndim:
atoms_per_graph.unsqueeze_(-1)
# ensures that atoms_per_graph is type cast correctly
abs_error = ((input - target) / atoms_per_graph.to(input.dtype)).abs()
return abs_error.mean()


class AtomWeightedMSE(nn.Module):
"""
Calculates the mean-squared-error between predicted and targets,
weighted by the number of atoms within each graph.
"""

def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
atoms_per_graph: torch.Tensor,
) -> torch.Tensor:
if atoms_per_graph.size(0) != target.size(0):
raise RuntimeError(
"Dimensions for atom-weighted loss do not match:"
f" expected atoms_per_graph to have {target.size(0)} elements; got {atoms_per_graph.size(0)}."
"This loss is intended to be applied to scalar targets only."
)
# check to make sure we are broadcasting correctly
if (input.ndim != target.ndim) and target.size(-1) == 1:
input.unsqueeze_(-1)
# for N-d targets, we might want to keep unsqueezing
while atoms_per_graph.ndim < target.ndim:
atoms_per_graph.unsqueeze_(-1)
# ensures that atoms_per_graph is type cast correctly
squared_error = ((input - target) / atoms_per_graph.to(input.dtype)) ** 2.0
return squared_error.mean()
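
A brief usage sketch for the two new losses (illustrative values only): scalar targets of shape (4,) and vector targets of shape (4, 3) both work, with the atom counts broadcast against any trailing target dimensions.

import torch

from matsciml.models.losses import AtomWeightedL1, AtomWeightedMSE

atoms_per_graph = torch.tensor([2, 8, 5, 11])  # one atom count per graph
scalar_pred, scalar_target = torch.rand(4), torch.rand(4)
vector_pred, vector_target = torch.rand(4, 3), torch.rand(4, 3)

print(AtomWeightedL1()(scalar_pred, scalar_target, atoms_per_graph))
print(AtomWeightedMSE()(vector_pred, vector_target, atoms_per_graph))
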
31 changes: 31 additions & 0 deletions matsciml/models/tests/test_losses.py
@@ -0,0 +1,31 @@
from __future__ import annotations

import pytest
import torch

from matsciml.models import losses


@pytest.fixture
def atom_weighted_l1():
return losses.AtomWeightedL1()


@pytest.fixture
def atom_weighted_mse():
return losses.AtomWeightedMSE()


@pytest.mark.parametrize(
"shape",
[
(10,),
(10, 3),
(120, 1, 5),
],
)
def test_weighted_mse(atom_weighted_mse, shape):
pred = torch.rand(*shape)
target = torch.rand_like(pred)
ptr = torch.randint(1, 100, (shape[0],))
atom_weighted_mse(pred, target, ptr)
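
The new tests only parametrize the MSE variant; a companion test for the L1 loss could follow the same pattern (a sketch, not part of this commit, reusing the fixtures defined above):

@pytest.mark.parametrize(
    "shape",
    [
        (10,),
        (10, 3),
        (120, 1, 5),
    ],
)
def test_weighted_l1(atom_weighted_l1, shape):
    pred = torch.rand(*shape)
    target = torch.rand_like(pred)
    ptr = torch.randint(1, 100, (shape[0],))
    atom_weighted_l1(pred, target, ptr)
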
2 changes: 1 addition & 1 deletion matsciml/models/tests/test_task_loss_scaling.py
@@ -89,7 +89,7 @@ def test_force_regression(egnn_config):
# Scenario where one task_key is set. Expect to use one task_loss_scaling value.
task = ForceRegressionTask(
**egnn_config,
task_keys=["force"],
task_keys=["force", "energy"],
task_loss_scaling={"force": 10},
)
trainer = pl.Trainer(max_steps=5, logger=False, enable_checkpointing=False)
