-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.py
36 lines (30 loc) · 1.19 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import argparse
import torch
import dist_utils
import dist_train
import torch.distributed as dist
def process_wrapper(rank, args, func):
    """Per-process entry point handed to ``torch.multiprocessing.spawn``.

    Sets up the process-group rendezvous environment, builds the
    per-rank distributed environment object, and transfers control to
    the training entry point.

    Args:
        rank: Process index assigned by ``spawn`` (0 .. args.nprocs-1).
        args: Parsed CLI namespace; ``nprocs`` and ``backend`` are read here.
        func: Training entry point, invoked as ``func(env, args)``.
    """
    # setdefault (instead of unconditional assignment) lets a user override
    # the rendezvous address/port/interface from the shell; the defaults
    # below reproduce the previous hard-coded single-node setup.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    os.environ.setdefault('NCCL_SOCKET_IFNAME', 'lo')
    # NCCL debugging / tuning knobs, kept for reference:
    # os.environ['NCCL_DEBUG']='INFO'
    # os.environ['NCCL_DEBUG_SUBSYS']='ALL'
    # os.environ['NCCL_P2P_DISABLE']='1'
    # os.environ['NCCL_ALGO'] = 'Ring'
    # os.environ['NCCL_MIN_NCHANNELS'] = '1'
    # os.environ['NCCL_MAX_NCHANNELS'] = '1'
    env = dist_utils.DistEnv(rank, args.nprocs, args.backend)
    # NOTE(review): these flags appear to toggle fp16 and sparse (CSR)
    # paths inside DistEnv — confirm against dist_utils.
    env.half_enabled = True
    env.csr_enabled = True
    func(env, args)
if __name__ == "__main__":
    gpu_count = torch.cuda.device_count()
    multi_gpu = gpu_count > 1

    # Multi-GPU host: one worker per GPU over NCCL.
    # Otherwise: 8 CPU workers communicating over gloo.
    parser = argparse.ArgumentParser()
    parser.add_argument("--nprocs", type=int,
                        default=gpu_count if multi_gpu else 8)
    parser.add_argument("--epoch", type=int, default=20)
    parser.add_argument("--backend", type=str,
                        default='nccl' if multi_gpu else 'gloo')
    parser.add_argument("--dataset", type=str, default='reddit')
    cli_args = parser.parse_args()

    # Fork one process per rank; each runs process_wrapper(rank, ...),
    # which ultimately dispatches into dist_train.main.
    torch.multiprocessing.spawn(process_wrapper,
                                args=(cli_args, dist_train.main),
                                nprocs=cli_args.nprocs)