import torch
import torch.nn as nn


class DataParallel(nn.DataParallel):
    """
    Overview:
        A wrapper class for nn.DataParallel.
    Interfaces:
        ``__init__``, ``parameters``
    """

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        """
        Overview:
            Initialize the DataParallel object.
        Arguments:
            - module (:obj:`nn.Module`): The module to be parallelized.
            - device_ids (:obj:`list`): The list of GPU ids; ``None`` means all visible GPUs.
            - output_device (:obj:`int`): The GPU id on which outputs are gathered; ``None`` means ``device_ids[0]``.
            - dim (:obj:`int`): The dimension along which inputs are scattered across devices.
        """
        # Forward the received arguments; hard-coding the defaults here would
        # silently discard any user-supplied values.
        super().__init__(module, device_ids=device_ids, output_device=output_device, dim=dim)
        self.module = module

    def parameters(self, recurse: bool = True):
        """
        Overview:
            Return the parameters of the module.
        Arguments:
            - recurse (:obj:`bool`): Whether to also yield the parameters of all submodules.
        Returns:
            - params (:obj:`generator`): The generator of the parameters.
        """
        # Delegate to the wrapped module and respect the caller's ``recurse`` flag
        # rather than always recursing.
        return self.module.parameters(recurse=recurse)
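

# A minimal usage sketch (illustrative, not part of the original module): the toy
# ``nn.Linear`` model, batch shape, and SGD optimizer below are assumptions chosen
# only to demonstrate the wrapper. ``device_ids=None`` lets nn.DataParallel pick
# all visible GPUs; with no GPU available it falls back to running the module as-is.
if __name__ == "__main__":
    model = nn.Linear(16, 4)
    dp_model = DataParallel(model)
    # parameters() comes from the wrapped module, so the optimizer sees the
    # original parameter tensors.
    optimizer = torch.optim.SGD(dp_model.parameters(), lr=0.1)

    x = torch.randn(8, 16)
    if torch.cuda.is_available():
        dp_model, x = dp_model.cuda(), x.cuda()
    loss = dp_model(x).sum()
    loss.backward()
    optimizer.step()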