Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Cyclic Learning Rate and Step-wise Learning Rate Scheduler #213

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dptb/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def train(
trainer.register_plugin(TrainLossMonitor())
trainer.register_plugin(LearningRateMonitor())
if jdata["train_options"]["use_tensorboard"]:
trainer.register_plugin(TensorBoardMonitor())
trainer.register_plugin(TensorBoardMonitor(interval=[(jdata["train_options"]["display_freq"], 'iteration'), (1, 'epoch')]))
trainer.register_plugin(Logger(log_field,
interval=[(jdata["train_options"]["display_freq"], 'iteration'), (1, 'epoch')]))

Expand Down
11 changes: 7 additions & 4 deletions dptb/nnops/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
'''
self.iter = 1
self.ep = 1
self.update_lr_per_step_flag = False

@abstractmethod
def restart(self, checkpoint):
Expand All @@ -52,10 +53,12 @@ def run(self, epochs=1):
# run plugins of epoch events.
self.call_plugins(queue_name='epoch', time=i)

if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step(self.stats["train_loss"]["epoch_mean"])
else:
self.lr_scheduler.step() # modify the lr at each epoch (should we add it to pluggins so we could record the lr scheduler process?)
if not self.update_lr_per_step_flag:
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step(self.stats["train_loss"]["epoch_mean"])
else:
self.lr_scheduler.step() # modify the lr at each epoch (should we add it to pluggins so we could record the lr scheduler process? update 0927, this has been done in tensorboard monitor.)

self.update()
self.ep += 1

Expand Down
6 changes: 6 additions & 0 deletions dptb/nnops/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
self.model = model.to(self.device)
self.optimizer = get_optimizer(model_param=self.model.parameters(), **train_options["optimizer"])
self.lr_scheduler = get_lr_scheduler(optimizer=self.optimizer, **train_options["lr_scheduler"]) # add optmizer
self.update_lr_per_step_flag = train_options["update_lr_per_step_flag"]
self.common_options = common_options
self.train_options = train_options

Expand Down Expand Up @@ -129,6 +130,11 @@ def iteration(self, batch, ref_batch=None):
loss.backward()
#TODO: add clip large gradient
self.optimizer.step()
if self.update_lr_per_step_flag:
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step(self.stats["train_loss"]["epoch_mean"])
else:
self.lr_scheduler.step()

state = {'field':'iteration', "train_loss": loss.detach(), "lr": self.optimizer.state_dict()["param_groups"][0]['lr']}
self.call_plugins(queue_name='iteration', time=self.iter, **state)
Expand Down
4 changes: 2 additions & 2 deletions dptb/plugins/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def _get_value(self, **kwargs):


class TensorBoardMonitor(Plugin):
def __init__(self):
super(TensorBoardMonitor, self).__init__([(25, 'iteration'), (1, 'epoch')])
def __init__(self, interval):
super(TensorBoardMonitor, self).__init__(interval=interval)
self.writer = SummaryWriter(log_dir='./tensorboard_logs')

def register(self, trainer):
Expand Down
39 changes: 37 additions & 2 deletions dptb/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,16 @@ def train_options():
"There are tree types of error will be recorded. `train_loss_iter` is iteration loss, `train_loss_last` is the error of the last iteration in an epoch, `train_loss_mean` is the mean error of all iterations in an epoch." \
"Learning rates are tracked as well. A folder named `tensorboard_logs` will be created in the working directory. Use `tensorboard --logdir=tensorboard_logs` to view the logs." \
"Default: `False`"
update_lr_per_step_flag = "Set true to update learning rate per-step. By default, it's false."

doc_optimizer = "\
The optimizer setting for selecting the gradient optimizer of model training. Optimizer supported includes `Adam`, `SGD` and `LBFGS` \n\n\
For more information about these optmization algorithm, we refer to:\n\n\
- `Adam`: [Adam: A Method for Stochastic Optimization.](https://arxiv.org/abs/1412.6980)\n\n\
- `SGD`: [Stochastic Gradient Descent.](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)\n\n\
- `LBFGS`: [On the limited memory BFGS method for large scale optimization.](http://users.iems.northwestern.edu/~nocedal/PDFfiles/limited-memory.pdf) \n\n\
"
doc_lr_scheduler = "The learning rate scheduler tools settings, the lr scheduler is used to scales down the learning rate during the training process. Proper setting can make the training more stable and efficient. The supported lr schedular includes: `Exponential Decaying (exp)`, `Linear multiplication (linear)`"
doc_lr_scheduler = "The learning rate scheduler tools settings, the lr scheduler is used to scales down the learning rate during the training process. Proper setting can make the training more stable and efficient. The supported lr schedular includes: `Exponential Decaying (exp)`, `Linear multiplication (linear)`, `Reduce on pleatau (rop)`, `Cyclic learning rate (cyclic)`. See more documentation on Pytorch. "
doc_batch_size = "The batch size used in training, Default: 1"
doc_ref_batch_size = "The batch size used in reference data, Default: 1"
doc_val_batch_size = "The batch size used in validation data, Default: 1"
Expand All @@ -126,6 +128,7 @@ def train_options():
Argument("validation_freq", int, optional=True, default=10, doc=doc_validation_freq),
Argument("display_freq", int, optional=True, default=1, doc=doc_display_freq),
Argument("use_tensorboard", bool, optional=True, default=False, doc=doc_use_tensorboard),
Argument("update_lr_per_step_flag", bool, optional=True, default=False, doc=update_lr_per_step_flag),
Argument("max_ckpt", int, optional=True, default=4, doc=doc_max_ckpt),
loss_options()
]
Expand Down Expand Up @@ -235,14 +238,46 @@ def ReduceOnPlateau():
Argument("eps", float, optional=True, default=1e-8, doc=doc_eps),
]

def CyclicLR():
doc_base_lr = "Initial learning rate which is the lower boundary in the cycle for each parameter group."
doc_max_lr = "Upper learning rate boundaries in the cycle for each parameter group. Functionally, it defines the cycle amplitude (max_lr - base_lr). The lr at any cycle is the sum of base_lr and some scaling of the amplitude; therefore max_lr may not actually be reached depending on scaling function."
doc_step_size_up = "Number of training iterations in the increasing half of a cycle. Default: 2000"
doc_step_size_down = "Number of training iterations in the decreasing half of a cycle. If step_size_down is None, it is set to step_size_up. Default: None"
doc_mode = "One of {triangular, triangular2, exp_range}. Values correspond to policies detailed above. If scale_fn is not None, this argument is ignored. Default: 'triangular'"
doc_gamma = "Constant in 'exp_range' scaling function: gamma**(cycle iterations) Default: 1.0"
doc_scale_fn = "Custom scaling policy defined by a single argument lambda function, where 0 <= scale_fn(x) <= 1 for all x >= 0. If specified, then 'mode' is ignored. Default: None"
doc_scale_mode = "{'cycle', 'iterations'}. Defines whether scale_fn is evaluated on cycle number or cycle iterations (training iterations since start of cycle). Default: 'cycle'"
doc_cycle_momentum = "If True, momentum is cycled inversely to learning rate between 'base_momentum' and 'max_momentum'. Default: True"
doc_base_momentum = "Lower momentum boundaries in the cycle for each parameter group. Note that momentum is cycled inversely to learning rate; at the start of a cycle, momentum is 'max_momentum' and learning rate is 'base_lr'. Default: 0.8"
doc_max_momentum = "Upper momentum boundaries in the cycle for each parameter group. Functionally, it defines the cycle amplitude (max_momentum - base_momentum). The momentum at any cycle is the difference of max_momentum and some scaling of the amplitude; therefore base_momentum may not actually be reached depending on scaling function. Note that momentum is cycled inversely to learning rate; at the start of a cycle, momentum is 'max_momentum' and learning rate is 'base_lr'. Default: 0.9"
doc_last_epoch = "The index of the last batch. This parameter is used when resuming a training job. Since step() should be invoked after each batch instead of after each epoch, this number represents the total number of batches computed, not the total number of epochs computed. When last_epoch=-1, the schedule is started from the beginning. Default: -1"
doc_verbose = "If True, prints a message to stdout for each update. Default: False."

return [
Argument("base_lr", [float, list], optional=False, doc=doc_base_lr),
Argument("max_lr", [float, list], optional=False, doc=doc_max_lr),
Argument("step_size_up", int, optional=True, default=10, doc=doc_step_size_up),
Argument("step_size_down", int, optional=True, default=40, doc=doc_step_size_down),
Argument("mode", str, optional=True, default="exp_range", doc=doc_mode),
Argument("gamma", float, optional=True, default=1.0, doc=doc_gamma),
Argument("scale_fn", object, optional=True, default=None, doc=doc_scale_fn),
Argument("scale_mode", str, optional=True, default="cycle", doc=doc_scale_mode),
Argument("cycle_momentum", bool, optional=True, default=False, doc=doc_cycle_momentum),
Argument("base_momentum", [float, list], optional=True, default=0.8, doc=doc_base_momentum),
Argument("max_momentum", [float, list], optional=True, default=0.9, doc=doc_max_momentum),
Argument("last_epoch", int, optional=True, default=-1, doc=doc_last_epoch),
Argument("verbose", [bool, str], optional=True, default="deprecated", doc=doc_verbose)
]


def lr_scheduler():
doc_type = "select type of lr_scheduler, support type includes `exp`, `linear`"

return Variant("type", [
Argument("exp", dict, ExponentialLR()),
Argument("linear", dict, LinearLR()),
Argument("rop", dict, ReduceOnPlateau(), doc="rop: reduce on plateau")
Argument("rop", dict, ReduceOnPlateau(), doc="rop: reduce on plateau"),
Argument("cyclic", dict, CyclicLR(), doc="Cyclic learning rate")
],optional=True, default_tag="exp", doc=doc_type)


Expand Down
4 changes: 3 additions & 1 deletion dptb/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,10 @@ def get_lr_scheduler(type: str, optimizer: optim.Optimizer, **sch_options):
scheduler = optim.lr_scheduler.LinearLR(optimizer=optimizer, **sch_options)
elif type == "rop":
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, **sch_options)
elif type == "cyclic":
scheduler = optim.lr_scheduler.CyclicLR(optimizer=optimizer, **sch_options)
else:
raise RuntimeError("Scheduler should be exp/linear/rop..., not {}".format(type))
raise RuntimeError("Scheduler should be exp/linear/rop/cyclic..., not {}".format(type))

return scheduler

Expand Down
68 changes: 68 additions & 0 deletions examples/clr_and_per_iter_update/clr_per_epoch_test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
{
"common_options": {
"basis": {
"C": "5s4p1d",
"H": "3s1p",
"O": "5s4p1d"
},
"device": "cuda",
"overlap": true
},
"model_options": {
"embedding": {
"method": "lem",
"irreps_hidden": "4x0e+4x1o+4x2e+4x3o+4x4e",
"n_layers": 5,
"avg_num_neighbors": 80,
"r_max": {
"C": 7,
"O": 7,
"H": 3
},
"tp_radial_emb": true
},
"prediction": {
"method": "e3tb",
"neurons": [
64,
64
]
}
},
"train_options": {
"num_epoch": 10,
"batch_size": 1,
"optimizer": {
"lr": 0.005,
"type": "Adam"
},
"lr_scheduler": {
"type": "cyclic",
"max_lr": 0.005,
"base_lr": 1e-06,
"step_size_up": 3,
"step_size_down": 7,
"mode": "exp_range",
"scale_mode": "cycle"
},
"loss_options": {
"train": {
"method": "hamil_abs"
}
},
"save_freq": 100,
"validation_freq": 10,
"display_freq": 1,
"use_tensorboard": true,
"update_lr_per_step_flag": false
},
"data_options": {
"train": {
"root": "./data_10",
"prefix": "data",
"type": "LMDBDataset",
"get_Hamiltonian": true,
"get_overlap": true
}
}
}
68 changes: 68 additions & 0 deletions examples/clr_and_per_iter_update/clr_per_iter_test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
{
"common_options": {
"basis": {
"C": "5s4p1d",
"H": "3s1p",
"O": "5s4p1d"
},
"device": "cuda",
"overlap": true
},
"model_options": {
"embedding": {
"method": "lem",
"irreps_hidden": "4x0e+4x1o+4x2e+4x3o+4x4e",
"n_layers": 5,
"avg_num_neighbors": 80,
"r_max": {
"C": 7,
"O": 7,
"H": 3
},
"tp_radial_emb": true
},
"prediction": {
"method": "e3tb",
"neurons": [
64,
64
]
}
},
"train_options": {
"num_epoch": 10,
"batch_size": 1,
"optimizer": {
"lr": 0.005,
"type": "Adam"
},
"lr_scheduler": {
"type": "cyclic",
"max_lr": 0.005,
"base_lr": 1e-06,
"step_size_up": 3,
"step_size_down": 7,
"mode": "exp_range",
"scale_mode": "iterations"
},
"loss_options": {
"train": {
"method": "hamil_abs"
}
},
"save_freq": 100,
"validation_freq": 10,
"display_freq": 1,
"use_tensorboard": true,
"update_lr_per_step_flag": true
},
"data_options": {
"train": {
"root": "./data_10",
"prefix": "data",
"type": "LMDBDataset",
"get_Hamiltonian": true,
"get_overlap": true
}
}
}
Binary file not shown.
Binary file not shown.
63 changes: 63 additions & 0 deletions examples/clr_and_per_iter_update/exp_per_epoch_test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{
"common_options": {
"basis": {
"C": "5s4p1d",
"H": "3s1p",
"O": "5s4p1d"
},
"device": "cuda",
"overlap": true
},
"model_options": {
"embedding": {
"method": "lem",
"irreps_hidden": "4x0e+4x1o+4x2e+4x3o+4x4e",
"n_layers": 5,
"avg_num_neighbors": 80,
"r_max": {
"C": 7,
"O": 7,
"H": 3
},
"tp_radial_emb": true
},
"prediction": {
"method": "e3tb",
"neurons": [
64,
64
]
}
},
"train_options": {
"num_epoch": 10,
"batch_size": 1,
"optimizer": {
"lr": 0.005,
"type": "Adam"
},
"lr_scheduler": {
"type": "exp",
"gamma": 0.8
},
"loss_options": {
"train": {
"method": "hamil_abs"
}
},
"save_freq": 100,
"validation_freq": 10,
"display_freq": 1,
"use_tensorboard": true,
"update_lr_per_step_flag": false
},
"data_options": {
"train": {
"root": "./data_10",
"prefix": "data",
"type": "LMDBDataset",
"get_Hamiltonian": true,
"get_overlap": true
}
}
}
Loading