import copy
import numbers
import numpy as np
from collections import namedtuple
from typing import Union, Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F

from ding.hpc_rl import hpc_wrapper
from ding.rl_utils.value_rescale import value_transform, value_inv_transform
from ding.torch_utils import to_tensor

q_1step_td_data = namedtuple('q_1step_td_data', ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight'])


def discount_cumsum(x, gamma: float = 1.0) -> np.ndarray:
    """
    Overview:
        Reverse discounted cumulative sum over the first axis of ``x``: \
        ``out[-1] = x[-1]`` and ``out[t] = x[t] + gamma * out[t + 1]``.
    Arguments:
        - x (:obj:`np.ndarray`): Input sequence, the recursion runs over its first axis.
        - gamma (:obj:`float`): Discount factor. Only 1.0 is accepted (asserted), following the original \
            decision transformer paper.
    Returns:
        - disc_cumsum (:obj:`np.ndarray`): Array with the same shape and dtype as ``x`` (empty for empty input).
    Examples:
        >>> discount_cumsum(np.array([1., 2., 3.]))
        array([6., 5., 3.])
    """
    assert abs(gamma - 1.) < 1e-5, "gamma equals to 1.0 in original decision transformer paper"
    disc_cumsum = np.zeros_like(x)
    if x.shape[0] == 0:
        # Nothing to accumulate; ``x[-1]`` below would raise on an empty sequence.
        return disc_cumsum
    disc_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        disc_cumsum[t] = x[t] + gamma * disc_cumsum[t + 1]
    return disc_cumsum


def q_1step_td_error(
        data: namedtuple,
        gamma: float,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')  # noqa
) -> torch.Tensor:
    """
    Overview:
        1 step td_error for the single agent case.
    Arguments:
        - data (:obj:`q_1step_td_data`): The input data, q_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, must return a per-element loss
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error, weighted mean over the batch (0-dim tensor)
    Shapes:
        - data (:obj:`q_1step_td_data`): the q_1step_td_data containing\
            ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or numeric tensor) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> action_dim = 4
        >>> data = q_1step_td_data(
        >>>     q=torch.randn(3, action_dim),
        >>>     next_q=torch.randn(3, action_dim),
        >>>     act=torch.randint(0, action_dim, (3,)),
        >>>     next_act=torch.randint(0, action_dim, (3,)),
        >>>     reward=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)).bool(),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss = q_1step_td_error(data, 0.99)
    """
    q, next_q, act, next_act, reward, done, weight = data
    assert len(act.shape) == 1, act.shape
    assert len(reward.shape) == 1, reward.shape
    batch_range = torch.arange(act.shape[0])
    if weight is None:
        weight = torch.ones_like(reward)
    if done.dtype == torch.bool:
        # PyTorch does not support ``1 - bool_tensor``; cast boolean done masks to float first.
        done = done.float()
    q_s_a = q[batch_range, act]
    target_q_s_a = next_q[batch_range, next_act]
    target_q_s_a = gamma * (1 - done) * target_q_s_a + reward
    return (criterion(q_s_a, target_q_s_a.detach()) * weight).mean()


m_q_1step_td_data = namedtuple('m_q_1step_td_data', ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight'])


def m_q_1step_td_error(
        data: namedtuple,
        gamma: float,
        tau: float,
        alpha: float,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')  # noqa
) -> torch.Tensor:
    """
    Overview:
        Munchausen td_error for DQN algorithm, support 1 step td error.
    Arguments:
        - data (:obj:`m_q_1step_td_data`): The input data, m_q_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - tau (:obj:`float`): Entropy factor for Munchausen DQN
        - alpha (:obj:`float`): Discount factor for Munchausen term
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error, weighted mean over the batch (0-dim tensor)
        - td_error_per_sample (:obj:`torch.Tensor`): :math:`(B, )` per-sample td error
        - action_gap (:obj:`torch.Tensor`): 0-dim, mean gap between the best and second best ``target_q``
        - clipfrac (:obj:`torch.Tensor`): :math:`(B, 1)` float 0/1 mask marking samples whose Munchausen \
            term was clipped (it is not averaged here)
    Shapes:
        - data (:obj:`m_q_1step_td_data`): the m_q_1step_td_data containing\
            ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - target_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> action_dim = 4
        >>> data = m_q_1step_td_data(
        >>>     q=torch.randn(3, action_dim),
        >>>     target_q=torch.randn(3, action_dim),
        >>>     next_q=torch.randn(3, action_dim),
        >>>     act=torch.randint(0, action_dim, (3,)),
        >>>     reward=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss = m_q_1step_td_error(data, 0.99, 0.01, 0.01)
    """
    q, target_q, next_q, act, reward, done, weight = data
    lower_bound = -1
    assert len(act.shape) == 1, act.shape
    assert len(reward.shape) == 1, reward.shape
    batch_range = torch.arange(act.shape[0])
    if weight is None:
        weight = torch.ones_like(reward)
    q_s_a = q[batch_range, act]
    # calculate muchausen addon
    # replay_log_policy
    target_v_s = target_q[batch_range].max(1)[0].unsqueeze(-1)

    logsum = torch.logsumexp((target_q - target_v_s) / tau, 1).unsqueeze(-1)
    # log_pi equals tau * log_softmax(target_q / tau), computed stably by subtracting the max first
    log_pi = target_q - target_v_s - tau * logsum
    act_get = act.unsqueeze(-1)
    # same to the last second tau_log_pi_a
    munchausen_addon = log_pi.gather(1, act_get)

    muchausen_term = alpha * torch.clamp(munchausen_addon, min=lower_bound, max=1)

    # replay_next_log_policy
    target_v_s_next = next_q[batch_range].max(1)[0].unsqueeze(-1)
    logsum_next = torch.logsumexp((next_q - target_v_s_next) / tau, 1).unsqueeze(-1)
    tau_log_pi_next = next_q - target_v_s_next - tau * logsum_next
    # do stable softmax == replay_next_policy (explicit dim: implicit-dim softmax is deprecated)
    pi_target = F.softmax((next_q - target_v_s_next) / tau, dim=1)
    target_q_s_a = (gamma * (pi_target * (next_q - tau_log_pi_next) * (1 - done.unsqueeze(-1))).sum(1)).unsqueeze(-1)

    target_q_s_a = reward.unsqueeze(-1) + muchausen_term + target_q_s_a
    td_error_per_sample = criterion(q_s_a.unsqueeze(-1), target_q_s_a.detach()).squeeze(-1)

    # calculate action_gap and clipfrac
    with torch.no_grad():
        top2_q_s = target_q[batch_range].topk(2, dim=1, largest=True, sorted=True)[0]
        action_gap = (top2_q_s[:, 0] - top2_q_s[:, 1]).mean()

        clipped = munchausen_addon.gt(1) | munchausen_addon.lt(lower_bound)
        clipfrac = torch.as_tensor(clipped).float()

    return (td_error_per_sample * weight).mean(), td_error_per_sample, action_gap, clipfrac


q_v_1step_td_data = namedtuple('q_v_1step_td_data', ['q', 'v', 'act', 'reward', 'done', 'weight'])


def q_v_1step_td_error(
        data: namedtuple, gamma: float, criterion: torch.nn.modules = nn.MSELoss(reduction='none')
) -> torch.Tensor:
    """
    Overview:
        td_error between q and v value for SAC algorithm, support 1 step td error. It is used in the \
        discrete SAC algorithm to calculate the td error between q and v value.
    Arguments:
        - data (:obj:`q_v_1step_td_data`): The input data, q_v_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): per-sample td error, :math:`(B, )` or \
            :math:`(B, A)` in the multi agent case
    Shapes:
        - data (:obj:`q_v_1step_td_data`): the q_v_1step_td_data containing\
            ['q', 'v', 'act', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] \
            (:math:`(B, A, N)` in the multi agent case)
        - v (:obj:`torch.FloatTensor`): :math:`(B, )`
        - act (:obj:`torch.LongTensor`): :math:`(B, )` (:math:`(B, A)` in the multi agent case)
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> action_dim = 4
        >>> data = q_v_1step_td_data(
        >>>     q=torch.randn(3, action_dim),
        >>>     v=torch.randn(3),
        >>>     act=torch.randint(0, action_dim, (3,)),
        >>>     reward=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss = q_v_1step_td_error(data, 0.99)
    """
    q, v, act, reward, done, weight = data
    if len(act.shape) == 1:
        assert len(reward.shape) == 1, reward.shape
        batch_range = torch.arange(act.shape[0])
        if weight is None:
            weight = torch.ones_like(reward)
        q_s_a = q[batch_range, act]
        target_q_s_a = gamma * (1 - done) * v + reward
    else:
        # multi agent case: flatten (B, A) into B * A rows to pick each agent's q value
        assert len(reward.shape) == 1, reward.shape
        batch_actor_range = torch.arange(act.shape[0] * act.shape[1])
        if weight is None:
            weight = torch.ones_like(act)
        temp_q = q.reshape(act.shape[0] * act.shape[1], -1)
        temp_act = act.reshape(act.shape[0] * act.shape[1])
        q_s_a = temp_q[batch_actor_range, temp_act]
        q_s_a = q_s_a.reshape(act.shape[0], act.shape[1])
        target_q_s_a = gamma * (1 - done).unsqueeze(1) * v + reward.unsqueeze(1)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample


def view_similar(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        View ``x`` with trailing singleton dims appended so it has as many dims as ``target`` (for broadcasting).
    """
    size = list(x.shape) + [1 for _ in range(len(target.shape) - len(x.shape))]
    return x.view(*size)


nstep_return_data = namedtuple('nstep_return_data', ['reward', 'next_value', 'done'])


def nstep_return(data: namedtuple, gamma: Union[float, list], nstep: int, value_gamma: Optional[torch.Tensor] = None):
    '''
    Overview:
        Calculate nstep return for DQN algorithm, support single agent case and multi agent case.
    Arguments:
        - data (:obj:`nstep_return_data`): The input data, nstep_return_data to calculate loss
        - gamma (:obj:`float` or :obj:`list`): Discount factor. Any real scalar (python or numpy number) is \
            accepted; a list of per-sample gamma tensors is used by the NGU policy
        - nstep (:obj:`int`): nstep num
        - value_gamma (:obj:`torch.Tensor`): Discount factor for value (scalar gamma only)
    Returns:
        - return (:obj:`torch.Tensor`): nstep return
    Shapes:
        - data (:obj:`nstep_return_data`): the nstep_return_data containing\
            ['reward', 'next_value', 'done']
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - next_value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
    Raises:
        - TypeError: if ``gamma`` is neither a real number nor a list
    Examples:
        >>> data = nstep_return_data(
        >>>     reward=torch.randn(3, 3),
        >>>     next_value=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)),
        >>> )
        >>> loss = nstep_return(data, 0.99, 3)
    '''

    reward, next_value, done = data
    assert reward.shape[0] == nstep
    device = reward.device

    if isinstance(gamma, numbers.Real):
        # reward_factor = [1, gamma, gamma ** 2, ...], broadcast over the trailing dims of reward
        reward_factor = torch.ones(nstep).to(device)
        for i in range(1, nstep):
            reward_factor[i] = gamma * reward_factor[i - 1]
        reward_factor = view_similar(reward_factor, reward)
        return_tmp = reward.mul(reward_factor).sum(0)
        if value_gamma is None:
            return_ = return_tmp + (gamma ** nstep) * next_value * (1 - done)
        else:
            return_ = return_tmp + value_gamma * next_value * (1 - done)

    elif isinstance(gamma, list):
        # if gamma is list, for NGU policy case
        gamma_tensor = torch.stack(gamma, dim=0).to(device)
        reward_factor = torch.ones([nstep + 1, done.shape[0]]).to(device)
        for i in range(1, nstep + 1):
            reward_factor[i] = gamma_tensor * reward_factor[i - 1]
        reward_factor = view_similar(reward_factor, reward)
        return_tmp = reward.mul(reward_factor[:nstep]).sum(0)
        return_ = return_tmp + reward_factor[nstep] * next_value * (1 - done)
    else:
        raise TypeError("The type of gamma should be float or list")

    return return_


dist_1step_td_data = namedtuple(
    'dist_1step_td_data', ['dist', 'next_dist', 'act', 'next_act', 'reward', 'done', 'weight']
)


def dist_1step_td_error(
        data: namedtuple,
        gamma: float,
        v_min: float,
        v_max: float,
        n_atom: int,
) -> torch.Tensor:
    """
    Overview:
        1 step td_error for distributed q-learning based algorithm
    Arguments:
        - data (:obj:`dist_1step_td_data`): The input data, dist_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - v_min (:obj:`float`): The min value of support
        - v_max (:obj:`float`): The max value of support
        - n_atom (:obj:`int`): The num of atom
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
    Shapes:
        - data (:obj:`dist_1step_td_data`): the dist_1step_td_data containing\
            ['dist', 'next_dist', 'act', 'next_act', 'reward', 'done', 'weight']
        - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom]
        - next_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)`
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor`, float or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True)
        >>> next_dist = torch.randn(4, 3, 51).abs()
        >>> act = torch.randint(0, 3, (4,))
        >>> next_act = torch.randint(0, 3, (4,))
        >>> reward = torch.randn(4)
        >>> done = torch.randint(0, 2, (4,))
        >>> data = dist_1step_td_data(dist, next_dist, act, next_act, reward, done, None)
        >>> loss = dist_1step_td_error(data, 0.99, -10.0, 10.0, 51)
    """
    dist, next_dist, act, next_act, reward, done, weight = data
    device = reward.device
    assert len(reward.shape) == 1, reward.shape
    support = torch.linspace(v_min, v_max, n_atom).to(device)
    delta_z = (v_max - v_min) / (n_atom - 1)

    if len(act.shape) == 1:
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
        batch_size = act.shape[0]
        batch_range = torch.arange(batch_size)
        if weight is None:
            weight = torch.ones_like(reward)
        next_dist = next_dist[batch_range, next_act].detach()
    else:
        # multi agent case: flatten (B, A) into B * A independent samples
        reward = reward.unsqueeze(-1).repeat(1, act.shape[1])
        done = done.unsqueeze(-1).repeat(1, act.shape[1])

        batch_size = act.shape[0] * act.shape[1]
        batch_range = torch.arange(act.shape[0] * act.shape[1])
        action_dim = dist.shape[2]
        dist = dist.reshape(act.shape[0] * act.shape[1], action_dim, -1)
        reward = reward.reshape(act.shape[0] * act.shape[1], -1)
        done = done.reshape(act.shape[0] * act.shape[1], -1)
        next_dist = next_dist.reshape(act.shape[0] * act.shape[1], action_dim, -1)

        next_act = next_act.reshape(act.shape[0] * act.shape[1])
        next_dist = next_dist[batch_range, next_act].detach()
        next_dist = next_dist.reshape(act.shape[0] * act.shape[1], -1)
        act = act.reshape(act.shape[0] * act.shape[1])
        if weight is None:
            weight = torch.ones_like(reward)
    if isinstance(weight, float):
        weight = torch.tensor(weight)
    if len(weight.shape) == 1:
        # a per-sample weight (B, ) must broadcast over the atom dimension of (B, n_atom)
        weight = weight.unsqueeze(-1)

    target_z = reward + (1 - done) * gamma * support
    target_z = target_z.clamp(min=v_min, max=v_max)
    b = (target_z - v_min) / delta_z
    l = b.floor().long()
    u = b.ceil().long()
    # Fix disappearing probability mass when l = b = u (b is int)
    l[(u > 0) * (l == u)] -= 1
    u[(l < (n_atom - 1)) * (l == u)] += 1

    # project the target distribution onto the fixed support, flattening (batch, atom) for index_add_
    proj_dist = torch.zeros_like(next_dist)
    offset = torch.linspace(0, (batch_size - 1) * n_atom, batch_size).unsqueeze(1).expand(batch_size,
                                                                                        n_atom).long().to(device)
    proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1))
    proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1))

    log_p = torch.log(dist[batch_range, act])

    loss = -(log_p * proj_dist * weight).sum(-1).mean()

    return loss


dist_nstep_td_data = namedtuple(
    'dist_nstep_td_data', ['dist', 'next_n_dist', 'act', 'next_n_act', 'reward', 'done', 'weight']
)


def shape_fn_dntd(args, kwargs):
    r"""
    Overview:
        Return dntd shape for hpc
    Returns:
        shape: [T, B, N, n_atom]
    """
    if len(args) <= 0:
        tmp = [kwargs['data'].reward.shape[0]]
        tmp.extend(list(kwargs['data'].dist.shape))
    else:
        tmp = [args[0].reward.shape[0]]
        tmp.extend(list(args[0].dist.shape))
    return tmp


@hpc_wrapper(
    shape_fn=shape_fn_dntd,
    namedtuple_data=True,
    include_args=[0, 1, 2, 3],
    include_kwargs=['data', 'gamma', 'v_min', 'v_max']
)
def dist_nstep_td_error(
        data: namedtuple,
        gamma: float,
        v_min: float,
        v_max: float,
        n_atom: int,
        nstep: int = 1,
        value_gamma: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error for distributed q-learning based algorithm, support single\
        agent case and multi agent case.
    Arguments:
        - data (:obj:`dist_nstep_td_data`): The input data, dist_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - v_min (:obj:`float`): The min value of support
        - v_max (:obj:`float`): The max value of support
        - n_atom (:obj:`int`): The num of atom
        - nstep (:obj:`int`): nstep num, default set to 1
        - value_gamma (:obj:`torch.Tensor` or float): Discount for the bootstrap value, replaces gamma ** nstep
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted per-sample cross entropy
    Shapes:
        - data (:obj:`dist_nstep_td_data`): the dist_nstep_td_data containing\
            ['dist', 'next_n_dist', 'act', 'next_n_act', 'reward', 'done', 'weight']
        - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom]
        - next_n_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)`
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
    Examples:
        >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True)
        >>> next_n_dist = torch.randn(4, 3, 51).abs()
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> reward = torch.randn(5, 4)
        >>> data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
        >>> loss, _ = dist_nstep_td_error(data, 0.95, -10.0, 10.0, 51, 5)
    """
    dist, next_n_dist, act, next_n_act, reward, done, weight = data
    device = reward.device
    # discounted sum of the nstep rewards: [1, gamma, gamma ** 2, ...] @ reward
    reward_factor = torch.ones(nstep).to(device)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    reward = torch.matmul(reward_factor, reward)
    support = torch.linspace(v_min, v_max, n_atom).to(device)
    delta_z = (v_max - v_min) / (n_atom - 1)
    if len(act.shape) == 1:
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
        batch_size = act.shape[0]
        batch_range = torch.arange(batch_size)
        if weight is None:
            weight = torch.ones_like(reward)
        elif isinstance(weight, float):
            weight = torch.tensor(weight)

        next_n_dist = next_n_dist[batch_range, next_n_act].detach()
    else:
        # multi agent case: flatten (B, A) into B * A independent samples
        reward = reward.unsqueeze(-1).repeat(1, act.shape[1])
        done = done.unsqueeze(-1).repeat(1, act.shape[1])

        batch_size = act.shape[0] * act.shape[1]
        batch_range = torch.arange(act.shape[0] * act.shape[1])
        action_dim = dist.shape[2]
        dist = dist.reshape(act.shape[0] * act.shape[1], action_dim, -1)
        reward = reward.reshape(act.shape[0] * act.shape[1], -1)
        done = done.reshape(act.shape[0] * act.shape[1], -1)
        next_n_dist = next_n_dist.reshape(act.shape[0] * act.shape[1], action_dim, -1)

        next_n_act = next_n_act.reshape(act.shape[0] * act.shape[1])
        next_n_dist = next_n_dist[batch_range, next_n_act].detach()
        next_n_dist = next_n_dist.reshape(act.shape[0] * act.shape[1], -1)
        act = act.reshape(act.shape[0] * act.shape[1])
        if weight is None:
            weight = torch.ones_like(reward)
        elif isinstance(weight, float):
            weight = torch.tensor(weight)

    if value_gamma is None:
        target_z = reward + (1 - done) * (gamma ** nstep) * support
    elif isinstance(value_gamma, float):
        value_gamma = torch.tensor(value_gamma).unsqueeze(-1)
        target_z = reward + (1 - done) * value_gamma * support
    else:
        value_gamma = value_gamma.unsqueeze(-1)
        target_z = reward + (1 - done) * value_gamma * support
    target_z = target_z.clamp(min=v_min, max=v_max)
    b = (target_z - v_min) / delta_z
    l = b.floor().long()
    u = b.ceil().long()
    # Fix disappearing probability mass when l = b = u (b is int)
    l[(u > 0) * (l == u)] -= 1
    u[(l < (n_atom - 1)) * (l == u)] += 1

    # project the target distribution onto the fixed support, flattening (batch, atom) for index_add_
    proj_dist = torch.zeros_like(next_n_dist)
    offset = torch.linspace(0, (batch_size - 1) * n_atom, batch_size).unsqueeze(1).expand(batch_size,
                                                                                        n_atom).long().to(device)
    proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_n_dist * (u.float() - b)).view(-1))
    proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_n_dist * (b - l.float())).view(-1))

    assert (dist[batch_range, act] > 0.0).all(), ("dist act", dist[batch_range, act], "dist:", dist)
    log_p = torch.log(dist[batch_range, act])

    if len(weight.shape) == 1:
        weight = weight.unsqueeze(-1)

    td_error_per_sample = -(log_p * proj_dist).sum(-1)

    loss = -(log_p * proj_dist * weight).sum(-1).mean()

    return loss, td_error_per_sample


