import os
import subprocess

import torch


def setup_for_distributed(is_master):
    """Replace ``builtins.print`` with a version that is silent off-master.

    After this call, ``print(...)`` produces output only when ``is_master``
    is true or when the caller passes ``force=True``. The ``force`` keyword
    is consumed by the wrapper and never forwarded to the real ``print``.

    The function may be called more than once: the wrapper remembers the
    print function it replaced, so a later call re-wraps that underlying
    function instead of stacking on top of an earlier (possibly silencing)
    wrapper.

    Args:
        is_master: Truthy if the current process should print normally.
    """
    import builtins as __builtin__

    current_print = __builtin__.print
    # If print was already replaced by this function, unwrap it first.
    # Otherwise setup_for_distributed(False) followed by
    # setup_for_distributed(True) would stay silent, because the outer
    # wrapper would delegate to the inner one, which drops the call.
    builtin_print = getattr(current_print, '_original_print', current_print)

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    print._original_print = builtin_print
    __builtin__.print = print


def init_distributed_mode(args):
    """Initialise ``torch.distributed`` from the process environment.

    Three cases are handled, in this order:

    1. ``RANK`` and ``WORLD_SIZE`` are set in the environment: rank, world
       size and GPU index are read from ``RANK``, ``WORLD_SIZE`` and
       ``LOCAL_RANK``.
    2. ``SLURM_PROCID`` is set: rank and world size come from
       ``SLURM_PROCID`` / ``SLURM_NTASKS``, the master address is the first
       host reported by ``scontrol`` for ``SLURM_NODELIST``, and the
       ``MASTER_*`` / ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK`` variables
       are written to ``os.environ``.
    3. Neither: ``args.distributed`` is set to ``False`` and the function
       returns without touching ``torch.distributed``.

    In cases 1 and 2 the function sets ``args.rank``, ``args.world_size``,
    ``args.gpu``, ``args.dist_url`` (always ``'env://'``),
    ``args.distributed = True`` and ``args.dist_backend = 'nccl'``, selects
    the CUDA device, initialises the process group, waits on a barrier and
    finally silences ``print`` on every rank except rank 0.

    Args:
        args: Mutable namespace-like object; the attributes listed above
            are assigned on it.

    Raises:
        KeyError: If an expected environment variable (for example
            ``LOCAL_RANK`` in case 1) is missing.
        ZeroDivisionError: In the SLURM case when
            ``torch.cuda.device_count()`` returns 0.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
    elif 'SLURM_PROCID' in os.environ:
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        # First hostname of the allocation is used as the rendezvous address.
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list))
        # Keep a MASTER_PORT the user already exported; default otherwise.
        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        # Local GPU index is the global task id modulo GPUs visible here.
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = ntasks
        args.rank = proc_id
        args.gpu = proc_id % num_gpus
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend,
                                         init_method=args.dist_url,
                                         world_size=args.world_size,
                                         rank=args.rank)
    torch.distributed.barrier()
    # From here on only rank 0 prints (unless force=True is passed).
    setup_for_distributed(args.rank == 0)