-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_dp.py
33 lines (22 loc) · 804 Bytes
/
run_dp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
import sys
import importlib
from utils import get_conf
from mpi4py import MPI
nproc = MPI.COMM_WORLD.Get_size() # Size of communicator
iproc = MPI.COMM_WORLD.Get_rank() # Ranks in communicator
inode = MPI.Get_processor_name() # Node where this MPI process runs
# dynamically import the experiment
experiment = sys.argv[1].split("/")[1].split(".")[0]
module = importlib.import_module("." + experiment, package="experiments")
exp = getattr(module, "Experiment")
cnf = get_conf("conf/main.yaml")
if cnf.wandb.dryrun:
os.environ["WANDB_MODE"] = "dryrun"
os.environ["WANDB_DISABLE_CODE"] = "true"
# set the experiment seeds according to the number of the process
cnf.env.torch_seed += iproc
cnf.env.np_seed += iproc
experiment = exp(cnf, iproc)
experiment.run()
MPI.Finalize()