from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import SiglipVisionModel, SiglipPreTrainedModel, SiglipVisionConfig
from transformers.utils import ModelOutput

from loss_fn import AsymmetricLossOptimized


@dataclass
class SiglipForImageClassifierOutput(ModelOutput):
    """Output of ``SiglipForImageClassification.forward``.

    Attributes:
        loss: Training loss; ``None`` when ``forward`` was called without ``labels``.
        logits: Classifier-head output computed from ``pooler_output``.
        pooler_output: Pooled output of the SigLIP vision encoder.
        hidden_states: Encoder hidden states exactly as returned by the vision
            model (presumably ``None`` unless requested — confirm per
            transformers version).
        attentions: Encoder attention weights exactly as returned by the vision
            model (presumably ``None`` unless requested — confirm per
            transformers version).
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None


class SiglipForImageClassification(SiglipPreTrainedModel):
    """SigLIP vision encoder with a linear classification head on its pooled output.

    When ``labels`` are passed to ``forward``, the loss is chosen by
    ``config.problem_type``. If that is unset, it is inferred on the first call
    and written back to the config:

    * ``"regression"`` (``num_labels == 1``) -> ``MSELoss``
    * ``"single_label_classification"`` (``num_labels > 1`` and int/long labels)
      -> ``CrossEntropyLoss``
    * ``"multi_label_classification"`` (anything else) -> ``AsymmetricLossOptimized``
    """

    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config,
    ):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.siglip = SiglipVisionModel(config)

        # Classifier head; with no labels configured, the pooled features are
        # passed through unchanged.
        self.classifier = (
            nn.Linear(config.hidden_size, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> SiglipForImageClassifierOutput:
        """Encode ``pixel_values``, classify the pooled output and optionally compute a loss.

        Args:
            pixel_values: Image batch handed straight to the SigLIP vision model.
            labels: Optional targets. Their dtype and ``num_labels`` decide the
                problem type when ``config.problem_type`` is unset.
            output_attentions: If set, forwarded to the vision model to request
                (or suppress) attention weights; otherwise the model's config
                default applies.
            output_hidden_states: If set, forwarded to the vision model to
                request (or suppress) hidden states; otherwise the model's
                config default applies.

        Returns:
            A ``SiglipForImageClassifierOutput``. Its ``loss`` is ``None`` when
            ``labels`` is ``None``.

        Raises:
            ValueError: If ``labels`` are given and ``config.problem_type`` is
                not one of the supported problem types.
        """
        # Forward only the flags the caller actually set, so the vision model
        # keeps using its own config defaults otherwise.
        encoder_kwargs = {}
        if output_attentions is not None:
            encoder_kwargs["output_attentions"] = output_attentions
        if output_hidden_states is not None:
            encoder_kwargs["output_hidden_states"] = output_hidden_states

        outputs = self.siglip(pixel_values, **encoder_kwargs)
        pooler_output = outputs.pooler_output
        logits = self.classifier(pooler_output)

        loss = None if labels is None else self._compute_loss(logits, labels)

        return SiglipForImageClassifierOutput(
            loss=loss,
            logits=logits,
            pooler_output=pooler_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _compute_loss(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """Return the loss of ``logits`` against ``labels`` for ``config.problem_type``."""
        # Keep labels on the same device as the logits (needed when the model
        # is spread over several devices).
        labels = labels.to(logits.device)

        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        problem_type = self.config.problem_type

        if problem_type == "regression":
            loss_fct = MSELoss()
            # MSELoss needs floating targets of the same dtype as the
            # predictions; integer or float64 targets would otherwise fail.
            labels = labels.to(logits.dtype)
            if self.num_labels == 1:
                return loss_fct(logits.squeeze(), labels.squeeze())
            return loss_fct(logits, labels)

        if problem_type == "single_label_classification":
            loss_fct = CrossEntropyLoss()
            # reshape (not view) also accepts non-contiguous label tensors.
            return loss_fct(
                logits.reshape(-1, self.num_labels), labels.reshape(-1)
            )

        if problem_type == "multi_label_classification":
            loss_fct = AsymmetricLossOptimized()
            return loss_fct(logits, labels)

        raise ValueError(
            f"Unsupported problem_type {problem_type!r}; expected 'regression', "
            "'single_label_classification' or 'multi_label_classification'."
        )