Skip to content

Commit

Permalink
fix pred_label validator
Browse files Browse the repository at this point in the history
  • Loading branch information
djdameln committed Oct 29, 2024
1 parent b719c95 commit ecd6e68
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/anomalib/data/validators/torch/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,11 +586,15 @@ def validate_pred_label(pred_label: torch.Tensor | None) -> torch.Tensor | None:
msg = f"Predicted label must be 1-dimensional or 2-dimensional, got shape {pred_label.shape}."
raise ValueError(msg)
if pred_label.ndim == 2:
if not (pred_label.shape[0] == 1 or pred_label.shape[1] == 1):
msg = f"Predicted label with 2 dimensions must have shape [N, 1], got shape {pred_label.shape}."
if pred_label.shape[0] == 1:
pred_label = pred_label.squeeze(0)
elif pred_label.shape[1] == 1:
pred_label = pred_label.squeeze(1)
else:
msg = (
f"Predicted label with 2 dimensions must have shape [N, 1] or [1, N], got shape {pred_label.shape}."
)
raise ValueError(msg)
pred_label = pred_label.squeeze()

return pred_label.to(torch.bool)

@staticmethod
Expand Down

0 comments on commit ecd6e68

Please sign in to comment.