forked from IntelLabs/matsciml
Commit: Merge branch 'main' into ema-support
Showing 4 changed files with 194 additions and 12 deletions.
@@ -0,0 +1,65 @@
from __future__ import annotations

import torch
from torch import nn


__all__ = ["AtomWeightedL1", "AtomWeightedMSE"]


class AtomWeightedL1(nn.Module):
    """
    Calculates the L1 loss between predictions and targets,
    weighted by the number of atoms within each graph.
    """

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        atoms_per_graph: torch.Tensor,
    ) -> torch.Tensor:
        if atoms_per_graph.size(0) != target.size(0):
            raise RuntimeError(
                "Dimensions for atom-weighted loss do not match:"
                f" expected atoms_per_graph to have {target.size(0)} elements; got {atoms_per_graph.size(0)}."
                " This loss is intended to be applied to scalar targets only."
            )
        # check to make sure we are broadcasting correctly; avoid in-place
        # unsqueeze_, which would mutate the caller's tensors
        if (input.ndim != target.ndim) and target.size(-1) == 1:
            input = input.unsqueeze(-1)
        # for N-d targets, we might want to keep unsqueezing
        while atoms_per_graph.ndim < target.ndim:
            atoms_per_graph = atoms_per_graph.unsqueeze(-1)
        # cast atoms_per_graph to the input dtype so the division broadcasts cleanly
        abs_error = ((input - target) / atoms_per_graph.to(input.dtype)).abs()
        return abs_error.mean()


class AtomWeightedMSE(nn.Module):
    """
    Calculates the mean-squared error between predictions and targets,
    weighted by the number of atoms within each graph.
    """

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        atoms_per_graph: torch.Tensor,
    ) -> torch.Tensor:
        if atoms_per_graph.size(0) != target.size(0):
            raise RuntimeError(
                "Dimensions for atom-weighted loss do not match:"
                f" expected atoms_per_graph to have {target.size(0)} elements; got {atoms_per_graph.size(0)}."
                " This loss is intended to be applied to scalar targets only."
            )
        # check to make sure we are broadcasting correctly; avoid in-place
        # unsqueeze_, which would mutate the caller's tensors
        if (input.ndim != target.ndim) and target.size(-1) == 1:
            input = input.unsqueeze(-1)
        # for N-d targets, we might want to keep unsqueezing
        while atoms_per_graph.ndim < target.ndim:
            atoms_per_graph = atoms_per_graph.unsqueeze(-1)
        # cast atoms_per_graph to the input dtype so the division broadcasts cleanly
        squared_error = ((input - target) / atoms_per_graph.to(input.dtype)) ** 2.0
        return squared_error.mean()
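
For context, a minimal usage sketch of the new loss (not part of the commit; the import path follows the test file below, and the atom counts are made-up values):

import torch

from matsciml.models.losses import AtomWeightedMSE

loss_fn = AtomWeightedMSE()
# four graphs in the batch, one scalar target each
pred = torch.rand(4, 1)
target = torch.rand(4, 1)
# hypothetical per-graph atom counts; dividing by these keeps large
# structures from dominating the batch loss
atoms_per_graph = torch.tensor([12, 48, 3, 7])
loss = loss_fn(pred, target, atoms_per_graph)  # mean of ((pred - target) / atoms) ** 2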
@@ -0,0 +1,31 @@
from __future__ import annotations

import pytest
import torch

from matsciml.models import losses


@pytest.fixture
def atom_weighted_l1():
    return losses.AtomWeightedL1()


@pytest.fixture
def atom_weighted_mse():
    return losses.AtomWeightedMSE()


@pytest.mark.parametrize(
    "shape",
    [
        (10,),
        (10, 3),
        (120, 1, 5),
    ],
)
def test_weighted_mse(atom_weighted_mse, shape):
    pred = torch.rand(*shape)
    target = torch.rand_like(pred)
    ptr = torch.randint(1, 100, (shape[0],))
    atom_weighted_mse(pred, target, ptr)
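
Note that the atom_weighted_l1 fixture above is not exercised by any test in this hunk; if a companion test were wanted, it could mirror test_weighted_mse (a sketch, not part of the commit):

@pytest.mark.parametrize(
    "shape",
    [
        (10,),
        (10, 3),
        (120, 1, 5),
    ],
)
def test_weighted_l1(atom_weighted_l1, shape):
    # smoke test: the forward pass should broadcast and reduce without error
    pred = torch.rand(*shape)
    target = torch.rand_like(pred)
    ptr = torch.randint(1, 100, (shape[0],))
    atom_weighted_l1(pred, target, ptr)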