"""Gumbel-sigmoid sampling and mask-generation modules (Sim2Mask, ImageMasker, FeatureMasker)."""
import torch
import torch.nn as nn
def gumbel_sigmoid(logits: torch.Tensor, tau: float = 1, hard: bool = False):
    """Draw a sample from the Gumbel-Sigmoid distribution, optionally binarized.

    The difference of two independent Gumbel(0, 1) variables follows a
    Logistic(0, 1) distribution, so a single logistic draw replaces the usual
    pair of Gumbel draws: X - Y ~ Logistic(0, 1) for X, Y ~ Gumbel(0, 1).

    Args:
        logits: Unnormalized log-odds of the "on" state.
        tau: Temperature; lower values sharpen the relaxed samples.
        hard: If True, return 0/1 values with a straight-through gradient
            estimator (forward pass is binary, backward uses the soft sample).

    References:
        - https://github.com/yandexdataschool/gumbel_dpg/blob/master/gumbel.py
        - https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
    """
    uniform = torch.rand_like(logits)
    # log(u / (1 - u)) maps U(0, 1) onto Logistic(0, 1).
    noise = torch.log(uniform / (1.0 - uniform))
    # (logits + noise) / tau ~ Logistic(logits, tau); sigmoid gives the relaxed sample.
    y_soft = torch.sigmoid((logits + noise) / tau)
    if not hard:
        # Reparametrization trick: keep the relaxed (soft) sample.
        return y_soft
    # Straight-through estimator: value comes from the 0.5 threshold, the
    # gradient flows through y_soft.
    y_hard = (y_soft > 0.5).to(y_soft.dtype)
    return y_hard - y_soft.detach() + y_soft
class Sim2Mask(nn.Module):
    """Maps similarity scores to binary masks via a learned affine transform
    followed by (Gumbel-)sigmoid thresholding.

    Reference:
        "Learning to Generate Text-grounded Mask for Open-world Semantic
        Segmentation from Only Image-Text Pairs" CVPR 2023
        - https://github.com/kakaobrain/tcl
        - https://arxiv.org/abs/2212.00785
    """

    def __init__(self, init_w: float = 1.0, init_b: float = 0.0, gumbel_tau: float = 1.0, learnable: bool = True):
        """
        Args:
            init_w (float): Initial value for the scale applied to the input.
            init_b (float): Initial value for the additive bias.
            gumbel_tau (float): Temperature of the Gumbel-sigmoid sampler.
            learnable (bool): If True, scale and bias are trainable parameters;
                otherwise they are kept as plain floats.
        """
        super().__init__()
        self.init_w = init_w
        self.init_b = init_b
        self.gumbel_tau = gumbel_tau
        self.learnable = learnable
        # Scale and bias must be provided (or omitted) together.
        assert not ((init_w is None) ^ (init_b is None))
        if not learnable:
            self.w = init_w
            self.b = init_b
        else:
            # Scalar (0-dim) trainable parameters.
            self.w = nn.Parameter(torch.full([], float(init_w)))
            self.b = nn.Parameter(torch.full([], float(init_b)))

    def forward(self, x, deterministic=False):
        logits = self.w * x + self.b
        soft_mask = torch.sigmoid(logits)
        if deterministic:
            # Plain 0.5 threshold on the sigmoid, i.e. logits > 0.
            hard_mask = (soft_mask > 0.5).type(logits.dtype)
        else:
            # Stochastic hard mask with straight-through gradients.
            hard_mask = gumbel_sigmoid(logits, hard=True, tau=self.gumbel_tau)
        return hard_mask, soft_mask

    def extra_repr(self):
        return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}'
def norm_img_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Min-max normalize an image tensor to (approximately) [0, 1].

    The min/max are computed per sample and per channel, i.e. over the spatial
    dimensions (2, 3) of an NCHW tensor. The range is widened by a small
    epsilon so constant images do not divide by zero.

    Bug fix vs. the original: the fixed 1e-7 epsilon is smaller than the
    float32 ULP for values of magnitude >~2 (e.g. 5.0 - 1e-7 rounds back to
    5.0), so a constant channel produced 0/0 = NaN. The denominator is now
    clamped to 2e-7, which is a no-op for any non-degenerate input and keeps
    flat images finite.

    Args:
        tensor (torch.Tensor): Input image tensor, assumed NCHW — TODO confirm
            rank-4 layout against callers.

    Returns:
        torch.Tensor: Normalized image tensor of the same shape.
    """
    eps = 1e-7
    vmin = tensor.amin((2, 3), keepdim=True) - eps
    vmax = tensor.amax((2, 3), keepdim=True) + eps
    # clamp_min guards flat (constant) images whose widened range underflows
    # to zero in float32; for all other inputs the result is bit-identical.
    return (tensor - vmin) / (vmax - vmin).clamp_min(2 * eps)
class ImageMasker(Sim2Mask):
    """Sim2Mask variant with a closed-form smooth path for pure inference.

    Reference:
        "Can CLIP Help Sound Source Localization?" WACV 2024
        - https://arxiv.org/abs/2311.04066
    """

    def forward(self, x: torch.Tensor, infer: bool = False) -> torch.Tensor:
        """Generate an image-level mask.

        Args:
            x (torch.Tensor): Input tensor.
            infer (bool): True for only inference stage.

        Returns:
            torch.Tensor: Stochastic hard mask (training / non-infer) or a
            smooth sigmoid mask (inference-only).
        """
        if infer and not self.training:
            # Inference-only: smooth mask whose 0.5-crossing (x = -b / w)
            # matches the hard threshold of the trained affine logits.
            return torch.sigmoid(x + self.b / self.w)
        # Training / non-infer: stochastic hard mask from the parent class.
        return super().forward(x, False)[0]
class FeatureMasker(nn.Module):
    """Generates soft feature-level masks by thresholding normalized activations.

    Reference:
        "Can CLIP Help Sound Source Localization?" WACV 2024
        - https://arxiv.org/abs/2311.04066
    """

    def __init__(self, thr: float = 0.5, tau: float = 0.07):
        """
        Args:
            thr (float): Threshold subtracted from the normalized input.
            tau (float): Sigmoid temperature; smaller values sharpen the mask.
        """
        super().__init__()
        self.thr = thr
        self.tau = tau

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a soft mask in (0, 1) for the input feature map.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Generated mask, same shape as ``x``.
        """
        normalized = norm_img_tensor(x)
        logits = (normalized - self.thr) / self.tau
        return torch.sigmoid(logits)