From 52a8c586413f2aa9ddf153a211e72c67974224c1 Mon Sep 17 00:00:00 2001 From: 1pha <1phantasmas@korea.ac.kr> Date: Tue, 5 Mar 2024 05:28:28 +0000 Subject: [PATCH] Fix Abstract class hierarchy for models --- config/train_cls.yaml | 2 +- sage/models/base.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/config/train_cls.yaml b/config/train_cls.yaml index c0d930a..00fae94 100644 --- a/config/train_cls.yaml +++ b/config/train_cls.yaml @@ -8,7 +8,7 @@ defaults: - _self_ - model: convnext_cls - dataset: ppmi - - scheduler: cosine_anneal_warmup + - scheduler: exp_decay - optim: adamw dataloader: diff --git a/sage/models/base.py b/sage/models/base.py index 7bd2d9a..588aaa3 100644 --- a/sage/models/base.py +++ b/sage/models/base.py @@ -41,18 +41,18 @@ def conv_layers(self): return find_conv_modules(self.backbone) -class ClsBase(ModelBase): +class RegBase(ModelBase): def forward(self, brain: torch.Tensor, age: torch.Tensor): pred = self.backbone(brain).squeeze() - loss = self.criterion(pred, age.long()) - return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu().long()) + loss = self.criterion(pred, age.float()) + return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu()) -class RegBase(ModelBase): +class ClsBase(ModelBase): def forward(self, brain: torch.Tensor, age: torch.Tensor): pred = self.backbone(brain).squeeze() - loss = self.criterion(pred, age.float()) - return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu()) + loss = self.criterion(pred, age.long()) + return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu().long()) class ResNet(RegBase): @@ -65,12 +65,12 @@ def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str): super().__init__(backbone=backbone, criterion=criterion, name=name) -class ResNetCls(RegBase): +class ResNetCls(ClsBase): def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str): super().__init__(backbone=backbone, criterion=criterion, name=name) -class ConvNextCls(RegBase): +class ConvNextCls(ClsBase): def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str): super().__init__(backbone=backbone, criterion=criterion, name=name)