v_1step_td_data = namedtuple('v_1step_td_data', ['v', 'next_v', 'reward', 'done', 'weight'])


def v_1step_td_error(
        data: namedtuple,
        gamma: float,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')  # noqa
) -> torch.Tensor:
    '''
    Overview:
        1 step td_error for distributed value based algorithm
    Arguments:
        - data (:obj:`v_1step_td_data`): The input data, v_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted per-sample td error
    Shapes:
        - data (:obj:`v_1step_td_data`): the v_1step_td_data containing\
            ['v', 'next_v', 'reward', 'done', 'weight']
        - v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
        - next_v (:obj:`torch.FloatTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or None) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> v = torch.randn(5).requires_grad_(True)
        >>> next_v = torch.randn(5)
        >>> reward = torch.rand(5)
        >>> done = torch.zeros(5)
        >>> data = v_1step_td_data(v, next_v, reward, done, None)
        >>> loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    '''
    v, next_v, reward, done, weight = data
    if weight is None:
        weight = torch.ones_like(v)
    if len(v.shape) == len(reward.shape):
        if done is not None:
            target_v = gamma * (1 - done) * next_v + reward
        else:
            target_v = gamma * next_v + reward
    else:
        # v has an extra trailing dim compared with reward/done, broadcast them along dim 1
        if done is not None:
            target_v = gamma * (1 - done).unsqueeze(1) * next_v + reward.unsqueeze(1)
        else:
            target_v = gamma * next_v + reward.unsqueeze(1)
    td_error_per_sample = criterion(v, target_v.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample


v_nstep_td_data = namedtuple('v_nstep_td_data', ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma'])


def v_nstep_td_error(
        data: namedtuple,
        gamma: float,
        nstep: int = 1,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')  # noqa
) -> torch.Tensor:
    r"""
    Overview:
        Multistep (n step) td_error for distributed value based algorithm
    Arguments:
        - data (:obj:`v_nstep_td_data`): The input data, v_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted per-sample td error
    Shapes:
        - data (:obj:`v_nstep_td_data`): The v_nstep_td_data containing\
            ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma']
        - v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
        - next_n_v (:obj:`torch.FloatTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
        - value_gamma (:obj:`torch.Tensor`): If the remaining data in the buffer is less than n_step\
            we use value_gamma as the gamma discount value for next_v rather than gamma**n_step
    Examples:
        >>> v = torch.randn(5).requires_grad_(True)
        >>> next_v = torch.randn(5)
        >>> reward = torch.rand(5, 5)
        >>> done = torch.zeros(5)
        >>> data = v_nstep_td_data(v, next_v, reward, done, 0.9, 0.99)
        >>> loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
    """
    v, next_n_v, reward, done, weight, value_gamma = data
    if weight is None:
        weight = torch.ones_like(v)
    target_v = nstep_return(nstep_return_data(reward, next_n_v, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(v, target_v.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample


q_nstep_td_data = namedtuple(
    'q_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
)

dqfd_nstep_td_data = namedtuple(
    'dqfd_nstep_td_data', [
        'q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_one_step', 'weight',
        'new_n_q_one_step', 'next_n_action_one_step', 'is_expert'
    ]
)


def shape_fn_qntd(args, kwargs):
    r"""
    Overview:
        Return qntd shape for hpc
    Returns:
        shape: [T, B, N]
    """
    if len(args) <= 0:
        tmp = [kwargs['data'].reward.shape[0]]
        tmp.extend(list(kwargs['data'].q.shape))
    else:
        tmp = [args[0].reward.shape[0]]
        tmp.extend(list(args[0].q.shape))
    return tmp


@hpc_wrapper(shape_fn=shape_fn_qntd, namedtuple_data=True, include_args=[0, 1], include_kwargs=['data', 'gamma'])
def q_nstep_td_error(
        data: namedtuple,
        gamma: Union[float, list],
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error for q-learning based algorithm, support single agent case \
        and multi agent case.
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
        - nstep (:obj:`int`): nstep num, default set to 1
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor \
            (:math:`(B, A)` in the multi agent case)
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] \
            (:math:`(B, A, N)` in the multi agent case)
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)` (:math:`(B, A, N)` in the multi agent case)
        - action (:obj:`torch.LongTensor`): :math:`(B, )` (:math:`(B, A)` in the multi agent case)
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` (:math:`(B, A)` in the multi agent case)
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep =3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    if weight is None:
        weight = torch.ones_like(reward)

    if len(action.shape) > 1:  # MARL case
        # per-agent q values: add a trailing agent dim so batch-level tensors broadcast over agents
        reward = reward.unsqueeze(-1)
        weight = weight.unsqueeze(-1)
        done = done.unsqueeze(-1)
        if value_gamma is not None:
            value_gamma = value_gamma.unsqueeze(-1)
    if len(action.shape) < len(q.shape):
        # torch.gather requires the index to have as many dims as q (single agent and MARL alike)
        action = action.unsqueeze(-1)

    q_s_a = q.gather(-1, action).squeeze(-1)

    target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1)

    if cum_reward:
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample


def bdq_nstep_td_error(
        data: namedtuple,
        gamma: Union[float, list],
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error for BDQ algorithm, referenced paper "Action Branching Architectures for \
        Deep Reinforcement Learning", link: https://arxiv.org/pdf/1711.08946.
        In fact, the original paper only provides the 1-step TD-error calculation method, and here we extend the \
        calculation method of n-step. The per-branch td errors are averaged over the branch dimension.
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
        - nstep (:obj:`int`): nstep num, default set to 1
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, D, N)` i.e. [batch_size, branch_num, action_bins_per_branch]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, D, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, D)`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, D)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> action_per_branch = 3
        >>> next_q = torch.randn(8, 6, action_per_branch)
        >>> done = torch.randn(8)
        >>> action = torch.randint(0, action_per_branch, size=(8, 6))
        >>> next_action = torch.randint(0, action_per_branch, size=(8, 6))
        >>> nstep =3
        >>> q = torch.randn(8, 6, action_per_branch).requires_grad_(True)
        >>> reward = torch.rand(nstep, 8)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    if weight is None:
        weight = torch.ones_like(reward)
    # add a trailing branch dim so batch-level tensors broadcast over the D branches
    reward = reward.unsqueeze(-1)
    done = done.unsqueeze(-1)
    if value_gamma is not None:
        value_gamma = value_gamma.unsqueeze(-1)

    q_s_a = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1)

    if cum_reward:
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    td_error_per_sample = td_error_per_sample.mean(-1)
    return (td_error_per_sample * weight).mean(), td_error_per_sample


def shape_fn_qntd_rescale(args, kwargs):
    r"""
    Overview:
        Return qntd_rescale shape for hpc
    Returns:
        shape: [T, B, N]
    """
    if len(args) <= 0:
        tmp = [kwargs['data'].reward.shape[0]]
        tmp.extend(list(kwargs['data'].q.shape))
    else:
        tmp = [args[0].reward.shape[0]]
        tmp.extend(list(args[0].q.shape))
    return tmp


@hpc_wrapper(
    shape_fn=shape_fn_qntd_rescale, namedtuple_data=True, include_args=[0, 1], include_kwargs=['data', 'gamma']
)
def q_nstep_td_error_with_rescale(
        data: namedtuple,
        gamma: Union[float, list],
        nstep: int = 1,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
        trans_fn: Callable = value_transform,
        inv_trans_fn: Callable = value_inv_transform,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error with value rescaling
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
        - trans_fn (:obj:`Callable`): Value transform function, default to value_transform\
            (refer to rl_utils/value_rescale.py)
        - inv_trans_fn (:obj:`Callable`): Value inverse transform function, default to value_inv_transform\
            (refer to rl_utils/value_rescale.py)
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted per-sample td error
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep =3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, _ = q_nstep_td_error_with_rescale(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)

    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]
    target_q_s_a = next_n_q[batch_range, next_n_action]

    # bootstrap in the untransformed value space, then map the nstep target back with trans_fn
    target_q_s_a = inv_trans_fn(target_q_s_a)
    target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    target_q_s_a = trans_fn(target_q_s_a)

    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample


def dqfd_nstep_td_error(
        data: namedtuple,
        gamma: float,
        lambda_n_step_td: float,
        lambda_supervised_loss: float,
        margin_function: float,
        lambda_one_step_td: float = 1.,
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        Multistep n step td_error + 1 step td_error + supervised margin loss or dqfd
    Arguments:
        - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss
        - gamma (:obj:`float`): discount factor
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
        - nstep (:obj:`int`): nstep num, default set to 10
    Returns:
        - loss (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error + supervised margin loss, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error\
            + supervised margin loss, 1-dim tensor
    Shapes:
        - data (:obj:`q_nstep_td_data`): the q_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight'\
                , 'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
        - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
        - is_expert (:obj:`int`) : 0 or 1
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> done_1 = torch.randn(4)
        >>> next_q_one_step = torch.randn(4, 3)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> next_action_one_step = torch.randint(0, 3, size=(4, ))
        >>> is_expert = torch.ones((4))
        >>> nstep = 3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = dqfd_nstep_td_data(
        >>>     q, next_q, action, next_action, reward, done, done_1, None,
        >>>     next_q_one_step, next_action_one_step, is_expert
        >>> )
        >>> loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
        >>>     data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1,
        >>>     margin_function=0.8, nstep=nstep
        >>> )
    """
    q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, next_n_action_one_step, \
        is_expert = data  # set is_expert flag(expert 1, agent 0)
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)

    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]
    target_q_s_a = next_n_q[batch_range, next_n_action]
    target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step]

    # calculate n-step TD-loss
    if cum_reward:
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())

    # calculate 1-step TD-loss
    nstep = 1
    reward = reward[0].unsqueeze(0)  # get the one-step reward
    value_gamma = None
    if cum_reward:
        if value_gamma is None:
            target_q_s_a_one_step = reward + (gamma ** nstep) * target_q_s_a_one_step * (1 - done_one_step)
        else:
            target_q_s_a_one_step = reward + value_gamma * target_q_s_a_one_step * (1 - done_one_step)
    else:
        target_q_s_a_one_step = nstep_return(
            nstep_return_data(reward, target_q_s_a_one_step, done_one_step), gamma, nstep, value_gamma
        )
    td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach())
    device = q_s_a.device
    device_cpu = torch.device('cpu')
    # calculate the supervised loss
    l = margin_function * torch.ones_like(q).to(device_cpu)  # q shape (B, A), action shape (B, )
    l.scatter_(1, torch.LongTensor(action.unsqueeze(1).to(device_cpu)), torch.zeros_like(q, device=device_cpu))
    # along the first dimension.
for the index of the action, fill the corresponding position in l with 0 JE = is_expert * (torch.max(q + l.to(device), dim=1)[0] - q_s_a) return ( ( ( lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample + lambda_supervised_loss * JE ) * weight ).mean(), lambda_n_step_td * td_error_per_sample.abs() + lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs(), (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean()) ) def dqfd_nstep_td_error_with_rescale( data: namedtuple, gamma: float, lambda_n_step_td: float, lambda_supervised_loss: float, lambda_one_step_td: float, margin_function: float, nstep: int = 1, cum_reward: bool = False, value_gamma: Optional[torch.Tensor] = None, criterion: torch.nn.modules = nn.MSELoss(reduction='none'), trans_fn: Callable = value_transform, inv_trans_fn: Callable = value_inv_transform, ) -> torch.Tensor: """ Overview: Multistep n step td_error + 1 step td_error + supervised margin loss or dqfd Arguments: - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss - gamma (:obj:`float`): Discount factor - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value - criterion (:obj:`torch.nn.modules`): Loss function criterion - nstep (:obj:`int`): nstep num, default set to 10 Returns: - loss (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error + supervised margin loss, 0-dim tensor - td_error_per_sample (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error\ + supervised margin loss, 1-dim tensor Shapes: - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\ ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight'\ , 'new_n_q_one_step', 'next_n_action_one_step', 'is_expert'] - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. 
[batch_size, action_dim] - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)` - action (:obj:`torch.LongTensor`): :math:`(B, )` - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )` - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)` - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )` - is_expert (:obj:`int`) : 0 or 1 """ q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, next_n_action_one_step, \ is_expert = data # set is_expert flag(expert 1, agent 0) assert len(action.shape) == 1, action.shape if weight is None: weight = torch.ones_like(action) batch_range = torch.arange(action.shape[0]) q_s_a = q[batch_range, action] target_q_s_a = next_n_q[batch_range, next_n_action] target_q_s_a = inv_trans_fn(target_q_s_a) # rescale target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step] target_q_s_a_one_step = inv_trans_fn(target_q_s_a_one_step) # rescale # calculate n-step TD-loss if cum_reward: if value_gamma is None: target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done) else: target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done) else: # to use value_gamma in n-step TD-loss target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma) target_q_s_a = trans_fn(target_q_s_a) # rescale td_error_per_sample = criterion(q_s_a, target_q_s_a.detach()) # calculate 1-step TD-loss nstep = 1 reward = reward[0].unsqueeze(0) # get the one-step reward value_gamma = None # This is very important, to use gamma in 1-step TD-loss if cum_reward: if value_gamma is None: target_q_s_a_one_step = reward + (gamma ** nstep) * target_q_s_a_one_step * (1 - done_one_step) else: target_q_s_a_one_step = reward + value_gamma * 
target_q_s_a_one_step * (1 - done_one_step) else: target_q_s_a_one_step = nstep_return( nstep_return_data(reward, target_q_s_a_one_step, done_one_step), gamma, nstep, value_gamma ) target_q_s_a_one_step = trans_fn(target_q_s_a_one_step) # rescale td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach()) device = q_s_a.device device_cpu = torch.device('cpu') # calculate the supervised loss l = margin_function * torch.ones_like(q).to(device_cpu) # q shape (B, A), action shape (B, ) l.scatter_(1, torch.LongTensor(action.unsqueeze(1).to(device_cpu)), torch.zeros_like(q, device=device_cpu)) # along the first dimension. for the index of the action, fill the corresponding position in l with 0 JE = is_expert * (torch.max(q + l.to(device), dim=1)[0] - q_s_a) return ( ( ( lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample + lambda_supervised_loss * JE ) * weight ).mean(), lambda_n_step_td * td_error_per_sample.abs() + lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs(), (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean()) ) qrdqn_nstep_td_data = namedtuple( 'qrdqn_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'tau', 'weight'] ) def qrdqn_nstep_td_error( data: namedtuple, gamma: float, nstep: int = 1, value_gamma: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Overview: Multistep (1 step or n step) td_error with in QRDQN Arguments: - data (:obj:`iqn_nstep_td_data`): The input data, iqn_nstep_td_data to calculate loss - gamma (:obj:`float`): Discount factor - nstep (:obj:`int`): nstep num, default set to 1 Returns: - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor Shapes: - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\ ['q', 'next_n_q', 'action', 'reward', 'done'] - q (:obj:`torch.FloatTensor`): :math:`(tau, B, N)` i.e. 
[tau x batch_size, action_dim] - next_n_q (:obj:`torch.FloatTensor`): :math:`(tau', B, N)` - action (:obj:`torch.LongTensor`): :math:`(B, )` - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep Examples: >>> next_q = torch.randn(4, 3, 3) >>> done = torch.randn(4) >>> action = torch.randint(0, 3, size=(4, )) >>> next_action = torch.randint(0, 3, size=(4, )) >>> nstep = 3 >>> q = torch.randn(4, 3, 3).requires_grad_(True) >>> reward = torch.rand(nstep, 4) >>> data = qrdqn_nstep_td_data(q, next_q, action, next_action, reward, done, 3, None) >>> loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep) """ q, next_n_q, action, next_n_action, reward, done, tau, weight = data assert len(action.shape) == 1, action.shape assert len(next_n_action.shape) == 1, next_n_action.shape assert len(done.shape) == 1, done.shape assert len(q.shape) == 3, q.shape assert len(next_n_q.shape) == 3, next_n_q.shape assert len(reward.shape) == 2, reward.shape if weight is None: weight = torch.ones_like(action) batch_range = torch.arange(action.shape[0]) # shape: batch_size x num x 1 q_s_a = q[batch_range, action, :].unsqueeze(2) # shape: batch_size x 1 x num target_q_s_a = next_n_q[batch_range, next_n_action, :].unsqueeze(1) assert reward.shape[0] == nstep reward_factor = torch.ones(nstep).to(reward) for i in range(1, nstep): reward_factor[i] = gamma * reward_factor[i - 1] # shape: batch_size reward = torch.matmul(reward_factor, reward) # shape: batch_size x 1 x num if value_gamma is None: target_q_s_a = reward.unsqueeze(-1).unsqueeze(-1) + (gamma ** nstep ) * target_q_s_a * (1 - done).unsqueeze(-1).unsqueeze(-1) else: target_q_s_a = reward.unsqueeze(-1).unsqueeze( -1 ) + value_gamma.unsqueeze(-1).unsqueeze(-1) * target_q_s_a * (1 - done).unsqueeze(-1).unsqueeze(-1) # shape: batch_size x num x num u = 
F.smooth_l1_loss(target_q_s_a, q_s_a, reduction="none") # shape: batch_size loss = (u * (tau - (target_q_s_a - q_s_a).detach().le(0.).float()).abs()).sum(-1).mean(1) return (loss * weight).mean(), loss def q_nstep_sql_td_error( data: namedtuple, gamma: float, alpha: float, nstep: int = 1, cum_reward: bool = False, value_gamma: Optional[torch.Tensor] = None, criterion: torch.nn.modules = nn.MSELoss(reduction='none'), ) -> torch.Tensor: """ Overview: Multistep (1 step or n step) td_error for q-learning based algorithm Arguments: - data (:obj:`q_nstep_td_data`): The input data, q_nstep_sql_td_data to calculate loss - gamma (:obj:`float`): Discount factor - Alpha (:obj:ï½€float`): A parameter to weight entropy term in a policy equation - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target soft_q_value - criterion (:obj:`torch.nn.modules`): Loss function criterion - nstep (:obj:`int`): nstep num, default set to 1 Returns: - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor Shapes: - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\ ['q', 'next_n_q', 'action', 'reward', 'done'] - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. 
[batch_size, action_dim] - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)` - action (:obj:`torch.LongTensor`): :math:`(B, )` - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )` Examples: >>> next_q = torch.randn(4, 3) >>> done = torch.randn(4) >>> action = torch.randint(0, 3, size=(4, )) >>> next_action = torch.randint(0, 3, size=(4, )) >>> nstep = 3 >>> q = torch.randn(4, 3).requires_grad_(True) >>> reward = torch.rand(nstep, 4) >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None) >>> loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 1.0, nstep=nstep) """ q, next_n_q, action, next_n_action, reward, done, weight = data assert len(action.shape) == 1, action.shape if weight is None: weight = torch.ones_like(action) batch_range = torch.arange(action.shape[0]) q_s_a = q[batch_range, action] # target_q_s_a = next_n_q[batch_range, next_n_action] target_v = alpha * torch.logsumexp( next_n_q / alpha, 1 ) # target_v = alpha * torch.log(torch.sum(torch.exp(next_n_q / alpha), 1)) target_v[target_v == float("Inf")] = 20 target_v[target_v == float("-Inf")] = -20 # For an appropriate hyper-parameter alpha, these hardcodes can be removed. # However, algorithms may face the danger of explosion for other alphas. 
# The hardcodes above are to prevent this situation from happening record_target_v = copy.deepcopy(target_v) # print(target_v) if cum_reward: if value_gamma is None: target_v = reward + (gamma ** nstep) * target_v * (1 - done) else: target_v = reward + value_gamma * target_v * (1 - done) else: target_v = nstep_return(nstep_return_data(reward, target_v, done), gamma, nstep, value_gamma) td_error_per_sample = criterion(q_s_a, target_v.detach()) return (td_error_per_sample * weight).mean(), td_error_per_sample, record_target_v iqn_nstep_td_data = namedtuple( 'iqn_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'replay_quantiles', 'weight'] ) def iqn_nstep_td_error( data: namedtuple, gamma: float, nstep: int = 1, kappa: float = 1.0, value_gamma: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Overview: Multistep (1 step or n step) td_error with in IQN, \ referenced paper Implicit Quantile Networks for Distributional Reinforcement Learning \ Arguments: - data (:obj:`iqn_nstep_td_data`): The input data, iqn_nstep_td_data to calculate loss - gamma (:obj:`float`): Discount factor - nstep (:obj:`int`): nstep num, default set to 1 - criterion (:obj:`torch.nn.modules`): Loss function criterion - beta_function (:obj:`Callable`): The risk function Returns: - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor Shapes: - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\ ['q', 'next_n_q', 'action', 'reward', 'done'] - q (:obj:`torch.FloatTensor`): :math:`(tau, B, N)` i.e. 
[tau x batch_size, action_dim] - next_n_q (:obj:`torch.FloatTensor`): :math:`(tau', B, N)` - action (:obj:`torch.LongTensor`): :math:`(B, )` - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep Examples: >>> next_q = torch.randn(3, 4, 3) >>> done = torch.randn(4) >>> action = torch.randint(0, 3, size=(4, )) >>> next_action = torch.randint(0, 3, size=(4, )) >>> nstep = 3 >>> q = torch.randn(3, 4, 3).requires_grad_(True) >>> replay_quantile = torch.randn([3, 4, 1]) >>> reward = torch.rand(nstep, 4) >>> data = iqn_nstep_td_data(q, next_q, action, next_action, reward, done, replay_quantile, None) >>> loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep) """ q, next_n_q, action, next_n_action, reward, done, replay_quantiles, weight = data assert len(action.shape) == 1, action.shape assert len(next_n_action.shape) == 1, next_n_action.shape assert len(done.shape) == 1, done.shape assert len(q.shape) == 3, q.shape assert len(next_n_q.shape) == 3, next_n_q.shape assert len(reward.shape) == 2, reward.shape if weight is None: weight = torch.ones_like(action) batch_size = done.shape[0] tau = q.shape[0] tau_prime = next_n_q.shape[0] action = action.repeat([tau, 1]).unsqueeze(-1) next_n_action = next_n_action.repeat([tau_prime, 1]).unsqueeze(-1) # shape: batch_size x tau x a q_s_a = torch.gather(q, -1, action).permute([1, 0, 2]) # shape: batch_size x tau_prim x 1 target_q_s_a = torch.gather(next_n_q, -1, next_n_action).permute([1, 0, 2]) assert reward.shape[0] == nstep device = torch.device("cuda" if reward.is_cuda else "cpu") reward_factor = torch.ones(nstep).to(device) for i in range(1, nstep): reward_factor[i] = gamma * reward_factor[i - 1] reward = torch.matmul(reward_factor, reward) if value_gamma is None: target_q_s_a = reward.unsqueeze(-1) + (gamma ** nstep) * target_q_s_a.squeeze(-1) * (1 - 
done).unsqueeze(-1) else: target_q_s_a = reward.unsqueeze(-1) + value_gamma.unsqueeze(-1) * target_q_s_a.squeeze(-1) * (1 - done ).unsqueeze(-1) target_q_s_a = target_q_s_a.unsqueeze(-1) # shape: batch_size x tau' x tau x 1. bellman_errors = (target_q_s_a[:, :, None, :] - q_s_a[:, None, :, :]) # The huber loss (see Section 2.3 of the paper) is defined via two cases: huber_loss = torch.where( bellman_errors.abs() <= kappa, 0.5 * bellman_errors ** 2, kappa * (bellman_errors.abs() - 0.5 * kappa) ) # Reshape replay_quantiles to batch_size x num_tau_samples x 1 replay_quantiles = replay_quantiles.reshape([tau, batch_size, 1]).permute([1, 0, 2]) # shape: batch_size x num_tau_prime_samples x num_tau_samples x 1. replay_quantiles = replay_quantiles[:, None, :, :].repeat([1, tau_prime, 1, 1]) # shape: batch_size x tau_prime x tau x 1. quantile_huber_loss = (torch.abs(replay_quantiles - ((bellman_errors < 0).float()).detach()) * huber_loss) / kappa # shape: batch_size loss = quantile_huber_loss.sum(dim=2).mean(dim=1)[:, 0] return (loss * weight).mean(), loss fqf_nstep_td_data = namedtuple( 'fqf_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'quantiles_hats', 'weight'] ) def fqf_nstep_td_error( data: namedtuple, gamma: float, nstep: int = 1, kappa: float = 1.0, value_gamma: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Overview: Multistep (1 step or n step) td_error with in FQF, \ referenced paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \ Arguments: - data (:obj:`fqf_nstep_td_data`): The input data, fqf_nstep_td_data to calculate loss - gamma (:obj:`float`): Discount factor - nstep (:obj:`int`): nstep num, default set to 1 - criterion (:obj:`torch.nn.modules`): Loss function criterion - beta_function (:obj:`Callable`): The risk function Returns: - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor Shapes: - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\ ['q', 'next_n_q', 
'action', 'reward', 'done'] - q (:obj:`torch.FloatTensor`): :math:`(B, tau, N)` i.e. [batch_size, tau, action_dim] - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, tau', N)` - action (:obj:`torch.LongTensor`): :math:`(B, )` - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep - quantiles_hats (:obj:`torch.FloatTensor`): :math:`(B, tau)` Examples: >>> next_q = torch.randn(4, 3, 3) >>> done = torch.randn(4) >>> action = torch.randint(0, 3, size=(4, )) >>> next_action = torch.randint(0, 3, size=(4, )) >>> nstep = 3 >>> q = torch.randn(4, 3, 3).requires_grad_(True) >>> quantiles_hats = torch.randn([4, 3]) >>> reward = torch.rand(nstep, 4) >>> data = fqf_nstep_td_data(q, next_q, action, next_action, reward, done, quantiles_hats, None) >>> loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep) """ q, next_n_q, action, next_n_action, reward, done, quantiles_hats, weight = data assert len(action.shape) == 1, action.shape assert len(next_n_action.shape) == 1, next_n_action.shape assert len(done.shape) == 1, done.shape assert len(q.shape) == 3, q.shape assert len(next_n_q.shape) == 3, next_n_q.shape assert len(reward.shape) == 2, reward.shape if weight is None: weight = torch.ones_like(action) batch_size = done.shape[0] tau = q.shape[1] tau_prime = next_n_q.shape[1] # shape: batch_size x tau x 1 q_s_a = evaluate_quantile_at_action(q, action) # shape: batch_size x tau_prime x 1 target_q_s_a = evaluate_quantile_at_action(next_n_q, next_n_action) assert reward.shape[0] == nstep reward_factor = torch.ones(nstep).to(reward.device) for i in range(1, nstep): reward_factor[i] = gamma * reward_factor[i - 1] reward = torch.matmul(reward_factor, reward) # [batch_size] if value_gamma is None: target_q_s_a = reward.unsqueeze(-1) + (gamma ** nstep) * target_q_s_a.squeeze(-1) * (1 - done).unsqueeze(-1) else: 
target_q_s_a = reward.unsqueeze(-1) + value_gamma.unsqueeze(-1) * target_q_s_a.squeeze(-1) * (1 - done ).unsqueeze(-1) target_q_s_a = target_q_s_a.unsqueeze(-1) # shape: batch_size x tau' x tau x 1. bellman_errors = (target_q_s_a.unsqueeze(2) - q_s_a.unsqueeze(1)) # shape: batch_size x tau' x tau x 1 huber_loss = F.smooth_l1_loss(target_q_s_a.unsqueeze(2), q_s_a.unsqueeze(1), reduction="none") # shape: batch_size x num_tau_prime_samples x num_tau_samples x 1. quantiles_hats = quantiles_hats[:, None, :, None].repeat([1, tau_prime, 1, 1]) # shape: batch_size x tau_prime x tau x 1. quantile_huber_loss = (torch.abs(quantiles_hats - ((bellman_errors < 0).float()).detach()) * huber_loss) / kappa # shape: batch_size loss = quantile_huber_loss.sum(dim=2).mean(dim=1)[:, 0] return (loss * weight).mean(), loss def evaluate_quantile_at_action(q_s, actions): assert q_s.shape[0] == actions.shape[0] batch_size, num_quantiles = q_s.shape[:2] # Expand actions into (batch_size, num_quantiles, 1). action_index = actions[:, None, None].expand(batch_size, num_quantiles, 1) # Calculate quantile values at specified actions. 
q_s_a = q_s.gather(dim=2, index=action_index) return q_s_a def fqf_calculate_fraction_loss(q_tau_i, q_value, quantiles, actions): """ Overview: Calculate the fraction loss in FQF, \ referenced paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \ Arguments: - q_tau_i (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles-1, action_dim)` - q_value (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles, action_dim)` - quantiles (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles+1)` - actions (:obj:`torch.LongTensor`): :math:`(batch_size, )` Returns: - fraction_loss (:obj:`torch.Tensor`): fraction loss, 0-dim tensor """ assert q_value.requires_grad batch_size = q_value.shape[0] num_quantiles = q_value.shape[1] with torch.no_grad(): sa_quantiles = evaluate_quantile_at_action(q_tau_i, actions) assert sa_quantiles.shape == (batch_size, num_quantiles - 1, 1) q_s_a_hats = evaluate_quantile_at_action(q_value, actions) # [batch_size, num_quantiles, 1] assert q_s_a_hats.shape == (batch_size, num_quantiles, 1) assert not q_s_a_hats.requires_grad # NOTE: Proposition 1 in the paper requires F^{-1} is non-decreasing. # I relax this requirements and calculate gradients of quantiles even when # F^{-1} is not non-decreasing. values_1 = sa_quantiles - q_s_a_hats[:, :-1] signs_1 = sa_quantiles > torch.cat([q_s_a_hats[:, :1], sa_quantiles[:, :-1]], dim=1) assert values_1.shape == signs_1.shape values_2 = sa_quantiles - q_s_a_hats[:, 1:] signs_2 = sa_quantiles < torch.cat([sa_quantiles[:, 1:], q_s_a_hats[:, -1:]], dim=1) assert values_2.shape == signs_2.shape gradient_of_taus = (torch.where(signs_1, values_1, -values_1) + torch.where(signs_2, values_2, -values_2)).view(batch_size, num_quantiles - 1) assert not gradient_of_taus.requires_grad assert gradient_of_taus.shape == quantiles[:, 1:-1].shape # Gradients of the network parameters and corresponding loss # are calculated using chain rule. 
fraction_loss = (gradient_of_taus * quantiles[:, 1:-1]).sum(dim=1).mean() return fraction_loss td_lambda_data = namedtuple('td_lambda_data', ['value', 'reward', 'weight']) def shape_fn_td_lambda(args, kwargs): r""" Overview: Return td_lambda shape for hpc Returns: shape: [T, B] """ if len(args) <= 0: tmp = kwargs['data'].reward.shape[0] else: tmp = args[0].reward.shape return tmp @hpc_wrapper( shape_fn=shape_fn_td_lambda, namedtuple_data=True, include_args=[0, 1, 2], include_kwargs=['data', 'gamma', 'lambda_'] ) def td_lambda_error(data: namedtuple, gamma: float = 0.9, lambda_: float = 0.8) -> torch.Tensor: """ Overview: Computing TD(lambda) loss given constant gamma and lambda. There is no special handling for terminal state value, if some state has reached the terminal, just fill in zeros for values and rewards beyond terminal (*including the terminal state*, values[terminal] should also be 0) Arguments: - data (:obj:`namedtuple`): td_lambda input data with fields ['value', 'reward', 'weight'] - gamma (:obj:`float`): Constant discount factor gamma, should be in [0, 1], defaults to 0.9 - lambda (:obj:`float`): Constant lambda, should be in [0, 1], defaults to 0.8 Returns: - loss (:obj:`torch.Tensor`): Computed MSE loss, averaged over the batch Shapes: - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`, where T is trajectory length and B is batch,\ which is the estimation of the state value at step 0 to T - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, the returns from time step 0 to T-1 - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight - loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor Examples: >>> T, B = 8, 4 >>> value = torch.randn(T + 1, B).requires_grad_(True) >>> reward = torch.rand(T, B) >>> loss = td_lambda_error(td_lambda_data(value, reward, None)) """ value, reward, weight = data if weight is None: weight = torch.ones_like(reward) with torch.no_grad(): return_ = generalized_lambda_returns(value, 
reward, gamma, lambda_) # discard the value at T as it should be considered in the next slice loss = 0.5 * (F.mse_loss(return_, value[:-1], reduction='none') * weight).mean() return loss def generalized_lambda_returns( bootstrap_values: torch.Tensor, rewards: torch.Tensor, gammas: float, lambda_: float, done: Optional[torch.Tensor] = None ) -> torch.Tensor: r""" Overview: Functional equivalent to trfl.value_ops.generalized_lambda_returns https://github.com/deepmind/trfl/blob/2c07ac22512a16715cc759f0072be43a5d12ae45/trfl/value_ops.py#L74 Passing in a number instead of tensor to make the value constant for all samples in batch Arguments: - bootstrap_values (:obj:`torch.Tensor` or :obj:`float`): estimation of the value at step 0 to *T*, of size [T_traj+1, batchsize] - rewards (:obj:`torch.Tensor`): The returns from 0 to T-1, of size [T_traj, batchsize] - gammas (:obj:`torch.Tensor` or :obj:`float`): Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize] - lambda (:obj:`torch.Tensor` or :obj:`float`): Determining the mix of bootstrapping vs further accumulation of multistep returns at each timestep, of size [T_traj, batchsize] - done (:obj:`torch.Tensor` or :obj:`float`): Whether the episode done at current step (from 0 to T-1), of size [T_traj, batchsize] Returns: - return (:obj:`torch.Tensor`): Computed lambda return value for each state from 0 to T-1, of size [T_traj, batchsize] """ if not isinstance(gammas, torch.Tensor): gammas = gammas * torch.ones_like(rewards) if not isinstance(lambda_, torch.Tensor): lambda_ = lambda_ * torch.ones_like(rewards) bootstrap_values_tp1 = bootstrap_values[1:, :] return multistep_forward_view(bootstrap_values_tp1, rewards, gammas, lambda_, done) def multistep_forward_view( bootstrap_values: torch.Tensor, rewards: torch.Tensor, gammas: float, lambda_: float, done: Optional[torch.Tensor] = None ) -> torch.Tensor: r""" Overview: Same as trfl.sequence_ops.multistep_forward_view Implementing (12.18) in Sutton & Barto 
``` result[T-1] = rewards[T-1] + gammas[T-1] * bootstrap_values[T] for t in 0...T-2 : result[t] = rewards[t] + gammas[t]*(lambdas[t]*result[t+1] + (1-lambdas[t])*bootstrap_values[t+1]) ``` Assuming the first dim of input tensors correspond to the index in batch Arguments: - bootstrap_values (:obj:`torch.Tensor`): Estimation of the value at *step 1 to T*, of size [T_traj, batchsize] - rewards (:obj:`torch.Tensor`): The returns from 0 to T-1, of size [T_traj, batchsize] - gammas (:obj:`torch.Tensor`): Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize] - lambda (:obj:`torch.Tensor`): Determining the mix of bootstrapping vs further accumulation of \ multistep returns at each timestep of size [T_traj, batchsize], the element for T-1 is ignored \ and effectively set to 0, as there is no information about future rewards. - done (:obj:`torch.Tensor` or :obj:`float`): Whether the episode done at current step (from 0 to T-1), of size [T_traj, batchsize] Returns: - ret (:obj:`torch.Tensor`): Computed lambda return value \ for each state from 0 to T-1, of size [T_traj, batchsize] """ result = torch.empty_like(rewards) if done is None: done = torch.zeros_like(rewards) # Forced cutoff at the last one result[-1, :] = rewards[-1, :] + (1 - done[-1, :]) * gammas[-1, :] * bootstrap_values[-1, :] discounts = gammas * lambda_ for t in reversed(range(rewards.size()[0] - 1)): result[t, :] = rewards[t, :] + (1 - done[t, :]) * \ ( discounts[t, :] * result[t + 1, :] + (gammas[t, :] - discounts[t, :]) * bootstrap_values[t, :] ) return result