import torch
from torch import nn
import timm

import config as CFG


class TextEncoder(nn.Module):
    """
    Text/Poem encoder used in PoemTextModel and CLIPModel
    ...
    Attributes:
    -----------
    model : a torch.nn.Module model
        The text/poem encoder model
    Methods:
    --------
    forward(x)
        returns model embeddings of x (batch of texts/poems) (of the CLS token)
    __init__()
        creates the encoder model using huggingface transformers,
        also freezes the model if it's not trainable.
    """
    def __init__(self, encoder_model, encoder_pretrained_name, pretrained, trainable):
        """
        Creates the poem or text encoder model using transformers and loads weights from a pretrained model if needed.
        Also freezes the model if it's not trainable.
        Parameters:
        -----------
        encoder_model: str
            text/poem encoder model name, used as a key to get the right model class from configs.
        encoder_pretrained_name: str
            pretrained checkpoint to load the encoder's weights from. (not used when pretrained=False)
        pretrained: bool
            if pretrained=True, load the pretrained model's weights. else create a fresh untrained model.
        trainable: bool
            if trainable=False, the model's weights will be frozen.
        """
        super().__init__()
        if pretrained:
            self.model = CFG.encoders[encoder_model].from_pretrained(encoder_pretrained_name)
        else:
            self.model = CFG.encoders[encoder_model](config=CFG.configs[encoder_model]())
        for p in self.model.parameters():
            p.requires_grad = trainable

        # Using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0
    def forward(self, input_ids, attention_mask):
        """
        forwards and calculates embeddings of the input using the attention mask.
        Parameters:
        -----------
        input_ids: input ids (output of tokenizer)
        attention_mask: input masks (e.g. for padding; pad tokens will be masked)
        Returns:
        --------
        the embedding of the CLS (or target) token of the encoder's last hidden state
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]
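

# Illustrative usage sketch (not part of the original module): shows how the
# CLS-token embedding above is obtained. The "bert" key and the
# "bert-base-uncased" checkpoint are assumptions about what config.py registers
# in CFG.encoders / CFG.configs; substitute the names actually defined there.
def _text_encoder_example():
    from transformers import AutoTokenizer  # assumed installed alongside CFG.encoders

    encoder = TextEncoder("bert", "bert-base-uncased", pretrained=True, trainable=False)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    batch = tokenizer(["an example poem line"], padding=True, return_tensors="pt")
    # One CLS vector per text: shape (batch_size, hidden_size)
    return encoder(batch["input_ids"], batch["attention_mask"])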


class ProjectionHead(nn.Module):
    """
    Projection head used to project embeddings from each encoder to a shared embedding space
    ...
    Attributes:
    -----------
    projection : torch.nn.Linear
        The main dense projection (from the encoder's embedding dim to the shared projection dim)
    gelu: torch.nn.GELU
        activation function
    fc: torch.nn.Linear
        a dense layer after projection (projection_dim to projection_dim)
    dropout: torch.nn.Dropout
        dropout after fc
    layer_norm: torch.nn.LayerNorm
        layer norm after dropout
    Methods:
    --------
    forward(x)
        returns projection embeddings from x (encoder output embeddings)
    __init__()
        creates the projection head
    """
    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        """
        Creates the projection head used after an encoder.
        Parameters:
        -----------
        embedding_dim: int
            dimension of the output embeddings of the encoder.
        projection_dim: int, optional
            dimension to project embeddings to.
        dropout: float, optional
            fraction of the fc layer's output to randomly zero out.
""" | |
super().__init__() | |
self.projection = nn.Linear(embedding_dim, projection_dim) | |
self.gelu = nn.GELU() | |
self.fc = nn.Linear(projection_dim, projection_dim) | |
self.dropout = nn.Dropout(dropout) | |
self.layer_norm = nn.LayerNorm(projection_dim) | |

    def forward(self, x):
        """
        Forwards and calculates projected embeddings from encoder embeddings.
        Parameters:
        -----------
        x: input (of shape (batch_size, embedding_dim))
            the output embedding of this projection head's encoder
        Returns:
        --------
        the embeddings in a shared embedding space (of shape (batch_size, projection_dim))
        """
        projected = self.projection(x)  # main projection layer
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        # the projected outputs are added to x as a residual connection
        x = x + projected
        x = self.layer_norm(x)
        return x
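

# Illustrative usage sketch (not part of the original module): projects a fake
# batch of encoder embeddings into the shared space. The embedding size 768 is
# an assumption (e.g. a BERT-base text encoder); the real value depends on the
# encoder in use, while the output size comes from CFG.projection_dim.
def _projection_head_example():
    head = ProjectionHead(embedding_dim=768)
    encoder_output = torch.randn(4, 768)  # fake batch of 4 encoder embeddings
    return head(encoder_output)  # shape (4, CFG.projection_dim)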


class ImageEncoder(nn.Module):
    """
    Image encoder used in CLIPModel
    ...
    Attributes:
    -----------
    model : a torch.nn.Module model from timm (pytorch-image-models)
        The image encoder model
    Methods:
    --------
    forward(x)
        returns model embeddings of x (batch of images)
    __init__()
        creates the encoder model using timm and loads a fine-tuned model's state dict if needed.
        also freezes the model if it's not trainable.
    """
    def __init__(
        self, pretrained, trainable, model_name=CFG.image_encoder_model
    ):
        """
        Creates the encoder model using timm and loads a fine-tuned model's state dict if needed.
        Also freezes the model if it's not trainable.
        Parameters:
        -----------
        pretrained: bool
            if pretrained=True, load timm's pretrained weights (or the fine-tuned weights saved at image_encoder_weights_load_path).
            else create a fresh untrained model.
        trainable: bool
            if trainable=False, the model's weights will be frozen.
        model_name: str
            image encoder model name used as input to timm.create_model.
        """
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        if pretrained and CFG.image_encoder_weights_load_path:
            self.model.load_state_dict(torch.load(CFG.image_encoder_weights_load_path, map_location=CFG.device))
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        """
        forwards and calculates embeddings of the input.
        Parameters:
        -----------
        x: input (batch of transformed images)
        Returns:
        --------
        embeddings of the model for the input (of shape (batch_size, image_embedding))
        """
        return self.model(x)
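

# Illustrative usage sketch (not part of the original module): encodes a fake
# batch of already-transformed images and projects the features into the shared
# space, mirroring how CLIPModel chains the two modules. The 224x224 input size
# is an assumption about the timm backbone selected in config.py; pretrained=False
# keeps the sketch self-contained (no weight download or checkpoint path needed).
def _image_encoder_example():
    encoder = ImageEncoder(pretrained=False, trainable=False)
    images = torch.randn(2, 3, 224, 224)  # fake batch of 2 RGB images
    features = encoder(images)  # shape (2, image_embedding)
    head = ProjectionHead(embedding_dim=features.shape[1])
    return head(features)  # shape (2, CFG.projection_dim)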