import os
import subprocess

import torch
def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        # Print only on the master rank, unless force=True is passed.
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
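

# Usage note (illustrative, not part of the original file): once
# setup_for_distributed(False) has been applied on a non-master rank,
# plain print(...) calls are silenced, while print(..., force=True)
# still reaches stdout because the wrapper pops the extra keyword
# before delegating to the saved builtin print.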
def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched with torch.distributed (e.g. torchrun): the launcher
        # already exported RANK / WORLD_SIZE / LOCAL_RANK.
        args.rank = int(os.environ['RANK'])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
    elif 'SLURM_PROCID' in os.environ:
        # Launched under SLURM: derive the rendezvous variables from the
        # SLURM environment and the first host in the node list.
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list))
        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = ntasks
        args.rank = proc_id
        args.gpu = proc_id % num_gpus
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    # Silence print() on every rank except rank 0.
    setup_for_distributed(args.rank == 0)
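

# --- Hypothetical usage sketch (not from the original file) ---
# A minimal sketch of how init_distributed_mode is typically driven from an
# argparse namespace and a launcher such as `torchrun --nproc_per_node=4 <script>.py`.
# The Linear model and the smoke-test structure are illustrative assumptions;
# only the attributes filled in above (rank, world_size, gpu, dist_url,
# distributed) come from this module.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser('distributed init smoke test')
    args = parser.parse_args()

    # Populates args.rank / args.world_size / args.gpu and calls
    # init_process_group when a launcher or SLURM environment is detected.
    init_distributed_mode(args)

    if args.distributed:
        # Each process now owns exactly one GPU; wrap the model in DDP.
        model = torch.nn.Linear(8, 8).cuda(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        # force=True bypasses the master-only print suppression above.
        print('rank {} ready, world size {}'.format(
            args.rank, args.world_size), force=True)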