# virtex-redcaps/virtex/models/zero_shot_classification_eval.py
import copy
from typing import Dict

import torch
from torch import nn

from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone


class ZeroShotClassifier(nn.Module):
    """
    Score a batch of images against a fixed set of category captions using a
    visual backbone and bidirectional textual heads. Higher score means the
    caption is a better match for the image.
    """

    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.padding_idx = self.textual.padding_idx

        # Clone the textual module for captioning in the backward direction
        # (the two directions are scored separately).
        self.backward_textual = copy.deepcopy(self.textual)

        # Share weights for visual projection, and input/output embeddings.
        self.backward_textual.visual_projection = self.textual.visual_projection
        self.backward_textual.embedding = self.textual.embedding
        self.backward_textual.output = self.textual.output

        # Keep per-token losses (``reduction="none"``) so they can be summed
        # per image later; padding tokens are ignored.
        self.loss = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx, reduction="none"
        )
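
    # Scoring rule applied in ``forward`` to each (image, category) pair
    # (a summary added for clarity, not present in the original source):
    #
    #     score = -(forward caption NLL + backward caption NLL) / length
    #
    # where each NLL is the summed per-token cross-entropy of the category
    # caption given the image, so higher scores mean more likely captions.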
    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        # shape: (batch_size, channels, height, width)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        # One score vector per category; filled in the loop below.
        classification_losses = []

        # Category caption tensors, shape: (num_categories, max_caption_length),
        # e.g. (1000, 20). ``noitpac_tokens`` are the reversed captions.
        caption_tokens = batch["caption_tokens"]
        backward_caption_tokens = batch["noitpac_tokens"]
        caption_lengths = batch["caption_lengths"]
        # Tile each category caption across the batch and score it against
        # every image.
        for i in range(caption_tokens.shape[0]):
            # shape: (batch_size, max_caption_length)
            category_caption_tokens = (
                caption_tokens[i, :].unsqueeze(0).repeat(batch_size, 1)
            )
            # shape: (batch_size, max_caption_length)
            category_backward_caption_tokens = (
                backward_caption_tokens[i, :].unsqueeze(0).repeat(batch_size, 1)
            )
            # shape: (batch_size, )
            category_caption_lengths = (
                caption_lengths[i].unsqueeze(0).repeat(batch_size)
            )
            # shape: (batch_size, max_caption_length, vocab_size)
            output_logits = self.textual(
                visual_features, category_caption_tokens, category_caption_lengths
            )
            # Per-token cross-entropy of the caption given the image
            # (predictions at step t are scored against tokens at t + 1).
            loss = self.loss(
                output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size),
                category_caption_tokens[:, 1:].contiguous().view(-1),
            )

            # Repeat in the backward direction with the reversed captions.
            backward_output_logits = self.backward_textual(
                visual_features,
                category_backward_caption_tokens,
                category_caption_lengths,
            )
            backward_loss = self.loss(
                backward_output_logits[:, :-1]
                .contiguous()
                .view(-1, self.textual.vocab_size),
                category_backward_caption_tokens[:, 1:].contiguous().view(-1),
            )
            # Sum per-token losses per image, then negate and normalize by
            # caption length: higher score = more likely caption.
            loss = loss.view(batch_size, -1).sum(dim=1)
            backward_loss = backward_loss.view(batch_size, -1).sum(dim=1)
            total_scores = (-loss - backward_loss) / category_caption_lengths

            classification_losses.append(total_scores)

        # shape: (batch_size, num_categories). Despite the name, these are
        # scores (negated losses), so classify via ``argmax`` over dim 1.
        classification_losses = torch.stack(classification_losses).t()
return classification_losses
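

# ----------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original file). The stub
# modules below are hypothetical stand-ins implementing just the attributes
# ``ZeroShotClassifier`` touches; in practice you would build a real
# VisualBackbone / TextualHead from a VirTex config and tokenize the category
# captions (forward and reversed) offline.
# ----------------------------------------------------------------------------
if __name__ == "__main__":

    class _StubVisual(nn.Module):
        def forward(self, image: torch.Tensor) -> torch.Tensor:
            # Pretend 2048-channel, 7x7 spatial features from a CNN backbone.
            return torch.randn(image.size(0), 2048, 7, 7)

    class _StubTextual(nn.Module):
        padding_idx = 0
        vocab_size = 100

        def __init__(self):
            super().__init__()
            self.visual_projection = nn.Identity()
            self.embedding = nn.Embedding(self.vocab_size, 8)
            self.output = nn.Linear(8, self.vocab_size)

        def forward(self, visual_features, caption_tokens, caption_lengths):
            # A real head attends over visual features; the stub just maps
            # tokens to logits so shapes line up for a smoke test.
            return self.output(self.embedding(caption_tokens))

    model = ZeroShotClassifier(_StubVisual(), _StubTextual()).eval()
    batch = {
        "image": torch.randn(4, 3, 224, 224),
        # 10 categories, captions padded to 20 tokens (forward and reversed).
        "caption_tokens": torch.randint(1, 100, (10, 20)),
        "noitpac_tokens": torch.randint(1, 100, (10, 20)),
        "caption_lengths": torch.full((10,), 20, dtype=torch.long),
    }
    with torch.no_grad():
        scores = model(batch)  # shape: (4, 10)
    print("Predicted category per image:", scores.argmax(dim=1).tolist())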