#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Code is based on
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Megvii, Inc. and its affiliates.
import sys
from datetime import timedelta
from loguru import logger
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import yolox.utils.dist as comm
__all__ = ["launch"]

DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
    """
    Find an available port on the current machine / node.
    """
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 makes the OS pick an available port for us.
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by another process
    # before it is actually used.
    return port


def launch(
    main_func,
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    backend="nccl",
    dist_url=None,
    args=(),
    timeout=DEFAULT_TIMEOUT,
):
    """
    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): the number of GPUs (one process each) per machine
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine (one per machine)
        backend (str): the torch.distributed backend to use, e.g. "nccl"
        dist_url (str): url to connect to for distributed training, including protocol,
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost.
        args (tuple): arguments passed to main_func
        timeout (timedelta): timeout for torch.distributed initialization; note that it
            is currently not forwarded to the spawned workers, which fall back to
            DEFAULT_TIMEOUT.
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        # https://github.com/pytorch/pytorch/pull/14391
        # TODO prctl in spawned processes
        if dist_url == "auto":
            assert (
                num_machines == 1
            ), "dist_url=auto cannot work with distributed training."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"

        start_method = "spawn"
        cache = vars(args[1]).get("cache", False)

        # To use numpy memmap for caching images into RAM, we have to use the fork method.
        if cache:
            assert sys.platform != "win32", (
                "As the Windows platform doesn't support the fork method, "
                "do not add --cache to your training command."
            )
            start_method = "fork"

        mp.start_processes(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                backend,
                dist_url,
                args,
            ),
            daemon=False,
            start_method=start_method,
        )
    else:
        main_func(*args)


def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    backend,
    dist_url,
    args,
    timeout=DEFAULT_TIMEOUT,
):
    """Entry point of each spawned worker: set up torch.distributed, then run main_func."""
    assert (
        torch.cuda.is_available()
    ), "cuda is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    logger.info("Rank {} initialization finished.".format(global_rank))
    try:
        dist.init_process_group(
            backend=backend,
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception:
        logger.error("Process group URL: {}".format(dist_url))
        raise

    # Set up the local process group (which contains the ranks within the same machine).
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
        )
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    main_func(*args)
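

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original YOLOX module): a minimal, hedged
# example of how `launch` is typically invoked. `_example_main` and the
# argparse flags below are hypothetical placeholders, not YOLOX's real CLI.
# ---------------------------------------------------------------------------
def _example_main(exp_name, opts):
    # Placeholder for a real training entry point (e.g. building an Exp and a
    # Trainer in YOLOX). Here each worker only reports its distributed rank.
    rank = dist.get_rank() if dist.is_initialized() else 0
    print(f"worker rank={rank}, exp_name={exp_name}, cache={opts.cache}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("launch usage sketch")
    parser.add_argument("--devices", type=int, default=1, help="GPUs per machine")
    parser.add_argument("--num_machines", type=int, default=1)
    parser.add_argument("--machine_rank", type=int, default=0)
    parser.add_argument("--dist_url", type=str, default="auto")
    parser.add_argument("--cache", action="store_true")
    opts = parser.parse_args()

    # NOTE: `launch` reads `vars(args[1]).get("cache", False)`, so the second
    # element of `args` must be a namespace-like object carrying `cache`.
    launch(
        _example_main,
        num_gpus_per_machine=opts.devices,
        num_machines=opts.num_machines,
        machine_rank=opts.machine_rank,
        backend="nccl",
        dist_url=opts.dist_url,
        args=("example_exp", opts),
    )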