Skip to content

Commit

Permalink
Remove setup method from the model implementations
Browse files Browse the repository at this point in the history
Signed-off-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
samet-akcay committed Oct 29, 2024
1 parent 03196fa commit d579312
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 62 deletions.
10 changes: 0 additions & 10 deletions src/anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def __init__(
self.post_processor = post_processor or self.default_post_processor()

self._input_size: tuple[int, int] | None = None

self._is_setup = False # flag to track if setup has been called from the trainer

@property
Expand All @@ -94,15 +93,6 @@ def _setup(self) -> None:
initialization.
"""

def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
"""Called when loading a checkpoint.
This method is called to ensure that the `TorchModel` is built before
loading the state dict.
"""
del checkpoint # `checkpoint` variable is not used.
self.setup(stage="load_checkpoint")

def configure_callbacks(self) -> Sequence[Callback] | Callback:
"""Configure default callbacks for AnomalyModule."""
return [self.pre_processor]
Expand Down
14 changes: 5 additions & 9 deletions src/anomalib/models/image/csflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,15 @@ def __init__(
) -> None:
super().__init__(pre_processor=pre_processor)

if self.input_size is None:
msg = "CsFlow needs input size to build torch model."
raise ValueError(msg)

self.cross_conv_hidden_channels = cross_conv_hidden_channels
self.n_coupling_blocks = n_coupling_blocks
self.clamp = clamp
self.num_channels = num_channels

self.loss = CsFlowLoss()

self.model: CsFlowModel

def _setup(self) -> None:
if self.input_size is None:
msg = "CsFlow needs input size to build torch model."
raise ValueError(msg)

self.model = CsFlowModel(
input_size=self.input_size,
cross_conv_hidden_channels=self.cross_conv_hidden_channels,
Expand All @@ -71,6 +66,7 @@ def _setup(self) -> None:
num_channels=self.num_channels,
)
self.model.feature_extractor.eval()
self.loss = CsFlowLoss()

def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
"""Perform the training step of CS-Flow.
Expand Down
13 changes: 5 additions & 8 deletions src/anomalib/models/image/fastflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,16 @@ def __init__(
) -> None:
super().__init__(pre_processor=pre_processor)

if self.input_size is None:
msg = "Fastflow needs input size to build torch model."
raise ValueError(msg)

self.backbone = backbone
self.pre_trained = pre_trained
self.flow_steps = flow_steps
self.conv3x3_only = conv3x3_only
self.hidden_ratio = hidden_ratio

self.model: FastflowModel
self.loss = FastflowLoss()

def _setup(self) -> None:
if self.input_size is None:
msg = "Fastflow needs input size to build torch model."
raise ValueError(msg)

self.model = FastflowModel(
input_size=self.input_size,
backbone=self.backbone,
Expand All @@ -73,6 +69,7 @@ def _setup(self) -> None:
conv3x3_only=self.conv3x3_only,
hidden_ratio=self.hidden_ratio,
)
self.loss = FastflowLoss()

def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
"""Perform the training step input and return the loss.
Expand Down
27 changes: 13 additions & 14 deletions src/anomalib/models/image/ganomaly/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def __init__(
) -> None:
super().__init__(pre_processor=pre_processor)

if self.input_size is None:
msg = "GANomaly needs input size to build torch model."
raise ValueError(msg)

self.n_features = n_features
self.latent_vec_size = latent_vec_size
self.extra_layers = extra_layers
Expand All @@ -83,6 +87,15 @@ def __init__(
self.min_scores: torch.Tensor = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable
self.max_scores: torch.Tensor = torch.tensor(float("-inf"), dtype=torch.float32) # pylint: disable=not-callable

self.model = GanomalyModel(
input_size=self.input_size,
num_input_channels=3,
n_features=self.n_features,
latent_vec_size=self.latent_vec_size,
extra_layers=self.extra_layers,
add_final_conv_layer=self.add_final_conv_layer,
)

self.generator_loss = GeneratorLoss(wadv, wcon, wenc)
self.discriminator_loss = DiscriminatorLoss()
self.automatic_optimization = False
Expand All @@ -95,20 +108,6 @@ def __init__(

self.model: GanomalyModel

def _setup(self) -> None:
if self.input_size is None:
msg = "GANomaly needs input size to build torch model."
raise ValueError(msg)

self.model = GanomalyModel(
input_size=self.input_size,
num_input_channels=3,
n_features=self.n_features,
latent_vec_size=self.latent_vec_size,
extra_layers=self.extra_layers,
add_final_conv_layer=self.add_final_conv_layer,
)

def _reset_min_max(self) -> None:
"""Reset min_max scores."""
self.min_scores = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,23 @@ def __init__(
) -> None:
super().__init__(pre_processor=pre_processor)

if self.input_size is None:
msg = "Input size is required for Reverse Distillation model."
raise ValueError(msg)

self.backbone = backbone
self.pre_trained = pre_trained
self.layers = layers
self.anomaly_map_mode = anomaly_map_mode

self.model: ReverseDistillationModel
self.loss = ReverseDistillationLoss()

def _setup(self) -> None:
if self.input_size is None:
msg = "Input size is required for Reverse Distillation model."
raise ValueError(msg)

self.model = ReverseDistillationModel(
backbone=self.backbone,
pre_trained=self.pre_trained,
layers=self.layers,
input_size=self.input_size,
anomaly_map_mode=self.anomaly_map_mode,
)
self.loss = ReverseDistillationLoss()

def configure_optimizers(self) -> optim.Adam:
"""Configure optimizers for decoder and bottleneck.
Expand Down
5 changes: 1 addition & 4 deletions src/anomalib/models/image/stfpm/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@ def __init__(
) -> None:
super().__init__(pre_processor=pre_processor)

self.model = STFPMModel(
backbone=backbone,
layers=layers,
)
self.model = STFPMModel(backbone=backbone, layers=layers)
self.loss = STFPMLoss()

def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
Expand Down
14 changes: 5 additions & 9 deletions src/anomalib/models/image/uflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,16 @@ def __init__(
"""
super().__init__(pre_processor=pre_processor)

if self.input_size is None:
msg = "Input size is required for UFlow model."
raise ValueError(msg)

self.backbone = backbone
self.flow_steps = flow_steps
self.affine_clamp = affine_clamp
self.affine_subnet_channels_ratio = affine_subnet_channels_ratio
self.permute_soft = permute_soft

self.loss = UFlowLoss()

self.model: UflowModel

def _setup(self) -> None:
if self.input_size is None:
msg = "Input size is required for UFlow model."
raise ValueError(msg)

self.model = UflowModel(
input_size=self.input_size,
backbone=self.backbone,
Expand All @@ -77,6 +72,7 @@ def _setup(self) -> None:
affine_subnet_channels_ratio=self.affine_subnet_channels_ratio,
permute_soft=self.permute_soft,
)
self.loss = UFlowLoss()

@classmethod
def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor:
Expand Down

0 comments on commit d579312

Please sign in to comment.