Skip to content

Commit

Permalink
Add check before loading metrics data from checkpoint (#2323)
Browse files Browse the repository at this point in the history
Add check before loading from checkpoint

Signed-off-by: Blaz Rolih <blaz.rolih@gmail.com>
Co-authored-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
blaz-r and samet-akcay authored Sep 26, 2024
1 parent 983ec58 commit f473df8
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions src/anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,20 +168,19 @@ def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True
if "pixel_threshold_class" in state_dict:
self.pixel_threshold = self._get_instance(state_dict, "pixel_threshold_class")

if "anomaly_maps_normalization_class" in state_dict:
self.anomaly_maps_normalization_metrics = self._get_instance(state_dict, "anomaly_maps_normalization_class")
if "box_scores_normalization_class" in state_dict:
self.box_scores_normalization_metrics = self._get_instance(state_dict, "box_scores_normalization_class")
# check only for pred score normalization metrics, because if this one is present, all others are too
if "pred_scores_normalization_class" in state_dict:
self.box_scores_normalization_metrics = self._get_instance(state_dict, "box_scores_normalization_class")
self.anomaly_maps_normalization_metrics = self._get_instance(state_dict, "anomaly_maps_normalization_class")
self.pred_scores_normalization_metrics = self._get_instance(state_dict, "pred_scores_normalization_class")

self.normalization_metrics = MetricCollection(
{
"anomaly_maps": self.anomaly_maps_normalization_metrics,
"box_scores": self.box_scores_normalization_metrics,
"pred_scores": self.pred_scores_normalization_metrics,
},
)
self.normalization_metrics = MetricCollection(
{
"anomaly_maps": self.anomaly_maps_normalization_metrics,
"box_scores": self.box_scores_normalization_metrics,
"pred_scores": self.pred_scores_normalization_metrics,
},
)
# Used to load metrics if there is any related data in state_dict
self._load_metrics(state_dict)

Expand Down

0 comments on commit f473df8

Please sign in to comment.