from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.utils import SequenceType


class ContrastiveLoss(nn.Module):
    """
    Overview:
        The class for contrastive learning losses. Only InfoNCE loss is supported currently. \
        Code Reference: https://github.com/rdevon/DIM. Paper Reference: https://arxiv.org/abs/1808.06670.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            x_size: Union[int, SequenceType],
            y_size: Union[int, SequenceType],
            heads: SequenceType = [1, 1],
            encode_shape: int = 64,
            loss_type: str = "infoNCE",  # Only the InfoNCE loss is available now.
            temperature: float = 1.0,
    ) -> None:
        """
        Overview:
            Initialize the ContrastiveLoss object using the given arguments.
        Arguments:
            - x_size (:obj:`Union[int, SequenceType]`): Input shape for x, both the obs shape and the encoding shape \
                are supported.
            - y_size (:obj:`Union[int, SequenceType]`): Input shape for y, both the obs shape and the encoding shape \
                are supported.
            - heads (:obj:`SequenceType`): A list of 2 int elems, ``heads[0]`` for x and ``heads[1]`` for y. \
                Used in the multi-head, global-local, local-local MI maximization process.
            - encode_shape (:obj:`int`): The dimension of the encoder hidden state.
            - loss_type (:obj:`str`): The type of loss. Only the InfoNCE loss is available now.
            - temperature (:obj:`float`): The scaling factor multiplied onto the logits before the ``log_softmax``.
        """
        super(ContrastiveLoss, self).__init__()
        assert len(heads) == 2, "Expected length of 2, but got: {}".format(len(heads))
        assert loss_type.lower() in ["infonce"]

        self._type = loss_type.lower()
        self._encode_shape = encode_shape
        self._heads = heads
        self._x_encoder = self._create_encoder(x_size, heads[0])
        self._y_encoder = self._create_encoder(y_size, heads[1])
        self._temperature = temperature

    def _create_encoder(self, obs_size: Union[int, SequenceType], heads: int) -> nn.Module:
        """
        Overview:
            Create the encoder for the input obs.
        Arguments:
            - obs_size (:obj:`Union[int, SequenceType]`): Input shape of the obs, both the obs shape and the \
                encoding shape are supported. If ``obs_size`` is an int, the obs is a 1D vector. If ``obs_size`` \
                is a list such as [1, 16, 16], the obs is a 3D image with shape [1, 16, 16].
            - heads (:obj:`int`): The number of heads.
        Returns:
            - encoder (:obj:`nn.Module`): The encoder module.
        Examples:
            >>> obs_size = 16  # a 1D vector obs; use e.g. obs_size = [1, 16, 16] for a 3D image obs
            >>> heads = 1
            >>> encoder = self._create_encoder(obs_size, heads)
        """
        from ding.model import ConvEncoder, FCEncoder

        if isinstance(obs_size, int):
            obs_size = [obs_size]
        assert len(obs_size) in [1, 3]

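        # A 1D obs is encoded by an MLP, a 3D image obs by a conv net; small images
        # (last dim < 36) get smaller kernels and strides so the conv feature map stays valid.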
        if len(obs_size) == 1:
            hidden_size_list = [128, 128, self._encode_shape * heads]
            encoder = FCEncoder(obs_size[0], hidden_size_list)
        else:
            hidden_size_list = [32, 64, 64, self._encode_shape * heads]
            if obs_size[-1] >= 36:
                encoder = ConvEncoder(obs_size, hidden_size_list)
            else:
                encoder = ConvEncoder(obs_size, hidden_size_list, kernel_size=[4, 3, 2], stride=[2, 1, 1])
        return encoder

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Compute the noise contrastive estimation-based loss, a.k.a. InfoNCE.
        Arguments:
            - x (:obj:`torch.Tensor`): The input x, both raw obs and encoding are supported.
            - y (:obj:`torch.Tensor`): The input y, both raw obs and encoding are supported.
        Returns:
            - loss (:obj:`torch.Tensor`): The calculated loss value.
        Examples:
            >>> x_dim = [3, 16]
            >>> encode_shape = 16
            >>> x = torch.randn(*x_dim)
            >>> y = x ** 2 + 0.01 * torch.randn(*x_dim)
            >>> estimator = ContrastiveLoss(x_dim[1:], x_dim[1:], encode_shape=encode_shape)
            >>> loss = estimator(x, y)
        Examples:
            >>> x_dim = [3, 1, 16, 16]
            >>> encode_shape = 16
            >>> x = torch.randn(*x_dim)
            >>> y = x ** 2 + 0.01 * torch.randn(*x_dim)
            >>> estimator = ContrastiveLoss(x_dim[1:], x_dim[1:], encode_shape=encode_shape)
            >>> loss = estimator(x, y)
        """

        N = x.size(0)
        x_heads, y_heads = self._heads
        x = self._x_encoder(x).view(N, x_heads, self._encode_shape)
        y = self._y_encoder(y).view(N, y_heads, self._encode_shape)

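        # Flatten the batch and head dims so that every (sample, head) pair is one row.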
        x_n = x.view(-1, self._encode_shape)
        y_n = y.view(-1, self._encode_shape)

        # Use the inner product to obtain the positive samples.
        # [N, x_heads, encode_dim] x [N, encode_dim, y_heads] -> [N, x_heads, 1, y_heads]
        u_pos = torch.matmul(x, y.permute(0, 2, 1)).unsqueeze(2)
        # Compute all pairwise inner products to obtain every sample pairing.
        # [N * y_heads, encode_dim] x [encode_dim, N * x_heads] -> [N, N, x_heads, y_heads]
        u_all = torch.mm(y_n, x_n.t()).view(N, y_heads, N, x_heads).permute(0, 2, 3, 1)

        # Mask the diagonal part to obtain the negative samples, with all diagonal entries set to -10.
        mask = torch.eye(N)[:, :, None, None].to(x.device)
        n_mask = 1 - mask
        u_neg = (n_mask * u_all) - (10. * (1 - n_mask))
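        # [N, N, x_heads, y_heads] -> [N, x_heads, N * x_heads, y_heads]: N * x_heads negatives per positive.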
        u_neg = u_neg.view(N, N * x_heads, y_heads).unsqueeze(dim=1).expand(-1, x_heads, -1, -1)

        # Concatenate positive and negative samples and apply log softmax.
        pred_lgt = torch.cat([u_pos, u_neg], dim=2)
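        # pred_lgt: [N, x_heads, 1 + N * x_heads, y_heads]; the softmax normalizes over dim 2.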
        pred_log = F.log_softmax(pred_lgt * self._temperature, dim=2)

        # The positive score is the first element of the log softmax.
        loss = -pred_log[:, :, 0, :].mean()
        return loss
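

# A minimal usage sketch (an illustrative addition, not part of the library API): it only
# assumes that ``torch`` and ``ding`` are importable, and exercises both the 1D-vector and
# 3D-image encoder paths with random data, so the printed values are only a shape sanity check.
if __name__ == "__main__":
    # 1D obs: a batch of 4 vectors with 16 features each (FCEncoder path).
    x_vec = torch.randn(4, 16)
    y_vec = x_vec ** 2 + 0.01 * torch.randn(4, 16)
    vec_loss = ContrastiveLoss(16, 16, encode_shape=32)(x_vec, y_vec)
    print("vector InfoNCE loss:", vec_loss.item())

    # 3D obs: a batch of 4 single-channel 16x16 images (small-kernel ConvEncoder path).
    x_img = torch.randn(4, 1, 16, 16)
    y_img = x_img + 0.01 * torch.randn(4, 1, 16, 16)
    img_loss = ContrastiveLoss([1, 16, 16], [1, 16, 16], encode_shape=32)(x_img, y_img)
    print("image InfoNCE loss:", img_loss.item())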