-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
57 lines (40 loc) · 1.39 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import os
import random
from typing import Any, List, Tuple, Dict
from types import ModuleType
import numpy as np
import torch
import torch.nn as nn
import torch.optim as module_optimizer
import torch.optim.lr_scheduler as module_scheduler
from torchvision import datasets, transforms
from sodium.utils import setup_logger, load_config, seed_everything
from sodium.trainer import Trainer
import sodium.runner as runner
logger = setup_logger(__name__)

# Input shape handed to `Runner.print_summary`; presumably (channels, height,
# width) -- confirm against sodium.runner. Kept as an immutable tuple.
DEFAULT_INPUT_SIZE = (3, 32, 32)
# Layer names handed to `Runner.plot_gradcam`. Tuple so the default cannot be
# mutated; converted to a fresh list at the call site.
DEFAULT_TARGET_LAYERS = ("layer1", "layer2", "layer3", "layer4")


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the training script.

    The ``--input-size`` and ``--target-layers`` options are optional and
    default to the values this script previously hard-coded, so existing
    invocations behave exactly as before.
    """
    parser = argparse.ArgumentParser(description='Train a Sodium Model')
    parser.add_argument('-c', '--config', default=None,
                        type=str, help='config file path (default: None)')
    parser.add_argument('--tsai-mode', action='store_true',
                        help='Enable TSAI Mode')
    parser.add_argument('--input-size', nargs=3, type=int,
                        default=list(DEFAULT_INPUT_SIZE),
                        metavar=('C', 'H', 'W'),
                        help='input size passed to the model summary '
                             '(default: 3 32 32)')
    parser.add_argument('--target-layers', nargs='+', type=str,
                        default=list(DEFAULT_TARGET_LAYERS),
                        help='layer names to visualise with GradCAM '
                             '(default: layer1 layer2 layer3 layer4)')
    return parser


def main(argv=None) -> None:
    """Run the full training pipeline.

    Parses the command line, loads the config, and builds a ``Runner``. It then
    sets up training, prints a model summary, runs the LR finder, trains with
    the best LR found, and plots metrics and GradCAM visualisations.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) uses the real command line.
    """
    args = _build_parser().parse_args(argv)

    config = load_config(args.config)

    # Deliberately NOT named `runner`: that would rebind the imported
    # `sodium.runner` module name to an instance.
    model_runner = runner.Runner(config)

    model_runner.setup_train(tsai_mode=args.tsai_mode)
    model_runner.print_summary(input_size=tuple(args.input_size))

    # Find a learning rate first; train() is then told to use the best one.
    model_runner.find_lr()
    model_runner.train(use_bestlr=True)

    model_runner.plot_metrics()
    model_runner.plot_gradcam(target_layers=list(args.target_layers))


if __name__ == "__main__":
    main()