import torch
import torch.nn as nn


def gumbel_sigmoid(logits: torch.Tensor, tau: float = 1.0, hard: bool = False) -> torch.Tensor:
    """Samples from the Gumbel-Sigmoid distribution and optionally discretizes.

    References:
    - https://github.com/yandexdataschool/gumbel_dpg/blob/master/gumbel.py
    - https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax

    Note:
        If X, Y ~ Gumbel(0, 1) independently, then X - Y ~ Logistic(0, 1).
        Hence gumbel_sigmoid can be implemented by sampling Logistic noise directly.
    """
    logistic = torch.rand_like(logits)
    logistic = logistic.div_(1. - logistic).log_()  # ~ Logistic(0, 1)

    gumbels = (logits + logistic) / tau  # ~ Logistic(logits, tau)
    y_soft = gumbels.sigmoid_()

    if hard:
        # Straight-through estimator: binary values in the forward pass,
        # gradients from the soft sigmoid in the backward pass.
        y_hard = y_soft.gt(0.5).type(y_soft.dtype)
        # Note: the in-place y_soft.gt_(0.5) would keep the dtype but break
        # gradient flow, unlike the out-of-place gt() used above.
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft

    return ret


class Sim2Mask(nn.Module):
    def __init__(self, init_w: float = 1.0, init_b: float = 0.0, gumbel_tau: float = 1.0, learnable: bool = True):
        """Sim2Mask module for generating binary masks.

        Args:
            init_w (float): Initial value for the weight.
            init_b (float): Initial value for the bias.
            gumbel_tau (float): Gumbel-Sigmoid temperature.
            learnable (bool): If True, weight and bias are learnable parameters.

        Reference:
            "Learning to Generate Text-grounded Mask for Open-world Semantic
            Segmentation from Only Image-Text Pairs", CVPR 2023
            - https://github.com/kakaobrain/tcl
            - https://arxiv.org/abs/2212.00785
        """
        super().__init__()
        self.init_w = init_w
        self.init_b = init_b
        self.gumbel_tau = gumbel_tau
        self.learnable = learnable

        # init_w and init_b must be provided together (or omitted together).
        assert not ((init_w is None) ^ (init_b is None))
        if learnable:
            self.w = nn.Parameter(torch.full([], float(init_w)))
            self.b = nn.Parameter(torch.full([], float(init_b)))
        else:
            self.w = init_w
            self.b = init_b

    def forward(self, x: torch.Tensor, deterministic: bool = False):
        logits = x * self.w + self.b

        soft_mask = torch.sigmoid(logits)
        if deterministic:
            hard_mask = soft_mask.gt(0.5).type(logits.dtype)
        else:
            hard_mask = gumbel_sigmoid(logits, hard=True, tau=self.gumbel_tau)

        return hard_mask, soft_mask

    def extra_repr(self):
        return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}'


def norm_img_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Normalize an image tensor to the range [0, 1] per sample and channel.

    Args:
        tensor (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Normalized image tensor.
    """
    # The epsilon margins keep vmax - vmin strictly positive for constant maps.
    vmin = tensor.amin((2, 3), keepdim=True) - 1e-7
    vmax = tensor.amax((2, 3), keepdim=True) + 1e-7
    tensor = (tensor - vmin) / (vmax - vmin)
    return tensor


class ImageMasker(Sim2Mask):
    def forward(self, x: torch.Tensor, infer: bool = False) -> torch.Tensor:
        """Forward pass for generating image-level binary masks.

        Args:
            x (torch.Tensor): Input tensor.
            infer (bool): If True, run in inference-only mode.

        Returns:
            torch.Tensor: Binary mask.

        Reference:
            "Can CLIP Help Sound Source Localization?", WACV 2024
            - https://arxiv.org/abs/2311.04066
        """
        if self.training or not infer:
            output = super().forward(x, False)[0]
        else:
            # Equivalent to sigmoid((x * w + b) / w): logits rescaled by w,
            # yielding a smooth (non-binarized) mask at inference time.
            output = torch.sigmoid(x + self.b / self.w)
        return output
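
# Illustrative sketch (not part of the referenced repos): gumbel_sigmoid with
# hard=True returns binary values in the forward pass while gradients flow
# through the soft sigmoid via the straight-through trick, so a thresholding
# decision stays trainable. The shape below is arbitrary and the helper name
# _demo_gumbel_sigmoid is hypothetical.
def _demo_gumbel_sigmoid():
    logits = torch.randn(2, 1, 4, 4, requires_grad=True)
    hard = gumbel_sigmoid(logits, tau=1.0, hard=True)
    # Forward values are 0/1 up to floating-point rounding.
    assert (hard - hard.round()).abs().max() < 1e-5
    hard.sum().backward()
    assert logits.grad is not None  # gradients reach the logits
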
class FeatureMasker(nn.Module):
    def __init__(self, thr: float = 0.5, tau: float = 0.07):
        """Masker module for generating feature-level masks.

        Args:
            thr (float): Threshold for generating the mask.
            tau (float): Temperature for the sigmoid function.

        Reference:
            "Can CLIP Help Sound Source Localization?", WACV 2024
            - https://arxiv.org/abs/2311.04066
        """
        super().__init__()
        self.thr = thr
        self.tau = tau

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for generating feature-level masks.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Generated mask with values in (0, 1); tau controls
            the sharpness of the transition around thr.
        """
        return torch.sigmoid((norm_img_tensor(x) - self.thr) / self.tau)
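
# Minimal usage sketch under assumed shapes and hyperparameters: a (B, 1, H, W)
# similarity map is turned into a trainable hard mask (ImageMasker) and a soft
# feature-level mask (FeatureMasker). The init values below are illustrative,
# not prescribed by this module.
if __name__ == "__main__":
    sim = torch.randn(2, 1, 7, 7)  # hypothetical similarity volume

    image_masker = ImageMasker(init_w=10.0, init_b=-2.5, gumbel_tau=1.0, learnable=True)
    image_masker.train()
    hard_mask = image_masker(sim)                # binary forward, soft gradients
    image_masker.eval()
    smooth_mask = image_masker(sim, infer=True)  # soft sigmoid at inference

    feature_masker = FeatureMasker(thr=0.5, tau=0.07)
    soft_mask = feature_masker(sim)              # values in (0, 1)

    print(image_masker)  # extra_repr shows init_w, init_b, learnable, gumbel_tau
    print(hard_mask.shape, smooth_mask.shape, soft_mask.shape)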