import torch
import torch.nn as nn


def gumbel_sigmoid(logits: torch.Tensor, tau: float = 1.0, hard: bool = False) -> torch.Tensor:
    """Samples from the Gumbel-Sigmoid distribution and optionally discretizes.
    References:
        - https://github.com/yandexdataschool/gumbel_dpg/blob/master/gumbel.py
        - https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
    Note:
        If X, Y ~ Gumbel(0, 1) independently, then X - Y ~ Logistic(0, 1).
        Hence gumbel_sigmoid can be implemented by adding Logistic noise to
        the logits, rather than via a two-class Gumbel-Softmax.
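    Example:
        A minimal usage sketch (shape, seed, and tau below are illustrative
        assumptions, not taken from the referenced repos); with hard=True the
        samples are exactly binary while gradients flow through the soft
        relaxation:

        >>> torch.manual_seed(0)
        >>> logits = torch.randn(4, 4, requires_grad=True)
        >>> sample = gumbel_sigmoid(logits, tau=0.5, hard=True)
        >>> bool(((sample == 0) | (sample == 1)).all())
        True
        >>> sample.sum().backward()
        >>> logits.grad is not None
        True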
    """
    # Sample Logistic(0, 1) noise via the inverse CDF: log(u / (1 - u)), u ~ Uniform(0, 1).
    logistic = torch.rand_like(logits)
    logistic = logistic.div_(1.0 - logistic).log_()  # ~ Logistic(0, 1)

    gumbels = (logits + logistic) / tau  # ~ Logistic(logits / tau, 1 / tau)
    y_soft = gumbels.sigmoid_()

    if hard:
        # Straight-through estimator: binarize the forward pass while the
        # backward pass uses the gradient of the soft relaxation.
        y_hard = y_soft.gt(0.5).type(y_soft.dtype)
        # Note: the in-place y_soft.gt_(0.5) would keep the floating dtype
        # but break the gradient flow, so the out-of-place gt() is used.
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft

    return ret


class Sim2Mask(nn.Module):
    def __init__(self, init_w: float = 1.0, init_b: float = 0.0, gumbel_tau: float = 1.0, learnable: bool = True):
        """
        Sim2Mask module for generating binary masks.

        Args:
            init_w (float): Initial value for weight.
            init_b (float): Initial value for bias.
            gumbel_tau (float): Temperature of the Gumbel-Sigmoid relaxation.
            learnable (bool): If True, weight and bias are learnable parameters.

        Reference:
            "Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs" CVPR 2023
            - https://github.com/kakaobrain/tcl
            - https://arxiv.org/abs/2212.00785
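
        Example:
            A minimal sketch (the similarity-map shape and the init values are
            illustrative assumptions); the hard mask is binary while the soft
            mask stays in (0, 1):

            >>> masker = Sim2Mask(init_w=10.0, init_b=-2.5)
            >>> sim = torch.rand(2, 1, 7, 7)  # e.g. an image-text similarity map
            >>> hard, soft = masker(sim)
            >>> hard.shape == soft.shape == sim.shape
            True
            >>> bool(((hard == 0) | (hard == 1)).all())
            True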
        """
        super().__init__()
        self.init_w = init_w
        self.init_b = init_b
        self.gumbel_tau = gumbel_tau
        self.learnable = learnable

        assert not ((init_w is None) ^ (init_b is None)), 'init_w and init_b must be provided (or omitted) together'
        if learnable:
            self.w = nn.Parameter(torch.full([], float(init_w)))
            self.b = nn.Parameter(torch.full([], float(init_b)))
        else:
            self.w = init_w
            self.b = init_b

    def forward(self, x: torch.Tensor, deterministic: bool = False):
        logits = x * self.w + self.b

        soft_mask = torch.sigmoid(logits)
        if deterministic:
            hard_mask = soft_mask.gt(0.5).type(logits.dtype)
        else:
            hard_mask = gumbel_sigmoid(logits, hard=True, tau=self.gumbel_tau)

        return hard_mask, soft_mask

    def extra_repr(self):
        return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}'


def norm_img_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """
    Min-max normalize an NCHW image tensor to (approximately) the range [0, 1],
    independently per sample and per channel over the spatial dimensions.

    Args:
        tensor (torch.Tensor): Input image tensor.

    Returns:
        torch.Tensor: Normalized image tensor.
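
    Example:
        A small sketch with illustrative values; the output lies in [0, 1]
        up to the epsilon used to avoid division by zero:

        >>> t = torch.arange(8.0).reshape(1, 2, 2, 2)
        >>> out = norm_img_tensor(t)
        >>> bool(((out >= 0) & (out <= 1)).all())
        True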
    """
    # A small epsilon keeps the denominator non-zero for constant inputs.
    vmin = tensor.amin((2, 3), keepdim=True) - 1e-7
    vmax = tensor.amax((2, 3), keepdim=True) + 1e-7
    tensor = (tensor - vmin) / (vmax - vmin)
    return tensor


class ImageMasker(Sim2Mask):
    def forward(self, x: torch.Tensor, infer: bool = False) -> torch.Tensor:
        """
        Forward pass for generating image-level binary masks.

        Args:
            x (torch.Tensor): Input tensor.
            infer (bool): If True and the module is in eval mode, return the
                deterministic soft mask instead of a sampled hard mask.

        Returns:
            torch.Tensor: Straight-through hard mask during training; smooth
                sigmoid mask at inference.

        Reference:
            "Can CLIP Help Sound Source Localization?" WACV 2024
            - https://arxiv.org/abs/2311.04066
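
        Example:
            A minimal sketch (shapes and init values are illustrative
            assumptions):

            >>> masker = ImageMasker(init_w=10.0, init_b=-2.5).eval()
            >>> sim = torch.rand(2, 1, 7, 7)
            >>> soft = masker(sim, infer=True)  # deterministic smooth mask
            >>> hard = masker(sim)  # sampled straight-through hard mask
            >>> bool(((hard == 0) | (hard == 1)).all())
            True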
        """
        if self.training or not infer:
            # Training (or non-inference) path: sampled straight-through hard mask.
            output = super().forward(x, False)[0]
        else:
            # Inference path: a smooth mask with the same 0.5-crossing at
            # x = -b / w, but without the learned sharpness w.
            output = torch.sigmoid(x + self.b / self.w)
        return output


class FeatureMasker(nn.Module):
    def __init__(self, thr: float = 0.5, tau: float = 0.07):
        """
        Masker module for generating feature-level masks.

        Args:
            thr (float): Threshold for generating the mask.
            tau (float): Temperature for the sigmoid function.

        Reference:
            "Can CLIP Help Sound Source Localization?" WACV 2024
            - https://arxiv.org/abs/2311.04066
        """
        super().__init__()
        self.thr = thr
        self.tau = tau

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for generating feature-level masks.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Generated mask.
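
        Example:
            A minimal sketch (the feature-map shape is an illustrative
            assumption); after per-channel min-max normalization, values above
            thr are pushed toward 1 and values below it toward 0:

            >>> masker = FeatureMasker(thr=0.5, tau=0.07)
            >>> feats = torch.randn(2, 3, 16, 16)
            >>> mask = masker(feats)
            >>> bool(((mask > 0) & (mask < 1)).all())
            True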
        """
        # Soft-threshold the normalized map at `thr`; smaller `tau` gives a sharper mask.
        return torch.sigmoid((norm_img_tensor(x) - self.thr) / self.tau)