# virtex-redcaps: virtex/models/contrastive.py
from typing import Any, Dict

import torch
import torch.distributed as dist
from torch import nn

from virtex.modules.label_smoothing import CrossEntropyLossWithLabelSmoothing
from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone


class ImageTextContrastiveModel(nn.Module):
    """
    Model to jointly train a visual backbone and a textual head with a
    symmetric image-text contrastive (InfoNCE) loss, in the style of CLIP.
    Features are gathered across processes to enlarge the pool of negatives.
    """

    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
        label_smoothing: float = 0.0,
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.padding_idx = self.textual.padding_idx

        # Project pooled visual features to the same size as textual features.
        self.visual_projection = nn.Linear(
            self.visual.visual_feature_size,
            self.textual.textual_feature_size,
            bias=False,
        )
        # Learnable log-temperature, initialized to log(1 / 0.07) as in CLIP.
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))
        # NOTE: ignore_index is inherited from the captioning losses; with the
        # arange targets used below, a target equal to padding_idx (often 0)
        # would be masked out of the loss.
        self.loss = CrossEntropyLossWithLabelSmoothing(
            label_smoothing, ignore_index=self.padding_idx
        )

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:

        # Check if logit_scale needs to be clamped after the last iteration's
        # update.
        self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, 3.912)
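        # In plain numbers: the parameter starts at log(1 / 0.07) ≈ 2.66, an
        # effective scale of ≈ 14.3 after exp(); clamping to [0, 3.912] bounds
        # exp(logit_scale) to [1, 50], since log(50) ≈ 3.912.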

        # shape: (batch_size, channels, height, width)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        # Global average pool over the spatial dimensions.
        # shape: (batch_size, channels)
        visual_features = visual_features.mean(dim=[2, 3]).view(batch_size, -1)

        # shape: (batch_size, textual_feature_size)
        visual_features = self.visual_projection(visual_features)

        caption_tokens = batch["caption_tokens"]
        caption_lengths = batch["caption_lengths"]

        # shape: (batch_size, max_caption_length, hidden_size)
        textual_features = self.textual(caption_tokens, caption_lengths)

        # Take features from the first time-step (as BERT-* models do).
        # shape: (batch_size, hidden_size)
        textual_features = textual_features[:, 0, :]

        # L2-normalize both, so the dot products below are cosine similarities.
        # shape: (batch_size, textual_feature_size)
        visual_features = visual_features / visual_features.norm(dim=-1, keepdim=True)
        textual_features = textual_features / textual_features.norm(
            dim=-1, keepdim=True
        )

        # Gather textual features from all processes into one large tensor to
        # increase the number of negative samples for contrastive learning.
        gathered_textual_features = [
            torch.zeros_like(textual_features) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(gathered_textual_features, textual_features)
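        # Note: plain dist.all_gather does not carry autograd history, so
        # gradients flow only through the local (non-gathered) features on
        # each side; this matches common open-source CLIP implementations.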
        # Move the current rank's features to index zero, so that after
        # concatenation the positive caption for local image [i] sits at
        # column [i] and the targets below are simply arange(batch_size).
        gathered_textual_features[0], gathered_textual_features[dist.get_rank()] = (
            gathered_textual_features[dist.get_rank()],
            gathered_textual_features[0],
        )
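        # Worked example (world_size 2, per-rank batch 2, current rank 1):
        # all_gather yields [rank0, rank1]; the swap reorders it to
        # [rank1, rank0], so the concatenated rows are
        # [r1_cap0, r1_cap1, r0_cap0, r0_cap1] and targets [0, 1] point at
        # this rank's own (positive) captions.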
        # shape: (batch_size * world_size, textual_feature_size)
        gathered_textual_features = torch.cat(gathered_textual_features, dim=0)

        # Calculate temperature-scaled pairwise cosine similarities as logits.
        logit_scale = self.logit_scale.exp()
        visual_logits = logit_scale * visual_features @ gathered_textual_features.t()

        # Positive pairs lie on the main diagonal of the local block: image [i]
        # should match caption [i].
        visual_loss = self.loss(
            visual_logits,
            torch.arange(visual_logits.size(0), device=visual_logits.device),
        )

        # Do the same thing in the other direction: gather visual features
        # from all processes.
        gathered_visual_features = [
            torch.zeros_like(visual_features) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(gathered_visual_features, visual_features)
        gathered_visual_features[0], gathered_visual_features[dist.get_rank()] = (
            gathered_visual_features[dist.get_rank()],
            gathered_visual_features[0],
        )
        # shape: (batch_size * world_size, textual_feature_size)
        gathered_visual_features = torch.cat(gathered_visual_features, dim=0)

        # Reuse the logit_scale computed above; again, caption [i] should
        # match image [i].
        textual_logits = logit_scale * textual_features @ gathered_visual_features.t()
        textual_loss = self.loss(
            textual_logits,
            torch.arange(textual_logits.size(0), device=textual_logits.device),
        )

        # Symmetric contrastive loss: average the image-to-text and
        # text-to-image terms.
        loss = 0.5 * (visual_loss + textual_loss)
        output_dict: Dict[str, Any] = {
            "loss": loss,
            # Single scalar per batch for logging in the training script.
            "loss_components": {"contrastive": loss.clone().detach()},
        }
        return output_dict
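

if __name__ == "__main__":
    # Minimal single-process smoke test (a sketch, not part of the original
    # module). StubVisual and StubTextual are hypothetical stand-ins exposing
    # only the attributes this model reads (visual_feature_size,
    # textual_feature_size, padding_idx); the real classes live in
    # virtex.modules. A one-process "gloo" group is initialized so the
    # torch.distributed collectives in forward() can run.
    import os

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    class StubVisual(nn.Module):
        visual_feature_size = 64

        def forward(self, image: torch.Tensor) -> torch.Tensor:
            # Fake conv feature map: (batch_size, channels, height, width).
            return torch.randn(image.size(0), self.visual_feature_size, 7, 7)

    class StubTextual(nn.Module):
        textual_feature_size = 32
        padding_idx = 0

        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(100, self.textual_feature_size, padding_idx=0)

        def forward(self, tokens: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
            # Fake encoder output: (batch_size, max_caption_length, hidden_size).
            return self.embed(tokens)

    model = ImageTextContrastiveModel(StubVisual(), StubTextual())
    batch = {
        "image": torch.randn(4, 3, 224, 224),
        "caption_tokens": torch.randint(1, 100, (4, 12)),
        "caption_lengths": torch.full((4,), 12, dtype=torch.long),
    }
    print(model(batch)["loss"].item())
    dist.destroy_process_group()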