feat: add tensorboard support #206

Merged: 4 commits, Aug 27, 2024
6 changes: 4 additions & 2 deletions dptb/entrypoints/train.py
@@ -1,7 +1,7 @@
from dptb.nnops.trainer import Trainer
from dptb.nn.build import build_model
from dptb.data.build import build_dataset
from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor, Validationer
from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor, Validationer, TensorBoardMonitor
from dptb.plugins.train_logger import Logger
from dptb.utils.argcheck import normalize, collect_cutoffs
from dptb.plugins.saver import Saver
@@ -209,7 +209,9 @@ def train(
log_field.append("validation_loss")
trainer.register_plugin(TrainLossMonitor())
trainer.register_plugin(LearningRateMonitor())
trainer.register_plugin(Logger(log_field,
if jdata["train_options"]["use_tensorboard"]:
trainer.register_plugin(TensorBoardMonitor())
trainer.register_plugin(Logger(log_field,
interval=[(jdata["train_options"]["display_freq"], 'iteration'), (1, 'epoch')]))

for q in trainer.plugin_queues.values():
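Reviewer note: the wiring above is the whole opt-in path. A minimal sketch of the same logic in isolation, assuming `jdata` has already been normalized by `dptb.utils.argcheck.normalize` so that `use_tensorboard` is present (defaulting to `False`); the helper name here is illustrative, not part of the PR:

```python
# Sketch only: mirrors the conditional registration added in train.py above.
# `trainer` and `jdata` are assumed to be the objects already built in train().
from dptb.plugins.monitor import TensorBoardMonitor

def maybe_enable_tensorboard(trainer, jdata):
    """Register the TensorBoard plugin only when the user opted in."""
    if jdata["train_options"].get("use_tensorboard", False):
        # Scalars are then written every 25 iterations and once per epoch
        # (the intervals are fixed in TensorBoardMonitor.__init__).
        trainer.register_plugin(TensorBoardMonitor())
```

Using `.get` with a `False` fallback is slightly more defensive than the direct indexing in the diff; the direct lookup is fine as long as `normalize` always runs first and fills the default.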
26 changes: 24 additions & 2 deletions dptb/plugins/monitor.py
@@ -1,8 +1,10 @@
from dptb.plugins.base_plugin import Plugin
import logging
import time

import torch
from dptb.data import AtomicData
from dptb.plugins.base_plugin import Plugin
from torch.utils.tensorboard import SummaryWriter

log = logging.getLogger(__name__)

@@ -140,4 +142,24 @@ def _get_value(self, **kwargs):
if kwargs.get('field') == "iteration":
return self.trainer.validation(fast=True)
else:
return self.trainer.validation()
return self.trainer.validation()


class TensorBoardMonitor(Plugin):
    def __init__(self):
        # Write scalars every 25 iterations and once per epoch.
        super(TensorBoardMonitor, self).__init__([(25, 'iteration'), (1, 'epoch')])
        self.writer = SummaryWriter(log_dir='./tensorboard_logs')

    def register(self, trainer):
        self.trainer = trainer

    def epoch(self, **kwargs):
        # Per-epoch scalars: learning rate, last-iteration loss, and mean loss over the epoch.
        epoch = self.trainer.ep
        self.writer.add_scalar('lr/epoch', self.trainer.stats['lr']['last'], epoch)
        self.writer.add_scalar('train_loss_last/epoch', self.trainer.stats['train_loss']['last'], epoch)
        self.writer.add_scalar('train_loss_mean/epoch', self.trainer.stats['train_loss']['epoch_mean'], epoch)

    def iteration(self, **kwargs):
        # Per-iteration scalars: current learning rate and the latest training loss.
        iteration = self.trainer.iter
        self.writer.add_scalar('lr_iter/iteration', self.trainer.stats['lr']['last'], iteration)
        self.writer.add_scalar('train_loss_iter/iteration', self.trainer.stats['train_loss']['last'], iteration)
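For reviewers who want to check what actually lands on disk without opening the TensorBoard UI, the event files written by this monitor can be read back with TensorBoard's `EventAccumulator`. A minimal sketch, assuming the default `./tensorboard_logs` directory and the scalar tags defined above (this utility is part of the `tensorboard` package, not of this PR):

```python
# Sketch: read back the scalars written by TensorBoardMonitor after a run.
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

acc = EventAccumulator("./tensorboard_logs")
acc.Reload()  # parse the event files found under the log directory

print(acc.Tags()["scalars"])  # expect tags like 'lr/epoch', 'train_loss_iter/iteration', ...
for event in acc.Scalars("train_loss_iter/iteration"):
    print(event.step, event.value)
```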
5 changes: 5 additions & 0 deletions dptb/utils/argcheck.py
@@ -98,6 +98,10 @@ def train_options():
doc_save_freq = "Frequency, or every how many iterations, to save the current model into checkpoints. The checkpoint name is formulated as `latest|best_dptb|nnsk_b<bond_cutoff>_c<sk_cutoff>_w<sk_decay_w>`. Default: `10`"
doc_validation_freq = "Frequency, or every how many iterations, to run model validation on the validation datasets. Default: `10`"
doc_display_freq = "Frequency, or every how many iterations, to display the training log on screen. Default: `1`"
doc_use_tensorboard = "Set to true to enable TensorBoard logging. Iteration-level scalars are recorded once every `25` iterations and epoch-level scalars once per epoch. " \
"Three loss values are recorded: `train_loss_iter` is the per-iteration loss, `train_loss_last` is the loss of the last iteration in an epoch, and `train_loss_mean` is the mean loss over all iterations in an epoch. " \
"The learning rate is tracked as well. A folder named `tensorboard_logs` is created in the working directory; use `tensorboard --logdir=tensorboard_logs` to view the logs. " \
"Default: `False`"
doc_optimizer = "\
The optimizer setting for selecting the gradient optimizer of model training. Optimizer supported includes `Adam`, `SGD` and `LBFGS` \n\n\
For more information about these optmization algorithm, we refer to:\n\n\
@@ -121,6 +125,7 @@ def train_options():
Argument("save_freq", int, optional=True, default=10, doc=doc_save_freq),
Argument("validation_freq", int, optional=True, default=10, doc=doc_validation_freq),
Argument("display_freq", int, optional=True, default=1, doc=doc_display_freq),
Argument("use_tensorboard", bool, optional=True, default=False, doc=doc_use_tensorboard),
Argument("max_ckpt", int, optional=True, default=4, doc=doc_max_ckpt),
loss_options()
]
Binary file not shown.
Binary file not shown.
64 changes: 64 additions & 0 deletions examples/tensorboard/input_short.json
@@ -0,0 +1,64 @@
{
"common_options": {
"basis": {
"C": "5s4p1d",
"H": "3s1p",
"O": "5s4p1d"
},
"device": "cuda",
"overlap": true
},
"model_options": {
"embedding": {
"method": "lem",
"irreps_hidden": "4x0e+4x1o+4x2e+4x3o+4x4e",
"n_layers": 5,
"avg_num_neighbors": 80,
"r_max": {
"C": 7,
"O": 7,
"H": 3
},
"tp_radial_emb": true
},
"prediction": {
"method": "e3tb",
"neurons": [
64,
64
]
}
},
"train_options": {
"num_epoch": 10,
"batch_size": 2,
"optimizer": {
"lr": 0.005,
"type": "Adam"
},
"lr_scheduler": {
"type": "rop",
"factor": 0.9,
"patience": 50,
"min_lr": 0.000001
},
"loss_options": {
"train": {
"method": "hamil_abs"
}
},
"save_freq": 100,
"validation_freq": 10,
"display_freq": 1,
"use_tensorboard": true
},
"data_options": {
"train": {
"root": "./data_10",
"prefix": "data",
"type": "LMDBDataset",
"get_Hamiltonian": true,
"get_overlap": true
}
}
}
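Once a run with this input has started, the logs can be viewed with the `tensorboard --logdir=tensorboard_logs` command mentioned in the argcheck doc string, or launched from Python. A small sketch of the programmatic route, assuming the standard `tensorboard.program` API and that it is run from the working directory containing `tensorboard_logs`:

```python
# Sketch: start TensorBoard from Python instead of the CLI and print its URL.
from tensorboard import program

tb = program.TensorBoard()
tb.configure(argv=[None, "--logdir", "tensorboard_logs"])
url = tb.launch()  # serves in a background thread, e.g. http://localhost:6006/
print(f"TensorBoard listening on {url}")
```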
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -30,7 +30,7 @@ opt-einsum = "3.3.0"
h5py = "3.7.0"
lmdb = "1.4.1"
pyfiglet = "1.0.2"

tensorboard = "*"

[tool.poetry.group.dev.dependencies]
pytest = ">=7.2.0"
@@ -54,6 +54,7 @@ opt-einsum = "3.3.0"
h5py = "3.7.0"
lmdb = "1.4.1"
pyfiglet = "1.0.2"
tensorboard = "*"

[tool.poetry.group.3Dfermi]
optional = true