import math
from collections.abc import Callable
from typing import Optional
import torch
import torch.nn as nn


class Lambda(nn.Module):
"""
Overview:
        A wrapper module that applies a user-provided callable to its input, which is convenient for
        constructing custom layers.
Interfaces:
``__init__``, ``forward``.
"""

    def __init__(self, f: Callable):
"""
Overview:
Initialize the lambda module with a given function.
Arguments:
            - f (:obj:`Callable`): A Python function to be applied to the input tensor.
"""
super(Lambda, self).__init__()
self.f = f

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
            Apply the stored function to the input tensor.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - output (:obj:`torch.Tensor`): The result of applying ``f`` to ``x``.
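        Examples:
            >>> # A minimal usage sketch: wrap a plain function as a layer.
            >>> square = Lambda(lambda x: x ** 2)
            >>> out = square(torch.full((2, 3), 2.0))  # every element is 4.0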
"""
return self.f(x)


class GLU(nn.Module):
    """
    Overview:
        Gated Linear Unit (GLU), a gating mechanism in which a sigmoid gate computed from a context tensor
        modulates the input. It was first proposed in
        [Language Modeling with Gated Convolutional Networks](https://arxiv.org/pdf/1612.08083.pdf).
Interfaces:
``__init__``, ``forward``.
"""

    def __init__(self, input_dim: int, output_dim: int, context_dim: int, input_type: str = 'fc') -> None:
"""
Overview:
Initialize the GLU module.
Arguments:
- input_dim (:obj:`int`): The dimension of the input tensor.
- output_dim (:obj:`int`): The dimension of the output tensor.
- context_dim (:obj:`int`): The dimension of the context tensor.
            - input_type (:obj:`str`): The type of input layer, currently supports ['fc', 'conv2d'].
"""
super(GLU, self).__init__()
        assert input_type in ['fc', 'conv2d']
        if input_type == 'fc':
            # layer1 maps the context to a gate of the same dimension as the input;
            # layer2 projects the gated input to the output dimension.
            self.layer1 = nn.Linear(context_dim, input_dim)
            self.layer2 = nn.Linear(input_dim, output_dim)
        elif input_type == 'conv2d':
            # 1x1 convolutions play the same roles for spatial (NCHW) inputs.
            self.layer1 = nn.Conv2d(context_dim, input_dim, 1, 1, 0)
            self.layer2 = nn.Conv2d(input_dim, output_dim, 1, 1, 0)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
"""
Overview:
Compute the GLU transformation of the input tensor.
Arguments:
- x (:obj:`torch.Tensor`): The input tensor.
- context (:obj:`torch.Tensor`): The context tensor.
Returns:
- x (:obj:`torch.Tensor`): The output tensor after GLU transformation.
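        Examples:
            >>> # A minimal sketch of the 'fc' variant: an 8-dim input gated by a 4-dim context.
            >>> glu = GLU(input_dim=8, output_dim=16, context_dim=4, input_type='fc')
            >>> out = glu(torch.randn(2, 8), torch.randn(2, 4))  # shape: (2, 16)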
"""
        # The context produces a sigmoid gate that modulates the input before the final projection.
        gate = self.layer1(context)
        gate = torch.sigmoid(gate)
        x = gate * x
        x = self.layer2(x)
return x


class Swish(nn.Module):
    """
    Overview:
        Swish activation function, a smooth and non-monotonic function defined as ``x * sigmoid(x)``
        (equivalent to PyTorch's ``nn.SiLU``). For more details, please refer
        to [Searching for Activation Functions](https://arxiv.org/pdf/1710.05941.pdf).
Interfaces:
``__init__``, ``forward``.
"""

    def __init__(self):
"""
Overview:
Initialize the Swish module.
"""
super(Swish, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
Compute the Swish transformation of the input tensor.
Arguments:
- x (:obj:`torch.Tensor`): The input tensor.
Returns:
- x (:obj:`torch.Tensor`): The output tensor after Swish transformation.
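        Examples:
            >>> act = Swish()
            >>> out = act(torch.tensor([-1.0, 0.0, 1.0]))  # Swish(0) is exactly 0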
"""
return x * torch.sigmoid(x)


class GELU(nn.Module):
    """
    Overview:
        Gaussian Error Linear Units (GELU) activation function, which is widely used in NLP models such as
        GPT and BERT. This implementation uses the tanh approximation of GELU.
        For more details, please refer to the original paper: https://arxiv.org/pdf/1606.08415.pdf.
Interfaces:
``__init__``, ``forward``.
"""

    def __init__(self):
"""
Overview:
Initialize the GELU module.
"""
super(GELU, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
Compute the GELU transformation of the input tensor.
Arguments:
- x (:obj:`torch.Tensor`): The input tensor.
Returns:
- x (:obj:`torch.Tensor`): The output tensor after GELU transformation.
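        Examples:
            >>> act = GELU()
            >>> out = act(torch.randn(4, 8))  # same shape as the input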
"""
        # Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


def build_activation(activation: str, inplace: Optional[bool] = None) -> nn.Module:
"""
Overview:
Build and return the activation module according to the given type.
Arguments:
- activation (:obj:`str`): The type of activation module, now supports \
['relu', 'glu', 'prelu', 'swish', 'gelu', 'tanh', 'sigmoid', 'softplus', 'elu', 'square', 'identity'].
        - inplace (:obj:`Optional[bool]`): Whether to execute the operation in-place; only 'relu' supports it. Defaults to None.
Returns:
        - act_func (:obj:`nn.Module`): The corresponding activation module. Note that 'glu' returns the ``GLU`` class itself, since it requires constructor arguments.
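    Examples:
        >>> # A minimal usage sketch of the factory function.
        >>> act = build_activation('relu', inplace=True)
        >>> out = act(torch.randn(4))
        >>> glu_cls = build_activation('glu')  # returns the GLU class, to be instantiated with its own arguments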
"""
    # Only 'relu' supports the in-place flag; it is rejected for any other activation.
    if inplace is not None:
        assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation)
    else:
        inplace = False
    act_func = {
        'relu': nn.ReLU(inplace=inplace),
        # GLU requires constructor arguments (input_dim, output_dim, context_dim),
        # so the class itself is returned rather than an instance.
        'glu': GLU,
        'prelu': nn.PReLU(),
        'swish': Swish(),
        'gelu': GELU(),
        'tanh': nn.Tanh(),
        'sigmoid': nn.Sigmoid(),
        'softplus': nn.Softplus(),
        'elu': nn.ELU(),
        'square': Lambda(lambda x: x ** 2),
        'identity': Lambda(lambda x: x),
    }
    if activation.lower() in act_func:
        # Look up with the lowercased key, so that e.g. 'ReLU' resolves correctly
        # instead of passing the membership check and then failing the lookup.
        return act_func[activation.lower()]
    else:
        raise KeyError("invalid key for activation: {}".format(activation))