import torch
import torch.nn as nn


class DataParallel(nn.DataParallel):
    """
    Overview:
        A thin wrapper around ``nn.DataParallel`` whose ``parameters`` method \
        delegates to the wrapped module instead of the wrapper itself.
    Interfaces:
        ``__init__``, ``parameters``
    """

    def __init__(self, module: nn.Module, device_ids=None, output_device=None, dim: int = 0) -> None:
        """
        Overview:
            Initialize the DataParallel object.
        Arguments:
            - module (:obj:`nn.Module`): The module to be parallelized.
            - device_ids (:obj:`list`): The list of GPU ids. ``None`` leaves the choice to \
                ``nn.DataParallel``.
            - output_device (:obj:`int`): The output GPU id. ``None`` leaves the choice to \
                ``nn.DataParallel``.
            - dim (:obj:`int`): The dimension to be parallelized.
        """
        # Forward the caller's arguments; hard-coding the defaults here would
        # silently discard any explicit device/dim selection.
        # The base constructor registers ``module`` as ``self.module``, so no
        # extra assignment is needed afterwards.
        super().__init__(module, device_ids=device_ids, output_device=output_device, dim=dim)

    def parameters(self, recurse: bool = True):
        """
        Overview:
            Return the parameters of the wrapped module.
        Arguments:
            - recurse (:obj:`bool`): Whether to also return the parameters of the wrapped \
                module's submodules. If ``False``, only the parameters that are direct \
                members of the wrapped module are yielded.
        Returns:
            - params (:obj:`generator`): The generator of the parameters.
        """
        # Honour the caller's ``recurse`` flag rather than always recursing.
        return self.module.parameters(recurse=recurse)