# Spaces: Running
from dataclasses import dataclass

import torch
import torch.nn as nn
from transformers import SiglipVisionModel, SiglipPreTrainedModel, SiglipVisionConfig
from transformers.utils import ModelOutput
@dataclass
class SiglipForImageClassifierOutput(ModelOutput):
    """Output container for ``SiglipForImageClassification``.

    ``ModelOutput`` subclasses must be dataclasses; the ``@dataclass``
    decorator generates the keyword ``__init__`` that populates these fields.
    Every field defaults to ``None``.

    Attributes:
        loss: Optional loss tensor; ``None`` when no loss was computed.
        logits: Output of the classification head.
        pooler_output: Pooled vision-model output that was fed to the head.
        hidden_states: Hidden states returned by the vision model, if any.
        attentions: Attention weights returned by the vision model, if any.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
class SiglipForImageClassification(SiglipPreTrainedModel):
    """SigLIP vision encoder with a linear classification head.

    The head is applied to the vision model's ``pooler_output``. When
    ``config.num_labels`` is not positive the head is ``nn.Identity`` and the
    pooled features are returned unchanged as ``logits``.
    """

    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.siglip = SiglipVisionModel(config)
        # Classifier head: projects pooled features (width config.hidden_size)
        # to one logit per label; Identity when no labels are configured.
        self.classifier = (
            nn.Linear(config.hidden_size, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> SiglipForImageClassifierOutput:
        """Classify images and, when ``labels`` are given, compute the loss.

        Args:
            pixel_values: Preprocessed images, passed straight to the SigLIP
                vision model.
            labels: Optional targets. When given, ``loss`` is computed by
                ``_classification_loss``; when ``None``, ``loss`` is ``None``.
            output_attentions: Forwarded to the vision model only when not
                ``None``; otherwise the vision model's own default applies.
            output_hidden_states: Same as ``output_attentions``, for hidden
                states.

        Returns:
            SiglipForImageClassifierOutput holding ``loss``, ``logits``, the
            ``pooler_output`` that fed the head, and the vision model's
            ``hidden_states`` / ``attentions``.

        Raises:
            ValueError: if ``labels`` are given but the model has no
                classification head, or ``config.problem_type`` is unknown.
        """
        # Forward only the flags the caller set explicitly, so an omitted flag
        # keeps whatever default the vision model / config provides.
        siglip_kwargs = {}
        if output_attentions is not None:
            siglip_kwargs["output_attentions"] = output_attentions
        if output_hidden_states is not None:
            siglip_kwargs["output_hidden_states"] = output_hidden_states

        outputs = self.siglip(pixel_values, **siglip_kwargs)
        pooler_output = outputs.pooler_output
        logits = self.classifier(pooler_output)

        loss = (
            self._classification_loss(logits, labels) if labels is not None else None
        )

        return SiglipForImageClassifierOutput(
            loss=loss,
            logits=logits,
            pooler_output=pooler_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _classification_loss(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """Return the loss of ``logits`` against ``labels``.

        Uses ``config.problem_type`` when set; otherwise infers it:
        one label -> regression (MSE); integer labels with several classes ->
        single-label classification (cross-entropy); anything else ->
        multi-label classification (BCE with logits).
        """
        if self.num_labels <= 0:
            raise ValueError(
                "labels were passed but the model has no classification head "
                "(config.num_labels <= 0), so no loss can be computed."
            )
        # Labels may come from a different device than the model outputs.
        labels = labels.to(logits.device)

        problem_type = getattr(self.config, "problem_type", None)
        if problem_type is None:
            if self.num_labels == 1:
                problem_type = "regression"
            elif labels.dtype in (torch.long, torch.int):
                problem_type = "single_label_classification"
            else:
                problem_type = "multi_label_classification"

        if problem_type == "regression":
            if self.num_labels == 1:
                return nn.MSELoss()(
                    logits.squeeze(), labels.squeeze().to(logits.dtype)
                )
            return nn.MSELoss()(logits, labels.to(logits.dtype))
        if problem_type == "single_label_classification":
            return nn.CrossEntropyLoss()(
                logits.view(-1, self.num_labels), labels.view(-1)
            )
        if problem_type == "multi_label_classification":
            # BCE expects float targets with the same shape as the logits.
            return nn.BCEWithLogitsLoss()(logits, labels.to(logits.dtype))
        raise ValueError(f"Unsupported problem_type: {problem_type!r}")