from typing import Dict, List
import math
import random
import pickle
import torch

from ding.utils import REWARD_MODEL_REGISTRY
from .base_reward_model import BaseRewardModel


def collect_state_action_pairs(iterator):
    """
    Overview:
        Collect state-action pairs from the input iterator.
    Arguments:
        - iterator (:obj:`Iterable`): An iterable whose items contain at least the ``obs`` and ``action`` tensor keys.
    Returns:
        - res (:obj:`list`): A list of ``(state, action)`` tuples.
    """
    res = []
    for item in iterator:
        state = item['obs']
        action = item['action']
        # s_a = torch.cat([state, action.float()], dim=-1)
        res.append((state, action))
    return res


@REWARD_MODEL_REGISTRY.register('pwil')
class PwilRewardModel(BaseRewardModel):
    """
    Overview:
        The PWIL reward model class (https://arxiv.org/pdf/2006.04678.pdf)
    Interface:
        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
            ``__init__``, ``_train``, ``_get_state_distance``, ``_get_action_distance``
    Config:
        == ==================  =====   =============  =======================================  =======================
        ID Symbol              Type    Default Value  Description                              Other(Shape)
        == ==================  =====   =============  =======================================  =======================
        1  ``type``             str     pwil          | Reward model register name, refer      |
                                                      | to registry ``REWARD_MODEL_REGISTRY``  |
        2  | ``expert_data_``   str     expert_data.  | Path to the expert dataset             | Should be a '.pkl'
           | ``path``                   .pkl          |                                        | file
        3  | ``sample_size``    int    1000           | sample data from expert dataset        |
                                                      | with fixed size                        |
        4  | ``alpha``          int    5              | factor alpha                           |
        5  | ``beta``           int    5              | factor beta                            |
        6  | ``s_size``         int    4              | state size                             |
        7  | ``a_size``         int    2              | action size                            |
        8  | ``clear_buffer``   int    1              | clear buffer per fixed iters           | make sure replay
             ``_per_iters``                                                                    | buffer's data count
                                                                                               | isn't too few.
                                                                                               | (code work in entry)
        == ==================  =====   =============  =======================================  =======================
    Properties:
        - reward_table (:obj:`Dict`): In this algorithm, the reward model is a dictionary that maps ``(obs, action)`` pairs to rewards.
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='pwil',
        # (str) Path to the expert dataset.
        # expert_data_path='expert_data.pkl',
        # (int) Sample data from expert dataset with fixed size.
        sample_size=1000,
        # r = alpha * exp((-beta * T / sqrt(|s_size| + |a_size|)) * c_i)
        # The key idea of this reward is to minimize the Wasserstein distance
        # between the agent's and the expert's state-action distributions
        # (see the illustrative numeric example below this config dict).
        # (int) Factor alpha.
        alpha=5,
        # (int) Factor beta.
        beta=5,
        # (int) State size.
        # s_size=4,
        # (int) Action size.
        # a_size=2,
        # (int) Clear buffer per fixed iters.
        clear_buffer_per_iters=1,
    )
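
    # Illustrative numeric example of the reward formula above (the values here are
    # assumptions chosen only for the example, not defaults): with alpha=5, beta=5,
    # T=500 collected transitions, s_size=4, a_size=2 and a greedy coupling cost
    # c_i = 0.002, the reward is
    #     r = 5 * exp(-5 * 500 / sqrt(4 + 2) * 0.002) = 5 * exp(-2.04) ~= 0.65,
    # so trajectories far from the expert (large c_i) get rewards that decay towards 0,
    # while a zero coupling cost yields the maximum reward alpha.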

    def __init__(self, config: Dict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate signature.
        Arguments:
            - config (:obj:`Dict`): Training config
            - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
            - tb_logger (:obj:`SummaryWriter`): Tensorboard logger used to record model summaries
        """
        super(PwilRewardModel, self).__init__()
        self.cfg: Dict = config
        assert device in ["cpu", "cuda"] or "cuda" in device
        self.device = device
        self.expert_data: List[tuple] = []
        self.train_data: List[tuple] = []
        # In this algo, model is a dict
        self.reward_table: Dict = {}
        self.T: int = 0

        self.load_expert_data()

    def load_expert_data(self) -> None:
        """
        Overview:
            Load the expert data from the file at ``config['expert_data_path']``, sub-sample \
                ``config['sample_size']`` items from it, and convert them into ``(obs, action)`` tuples.
        Effects:
            This is a side effect function which updates the expert data attribute (i.e. ``self.expert_data``); \
            in this algorithm, ``self.expert_s`` and ``self.expert_a`` (the expert states and actions) are also updated.

        """
        with open(self.cfg.expert_data_path, 'rb') as f:
            self.expert_data = pickle.load(f)
            print("the data size is:", len(self.expert_data))
        sample_size = min(self.cfg.sample_size, len(self.expert_data))
        self.expert_data = random.sample(self.expert_data, sample_size)
        self.expert_data = [(item['obs'], item['action']) for item in self.expert_data]
        self.expert_s, self.expert_a = list(zip(*self.expert_data))
        print('The number of sampled expert demonstrations is:', len(self.expert_data))

    def collect_data(self, data: list) -> None:
        """
        Overview:
            Collecting training data formatted by ``collect_state_action_pairs``.
        Arguments:
            - data (:obj:`list`): Raw training data (e.g. some form of states, actions, obs, etc)
        Effects:
            - This is a side effect function which extends ``self.train_data`` and updates ``self.T`` \
                (the number of collected transitions); ``self.reward_factor`` is also recomputed from \
                    ``beta``, ``T``, ``s_size`` and ``a_size`` every time ``collect_data`` is called.
        """
        self.train_data.extend(collect_state_action_pairs(data))
        self.T = len(self.train_data)

        s_size = self.cfg.s_size
        a_size = self.cfg.a_size
        beta = self.cfg.beta
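        # This factor is the exponent scale -beta * T / sqrt(|s_size| + |a_size|) in the
        # PWIL reward r = alpha * exp(factor * c_i); it is recomputed here because it
        # depends on the current number of collected transitions T.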
        self.reward_factor = -beta * self.T / math.sqrt(s_size + a_size)

    def train(self) -> None:
        """
        Overview:
            Training the Pwil reward model.
        """
        self._train(self.train_data)

    def estimate(self, data: list) -> List[Dict]:
        """
        Overview:
            Estimate reward by rewriting the reward key in each row of the data.
        Arguments:
            - data (:obj:`list`): the list of data used for estimation, \
                with at least ``obs`` and ``action`` keys.
        Returns:
            - train_data_augmented (:obj:`List[Dict]`): A deep copy of the input data in which each \
                ``reward`` is replaced by the value looked up in ``self.reward_table``; pairs not seen \
                during training get a default reward of 0.
        """
        # NOTE: deepcopy reward part of data is very important,
        # otherwise the reward of data in the replay buffer will be incorrectly modified.
        train_data_augmented = self.reward_deepcopy(data)
        for item in train_data_augmented:
            s = item['obs']
            a = item['action']
            if (s, a) in self.reward_table:
                item['reward'] = self.reward_table[(s, a)]
            else:
                # when the (s, a) pair has not been trained on, set the reward to a default value (e.g. 0)
                item['reward'] = torch.zeros_like(item['reward'])
        return train_data_augmented

    def _get_state_distance(self, s1: list, s2: list) -> torch.Tensor:
        """
        Overview:
            Compute the pairwise distances between two lists of states. A single state \
                is a tensor of shape ``torch.Size([n])``.
        Arguments:
            - s1 (:obj:`list` of :obj:`torch.Tensor`): the 1st list of states, of length M
            - s2 (:obj:`list` of :obj:`torch.Tensor`): the 2nd list of states, of length N
        Returns:
            - distance (:obj:`torch.Tensor`): Mean squared distance matrix between the two \
                state lists, of size M x N.
        """
        # Format the values in the tensors to be of float type
        s1 = torch.stack(s1).float()
        s2 = torch.stack(s2).float()
        M, N = s1.shape[0], s2.shape[0]
        # Flatten each state into a 1-D vector
        s1 = s1.view(M, -1)
        s2 = s2.view(N, -1)
        # Expand both to shape (M, N, n) so that every pair of states can be compared
        s1 = s1.unsqueeze(1).repeat(1, N, 1)
        s2 = s2.unsqueeze(0).repeat(M, 1, 1)
        # Return the mean squared distance tensor of size M x N
        return ((s1 - s2) ** 2).mean(dim=-1)

    def _get_action_distance(self, a1: list, a2: list) -> torch.Tensor:
        # TODO: the action distance metric may need to differ across envs
        """
        Overview:
            Compute the pairwise distances between two lists of actions. A single action \
                is a tensor of shape ``torch.Size([n])``.
        Arguments:
            - a1 (:obj:`list` of :obj:`torch.Tensor`): the 1st list of actions, of length M
            - a2 (:obj:`list` of :obj:`torch.Tensor`): the 2nd list of actions, of length N
        Returns:
            - distance (:obj:`torch.Tensor`): Mean squared distance matrix between the two \
                action lists, of size M x N.
        """
        a1 = torch.stack(a1).float()
        a2 = torch.stack(a2).float()
        M, N = a1.shape[0], a2.shape[0]
        a1 = a1.view(M, -1)
        a2 = a2.view(N, -1)
        a1 = a1.unsqueeze(1).repeat(1, N, 1)
        a2 = a2.unsqueeze(0).repeat(M, 1, 1)
        return ((a1 - a2) ** 2).mean(dim=-1)

    def _train(self, data: list):
        """
        Overview:
            Helper function for ``train``: greedily couple each collected ``(s, a)`` pair with its \
                nearest expert pairs and convert the accumulated transport cost into a reward.
        Arguments:
            - data (:obj:`list`): Raw training data (e.g. some form of states, actions, obs, etc)
        Effects:
            - This is a side effect function which updates the ``reward_table`` attribute in ``self`` .
        """
        batch_s, batch_a = list(zip(*data))
        s_distance_matrix = self._get_state_distance(batch_s, self.expert_s)
        a_distance_matrix = self._get_action_distance(batch_a, self.expert_a)
        distance_matrix = s_distance_matrix + a_distance_matrix
        w_e_list = [1 / len(self.expert_data)] * len(self.expert_data)
        for i, item in enumerate(data):
            s, a = item
            w_pi = 1 / self.T
            c = 0
            expert_data_idx = torch.arange(len(self.expert_data)).tolist()
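            # Greedy coupling step (cf. the PWIL paper): the current collected transition
            # carries mass w_pi = 1 / T and each expert transition carries mass
            # 1 / len(expert_data), tracked in ``w_e_list``. Repeatedly pour the remaining
            # mass onto the nearest expert transition that still has capacity, accumulating
            # the transport cost in ``c``.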
            while w_pi > 0:
                selected_dist = distance_matrix[i, expert_data_idx]
                nearest_distance = selected_dist.min().item()
                nearest_index_selected = selected_dist.argmin().item()
                nearest_index = expert_data_idx[nearest_index_selected]
                if w_pi >= w_e_list[nearest_index]:
                    c = c + nearest_distance * w_e_list[nearest_index]
                    w_pi = w_pi - w_e_list[nearest_index]
                    expert_data_idx.pop(nearest_index_selected)
                else:
                    c = c + w_pi * nearest_distance
                    w_e_list[nearest_index] = w_e_list[nearest_index] - w_pi
                    w_pi = 0
            reward = self.cfg.alpha * math.exp(self.reward_factor * c)
            self.reward_table[(s, a)] = torch.FloatTensor([reward])

    def clear_data(self) -> None:
        """
        Overview:
            Clearing the collected training data. \
            This is a side effect function which clears ``self.train_data``.
        """
        self.train_data.clear()
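

# The block below is a minimal, illustrative usage sketch rather than part of the reward
# model itself. It assumes the DI-engine package is installed (providing the ``ding``
# imports above and ``BaseRewardModel.reward_deepcopy``) and that the config is an
# attribute-accessible dict, here built with ``easydict.EasyDict`` as DI-engine configs
# usually are. All file names, tensor shapes and config values are made up for the example.
if __name__ == '__main__':
    import os
    import tempfile
    from easydict import EasyDict

    # Build a tiny fake expert dataset of transitions and pickle it, mimicking the
    # '.pkl' format expected by ``load_expert_data``. Counts are chosen as powers of two
    # so the greedy mass bookkeeping in ``_train`` stays exact in floating point.
    fake_expert = [{'obs': torch.randn(4), 'action': torch.randint(0, 2, (1, ))} for _ in range(20)]
    expert_path = os.path.join(tempfile.gettempdir(), 'pwil_fake_expert.pkl')
    with open(expert_path, 'wb') as f:
        pickle.dump(fake_expert, f)

    cfg = EasyDict(
        dict(
            type='pwil',
            expert_data_path=expert_path,
            sample_size=16,
            alpha=5,
            beta=5,
            s_size=4,
            a_size=1,
            clear_buffer_per_iters=1,
        )
    )
    # ``tb_logger`` is not used by the methods exercised here, so ``None`` suffices for the sketch.
    reward_model = PwilRewardModel(cfg, 'cpu', None)

    # Fake collected transitions; a ``reward`` key must exist so that ``estimate`` can rewrite it.
    collected = [
        {'obs': torch.randn(4), 'action': torch.randint(0, 2, (1, )), 'reward': torch.zeros(1)} for _ in range(8)
    ]
    reward_model.collect_data(collected)
    reward_model.train()
    estimated = reward_model.estimate(collected)
    print('Estimated PWIL rewards:', [item['reward'].item() for item in estimated])
    reward_model.clear_data()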