File size: 1,145 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
import torch.nn as nn
class DataParallel(nn.DataParallel):
    """
    Overview:
        A thin wrapper around ``nn.DataParallel``. It forwards the constructor arguments unchanged to the parent \
        class and exposes ``parameters`` directly from the wrapped module.
    Interfaces:
        ``__init__``, ``parameters``
    """

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        """
        Overview:
            Initialize the DataParallel object by delegating to ``nn.DataParallel.__init__``.
        Arguments:
            - module (:obj:`nn.Module`): The module to be parallelized.
            - device_ids (:obj:`list`): The list of GPU ids. ``None`` leaves the choice to ``nn.DataParallel``.
            - output_device (:obj:`int`): The output GPU id. ``None`` leaves the choice to ``nn.DataParallel``.
            - dim (:obj:`int`): The dimension to be parallelized (passed through to ``nn.DataParallel``).
        """
        # Forward the caller's arguments as given. The parent constructor stores ``module`` as ``self.module``,
        # so there is no need to assign it again here.
        super().__init__(module, device_ids=device_ids, output_device=output_device, dim=dim)

    def parameters(self, recurse: bool = True):
        """
        Overview:
            Return the parameters of the wrapped module.
        Arguments:
            - recurse (:obj:`bool`): Whether to also return the parameters of the wrapped module's submodules.
        Returns:
            - params (:obj:`generator`): The generator of the parameters.
        """
        # Honour ``recurse``; it was previously hard-coded to ``True``.
        return self.module.parameters(recurse=recurse)
|