Skip to content

Commit

Permalink
Diverge abstract class
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Mar 5, 2024
1 parent a80b466 commit 0913fe2
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 4 deletions.
2 changes: 1 addition & 1 deletion config/model/convnext_cls.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: sage.models.base.ConvNext
_target_: sage.models.base.ConvNextCls
backbone:
_target_: sage.models.model_zoo.convnext.build_convnext
model_name: convnext-base
Expand Down
2 changes: 1 addition & 1 deletion config/model/resnet_cls.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: sage.models.base.ResNet
_target_: sage.models.base.ResNetCls
backbone:
_target_: sage.models.model_zoo.resnet.build_resnet
model_depth: 10
Expand Down
38 changes: 36 additions & 2 deletions sage/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,49 @@ def conv_layers(self):
return find_conv_modules(self.backbone)


class ResNet(ModelBase):
class ClsBase(ModelBase):
def forward(self, brain: torch.Tensor, age: torch.Tensor):
pred = self.backbone(brain).squeeze()
loss = self.criterion(pred, age.long())
return dict(loss=loss,
pred=pred.detach().cpu(),
target=age.detach().cpu().long())


class RegBase(ModelBase):
def forward(self, brain: torch.Tensor, age: torch.Tensor):
pred = self.backbone(brain).squeeze()
loss = self.criterion(pred, age.float())
return dict(loss=loss,
pred=pred.detach().cpu(),
target=age.detach().cpu())


class ResNet(RegBase):
def __init__(self,
backbone: nn.Module,
criterion: nn.Module,
name: str):
super().__init__(backbone=backbone, criterion=criterion, name=name)


class ConvNext(RegBase):
def __init__(self,
backbone: nn.Module,
criterion: nn.Module,
name: str):
super().__init__(backbone=backbone, criterion=criterion, name=name)


class ResNetCls(RegBase):
def __init__(self,
backbone: nn.Module,
criterion: nn.Module,
name: str):
super().__init__(backbone=backbone, criterion=criterion, name=name)


class ConvNext(ModelBase):
class ConvNextCls(RegBase):
def __init__(self,
backbone: nn.Module,
criterion: nn.Module,
Expand Down

0 comments on commit 0913fe2

Please sign in to comment.