run_hydra.py
47 lines (37 loc) · 1.46 KB
from typing import Literal, Optional

import hydra
import pytorch_lightning as pl
from hydra.utils import instantiate
from hydra_plugins.hydra_optuna_pruning_sweeper import trial_provider
from omegaconf import DictConfig
from optuna import Trial

from run import train_nn


@hydra.main(version_base='1.3', config_path='config_hydra', config_name='config')
def hydra_train_nn(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: instantiate trainer, task and datamodule from the
    config, then delegate to `train_nn` and return its result so a Hydra
    sweeper (e.g. the Optuna pruning sweeper) can use it as the objective."""
    # Seed everything for reproducibility if a seed is configured
    if cfg.get('seed'):
        pl.seed_everything(cfg.seed, workers=True)

    # Instantiate callbacks manually so the current Optuna trial can be
    # injected into `PyTorchLightningPruningCallback`
    callbacks = []
    for cb_cfg in cfg.trainer.callbacks:
        if 'PruningCallback' in cb_cfg._target_:
            trial = trial_provider.trial
            assert isinstance(trial, Trial)
            cb = instantiate(cb_cfg, trial=trial)
        else:
            cb = instantiate(cb_cfg)
        callbacks.append(cb)
    del cfg.trainer._callback_dict

    # Instantiate the remaining components directly from the config
    trainer: pl.Trainer = instantiate(cfg.trainer, callbacks=callbacks)
    task: pl.LightningModule = instantiate(cfg.task)
    datamodule: pl.LightningDataModule = instantiate(cfg.datamodule)

    compile_mode: Literal[False] | str = cfg.compile_mode
    auto_lr_find: bool = cfg.auto_lr_find
    monitor: str = cfg.monitor
    test: bool = cfg.test
    return train_nn(
        trainer=trainer, task=task, datamodule=datamodule,
        compile_mode=compile_mode, auto_lr_find=auto_lr_find,
        monitor=monitor, test=test
    )


if __name__ == '__main__':
    hydra_train_nn()
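
A sketch of how this entry point is typically launched (assuming the Optuna pruning sweeper and a search space are configured in config_hydra/config.yaml; the `task.lr` override below is purely illustrative):

    # single training run
    python run_hydra.py

    # hyperparameter sweep: -m/--multirun hands each trial's return value to the sweeper
    python run_hydra.py -m 'task.lr=interval(1e-4,1e-1)'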