|
import math |
|
import numpy as np |
|
|
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from torch import distributions as torchd |
|
from ding.torch_utils import MLP |
|
from ding.rl_utils import symlog, inv_symlog |
|
|
|
|
|
class Conv2dSame(torch.nn.Conv2d): |
|
""" |
|
Overview: |
|
Conv2d layer with 'same' padding (output spatial size = ceil(input / stride)) for dreamerv3.
|
Interfaces: |
|
``__init__``, ``forward`` |
|
""" |
|
|
|
def calc_same_pad(self, i, k, s, d): |
|
""" |
|
Overview: |
|
Calculate the same padding size. |
|
Arguments: |
|
- i (:obj:`int`): Input size. |
|
- k (:obj:`int`): Kernel size. |
|
- s (:obj:`int`): Stride size. |
|
- d (:obj:`int`): Dilation size. |
|
""" |
|
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) |
|
|
|
def forward(self, x): |
|
""" |
|
Overview: |
|
Compute the forward pass of Conv2dSame.
|
Arguments: |
|
- x (:obj:`torch.Tensor`): Input tensor. |
|
""" |
|
ih, iw = x.size()[-2:] |
|
pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]) |
|
pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]) |
|
|
|
if pad_h > 0 or pad_w > 0: |
|
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) |
|
|
|
ret = F.conv2d( |
|
x, |
|
self.weight, |
|
self.bias, |
|
self.stride, |
|
self.padding, |
|
self.dilation, |
|
self.groups, |
|
) |
|
return ret |
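
# Illustrative usage (a minimal sketch; the channel counts and input size below are
# assumptions, not values used elsewhere in this module):
#
#     conv = Conv2dSame(in_channels=3, out_channels=32, kernel_size=4, stride=2)
#     x = torch.randn(1, 3, 31, 31)
#     y = conv(x)  # (1, 32, 16, 16): "same" padding keeps the spatial size at ceil(31 / 2)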
|
|
|
|
|
class DreamerLayerNorm(nn.Module): |
|
""" |
|
Overview: |
|
LayerNorm applied over the channel dimension of NCHW feature maps, as used in dreamerv3.
|
Interfaces: |
|
``__init__``, ``forward`` |
|
""" |
|
|
|
def __init__(self, ch, eps=1e-03): |
|
""" |
|
Overview: |
|
Init the DreamerLayerNorm class. |
|
Arguments: |
|
- ch (:obj:`int`): Input channel. |
|
- eps (:obj:`float`): Epsilon. |
|
""" |
|
|
|
super(DreamerLayerNorm, self).__init__() |
|
self.norm = torch.nn.LayerNorm(ch, eps=eps) |
|
|
|
def forward(self, x): |
|
""" |
|
Overview: |
|
Compute the forward pass of DreamerLayerNorm.
|
Arguments: |
|
- x (:obj:`torch.Tensor`): Input tensor. |
|
""" |
|
|
|
x = x.permute(0, 2, 3, 1) |
|
x = self.norm(x) |
|
x = x.permute(0, 3, 1, 2) |
|
return x |
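
# Illustrative usage (a minimal sketch; shapes are assumptions). DreamerLayerNorm normalizes
# each spatial position over its channels by permuting NCHW -> NHWC, applying LayerNorm on the
# last dimension, and permuting back:
#
#     norm = DreamerLayerNorm(32)
#     y = norm(torch.randn(8, 32, 16, 16))  # same shape (8, 32, 16, 16)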
|
|
|
|
|
class DenseHead(nn.Module): |
|
""" |
|
Overview: |
|
DenseHead Network for value head, reward head, and discount head of dreamerv3. |
|
Interfaces: |
|
``__init__``, ``forward`` |
|
""" |
|
|
|
def __init__( |
|
self, |
|
inp_dim, |
|
shape, |
|
layer_num, |
|
units, |
|
act='SiLU', |
|
norm='LN', |
|
dist='normal', |
|
std=1.0, |
|
outscale=1.0, |
|
device='cpu', |
|
): |
|
""" |
|
Overview: |
|
Init the DenseHead class. |
|
Arguments: |
|
- inp_dim (:obj:`int`): Input dimension. |
|
- shape (:obj:`tuple`): Output shape. |
|
- layer_num (:obj:`int`): Number of layers. |
|
- units (:obj:`int`): Number of units. |
|
- act (:obj:`str`): Activation function. |
|
- norm (:obj:`str`): Normalization function. |
|
- dist (:obj:`str`): Distribution function. |
|
- std (:obj:`float`): Standard deviation. |
|
- outscale (:obj:`float`): Output scale. |
|
- device (:obj:`str`): Device. |
|
""" |
|
|
|
super(DenseHead, self).__init__() |
|
self._shape = (shape, ) if isinstance(shape, int) else shape |
|
if len(self._shape) == 0: |
|
self._shape = (1, ) |
|
self._layer_num = layer_num |
|
self._units = units |
|
self._act = getattr(torch.nn, act)() |
|
self._norm = norm |
|
self._dist = dist |
|
self._std = std |
|
self._device = device |
|
|
|
self.mlp = MLP( |
|
inp_dim, |
|
self._units, |
|
self._units, |
|
self._layer_num, |
|
layer_fn=nn.Linear, |
|
activation=self._act, |
|
norm_type=self._norm |
|
) |
|
self.mlp.apply(weight_init) |
|
|
|
self.mean_layer = nn.Linear(self._units, np.prod(self._shape)) |
|
self.mean_layer.apply(uniform_weight_init(outscale)) |
|
|
|
if self._std == "learned": |
|
self.std_layer = nn.Linear(self._units, np.prod(self._shape)) |
|
self.std_layer.apply(uniform_weight_init(outscale)) |
|
|
|
def forward(self, features): |
|
""" |
|
Overview: |
|
Compute the forward pass of DenseHead.
|
Arguments: |
|
- features (:obj:`torch.Tensor`): Input tensor. |
|
""" |
|
|
|
x = features |
|
out = self.mlp(x) |
|
mean = self.mean_layer(out) |
|
if self._std == "learned": |
|
std = self.std_layer(out) |
|
else: |
|
std = self._std |
|
if self._dist == "normal": |
|
return ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), len(self._shape))) |
|
elif self._dist == "huber": |
|
return ContDist(torchd.independent.Independent(UnnormalizedHuber(mean, std, 1.0), len(self._shape))) |
|
elif self._dist == "binary": |
|
return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape))) |
|
elif self._dist == "twohot_symlog": |
|
return TwoHotDistSymlog(logits=mean, device=self._device) |
|
raise NotImplementedError(self._dist) |
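
# Illustrative usage (a minimal sketch; the dimensions, the 255-bucket output shape and the
# `target` tensor are assumptions, not values prescribed by this module):
#
#     head = DenseHead(inp_dim=512, shape=(255, ), layer_num=2, units=512, dist='twohot_symlog')
#     dist = head(torch.randn(16, 512))
#     value = dist.mode()                    # (16, 1) prediction mapped back from symlog space
#     target = torch.randn(16, 1)            # hypothetical regression target (e.g. returns)
#     loss = -dist.log_prob(target).mean()   # cross-entropy against the two-hot encoded target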
|
|
|
|
|
class ActionHead(nn.Module): |
|
""" |
|
Overview: |
|
ActionHead Network for action head of dreamerv3. |
|
Interfaces: |
|
``__init__``, ``forward`` |
|
""" |
|
|
|
def __init__( |
|
self, |
|
inp_dim, |
|
size, |
|
layers, |
|
units, |
|
act='ELU',

norm='LayerNorm',
|
dist="trunc_normal", |
|
init_std=0.0, |
|
min_std=0.1, |
|
max_std=1.0, |
|
temp=0.1, |
|
outscale=1.0, |
|
unimix_ratio=0.01, |
|
): |
|
""" |
|
Overview: |
|
Initialize the ActionHead class. |
|
Arguments: |
|
- inp_dim (:obj:`int`): Input dimension. |
|
- size (:obj:`int`): Output size. |
|
- layers (:obj:`int`): Number of layers. |
|
- units (:obj:`int`): Number of units. |
|
- act (:obj:`str`): Activation function. |
|
- norm (:obj:`str`): Normalization function. |
|
- dist (:obj:`str`): Distribution function. |
|
- init_std (:obj:`float`): Initial standard deviation. |
|
- min_std (:obj:`float`): Minimum standard deviation. |
|
- max_std (:obj:`float`): Maximum standard deviation. |
|
- temp (:obj:`float`): Temperature. |
|
- outscale (:obj:`float`): Output scale. |
|
- unimix_ratio (:obj:`float`): Unimix ratio. |
|
""" |
|
super(ActionHead, self).__init__() |
|
self._size = size |
|
self._layers = layers |
|
self._units = units |
|
self._dist = dist |
|
self._act = getattr(torch.nn, act) |
|
self._norm = getattr(torch.nn, norm) |
|
self._min_std = min_std |
|
self._max_std = max_std |
|
self._init_std = init_std |
|
self._unimix_ratio = unimix_ratio |
|
self._temp = temp() if callable(temp) else temp |
|
|
|
pre_layers = [] |
|
for index in range(self._layers): |
|
pre_layers.append(nn.Linear(inp_dim, self._units, bias=False)) |
|
pre_layers.append(self._norm(self._units, eps=1e-03)) |
|
pre_layers.append(self._act()) |
|
if index == 0: |
|
inp_dim = self._units |
|
self._pre_layers = nn.Sequential(*pre_layers) |
|
self._pre_layers.apply(weight_init) |
|
|
|
if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]: |
|
self._dist_layer = nn.Linear(self._units, 2 * self._size) |
|
self._dist_layer.apply(uniform_weight_init(outscale)) |
|
|
|
elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]: |
|
self._dist_layer = nn.Linear(self._units, self._size) |
|
self._dist_layer.apply(uniform_weight_init(outscale)) |
|
|
|
def forward(self, features): |
|
""" |
|
Overview: |
|
Compute the forward pass of ActionHead.
|
Arguments: |
|
- features (:obj:`torch.Tensor`): Input tensor. |
|
""" |
|
|
|
x = features |
|
x = self._pre_layers(x) |
|
if self._dist == "tanh_normal": |
|
x = self._dist_layer(x) |
|
mean, std = torch.split(x, [self._size] * 2, -1)
|
mean = torch.tanh(mean) |
|
std = F.softplus(std + self._init_std) + self._min_std |
|
dist = torchd.normal.Normal(mean, std) |
|
dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector()) |
|
dist = torchd.independent.Independent(dist, 1) |
|
dist = SampleDist(dist) |
|
elif self._dist == "tanh_normal_5": |
|
x = self._dist_layer(x) |
|
mean, std = torch.split(x, [self._size] * 2, -1)
|
mean = 5 * torch.tanh(mean / 5) |
|
std = F.softplus(std + 5) + 5 |
|
dist = torchd.normal.Normal(mean, std) |
|
dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector()) |
|
dist = torchd.independent.Independent(dist, 1) |
|
dist = SampleDist(dist) |
|
elif self._dist == "normal": |
|
x = self._dist_layer(x) |
|
mean, std = torch.split(x, [self._size] * 2, -1) |
|
std = (self._max_std - self._min_std) * torch.sigmoid(std + 2.0) + self._min_std |
|
dist = torchd.normal.Normal(torch.tanh(mean), std) |
|
dist = ContDist(torchd.independent.Independent(dist, 1)) |
|
elif self._dist == "normal_1": |
|
x = self._dist_layer(x) |
|
dist = torchd.normal.Normal(x, 1)
|
dist = ContDist(torchd.independent.Independent(dist, 1)) |
|
elif self._dist == "trunc_normal": |
|
x = self._dist_layer(x) |
|
mean, std = torch.split(x, [self._size] * 2, -1) |
|
mean = torch.tanh(mean) |
|
std = 2 * torch.sigmoid(std / 2) + self._min_std |
|
dist = SafeTruncatedNormal(mean, std, -1, 1) |
|
dist = ContDist(torchd.independent.Independent(dist, 1)) |
|
elif self._dist == "onehot": |
|
x = self._dist_layer(x) |
|
dist = OneHotDist(x, unimix_ratio=self._unimix_ratio) |
|
elif self._dist == "onehot_gumble": |
|
x = self._dist_layer(x) |
|
temp = self._temp |
|
dist = ContDist(torchd.gumbel.Gumbel(x, 1 / temp)) |
|
else: |
|
raise NotImplementedError(self._dist) |
|
return dist |
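
# Illustrative usage (a minimal sketch; the feature size and action size are assumptions):
#
#     head = ActionHead(inp_dim=512, size=6, layers=2, units=512, act='SiLU', norm='LayerNorm')
#     dist = head(torch.randn(16, 512))      # default dist='trunc_normal'
#     action = dist.sample()                 # (16, 6), drawn via rsample so gradients can flow
#     entropy = dist.entropy()               # (16, )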
|
|
|
|
|
class SampleDist: |
|
""" |
|
Overview: |
|
A kind of sample Dist for ActionHead of dreamerv3. |
|
Interfaces: |
|
``__init__``, ``mean``, ``mode``, ``entropy`` |
|
""" |
|
|
|
def __init__(self, dist, samples=100): |
|
""" |
|
Overview: |
|
Initialize the SampleDist class. |
|
Arguments: |
|
- dist (:obj:`torch.distributions.Distribution`): The distribution to sample from.
|
- samples (:obj:`int`): Number of samples. |
|
""" |
|
|
|
self._dist = dist |
|
self._samples = samples |
|
|
|
def mean(self): |
|
""" |
|
Overview: |
|
Calculate the mean of the distribution. |
|
""" |
|
|
|
samples = self._dist.sample((self._samples, ))
|
return torch.mean(samples, 0) |
|
|
|
def mode(self): |
|
""" |
|
Overview: |
|
Calculate the mode of the distribution. |
|
""" |
|
|
|
sample = self._dist.sample((self._samples, ))
|
logprob = self._dist.log_prob(sample) |
|
return sample[torch.argmax(logprob)][0] |
|
|
|
def entropy(self): |
|
""" |
|
Overview: |
|
Calculate the entropy of the distribution. |
|
""" |
|
|
|
sample = self._dist.sample((self._samples, ))

logprob = self._dist.log_prob(sample)
|
return -torch.mean(logprob, 0) |
|
|
|
|
|
class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): |
|
""" |
|
Overview: |
|
A kind of onehot Dist for dreamerv3. |
|
Interfaces: |
|
``__init__``, ``mode``, ``sample`` |
|
""" |
|
|
|
def __init__(self, logits=None, probs=None, unimix_ratio=0.0): |
|
""" |
|
Overview: |
|
Initialize the OneHotDist class. |
|
Arguments: |
|
- logits (:obj:`torch.Tensor`): Logits. |
|
- probs (:obj:`torch.Tensor`): Probabilities. |
|
- unimix_ratio (:obj:`float`): Unimix ratio. |
|
""" |
|
|
|
if logits is not None and unimix_ratio > 0.0: |
|
probs = F.softmax(logits, dim=-1) |
|
probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1] |
|
logits = torch.log(probs) |
|
super().__init__(logits=logits, probs=None) |
|
else: |
|
super().__init__(logits=logits, probs=probs) |
|
|
|
def mode(self): |
|
""" |
|
Overview: |
|
Calculate the mode of the distribution. |
|
""" |
|
|
|
_mode = F.one_hot(torch.argmax(super().logits, dim=-1), super().logits.shape[-1])
|
return _mode.detach() + super().logits - super().logits.detach() |
|
|
|
def sample(self, sample_shape=(), seed=None): |
|
""" |
|
Overview: |
|
Sample from the distribution. |
|
Arguments: |
|
- sample_shape (:obj:`tuple`): Sample shape. |
|
- seed (:obj:`int`): Seed. |
|
""" |
|
|
|
if seed is not None: |
|
raise ValueError('need to check') |
|
sample = super().sample(sample_shape) |
|
probs = super().probs |
|
while len(probs.shape) < len(sample.shape): |
|
probs = probs[None] |
|
sample += probs - probs.detach() |
|
return sample |
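
# Illustrative usage (a minimal sketch; shapes are assumptions). The sample is a one-hot vector
# in the forward pass, while `sample += probs - probs.detach()` routes gradients to the
# (unimix-smoothed) categorical probabilities, i.e. a straight-through estimator:
#
#     dist = OneHotDist(logits=torch.randn(16, 32), unimix_ratio=0.01)
#     y = dist.sample()   # (16, 32) one-hot values with straight-through gradients
#     m = dist.mode()     # straight-through argmax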
|
|
|
|
|
class TwoHotDistSymlog: |
|
""" |
|
Overview: |
|
A kind of twohotsymlog Dist for dreamerv3. |
|
Interfaces: |
|
``__init__``, ``mode``, ``mean``, ``log_prob``, ``log_prob_target`` |
|
""" |
|
|
|
def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'): |
|
""" |
|
Overview: |
|
Initialize the TwoHotDistSymlog class. |
|
Arguments: |
|
- logits (:obj:`torch.Tensor`): Logits. |
|
- low (:obj:`float`): Low. |
|
- high (:obj:`float`): High. |
|
- device (:obj:`str`): Device. |
|
""" |
|
|
|
self.logits = logits |
|
self.probs = torch.softmax(logits, -1) |
|
self.buckets = torch.linspace(low, high, steps=255).to(device) |
|
self.width = (self.buckets[-1] - self.buckets[0]) / 255 |
|
|
|
def mean(self): |
|
""" |
|
Overview: |
|
Calculate the mean of the distribution. |
|
""" |
|
|
|
_mean = self.probs * self.buckets |
|
return inv_symlog(torch.sum(_mean, dim=-1, keepdim=True)) |
|
|
|
def mode(self): |
|
""" |
|
Overview: |
|
Calculate the mode of the distribution. |
|
""" |
|
|
|
_mode = self.probs * self.buckets |
|
return inv_symlog(torch.sum(_mode, dim=-1, keepdim=True)) |
|
|
|
|
|
def log_prob(self, x): |
|
""" |
|
Overview: |
|
Calculate the log probability of the distribution. |
|
Arguments: |
|
- x (:obj:`torch.Tensor`): Input tensor. |
|
""" |
|
|
|
x = symlog(x) |
|
|
|
below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1 |
|
above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1) |
|
below = torch.clip(below, 0, len(self.buckets) - 1) |
|
above = torch.clip(above, 0, len(self.buckets) - 1) |
|
equal = (below == above) |
|
|
|
dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x)) |
|
dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x)) |
|
total = dist_to_below + dist_to_above |
|
weight_below = dist_to_above / total |
|
weight_above = dist_to_below / total |
|
target = ( |
|
F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] + |
|
F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None] |
|
) |
|
log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True) |
|
target = target.squeeze(-2) |
|
|
|
return (target * log_pred).sum(-1) |
|
|
|
def log_prob_target(self, target): |
|
""" |
|
Overview: |
|
Calculate the log probability of the target. |
|
Arguments: |
|
- target (:obj:`torch.Tensor`): Target tensor. |
|
""" |
|
|
|
log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
|
return (target * log_pred).sum(-1) |
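
# Illustrative usage (a minimal sketch; the batch size and the target value are assumptions).
# A scalar target y is mapped to symlog(y) and distributed over the two nearest of the 255
# buckets ("two-hot" encoding); log_prob is the cross-entropy against that soft target, and
# mean()/mode() map the expected bucket value back through inv_symlog:
#
#     dist = TwoHotDistSymlog(logits=torch.zeros(1, 255))
#     nll = -dist.log_prob(torch.tensor([[3.0]]))   # cross-entropy w.r.t. the two-hot target for symlog(3.0)
#     pred = dist.mean()                            # ~0 here, since uniform logits give a symmetric expectation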
|
|
|
|
|
class SymlogDist: |
|
""" |
|
Overview: |
|
A kind of Symlog Dist for dreamerv3. |
|
Interfaces: |
|
``__init__``, ``entropy``, ``mode``, ``mean``, ``log_prob`` |
|
""" |
|
|
|
def __init__(self, mode, dist='mse', aggregation='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]): |
|
""" |
|
Overview: |
|
Initialize the SymlogDist class. |
|
Arguments: |
|
- mode (:obj:`torch.Tensor`): Mode. |
|
- dist (:obj:`str`): Distribution function. |
|
- aggregation (:obj:`str`): Aggregation function. |
|
- tol (:obj:`float`): Tolerance. |
|
- dim_to_reduce (:obj:`list`): Dimension to reduce. |
|
""" |
|
self._mode = mode |
|
self._dist = dist |
|
self._aggregation = aggregation |
|
self._tol = tol |
|
self._dim_to_reduce = dim_to_reduce |
|
|
|
def mode(self): |
|
""" |
|
Overview: |
|
Calculate the mode of the distribution. |
|
""" |
|
|
|
return inv_symlog(self._mode) |
|
|
|
def mean(self): |
|
""" |
|
Overview: |
|
Calculate the mean of the distribution. |
|
""" |
|
|
|
return inv_symlog(self._mode) |
|
|
|
def log_prob(self, value): |
|
""" |
|
Overview: |
|
Calculate the log probability of the distribution. |
|
Arguments: |
|
- value (:obj:`torch.Tensor`): Input tensor. |
|
""" |
|
|
|
assert self._mode.shape == value.shape |
|
if self._dist == 'mse': |
|
distance = (self._mode - symlog(value)) ** 2.0 |
|
distance = torch.where(distance < self._tol, 0, distance) |
|
elif self._dist == 'abs': |
|
distance = torch.abs(self._mode - symlog(value)) |
|
distance = torch.where(distance < self._tol, 0, distance) |
|
else: |
|
raise NotImplementedError(self._dist) |
|
if self._aggregation == 'mean': |
|
loss = distance.mean(self._dim_to_reduce) |
|
elif self._aggregation == 'sum': |
|
loss = distance.sum(self._dim_to_reduce) |
|
else: |
|
raise NotImplementedError(self._aggregation) |
|
return -loss |
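
# Illustrative usage (a minimal sketch; `pred` and `obs` are hypothetical tensors of identical
# shape). SymlogDist scores predictions that already live in symlog space with a squared (or
# absolute) error against symlog(obs):
#
#     pred = torch.zeros(16, 3, 64, 64)
#     obs = torch.rand(16, 3, 64, 64)
#     dist = SymlogDist(pred, dist='mse', aggregation='sum')
#     loss = -dist.log_prob(obs).mean()   # squared error in symlog space, summed over C, H, W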
|
|
|
|
|
class ContDist: |
|
""" |
|
Overview: |
|
A kind of ordinary Dist for dreamerv3. |
|
Interfaces: |
|
``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob`` |
|
""" |
|
|
|
def __init__(self, dist=None): |
|
""" |
|
Overview: |
|
Initialize the ContDist class. |
|
Arguments: |
|
- dist (:obj:`torch.distributions.Distribution`): The wrapped distribution.
|
""" |
|
|
|
super().__init__() |
|
self._dist = dist |
|
self.mean = dist.mean |
|
|
|
def __getattr__(self, name): |
|
""" |
|
Overview: |
|
Get attribute. |
|
Arguments: |
|
- name (:obj:`str`): Attribute name. |
|
""" |
|
|
|
return getattr(self._dist, name) |
|
|
|
def entropy(self): |
|
""" |
|
Overview: |
|
Calculate the entropy of the distribution. |
|
""" |
|
|
|
return self._dist.entropy() |
|
|
|
def mode(self): |
|
""" |
|
Overview: |
|
Calculate the mode of the distribution. |
|
""" |
|
|
|
return self._dist.mean |
|
|
|
def sample(self, sample_shape=()): |
|
""" |
|
Overview: |
|
Sample from the distribution. |
|
Arguments: |
|
- sample_shape (:obj:`tuple`): Sample shape. |
|
""" |
|
|
|
return self._dist.rsample(sample_shape) |
|
|
|
def log_prob(self, x): |
|
return self._dist.log_prob(x) |
|
|
|
|
|
class Bernoulli: |
|
""" |
|
Overview: |
|
A kind of Bernoulli Dist for dreamerv3. |
|
Interfaces: |
|
``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob`` |
|
""" |
|
|
|
def __init__(self, dist=None): |
|
""" |
|
Overview: |
|
Initialize the Bernoulli distribution. |
|
Arguments: |
|
- dist (:obj:`torch.distributions.Distribution`): The wrapped Bernoulli distribution.
|
""" |
|
|
|
super().__init__() |
|
self._dist = dist |
|
self.mean = dist.mean |
|
|
|
def __getattr__(self, name): |
|
""" |
|
Overview: |
|
Get attribute. |
|
Arguments: |
|
- name (:obj:`str`): Attribute name. |
|
""" |
|
|
|
return getattr(self._dist, name) |
|
|
|
def entropy(self): |
|
""" |
|
Overview: |
|
Calculate the entropy of the distribution. |
|
""" |
|
return self._dist.entropy() |
|
|
|
def mode(self): |
|
""" |
|
Overview: |
|
Calculate the mode of the distribution. |
|
""" |
|
|
|
_mode = torch.round(self._dist.mean) |
|
return _mode.detach() + self._dist.mean - self._dist.mean.detach() |
|
|
|
def sample(self, sample_shape=()): |
|
""" |
|
Overview: |
|
Sample from the distribution. |
|
Arguments: |
|
- sample_shape (:obj:`tuple`): Sample shape. |
|
""" |
|
|
|
return self._dist.sample(sample_shape)
|
|
|
def log_prob(self, x): |
|
""" |
|
Overview: |
|
Calculate the log probability of the distribution. |
|
Arguments: |
|
- x (:obj:`torch.Tensor`): Input tensor. |
|
""" |
|
|
|
_logits = self._dist.base_dist.logits |
|
log_probs0 = -F.softplus(_logits) |
|
log_probs1 = -F.softplus(-_logits) |
|
|
|
return log_probs0 * (1 - x) + log_probs1 * x |
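
# Note: with p = sigmoid(logits), the two terms above are log(1 - p) = -softplus(logits) and
# log(p) = -softplus(-logits), so the return value is the elementwise Bernoulli log-likelihood
# computed in a numerically stable way.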
|
|
|
|
|
class UnnormalizedHuber(torchd.normal.Normal): |
|
""" |
|
Overview: |
|
A kind of UnnormalizedHuber Dist for dreamerv3. |
|
Interfaces: |
|
``__init__``, ``mode``, ``log_prob`` |
|
""" |
|
|
|
def __init__(self, loc, scale, threshold=1, **kwargs): |
|
""" |
|
Overview: |
|
Initialize the UnnormalizedHuber class. |
|
Arguments: |
|
- loc (:obj:`torch.Tensor`): Location. |
|
- scale (:obj:`torch.Tensor`): Scale. |
|
- threshold (:obj:`float`): Threshold. |
|
""" |
|
super().__init__(loc, scale, **kwargs) |
|
self._threshold = threshold |
|
|
|
def log_prob(self, event): |
|
""" |
|
Overview: |
|
Calculate the log probability of the distribution. |
|
Arguments: |
|
- event (:obj:`torch.Tensor`): Event. |
|
""" |
|
|
|
return -(torch.sqrt((event - self.mean) ** 2 + self._threshold ** 2) - self._threshold) |
|
|
|
def mode(self): |
|
""" |
|
Overview: |
|
Calculate the mode of the distribution. |
|
""" |
|
|
|
return self.mean |
|
|
|
|
|
class SafeTruncatedNormal(torchd.normal.Normal): |
|
""" |
|
Overview: |
|
A kind of SafeTruncatedNormal Dist for dreamerv3. |
|
Interfaces: |
|
``__init__``, ``sample`` |
|
""" |
|
|
|
def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): |
|
""" |
|
Overview: |
|
Initialize the SafeTruncatedNormal class. |
|
Arguments: |
|
- loc (:obj:`torch.Tensor`): Location. |
|
- scale (:obj:`torch.Tensor`): Scale. |
|
- low (:obj:`float`): Low. |
|
- high (:obj:`float`): High. |
|
- clip (:obj:`float`): Clip. |
|
- mult (:obj:`float`): Mult. |
|
""" |
|
|
|
super().__init__(loc, scale) |
|
self._low = low |
|
self._high = high |
|
self._clip = clip |
|
self._mult = mult |
|
|
|
def sample(self, sample_shape): |
|
""" |
|
Overview: |
|
Sample from the distribution. |
|
Arguments: |
|
- sample_shape (:obj:`tuple`): Sample shape. |
|
""" |
|
|
|
event = super().sample(sample_shape) |
|
if self._clip: |
|
clipped = torch.clip(event, self._low + self._clip, self._high - self._clip) |
|
event = event - event.detach() + clipped.detach() |
|
if self._mult: |
|
event *= self._mult |
|
return event |
|
|
|
|
|
class TanhBijector(torchd.Transform): |
|
""" |
|
Overview: |
|
A kind of TanhBijector Dist for dreamerv3. |
|
Interfaces: |
|
``__init__``, ``_forward``, ``_inverse``, ``_forward_log_det_jacobian`` |
|
""" |
|
|
|
def __init__(self, validate_args=False, name='tanh'): |
|
""" |
|
Overview: |
|
Initialize the TanhBijector class. |
|
Arguments: |
|
- validate_args (:obj:`bool`): Validate arguments. |
|
- name (:obj:`str`): Name. |
|
""" |
|
|
|
super().__init__() |
|
|
|
def _forward(self, x): |
|
""" |
|
Overview: |
|
Calculate the forward of the distribution. |
|
Arguments: |
|
- x (:obj:`torch.Tensor`): Input tensor. |
|
""" |
|
|
|
return torch.tanh(x) |
|
|
|
def _inverse(self, y): |
|
""" |
|
Overview: |
|
Calculate the inverse of the distribution. |
|
Arguments: |
|
- y (:obj:`torch.Tensor`): Input tensor. |
|
""" |
|
|
|
y = torch.where((torch.abs(y) <= 1.), torch.clamp(y, -0.99999997, 0.99999997), y) |
|
y = torch.atanh(y) |
|
return y |
|
|
|
def _forward_log_det_jacobian(self, x): |
|
""" |
|
Overview: |
|
Calculate the forward log det jacobian of the distribution. |
|
Arguments: |
|
- x (:obj:`torch.Tensor`): Input tensor. |
|
""" |
|
|
|
log2 = math.log(2.0)

return 2.0 * (log2 - x - F.softplus(-2.0 * x))
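
# Note: this is the standard tanh log-determinant identity
# log|d tanh(x) / dx| = log(1 - tanh(x) ** 2) = 2 * (log(2) - x - softplus(-2 * x)),
# written with softplus for numerical stability at large |x|.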
|
|
|
|
|
def static_scan(fn, inputs, start): |
|
""" |
|
Overview: |
|
Static scan function. |
|
Arguments: |
|
- fn (:obj:`function`): Function. |
|
- inputs (:obj:`tuple`): Inputs. |
|
- start (:obj:`torch.Tensor`): Start tensor. |
|
""" |
|
|
|
last = start |
|
indices = range(inputs[0].shape[0]) |
|
flag = True |
|
for index in indices: |
|
inp = lambda x: (_input[x] for _input in inputs) |
|
last = fn(last, *inp(index)) |
|
if flag: |
|
if isinstance(last, dict): |
|
outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()} |
|
else: |
|
outputs = [] |
|
for _last in last: |
|
if isinstance(_last, dict): |
|
outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()}) |
|
else: |
|
outputs.append(_last.clone().unsqueeze(0)) |
|
flag = False |
|
else: |
|
if isinstance(last, dict): |
|
for key in last.keys(): |
|
outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0) |
|
else: |
|
for j in range(len(outputs)): |
|
if isinstance(last[j], dict): |
|
for key in last[j].keys(): |
|
outputs[j][key] = torch.cat([outputs[j][key], last[j][key].unsqueeze(0)], dim=0) |
|
else: |
|
outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0) |
|
if isinstance(last, dict): |
|
outputs = [outputs] |
|
return outputs |
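
# Illustrative usage (a minimal sketch; the step function, lengths and sizes are assumptions).
# static_scan unrolls `fn` over the leading (time) axis of `inputs`, starting from `start`, and
# stacks the per-step outputs along a new leading axis:
#
#     def step(prev, x):
#         return {'state': prev['state'] * 0.9 + x}
#
#     out = static_scan(step, (torch.randn(10, 4), ), {'state': torch.zeros(4)})
#     out[0]['state'].shape   # torch.Size([10, 4])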
|
|
|
|
|
def weight_init(m): |
|
""" |
|
Overview: |
|
weight_init for Linear, Conv2d, ConvTranspose2d, and LayerNorm. |
|
Arguments: |
|
- m (:obj:`torch.nn`): Module. |
|
""" |
|
|
|
if isinstance(m, nn.Linear): |
|
in_num = m.in_features |
|
out_num = m.out_features |
|
denoms = (in_num + out_num) / 2.0 |
|
scale = 1.0 / denoms |
|
std = np.sqrt(scale) / 0.87962566103423978 |
|
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0) |
|
if hasattr(m.bias, 'data'): |
|
m.bias.data.fill_(0.0) |
|
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): |
|
space = m.kernel_size[0] * m.kernel_size[1] |
|
in_num = space * m.in_channels |
|
out_num = space * m.out_channels |
|
denoms = (in_num + out_num) / 2.0 |
|
scale = 1.0 / denoms |
|
std = np.sqrt(scale) / 0.87962566103423978 |
|
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0) |
|
if hasattr(m.bias, 'data'): |
|
m.bias.data.fill_(0.0) |
|
elif isinstance(m, nn.LayerNorm): |
|
m.weight.data.fill_(1.0) |
|
if hasattr(m.bias, 'data'): |
|
m.bias.data.fill_(0.0) |
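
# Note: weight_init is a truncated-normal variance-scaling initializer. The variance is scaled by
# the average of fan-in and fan-out (Glorot-style), and the constant 0.87962566103423978 is the
# standard deviation of a unit normal truncated to [-2, 2], so the truncated draw keeps the
# intended std. Typical use is `module.apply(weight_init)`, as done for the MLPs above.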
|
|
|
|
|
def uniform_weight_init(given_scale): |
|
""" |
|
Overview: |
|
weight_init for Linear and LayerNorm. |
|
Arguments: |
|
- given_scale (:obj:`float`): Given scale. |
|
""" |
|
|
|
def f(m): |
|
if isinstance(m, nn.Linear): |
|
in_num = m.in_features |
|
out_num = m.out_features |
|
denoms = (in_num + out_num) / 2.0 |
|
scale = given_scale / denoms |
|
limit = np.sqrt(3 * scale) |
|
nn.init.uniform_(m.weight.data, a=-limit, b=limit) |
|
if hasattr(m.bias, 'data'): |
|
m.bias.data.fill_(0.0) |
|
elif isinstance(m, nn.LayerNorm): |
|
m.weight.data.fill_(1.0) |
|
if hasattr(m.bias, 'data'): |
|
m.bias.data.fill_(0.0) |
|
|
|
return f |
|
|