import math
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch import distributions as torchd
from ding.torch_utils import MLP
from ding.rl_utils import symlog, inv_symlog
class Conv2dSame(torch.nn.Conv2d):
"""
Overview:
Conv2dSame Network for dreamerv3.
Interfaces:
``__init__``, ``forward``
"""
def calc_same_pad(self, i, k, s, d):
"""
Overview:
Calculate the same padding size.
Arguments:
- i (:obj:`int`): Input size.
- k (:obj:`int`): Kernel size.
- s (:obj:`int`): Stride size.
- d (:obj:`int`): Dilation size.
"""
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
def forward(self, x):
"""
Overview:
            Compute the forward of Conv2dSame.
Arguments:
- x (:obj:`torch.Tensor`): Input tensor.
"""
ih, iw = x.size()[-2:]
pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
ret = F.conv2d(
x,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
return ret
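# Usage sketch (illustrative, not part of the module): with stride 2, 'same' padding keeps the
# spatial output at ceil(H / 2) x ceil(W / 2) regardless of kernel size.
#     conv = Conv2dSame(in_channels=3, out_channels=16, kernel_size=4, stride=2)
#     out = conv(torch.randn(1, 3, 64, 64))    # -> torch.Size([1, 16, 32, 32])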
class DreamerLayerNorm(nn.Module):
"""
Overview:
DreamerLayerNorm Network for dreamerv3.
Interfaces:
``__init__``, ``forward``
"""
def __init__(self, ch, eps=1e-03):
"""
Overview:
Init the DreamerLayerNorm class.
Arguments:
- ch (:obj:`int`): Input channel.
- eps (:obj:`float`): Epsilon.
"""
super(DreamerLayerNorm, self).__init__()
self.norm = torch.nn.LayerNorm(ch, eps=eps)
def forward(self, x):
"""
Overview:
            Compute the forward of DreamerLayerNorm.
Arguments:
- x (:obj:`torch.Tensor`): Input tensor.
"""
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
return x
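# Usage sketch (illustrative): LayerNorm is applied over the channel dimension of NCHW feature maps
# by permuting to NHWC and back, so the output shape is unchanged.
#     norm = DreamerLayerNorm(32)
#     y = norm(torch.randn(8, 32, 16, 16))     # -> torch.Size([8, 32, 16, 16])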
class DenseHead(nn.Module):
"""
Overview:
DenseHead Network for value head, reward head, and discount head of dreamerv3.
Interfaces:
``__init__``, ``forward``
"""
def __init__(
self,
inp_dim,
shape, # (255,)
layer_num,
units, # 512
act='SiLU',
norm='LN',
dist='normal',
std=1.0,
outscale=1.0,
device='cpu',
):
"""
Overview:
Init the DenseHead class.
Arguments:
- inp_dim (:obj:`int`): Input dimension.
- shape (:obj:`tuple`): Output shape.
- layer_num (:obj:`int`): Number of layers.
- units (:obj:`int`): Number of units.
- act (:obj:`str`): Activation function.
- norm (:obj:`str`): Normalization function.
- dist (:obj:`str`): Distribution function.
- std (:obj:`float`): Standard deviation.
- outscale (:obj:`float`): Output scale.
- device (:obj:`str`): Device.
"""
super(DenseHead, self).__init__()
self._shape = (shape, ) if isinstance(shape, int) else shape
if len(self._shape) == 0:
self._shape = (1, )
self._layer_num = layer_num
self._units = units
self._act = getattr(torch.nn, act)()
self._norm = norm
self._dist = dist
self._std = std
self._device = device
self.mlp = MLP(
inp_dim,
self._units,
self._units,
self._layer_num,
layer_fn=nn.Linear,
activation=self._act,
norm_type=self._norm
)
self.mlp.apply(weight_init)
self.mean_layer = nn.Linear(self._units, np.prod(self._shape))
self.mean_layer.apply(uniform_weight_init(outscale))
if self._std == "learned":
self.std_layer = nn.Linear(self._units, np.prod(self._shape))
self.std_layer.apply(uniform_weight_init(outscale))
def forward(self, features):
"""
Overview:
            Compute the forward of DenseHead.
Arguments:
- features (:obj:`torch.Tensor`): Input tensor.
"""
x = features
out = self.mlp(x) # (batch, time, _units=512)
mean = self.mean_layer(out) # (batch, time, 255)
if self._std == "learned":
std = self.std_layer(out)
else:
std = self._std
if self._dist == "normal":
return ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), len(self._shape)))
elif self._dist == "huber":
return ContDist(torchd.independent.Independent(UnnormalizedHuber(mean, std, 1.0), len(self._shape)))
elif self._dist == "binary":
return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)))
elif self._dist == "twohot_symlog":
return TwoHotDistSymlog(logits=mean, device=self._device)
raise NotImplementedError(self._dist)
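# Usage sketch (illustrative; relies on DI-engine's MLP defaults): a two-hot symlog value head.
#     head = DenseHead(inp_dim=1024, shape=(255, ), layer_num=2, units=512, dist='twohot_symlog')
#     dist = head(torch.randn(8, 16, 1024))    # (batch, time, feature) -> TwoHotDistSymlog over 255 buckets
#     value = dist.mean()                      # (8, 16, 1), decoded back with inv_symlog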
class ActionHead(nn.Module):
"""
Overview:
ActionHead Network for action head of dreamerv3.
Interfaces:
``__init__``, ``forward``
"""
def __init__(
self,
inp_dim,
size,
layers,
units,
        act='ELU',
        norm='LayerNorm',
dist="trunc_normal",
init_std=0.0,
min_std=0.1,
max_std=1.0,
temp=0.1,
outscale=1.0,
unimix_ratio=0.01,
):
"""
Overview:
Initialize the ActionHead class.
Arguments:
- inp_dim (:obj:`int`): Input dimension.
- size (:obj:`int`): Output size.
- layers (:obj:`int`): Number of layers.
- units (:obj:`int`): Number of units.
- act (:obj:`str`): Activation function.
- norm (:obj:`str`): Normalization function.
- dist (:obj:`str`): Distribution function.
- init_std (:obj:`float`): Initial standard deviation.
- min_std (:obj:`float`): Minimum standard deviation.
- max_std (:obj:`float`): Maximum standard deviation.
- temp (:obj:`float`): Temperature.
- outscale (:obj:`float`): Output scale.
- unimix_ratio (:obj:`float`): Unimix ratio.
"""
super(ActionHead, self).__init__()
self._size = size
self._layers = layers
self._units = units
self._dist = dist
self._act = getattr(torch.nn, act)
self._norm = getattr(torch.nn, norm)
self._min_std = min_std
self._max_std = max_std
self._init_std = init_std
self._unimix_ratio = unimix_ratio
self._temp = temp() if callable(temp) else temp
pre_layers = []
for index in range(self._layers):
pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
pre_layers.append(self._norm(self._units, eps=1e-03))
pre_layers.append(self._act())
if index == 0:
inp_dim = self._units
self._pre_layers = nn.Sequential(*pre_layers)
self._pre_layers.apply(weight_init)
if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]:
self._dist_layer = nn.Linear(self._units, 2 * self._size)
self._dist_layer.apply(uniform_weight_init(outscale))
elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]:
self._dist_layer = nn.Linear(self._units, self._size)
self._dist_layer.apply(uniform_weight_init(outscale))
def forward(self, features):
"""
Overview:
            Compute the forward of ActionHead.
Arguments:
- features (:obj:`torch.Tensor`): Input tensor.
"""
x = features
x = self._pre_layers(x)
if self._dist == "tanh_normal":
x = self._dist_layer(x)
            mean, std = torch.split(x, [self._size] * 2, -1)
mean = torch.tanh(mean)
std = F.softplus(std + self._init_std) + self._min_std
dist = torchd.normal.Normal(mean, std)
dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector())
dist = torchd.independent.Independent(dist, 1)
dist = SampleDist(dist)
elif self._dist == "tanh_normal_5":
x = self._dist_layer(x)
            mean, std = torch.split(x, [self._size] * 2, -1)
mean = 5 * torch.tanh(mean / 5)
std = F.softplus(std + 5) + 5
dist = torchd.normal.Normal(mean, std)
dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector())
dist = torchd.independent.Independent(dist, 1)
dist = SampleDist(dist)
elif self._dist == "normal":
x = self._dist_layer(x)
mean, std = torch.split(x, [self._size] * 2, -1)
std = (self._max_std - self._min_std) * torch.sigmoid(std + 2.0) + self._min_std
dist = torchd.normal.Normal(torch.tanh(mean), std)
dist = ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "normal_1":
            x = self._dist_layer(x)
            dist = torchd.normal.Normal(x, 1)
dist = ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "trunc_normal":
x = self._dist_layer(x)
mean, std = torch.split(x, [self._size] * 2, -1)
mean = torch.tanh(mean)
std = 2 * torch.sigmoid(std / 2) + self._min_std
dist = SafeTruncatedNormal(mean, std, -1, 1)
dist = ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "onehot":
x = self._dist_layer(x)
dist = OneHotDist(x, unimix_ratio=self._unimix_ratio)
elif self._dist == "onehot_gumble":
x = self._dist_layer(x)
temp = self._temp
dist = ContDist(torchd.gumbel.Gumbel(x, 1 / temp))
else:
raise NotImplementedError(self._dist)
return dist
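# Usage sketch (illustrative): a discrete policy head with 1% uniform mixing over 6 actions.
#     policy = ActionHead(inp_dim=1024, size=6, layers=2, units=512, act='SiLU', norm='LayerNorm', dist='onehot')
#     action = policy(torch.randn(8, 1024)).sample()    # one-hot (8, 6) with straight-through gradients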
class SampleDist:
"""
Overview:
        A sample-based wrapper that estimates the statistics of a distribution from Monte Carlo samples, used by the ActionHead of dreamerv3.
Interfaces:
``__init__``, ``mean``, ``mode``, ``entropy``
"""
def __init__(self, dist, samples=100):
"""
Overview:
Initialize the SampleDist class.
Arguments:
- dist (:obj:`torch.Tensor`): Distribution.
- samples (:obj:`int`): Number of samples.
"""
self._dist = dist
self._samples = samples
def mean(self):
"""
Overview:
Calculate the mean of the distribution.
"""
        samples = self._dist.sample((self._samples, ))
return torch.mean(samples, 0)
def mode(self):
"""
Overview:
Calculate the mode of the distribution.
"""
        sample = self._dist.sample((self._samples, ))
logprob = self._dist.log_prob(sample)
return sample[torch.argmax(logprob)][0]
def entropy(self):
"""
Overview:
Calculate the entropy of the distribution.
"""
        sample = self._dist.sample((self._samples, ))
        logprob = self._dist.log_prob(sample)
return -torch.mean(logprob, 0)
class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
"""
Overview:
        A one-hot categorical distribution with optional uniform mixing (unimix) and straight-through gradients, used by dreamerv3.
Interfaces:
``__init__``, ``mode``, ``sample``
"""
def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
"""
Overview:
Initialize the OneHotDist class.
Arguments:
- logits (:obj:`torch.Tensor`): Logits.
- probs (:obj:`torch.Tensor`): Probabilities.
- unimix_ratio (:obj:`float`): Unimix ratio.
"""
if logits is not None and unimix_ratio > 0.0:
probs = F.softmax(logits, dim=-1)
probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1]
logits = torch.log(probs)
super().__init__(logits=logits, probs=None)
else:
super().__init__(logits=logits, probs=probs)
def mode(self):
"""
Overview:
Calculate the mode of the distribution.
"""
_mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1])
return _mode.detach() + super().logits - super().logits.detach()
def sample(self, sample_shape=(), seed=None):
"""
Overview:
Sample from the distribution.
Arguments:
- sample_shape (:obj:`tuple`): Sample shape.
- seed (:obj:`int`): Seed.
"""
if seed is not None:
raise ValueError('need to check')
sample = super().sample(sample_shape)
probs = super().probs
while len(probs.shape) < len(sample.shape):
probs = probs[None]
sample += probs - probs.detach()
return sample
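# Straight-through sampling sketch (illustrative): the drawn one-hot is discrete, but gradients flow
# through the (unimix-smoothed) probabilities via ``sample += probs - probs.detach()``.
#     logits = torch.randn(4, 6, requires_grad=True)
#     action = OneHotDist(logits=logits, unimix_ratio=0.01).sample()
#     action.sum().backward()                  # logits.grad is populated despite the discrete sample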
class TwoHotDistSymlog:
"""
Overview:
        A two-hot categorical distribution over symlog-transformed values, used by dreamerv3.
Interfaces:
``__init__``, ``mode``, ``mean``, ``log_prob``, ``log_prob_target``
"""
def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'):
"""
Overview:
Initialize the TwoHotDistSymlog class.
Arguments:
- logits (:obj:`torch.Tensor`): Logits.
- low (:obj:`float`): Low.
- high (:obj:`float`): High.
- device (:obj:`str`): Device.
"""
self.logits = logits
self.probs = torch.softmax(logits, -1)
self.buckets = torch.linspace(low, high, steps=255).to(device)
self.width = (self.buckets[-1] - self.buckets[0]) / 255
def mean(self):
"""
Overview:
Calculate the mean of the distribution.
"""
_mean = self.probs * self.buckets
return inv_symlog(torch.sum(_mean, dim=-1, keepdim=True))
def mode(self):
"""
Overview:
Calculate the mode of the distribution.
"""
_mode = self.probs * self.buckets
return inv_symlog(torch.sum(_mode, dim=-1, keepdim=True))
    # Note: OneHotCategorical.log_prob would only use the argmax element of the target, so the two-hot cross-entropy is computed manually here.
def log_prob(self, x):
"""
Overview:
Calculate the log probability of the distribution.
Arguments:
- x (:obj:`torch.Tensor`): Input tensor.
"""
x = symlog(x)
# x(time, batch, 1)
below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1)
below = torch.clip(below, 0, len(self.buckets) - 1)
above = torch.clip(above, 0, len(self.buckets) - 1)
equal = (below == above)
dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
total = dist_to_below + dist_to_above
weight_below = dist_to_above / total
weight_above = dist_to_below / total
target = (
F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] +
F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None]
)
log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
target = target.squeeze(-2)
return (target * log_pred).sum(-1)
def log_prob_target(self, target):
"""
Overview:
Calculate the log probability of the target.
Arguments:
- target (:obj:`torch.Tensor`): Target tensor.
"""
        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
return (target * log_pred).sum(-1)
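# Two-hot target sketch (illustrative): a raw target of 10.0 is squashed to symlog(10.0) ~= 2.398,
# which falls between two adjacent buckets of torch.linspace(-20, 20, 255); log_prob spreads the target
# mass over exactly those two buckets, inversely proportional to distance, and takes a cross-entropy
# against the predicted logits.
#     dist = TwoHotDistSymlog(logits=torch.zeros(1, 255))
#     dist.log_prob(torch.tensor([[10.0]]))    # ~= -log(255) for uniform logits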
class SymlogDist:
"""
Overview:
        A symlog-squashed regression distribution (MSE or absolute error in symlog space) for dreamerv3.
    Interfaces:
        ``__init__``, ``mode``, ``mean``, ``log_prob``
"""
def __init__(self, mode, dist='mse', aggregation='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]):
"""
Overview:
Initialize the SymlogDist class.
Arguments:
- mode (:obj:`torch.Tensor`): Mode.
- dist (:obj:`str`): Distribution function.
- aggregation (:obj:`str`): Aggregation function.
- tol (:obj:`float`): Tolerance.
- dim_to_reduce (:obj:`list`): Dimension to reduce.
"""
self._mode = mode
self._dist = dist
self._aggregation = aggregation
self._tol = tol
self._dim_to_reduce = dim_to_reduce
def mode(self):
"""
Overview:
Calculate the mode of the distribution.
"""
return inv_symlog(self._mode)
def mean(self):
"""
Overview:
Calculate the mean of the distribution.
"""
return inv_symlog(self._mode)
def log_prob(self, value):
"""
Overview:
Calculate the log probability of the distribution.
Arguments:
- value (:obj:`torch.Tensor`): Input tensor.
"""
assert self._mode.shape == value.shape
if self._dist == 'mse':
distance = (self._mode - symlog(value)) ** 2.0
distance = torch.where(distance < self._tol, 0, distance)
elif self._dist == 'abs':
distance = torch.abs(self._mode - symlog(value))
distance = torch.where(distance < self._tol, 0, distance)
else:
raise NotImplementedError(self._dist)
if self._aggregation == 'mean':
loss = distance.mean(self._dim_to_reduce)
elif self._aggregation == 'sum':
loss = distance.sum(self._dim_to_reduce)
else:
raise NotImplementedError(self._aggregation)
return -loss
class ContDist:
"""
Overview:
        A thin wrapper around a continuous torch distribution, used by dreamerv3.
Interfaces:
``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob``
"""
def __init__(self, dist=None):
"""
Overview:
Initialize the ContDist class.
Arguments:
- dist (:obj:`torch.Tensor`): Distribution.
"""
super().__init__()
self._dist = dist
self.mean = dist.mean
def __getattr__(self, name):
"""
Overview:
Get attribute.
Arguments:
- name (:obj:`str`): Attribute name.
"""
return getattr(self._dist, name)
def entropy(self):
"""
Overview:
Calculate the entropy of the distribution.
"""
return self._dist.entropy()
def mode(self):
"""
Overview:
Calculate the mode of the distribution.
"""
return self._dist.mean
def sample(self, sample_shape=()):
"""
Overview:
Sample from the distribution.
Arguments:
- sample_shape (:obj:`tuple`): Sample shape.
"""
return self._dist.rsample(sample_shape)
    def log_prob(self, x):
        """
        Overview:
            Calculate the log probability of the distribution.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        return self._dist.log_prob(x)
class Bernoulli:
"""
Overview:
A kind of Bernoulli Dist for dreamerv3.
Interfaces:
``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob``
"""
def __init__(self, dist=None):
"""
Overview:
Initialize the Bernoulli distribution.
Arguments:
- dist (:obj:`torch.Tensor`): Distribution.
"""
super().__init__()
self._dist = dist
self.mean = dist.mean
def __getattr__(self, name):
"""
Overview:
Get attribute.
Arguments:
- name (:obj:`str`): Attribute name.
"""
return getattr(self._dist, name)
def entropy(self):
"""
Overview:
Calculate the entropy of the distribution.
"""
return self._dist.entropy()
def mode(self):
"""
Overview:
Calculate the mode of the distribution.
"""
_mode = torch.round(self._dist.mean)
return _mode.detach() + self._dist.mean - self._dist.mean.detach()
def sample(self, sample_shape=()):
"""
Overview:
Sample from the distribution.
Arguments:
- sample_shape (:obj:`tuple`): Sample shape.
"""
        return self._dist.sample(sample_shape)
def log_prob(self, x):
"""
Overview:
Calculate the log probability of the distribution.
Arguments:
- x (:obj:`torch.Tensor`): Input tensor.
"""
_logits = self._dist.base_dist.logits
log_probs0 = -F.softplus(_logits)
log_probs1 = -F.softplus(-_logits)
return log_probs0 * (1 - x) + log_probs1 * x
class UnnormalizedHuber(torchd.normal.Normal):
"""
Overview:
A kind of UnnormalizedHuber Dist for dreamerv3.
Interfaces:
``__init__``, ``mode``, ``log_prob``
"""
def __init__(self, loc, scale, threshold=1, **kwargs):
"""
Overview:
Initialize the UnnormalizedHuber class.
Arguments:
- loc (:obj:`torch.Tensor`): Location.
- scale (:obj:`torch.Tensor`): Scale.
- threshold (:obj:`float`): Threshold.
"""
super().__init__(loc, scale, **kwargs)
self._threshold = threshold
def log_prob(self, event):
"""
Overview:
Calculate the log probability of the distribution.
Arguments:
- event (:obj:`torch.Tensor`): Event.
"""
return -(torch.sqrt((event - self.mean) ** 2 + self._threshold ** 2) - self._threshold)
def mode(self):
"""
Overview:
Calculate the mode of the distribution.
"""
return self.mean
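# Note: log_prob above is the negative pseudo-Huber penalty -(sqrt((x - mean) ** 2 + threshold ** 2) - threshold),
# roughly quadratic near the mean and linear in the tails; it is not a normalized log-density.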
class SafeTruncatedNormal(torchd.normal.Normal):
"""
Overview:
A kind of SafeTruncatedNormal Dist for dreamerv3.
Interfaces:
``__init__``, ``sample``
"""
def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
"""
Overview:
Initialize the SafeTruncatedNormal class.
Arguments:
- loc (:obj:`torch.Tensor`): Location.
- scale (:obj:`torch.Tensor`): Scale.
- low (:obj:`float`): Low.
- high (:obj:`float`): High.
- clip (:obj:`float`): Clip.
- mult (:obj:`float`): Mult.
"""
super().__init__(loc, scale)
self._low = low
self._high = high
self._clip = clip
self._mult = mult
def sample(self, sample_shape):
"""
Overview:
Sample from the distribution.
Arguments:
- sample_shape (:obj:`tuple`): Sample shape.
"""
event = super().sample(sample_shape)
if self._clip:
clipped = torch.clip(event, self._low + self._clip, self._high - self._clip)
event = event - event.detach() + clipped.detach()
if self._mult:
event *= self._mult
return event
class TanhBijector(torchd.Transform):
"""
Overview:
        A tanh bijector (transform) used by dreamerv3 to squash Normal samples into (-1, 1).
Interfaces:
``__init__``, ``_forward``, ``_inverse``, ``_forward_log_det_jacobian``
"""
def __init__(self, validate_args=False, name='tanh'):
"""
Overview:
Initialize the TanhBijector class.
Arguments:
- validate_args (:obj:`bool`): Validate arguments.
- name (:obj:`str`): Name.
"""
super().__init__()
def _forward(self, x):
"""
Overview:
Calculate the forward of the distribution.
Arguments:
- x (:obj:`torch.Tensor`): Input tensor.
"""
return torch.tanh(x)
def _inverse(self, y):
"""
Overview:
Calculate the inverse of the distribution.
Arguments:
- y (:obj:`torch.Tensor`): Input tensor.
"""
y = torch.where((torch.abs(y) <= 1.), torch.clamp(y, -0.99999997, 0.99999997), y)
y = torch.atanh(y)
return y
def _forward_log_det_jacobian(self, x):
"""
Overview:
Calculate the forward log det jacobian of the distribution.
Arguments:
- x (:obj:`torch.Tensor`): Input tensor.
"""
        log2 = math.log(2.0)
        return 2.0 * (log2 - x - F.softplus(-2.0 * x))
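# Note: the forward log-determinant uses the identity log(1 - tanh(x) ** 2) = 2 * (log(2) - x - softplus(-2 * x)),
# which stays numerically stable for large |x|.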
def static_scan(fn, inputs, start):
"""
Overview:
Static scan function.
Arguments:
- fn (:obj:`function`): Function.
- inputs (:obj:`tuple`): Inputs.
- start (:obj:`torch.Tensor`): Start tensor.
"""
last = start # {logit, stoch, deter:[batch_size, self._deter]}
indices = range(inputs[0].shape[0])
flag = True
for index in indices:
inp = lambda x: (_input[x] for _input in inputs) # inputs:(action:(time, batch, 6), embed:(time, batch, 4096))
last = fn(last, *inp(index)) # post, prior
if flag:
if isinstance(last, dict):
outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()}
else:
outputs = []
for _last in last:
if isinstance(_last, dict):
outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()})
else:
outputs.append(_last.clone().unsqueeze(0))
flag = False
else:
if isinstance(last, dict):
for key in last.keys():
outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0)
else:
for j in range(len(outputs)):
if isinstance(last[j], dict):
for key in last[j].keys():
outputs[j][key] = torch.cat([outputs[j][key], last[j][key].unsqueeze(0)], dim=0)
else:
outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0)
if isinstance(last, dict):
outputs = [outputs]
return outputs
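# Usage sketch (illustrative toy transition, not the RSSM step): unroll a step function over the
# leading time axis and stack the per-step states.
#     def step(state, a):
#         return {'s': state['s'] + a}
#     out = static_scan(step, (torch.ones(10, 4, 3), ), {'s': torch.zeros(4, 3)})
#     out[0]['s'].shape                        # torch.Size([10, 4, 3])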
def weight_init(m):
"""
Overview:
weight_init for Linear, Conv2d, ConvTranspose2d, and LayerNorm.
Arguments:
- m (:obj:`torch.nn`): Module.
"""
if isinstance(m, nn.Linear):
in_num = m.in_features
out_num = m.out_features
denoms = (in_num + out_num) / 2.0
scale = 1.0 / denoms
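        # 0.87962... is the standard deviation of a standard Normal truncated to [-2, 2]; dividing by it
        # restores the intended variance after trunc_normal_ clips at a=-2, b=2.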
std = np.sqrt(scale) / 0.87962566103423978
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
if hasattr(m.bias, 'data'):
m.bias.data.fill_(0.0)
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
space = m.kernel_size[0] * m.kernel_size[1]
in_num = space * m.in_channels
out_num = space * m.out_channels
denoms = (in_num + out_num) / 2.0
scale = 1.0 / denoms
std = np.sqrt(scale) / 0.87962566103423978
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
if hasattr(m.bias, 'data'):
m.bias.data.fill_(0.0)
elif isinstance(m, nn.LayerNorm):
m.weight.data.fill_(1.0)
if hasattr(m.bias, 'data'):
m.bias.data.fill_(0.0)
def uniform_weight_init(given_scale):
"""
Overview:
weight_init for Linear and LayerNorm.
Arguments:
- given_scale (:obj:`float`): Given scale.
"""
def f(m):
if isinstance(m, nn.Linear):
in_num = m.in_features
out_num = m.out_features
denoms = (in_num + out_num) / 2.0
scale = given_scale / denoms
limit = np.sqrt(3 * scale)
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
if hasattr(m.bias, 'data'):
m.bias.data.fill_(0.0)
elif isinstance(m, nn.LayerNorm):
m.weight.data.fill_(1.0)
if hasattr(m.bias, 'data'):
m.bias.data.fill_(0.0)
return f