#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
import pickle
from collections import OrderedDict

import torch
from torch import distributed as dist
from torch import nn

from .dist import _get_global_gloo_group, get_world_size
ASYNC_NORM = ( | |
nn.BatchNorm1d, | |
nn.BatchNorm2d, | |
nn.BatchNorm3d, | |
nn.InstanceNorm1d, | |
nn.InstanceNorm2d, | |
nn.InstanceNorm3d, | |
) | |
__all__ = [ | |
"get_async_norm_states", | |
"pyobj2tensor", | |
"tensor2pyobj", | |
"all_reduce", | |
"all_reduce_norm", | |
] | |
def get_async_norm_states(module):
    """Collect the state of every async-norm layer inside *module*.

    Recursively walks ``module`` and gathers the state-dict entries
    (weights, biases, running statistics) of each submodule that is an
    instance of one of the ``ASYNC_NORM`` types, keyed by the
    fully-qualified submodule name.
    """
    collected = OrderedDict()
    for prefix, layer in module.named_modules():
        if not isinstance(layer, ASYNC_NORM):
            continue
        for key, value in layer.state_dict().items():
            collected[f"{prefix}.{key}"] = value
    return collected
def pyobj2tensor(pyobj, device="cuda"):
    """Serialize a picklable Python object into a byte tensor on *device*."""
    raw = pickle.dumps(pyobj)
    byte_storage = torch.ByteStorage.from_buffer(raw)
    return torch.ByteTensor(byte_storage).to(device=device)
def tensor2pyobj(tensor):
    """Deserialize a byte tensor (from :func:`pyobj2tensor`) back to a Python object."""
    raw_bytes = tensor.cpu().numpy().tobytes()
    return pickle.loads(raw_bytes)
def _get_reduce_op(op_name): | |
return { | |
"sum": dist.ReduceOp.SUM, | |
"mean": dist.ReduceOp.SUM, | |
}[op_name.lower()] | |
def all_reduce(py_dict, op="sum", group=None):
    """
    Apply all reduce function for python dict object.

    NOTE: make sure that every py_dict has the same keys and values are in the same shape.

    Args:
        py_dict (dict): dict to apply all reduce op; values are tensors.
        op (str): operator, could be "sum" or "mean" (case-insensitive).
        group: process group for the world-size check; defaults to the
            global gloo group.

    Returns:
        OrderedDict: reduced tensors, keyed in rank-0's key order.
        When the world size is 1, ``py_dict`` is returned unchanged.
    """
    world_size = get_world_size()
    if world_size == 1:
        return py_dict
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return py_dict

    # all reduce logic across different devices.
    # Broadcast rank-0's key order so every rank flattens/concatenates
    # the tensors in the identical order.
    py_key = list(py_dict.keys())
    py_key_tensor = pyobj2tensor(py_key)
    dist.broadcast(py_key_tensor, src=0)
    py_key = tensor2pyobj(py_key_tensor)

    # Remember per-tensor shapes/sizes so the flat result can be split back.
    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]
    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])

    # NOTE(review): the collectives run on the *default* group, not `group`,
    # which is only used for the size check above — confirm this is intended.
    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op))
    # Fix: compare case-insensitively, consistent with _get_reduce_op's
    # op_name.lower(); previously op="MEAN" summed but skipped the division.
    if op.lower() == "mean":
        flatten_tensor /= world_size

    split_tensors = [
        x.reshape(shape)
        for x, shape in zip(torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    return OrderedDict(zip(py_key, split_tensors))
def all_reduce_norm(module):
    """
    All reduce norm statistics in different devices.
    """
    # Average the async-norm layers' states across ranks, then load the
    # synchronized values back into the module (non-norm keys untouched).
    norm_states = all_reduce(get_async_norm_states(module), op="mean")
    module.load_state_dict(norm_states, strict=False)