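"""Train and evaluate the supervised classifier defined in the ssl_study package.

The script is configured through an ml_collections config file plus a few absl
flags. A typical invocation looks like the following sketch (the config path is
illustrative; point --config at your own config file):

    python train.py --config configs/baseline.py --wandb --log_model --log_eval
"""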
# General imports
import os

# Silence TensorFlow's C++ logging; this must be set before tensorflow is imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import glob

import numpy as np
import tensorflow as tf
import wandb
from absl import app, flags
from ml_collections.config_flags import config_flags
from sklearn.utils import class_weight
from tensorflow.keras.callbacks import LearningRateScheduler
from wandb.keras import WandbCallback

from ssl_study import callbacks

# Import modules
from ssl_study.data import (GetDataloader, download_dataset,
                            preprocess_dataframe)
from ssl_study.models import SimpleSupervisedModel
from ssl_study.pipeline import SupervisedPipeline

FLAGS = flags.FLAGS
CONFIG = config_flags.DEFINE_config_file("config")

flags.DEFINE_bool("wandb", False, "Track the training run with Weights & Biases.")
flags.DEFINE_bool("log_model", False, "Checkpoint model while training.")
flags.DEFINE_bool(
    "log_eval", False, "Log model prediction, needs --wandb argument as well."
)
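
# Note: the ml_collections config passed via --config is expected to expose at
# least the fields read below. This is an illustrative summary, not the full
# config used by the project:
#   config.wandb_config.entity, config.wandb_config.project
#   config.bool_config.use_class_weights
#   config.callback_config.use_tensorboard
#   config.callback_config.use_earlystopping
#   config.callback_config.use_reduce_lr_on_plateau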


def main(_):
    # Get configs from the config file.
    config = CONFIG.value
    print(config)

    CALLBACKS = []
    sync_tensorboard = None
    if config.callback_config.use_tensorboard:
        sync_tensorboard = True

    # Initialize a Weights and Biases run.
    if FLAGS.wandb:
        run = wandb.init(
            entity=CONFIG.value.wandb_config.entity,
            project=CONFIG.value.wandb_config.project,
            job_type="train",
            config=config.to_dict(),
            sync_tensorboard=sync_tensorboard,
        )

        # Initialize W&B metrics logger callback.
        CALLBACKS += [callbacks.WandBMetricsLogger()]
    # Load the dataframes
    train_df = download_dataset("train", "labelled-dataset")
    valid_df = download_dataset("val", "labelled-dataset")

    # Preprocess the DataFrames
    train_paths, train_labels = preprocess_dataframe(train_df, is_labelled=True)
    valid_paths, valid_labels = preprocess_dataframe(valid_df, is_labelled=True)

    # Compute class weights if use_class_weights is True. "balanced" weights are
    # inversely proportional to class frequencies, so rarer classes get a larger
    # weight in the loss.
    class_weights = None
    if config.bool_config.use_class_weights:
        class_weights = class_weight.compute_class_weight(
            class_weight="balanced", classes=np.unique(train_labels), y=train_labels
        )
        class_weights = dict(zip(np.unique(train_labels), class_weights))
    # Build dataloaders
    dataset = GetDataloader(config)
    trainloader = dataset.get_dataloader(
        train_paths, train_labels, dataloader_type="train"
    )
    validloader = dataset.get_dataloader(
        valid_paths, valid_labels, dataloader_type="valid"
    )

    # Initialize callbacks
    callback_config = config.callback_config

    # Built-in early stopping callback
    if callback_config.use_earlystopping:
        earlystopper = callbacks.get_earlystopper(config)
        CALLBACKS += [earlystopper]

    # Built-in callback to reduce learning rate on plateau
    if callback_config.use_reduce_lr_on_plateau:
        reduce_lr_on_plateau = callbacks.get_reduce_lr_on_plateau(config)
        CALLBACKS += [reduce_lr_on_plateau]
    # Initialize model checkpointing callback
    if FLAGS.log_model:
        # Custom W&B model checkpoint callback
        model_checkpointer = callbacks.get_model_checkpoint_callback(config)
        CALLBACKS += [model_checkpointer]

    # Log model predictions on the validation set when --log_eval is passed.
    if wandb.run is not None:
        if FLAGS.log_eval:
            model_pred_viz = callbacks.get_evaluation_callback(config, validloader)
            CALLBACKS += [model_pred_viz]

    if callback_config.use_tensorboard:
        CALLBACKS += [tf.keras.callbacks.TensorBoard()]

    # Build the model
    tf.keras.backend.clear_session()
    model = SimpleSupervisedModel(config).get_model()
    model.summary()

    # Build the pipeline
    pipeline = SupervisedPipeline(model, config, class_weights, CALLBACKS)

    # Train and Evaluate
    pipeline.train_and_evaluate(valid_df, trainloader, validloader)


if __name__ == "__main__":
    app.run(main)