Skip to content

Commit

Permalink
Add PR/AUROC curve. Fix wandb log_table logic. Fix regression earlyst…
Browse files Browse the repository at this point in the history
…op criteria
  • Loading branch information
1pha committed Mar 20, 2024
1 parent 227a5a1 commit f1c3138
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 39 deletions.
8 changes: 4 additions & 4 deletions config/callbacks/early_stop/binary.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
monitor: epoch/valid_BinaryF1Score
mode: max
patience: 10
_target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
monitor: epoch/valid_BinaryF1Score
mode: max
patience: 10
2 changes: 1 addition & 1 deletion config/callbacks/early_stop/reg.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
monitor: epoch/valid_MeanSquaredError
mode: max
mode: min
patience: 10
21 changes: 21 additions & 0 deletions config/sweep/adni_sweep.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: "ADNI Sweep"
description: "HPO for ADNI classification task"
method: bayes
metric:
goal: maximize
name: test_acc
parameters:
model:
values: [ resnet_binary , convnext_binary ]
optim:
values: [ adamw , lion ]
scheduler:
values: [ exp_decay , cosine_anneal_warmup ]
optim.lr:
values: [ 5e-3 , 1e-3 , 1e-4 , 5e-5 ]
early_terminate:
type: hyperband
s: 2
eta: 3
max_iter: 27
run_cap: 50
1 change: 0 additions & 1 deletion sage/data/adni.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def _load_data(self, idx: int) -> Tuple[torch.Tensor]:
""" Make sure to properly return PPMI """
raise NotImplementedError

@overrides
def _exclude_data(self, labels: pd.DataFrame, pk_col: str, root: Path,
exclusion_fname: str = "donotuse-adni.txt") -> pd.DataFrame:
""" TODO: Remove exclude from label """
Expand Down
1 change: 0 additions & 1 deletion sage/data/ppmi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Tuple, List

import torch
import pandas as pd

from sage.data.dataloader import DatasetBase, open_scan
import sage.constants as C
Expand Down
5 changes: 3 additions & 2 deletions sage/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,15 @@ def conv_layers(self):

class RegBase(ModelBase):
    """Regression head wrapper: one scalar prediction per sample.

    Relies on ``self.backbone`` and ``self.criterion``, which are not defined
    in this class (presumably provided by ``ModelBase`` -- confirm there).
    """

    def forward(self, brain: torch.Tensor, age: torch.Tensor):
        """Run the backbone on ``brain`` and score it against ``age``.

        Args:
            brain: Input batch handed to ``self.backbone``.
            age: Regression target; cast to float before the loss.

        Returns:
            dict with ``loss`` (attached to the graph), and detached
            ``pred`` / ``target`` tensors.
        """
        # Squeeze only dim 1 so that a batch of size 1 keeps its batch axis
        # instead of collapsing to a single scalar.
        # NOTE(review): assumes the backbone output is (batch, 1) -- confirm.
        # The backbone is evaluated exactly once; an earlier bare
        # ``.squeeze()`` call whose result was immediately overwritten
        # has been dropped.
        pred = self.backbone(brain).squeeze(dim=1)
        loss = self.criterion(pred, age.float())
        return dict(loss=loss, pred=pred.detach(), target=age.detach())


class ClsBase(ModelBase):
    """Classification head wrapper scored with ``self.criterion``.

    Relies on ``self.backbone`` and ``self.criterion``, which are not defined
    in this class (presumably provided by ``ModelBase`` -- confirm there).
    """

    def forward(self, brain: torch.Tensor, age: torch.Tensor):
        """Run the backbone on ``brain`` and score it against ``age``.

        Args:
            brain: Input batch handed to ``self.backbone``.
            age: Target tensor; despite the name it is cast to ``long``
                before being passed to the criterion.

        Returns:
            dict with ``loss`` (attached to the graph), detached ``pred``,
            and detached ``target`` cast to ``long``.
        """
        # squeeze(dim=1) removes axis 1 only when it has size 1, so the batch
        # axis always survives, even for batch_size=1.
        # The backbone is evaluated exactly once; an earlier bare
        # ``.squeeze()`` call whose result was immediately overwritten
        # has been dropped.
        pred = self.backbone(brain).squeeze(dim=1)
        loss = self.criterion(pred, age.long())
        return dict(loss=loss, pred=pred.detach(), target=age.detach().long())

Expand Down
26 changes: 4 additions & 22 deletions sage/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def forward(self, batch, mode: str = "train"):
logger.exception(e)
breakpoint()
raise e

def move_device(self,
result: Dict[str, torch.Tensor],
exclude_keys: List[str] = ["loss"]) -> Dict[str, torch.Tensor]:
Expand Down Expand Up @@ -243,26 +243,6 @@ def log_confusion_matrix(self, result: dict):
pr = wandb.plot.pr_curve(y_true=labels, y_probas=probs)
self.logger.experiment.log({"confusion_matrix": cf, "roc_curve": roc, "pr_curve": pr})

