diff --git a/dptb/entrypoints/train.py b/dptb/entrypoints/train.py index a9e5c7d7..c3e50a92 100644 --- a/dptb/entrypoints/train.py +++ b/dptb/entrypoints/train.py @@ -1,7 +1,7 @@ from dptb.nnops.trainer import Trainer from dptb.nn.build import build_model from dptb.data.build import build_dataset -from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor, Validationer +from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor, Validationer, TensorBoardMonitor from dptb.plugins.train_logger import Logger from dptb.utils.argcheck import normalize, collect_cutoffs from dptb.plugins.saver import Saver @@ -209,7 +209,9 @@ def train( log_field.append("validation_loss") trainer.register_plugin(TrainLossMonitor()) trainer.register_plugin(LearningRateMonitor()) - trainer.register_plugin(Logger(log_field, + if jdata["train_options"]["use_tensorboard"]: + trainer.register_plugin(TensorBoardMonitor()) + trainer.register_plugin(Logger(log_field, interval=[(jdata["train_options"]["display_freq"], 'iteration'), (1, 'epoch')])) for q in trainer.plugin_queues.values(): diff --git a/dptb/plugins/monitor.py b/dptb/plugins/monitor.py index 449b3838..f7e7a3f5 100644 --- a/dptb/plugins/monitor.py +++ b/dptb/plugins/monitor.py @@ -1,8 +1,10 @@ -from dptb.plugins.base_plugin import Plugin import logging import time + import torch from dptb.data import AtomicData +from dptb.plugins.base_plugin import Plugin +from torch.utils.tensorboard import SummaryWriter log = logging.getLogger(__name__) @@ -140,4 +142,24 @@ def _get_value(self, **kwargs): if kwargs.get('field') == "iteration": return self.trainer.validation(fast=True) else: - return self.trainer.validation() \ No newline at end of file + return self.trainer.validation() + + +class TensorBoardMonitor(Plugin): + def __init__(self): + super(TensorBoardMonitor, self).__init__([(25, 'iteration'), (1, 'epoch')]) + self.writer = SummaryWriter(log_dir='./tensorboard_logs') + + def register(self, trainer): + 
self.trainer = trainer + + def epoch(self, **kwargs): + epoch = self.trainer.ep + self.writer.add_scalar(f'lr/epoch', self.trainer.stats['lr']['last'], epoch) + self.writer.add_scalar(f'train_loss_last/epoch', self.trainer.stats['train_loss']['last'], epoch) + self.writer.add_scalar(f'train_loss_mean/epoch', self.trainer.stats['train_loss']['epoch_mean'], epoch) + + def iteration(self, **kwargs): + iteration = self.trainer.iter + self.writer.add_scalar(f'lr_iter/iteration', self.trainer.stats['lr']['last'], iteration) + self.writer.add_scalar(f'train_loss_iter/iteration', self.trainer.stats['train_loss']['last'], iteration) diff --git a/dptb/utils/argcheck.py b/dptb/utils/argcheck.py index e75fd3fb..db81efe8 100644 --- a/dptb/utils/argcheck.py +++ b/dptb/utils/argcheck.py @@ -98,6 +98,10 @@ def train_options(): doc_save_freq = "Frequency, or every how many iteration to saved the current model into checkpoints, The name of checkpoint is formulated as `latest|best_dptb|nnsk_b_c_w`. Default: `10`" doc_validation_freq = "Frequency or every how many iteration to do model validation on validation datasets. Default: `10`" doc_display_freq = "Frequency, or every how many iteration to display the training log to screem. Default: `1`" + doc_use_tensorboard = "Set true to use tensorboard. It will record iteration error once every `25` iterations, epoch error once per epoch. " \ + "There are three types of errors that will be recorded. `train_loss_iter` is iteration loss, `train_loss_last` is the error of the last iteration in an epoch, `train_loss_mean` is the mean error of all iterations in an epoch. " \ + "Learning rates are tracked as well. A folder named `tensorboard_logs` will be created in the working directory. Use `tensorboard --logdir=tensorboard_logs` to view the logs. " \ + "Default: `False`" doc_optimizer = "\ The optimizer setting for selecting the gradient optimizer of model training. 
Optimizer supported includes `Adam`, `SGD` and `LBFGS` \n\n\ For more information about these optmization algorithm, we refer to:\n\n\ @@ -121,6 +125,7 @@ def train_options(): Argument("save_freq", int, optional=True, default=10, doc=doc_save_freq), Argument("validation_freq", int, optional=True, default=10, doc=doc_validation_freq), Argument("display_freq", int, optional=True, default=1, doc=doc_display_freq), + Argument("use_tensorboard", bool, optional=True, default=False, doc=doc_use_tensorboard), Argument("max_ckpt", int, optional=True, default=4, doc=doc_max_ckpt), loss_options() ] diff --git a/examples/tensorboard/data_10/data.1312.lmdb/data.mdb b/examples/tensorboard/data_10/data.1312.lmdb/data.mdb new file mode 100644 index 00000000..8385af77 Binary files /dev/null and b/examples/tensorboard/data_10/data.1312.lmdb/data.mdb differ diff --git a/examples/tensorboard/data_10/data.1312.lmdb/lock.mdb b/examples/tensorboard/data_10/data.1312.lmdb/lock.mdb new file mode 100644 index 00000000..800aafc0 Binary files /dev/null and b/examples/tensorboard/data_10/data.1312.lmdb/lock.mdb differ diff --git a/examples/tensorboard/input_short.json b/examples/tensorboard/input_short.json new file mode 100644 index 00000000..782147f3 --- /dev/null +++ b/examples/tensorboard/input_short.json @@ -0,0 +1,64 @@ +{ + "common_options": { + "basis": { + "C": "5s4p1d", + "H": "3s1p", + "O": "5s4p1d" + }, + "device": "cuda", + "overlap": true + }, + "model_options": { + "embedding": { + "method": "lem", + "irreps_hidden": "4x0e+4x1o+4x2e+4x3o+4x4e", + "n_layers": 5, + "avg_num_neighbors": 80, + "r_max": { + "C": 7, + "O": 7, + "H": 3 + }, + "tp_radial_emb": true + }, + "prediction": { + "method": "e3tb", + "neurons": [ + 64, + 64 + ] + } + }, + "train_options": { + "num_epoch": 10, + "batch_size": 2, + "optimizer": { + "lr": 0.005, + "type": "Adam" + }, + "lr_scheduler": { + "type": "rop", + "factor": 0.9, + "patience": 50, + "min_lr": 0.000001 + }, + "loss_options": { + "train": { + 
"method": "hamil_abs" + } + }, + "save_freq": 100, + "validation_freq": 10, + "display_freq": 1, + "use_tensorboard": true + }, + "data_options": { + "train": { + "root": "./data_10", + "prefix": "data", + "type": "LMDBDataset", + "get_Hamiltonian": true, + "get_overlap": true + } + } +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index bf4c6809..12acb46a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ opt-einsum = "3.3.0" h5py = "3.7.0" lmdb = "1.4.1" pyfiglet = "1.0.2" - +tensorboard = "*" [tool.poetry.group.dev.dependencies] pytest = ">=7.2.0" @@ -54,6 +54,7 @@ opt-einsum = "3.3.0" h5py = "3.7.0" lmdb = "1.4.1" pyfiglet = "1.0.2" +tensorboard = "*" [tool.poetry.group.3Dfermi] optional = true