import torch.nn as nn | |
import torch | |
import torch.distributed as dist | |
class GlobalAvgPool2d(nn.Module):
    """Global average pooling over the input's spatial dimensions.

    Every dimension after the first two is flattened and averaged, so an
    input of shape ``(N, C, *rest)`` produces an output of shape ``(N, C)``.
    """

    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        """Average ``inputs`` over all dimensions after dim 1.

        Args:
            inputs: tensor with at least two dimensions; the first two are
                kept (presumably batch and channels — confirm with callers)
                and all trailing dimensions are averaged together.

        Returns:
            Tensor of shape ``(inputs.size(0), inputs.size(1))``.
        """
        in_size = inputs.size()
        # reshape rather than view: view raises on non-contiguous tensors
        # (e.g. after transpose/permute); reshape returns a view when it can
        # and copies only when it must, so contiguous inputs behave as before.
        flat = inputs.reshape((in_size[0], in_size[1], -1))
        return flat.mean(dim=2)
class SingleGPU(nn.Module):
    """Wrapper that copies its input to CUDA before calling ``module``.

    Only the input is moved; the wrapped module is stored unchanged, so it is
    presumably expected to already be on the GPU — confirm with callers.
    """

    def __init__(self, module):
        super().__init__()
        # Kept under the attribute name ``module``; nn.Module registers it as
        # a child when the value is itself an nn.Module.
        self.module = module

    def forward(self, input):
        """Copy ``input`` via ``.cuda(non_blocking=True)`` and run the wrapped module on it."""
        on_device = input.cuda(non_blocking=True)
        return self.module(on_device)