def log_table(self, batch: Dict[str, torch.Tensor], result: Dict[str, torch.Tensor]):
""" Accumulate per-sample rows in `self.table_data` for a wandb table.

Reads `batch["image_path"]`, `batch["image"]`, `batch["indicator"]` and
`result["cls_pred"]`, `result["cls_target"]`. Nothing is sent to wandb
here; rows are only appended to `self.table_data`.
"""
# Column names are built once, on the first call: five fixed columns plus
# one "Logit c" column per entry along dim 1 of result["cls_pred"].
if not hasattr(self, "table_columns"):
self.table_columns = ["PID", "Image", "Target", "Prediction", "Entropy"] + \
[f"Logit {c}" for c in range(result["cls_pred"].size(1))]
# Row storage is likewise created lazily on the first call.
if not hasattr(self, "table_data"):
self.table_data = []

img_path, img = batch["image_path"], batch["image"]
# Each `ind` is used as a slice length: the leading `ind` entries of
# img / img_path are taken for sample i.
# NOTE(review): presumably `indicator` holds the number of images per sample -- confirm.
for i, ind in enumerate(batch["indicator"]):
x, path = img[:ind], img_path[:ind]
pred = result["cls_pred"][i]
prediction = int(pred.argmax())
# Entropy as -sum(p * log p).
# NOTE(review): assumes `pred` holds strictly positive probabilities
# (the columns are labelled "Logit", so confirm what cls_pred contains).
entropy = -float((pred * pred.log()).sum())
pred, target = pred.tolist(), int(result["cls_target"][i])
self.table_data.append(
["\n".join(path), wandb.Image(x), target, prediction, entropy] + pred
)
# Drop the entries just consumed so the next sample starts at offset 0.
# NOTE(review): presumably inside the loop body; indentation is not visible here -- confirm.
img, img_path = img[ind:], img_path[ind:]

def log_result(self, output: dict, unit: str = "step", prog_bar: bool = False):
output = {f"{unit}/{k}": float(v) for k, v in output.items()}
self.log_dict(dictionary=output,
Expand Down Expand Up @@ -321,7 +301,7 @@ def on_predict_end(self):
self.log_confusion_matrix(result=result)
self.logger.log_table(key="Test Prediction", columns=["Target", "Prediction"],
data=[(t, p) for t, p in zip(result["target"].tolist(),
result["pred"].tolist())])
nn.functional.sigmoid(result["pred"]).tolist())])

def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
result: dict = self.forward(batch, mode="test")
Expand Down Expand Up @@ -429,6 +409,8 @@ def train(config: omegaconf.DictConfig) -> Dict[str, float]:
name=config.logger.name,
root_dir=Path(config.callbacks.checkpoint.dirpath))
trainer.logger.log_metrics(metric)
else:
metric = None

if config_update:
# Update configuration if needed
Expand Down
46 changes: 38 additions & 8 deletions sweep_command.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
export HYDRA_FULL_ERROR=1
export CUDA_VISIBLE_DEVICES=1

python sweep.py --sweep_cfg_name=ppmi_sweep.yaml\
--wandb_project=ppmi\
--config_name=train_binary.yaml\
--sweep_prefix='Scratch'
# --overrides="['module.load_model_ckpt=meta_brain/weights/default/resnet10-42/156864-valid_mae3.465.ckpt',\
# '+module.load_model_strict=False']"

# Ask which GPU id(s) to expose; the answer is exported as CUDA_VISIBLE_DEVICES.
# -r keeps backslashes in the answer literal.
read -r -p "Enter devices: " device
export CUDA_VISIBLE_DEVICES="$device"

# Ask which dataset to sweep on. The prompt previously said "devices",
# which did not match what is being read.
read -r -p "Enter dataset: ppmi|adni " ds
dataset=$ds

# Run sweep.py with the PPMI sweep configuration and its fixed overrides.
sweep_ppmi() {
    echo "Sweep on PPMI"

    # Assemble the overrides list piecewise so that source indentation
    # never leaks into the quoted string handed to sweep.py.
    local overrides="['dataset=ppmi_binary', "
    overrides+="'+dataset.modality=[t2]', "
    overrides+="'dataloader.batch_size=4', "
    overrides+="'dataloader.num_workers=2', "
    overrides+="'trainer.accumulate_grad_batches=8']"

    # Keep a space before every trailing backslash: a backslash-newline is
    # removed outright, so without the space adjacent options fuse into a
    # single word whenever the following line is not indented.
    python sweep.py --sweep_cfg_name=ppmi_sweep.yaml \
        --wandb_project=ppmi \
        --config_name=train_binary.yaml \
        --sweep_prefix='Scratch' \
        --overrides="$overrides"
}

# Run sweep.py with the ADNI sweep configuration and its fixed overrides.
sweep_adni() {
    echo "Sweep on ADNI"

    # Keep a space before every trailing backslash: a backslash-newline is
    # removed outright, so without the space adjacent options fuse into a
    # single word whenever the following line is not indented.
    python sweep.py --sweep_cfg_name=adni_sweep.yaml \
        --wandb_project=adni \
        --config_name=train_binary.yaml \
        --sweep_prefix='Scratch' \
        --overrides="['dataset=adni']"
}

# Dispatch to the sweep that matches the dataset chosen at the prompt.
# "$dataset" is quoted so that empty or multi-word input reaches the
# fallback branch cleanly instead of making the comparison itself error out.
case "$dataset" in
    ppmi)
        sweep_ppmi
        ;;
    adni)
        sweep_adni
        ;;
    *)
        echo "Invalid argument. Usage: $0 [ppmi|adni]. Got $dataset instead"
        exit 1
        ;;
esac

0 comments on commit f1c3138

Please sign in to comment.