from typing import Dict, List

import torch

# torch._six.container_abcs was removed in torch 1.9; compare version parts
# numerically so that e.g. '1.10' is not mistakenly treated as older than '1.9'
# (a plain string comparison would do exactly that).
if tuple(int(v) for v in torch.__version__.split('.')[:2]) < (1, 9):
    Iterable = torch._six.container_abcs.Iterable
else:
    import collections
    Iterable = collections.abc.Iterable

from torch.cuda.amp import GradScaler


class _MultiDeviceReplicator(object):
    """
    Lazily serves copies of a tensor to requested devices. Copies are cached
    per-device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert master_tensor.is_cuda
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval


class MaxClipGradScaler(GradScaler):
    """
    A :class:`GradScaler` whose loss scale is capped at ``max_scale``. While the
    scale is below the cap it grows normally (growth factor 2); once it reaches
    or exceeds the cap it is clamped and further growth is disabled.
    """

    def __init__(self, init_scale, max_scale: float, growth_interval=100):
        GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
        self.max_scale = max_scale

    def scale_clip(self):
        if self.get_scale() == self.max_scale:
            # Exactly at the cap: stop growing.
            self.set_growth_factor(1)
        elif self.get_scale() < self.max_scale:
            # Below the cap: allow the usual doubling.
            self.set_growth_factor(2)
        elif self.get_scale() > self.max_scale:
            # Overshot the cap (e.g. after a growth step in update()):
            # clamp the scale in place and stop growing.
            self._scale.fill_(self.max_scale)
            self.set_growth_factor(1)

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not
        enabled, outputs are returned unmodified.

        Arguments:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs
        self.scale_clip()
        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.is_cuda
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.is_cuda
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)
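
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): MaxClipGradScaler drops into
# a standard torch.cuda.amp training loop exactly like GradScaler, except that
# scale() also clamps the loss scale at max_scale via scale_clip(). The model,
# optimizer, data, and hyperparameters below are hypothetical placeholders,
# and a CUDA device is assumed to be available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = torch.nn.Linear(16, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = MaxClipGradScaler(init_scale=2 ** 10, max_scale=2 ** 16, growth_interval=100)

    inputs = torch.randn(8, 16, device="cuda")
    targets = torch.randn(8, 4, device="cuda")

    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(inputs), targets)

    scaler.scale(loss).backward()  # scales the loss; also clips the scale factor
    scaler.step(optimizer)         # unscales grads, skips the step on inf/nan
    scaler.update()                # may grow the scale; clamped on the next scale()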