Skip to content

Commit

Permalink
Fix Abstract class hierarchy for models
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Mar 5, 2024
1 parent 14b3012 commit 52a8c58
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion config/train_cls.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ defaults:
- _self_
- model: convnext_cls
- dataset: ppmi
- scheduler: cosine_anneal_warmup
- scheduler: exp_decay
- optim: adamw

dataloader:
Expand Down
16 changes: 8 additions & 8 deletions sage/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,18 @@ def conv_layers(self):
return find_conv_modules(self.backbone)


class ClsBase(ModelBase):
class RegBase(ModelBase):
def forward(self, brain: torch.Tensor, age: torch.Tensor):
pred = self.backbone(brain).squeeze()
loss = self.criterion(pred, age.long())
return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu().long())
loss = self.criterion(pred, age.float())
return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu())


class RegBase(ModelBase):
class ClsBase(ModelBase):
def forward(self, brain: torch.Tensor, age: torch.Tensor):
pred = self.backbone(brain).squeeze()
loss = self.criterion(pred, age.float())
return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu())
loss = self.criterion(pred, age.long())
return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu().long())


class ResNet(RegBase):
Expand All @@ -65,12 +65,12 @@ def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str):
super().__init__(backbone=backbone, criterion=criterion, name=name)


class ResNetCls(RegBase):
class ResNetCls(ClsBase):
def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str):
super().__init__(backbone=backbone, criterion=criterion, name=name)


class ConvNextCls(RegBase):
class ConvNextCls(ClsBase):
def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str):
super().__init__(backbone=backbone, criterion=criterion, name=name)

Expand Down

0 comments on commit 52a8c58

Please sign in to comment.