# siglip-tagger-test-2 / modeling_siglip.py
# Uploaded by p1atdev ("Upload 2 files", commit a467fd4).
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import SiglipVisionModel, SiglipPreTrainedModel, SiglipVisionConfig
from transformers.utils import ModelOutput
from loss_fn import AsymmetricLossOptimized
@dataclass
class SiglipForImageClassifierOutput(ModelOutput):
    """Output container returned by ``SiglipForImageClassification.forward``.

    All fields default to ``None`` so the object can be built with only the
    values that are available for a given call.
    """

    # Loss against the provided labels; None when forward() is called without labels.
    loss: torch.FloatTensor | None = None
    # Output of the classifier head applied to ``pooler_output``.
    logits: torch.FloatTensor | None = None
    # Pooled features from the SigLIP vision encoder (the classifier's input).
    pooler_output: torch.FloatTensor | None = None
    # Per-layer hidden states passed through from the encoder; may be None.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer attention maps passed through from the encoder; may be None.
    attentions: tuple[torch.FloatTensor, ...] | None = None
class SiglipForImageClassification(SiglipPreTrainedModel):
    """SigLIP vision encoder topped with a linear classification head.

    The head is applied to the encoder's ``pooler_output``. When
    ``config.num_labels`` is not positive the head is ``nn.Identity`` and the
    "logits" are the pooled features themselves.

    When labels are given, the loss is chosen by ``config.problem_type``. If that
    is unset, it is inferred (and stored on the config) on the first call with
    labels:

    * ``"regression"`` (``num_labels == 1``) -> ``MSELoss``
    * ``"single_label_classification"`` (``num_labels > 1``, integer labels)
      -> ``CrossEntropyLoss``
    * ``"multi_label_classification"`` (anything else) -> ``AsymmetricLossOptimized``
    """

    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config,
    ):
        super().__init__(config)

        self.num_labels = config.num_labels
        # The attribute name sets the state-dict key prefix ("siglip."); keep it
        # stable so existing checkpoints still load.
        self.siglip = SiglipVisionModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(config.hidden_size, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    def _infer_problem_type(self, labels: torch.Tensor) -> str:
        """Return the problem type implied by ``num_labels`` and ``labels.dtype``."""
        if self.num_labels == 1:
            return "regression"
        if self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
            return "single_label_classification"
        return "multi_label_classification"

    def _compute_loss(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor | None:
        """Compute the loss of ``logits`` against ``labels``.

        Returns:
            The loss tensor, or ``None`` if ``config.problem_type`` holds a
            value none of the branches below handle.
        """
        if self.config.problem_type is None:
            # Stored on the config so subsequent calls skip the inference.
            self.config.problem_type = self._infer_problem_type(labels)

        # Keep labels next to the logits (e.g. when the model is split across devices).
        labels = labels.to(logits.device)
        problem_type = self.config.problem_type

        if problem_type == "regression":
            loss_fct = MSELoss()
            # Regression is selected for num_labels == 1 regardless of label
            # dtype, so integer labels can arrive here; MSELoss wants targets
            # in the same floating dtype as the predictions.
            labels = labels.to(logits.dtype)
            if self.num_labels == 1:
                return loss_fct(logits.squeeze(), labels.squeeze())
            return loss_fct(logits, labels)
        if problem_type == "single_label_classification":
            loss_fct = CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        if problem_type == "multi_label_classification":
            loss_fct = AsymmetricLossOptimized()
            return loss_fct(logits, labels)
        return None

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> SiglipForImageClassifierOutput:
        """Run the encoder and classifier head on a batch of images.

        Args:
            pixel_values: Preprocessed images, passed unchanged to the SigLIP
                vision model.
            labels: Optional targets. When given, ``loss`` is computed as
                described in the class docstring.
            output_attentions: Forwarded to the vision model. ``None`` keeps
                the vision model's default behavior.
            output_hidden_states: Forwarded to the vision model. ``None`` keeps
                the vision model's default behavior.

        Returns:
            ``SiglipForImageClassifierOutput`` with ``loss`` (``None`` without
            labels), ``logits``, the encoder's ``pooler_output``, and its
            ``hidden_states`` / ``attentions``.
        """
        outputs = self.siglip(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Attribute access below needs an output object, not a tuple,
            # regardless of config.use_return_dict.
            return_dict=True,
        )
        pooler_output = outputs.pooler_output
        logits = self.classifier(pooler_output)

        loss = None if labels is None else self._compute_loss(logits, labels)

        return SiglipForImageClassifierOutput(
            loss=loss,
            logits=logits,
            pooler_output=pooler_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )