cella110n's picture
Update README.md
354fdd6 verified
|
raw
history blame
3.52 kB
License: apache-2.0

Fine-tuned from [p1atdev/siglip-tagger-test-3](https://huggingface.co/p1atdev/siglip-tagger-test-3).

This is an experimental test model.

Usage:

import torch
import torch.nn as nn
import numpy as np
from dataclasses import dataclass
from transformers import SiglipVisionModel, SiglipPreTrainedModel, SiglipVisionConfig, AutoImageProcessor
from transformers.utils import ModelOutput

@dataclass
class SiglipForImageClassifierOutput(ModelOutput):
    """Output of ``SiglipForImageClassification.forward``.

    Every field defaults to ``None``; ``forward`` fills them from the SigLIP
    vision backbone's output and the classifier head.
    """

    # BCE-with-logits classification loss; only set when ``labels`` are passed to forward.
    loss: torch.FloatTensor | None = None
    # Raw classifier-head scores (pre-sigmoid; the loss applies the sigmoid internally).
    logits: torch.FloatTensor | None = None
    # Pooled image embedding taken from the vision backbone.
    pooler_output: torch.FloatTensor | None = None
    # Backbone hidden states, forwarded as-is (may be None if the backbone did not return them).
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Backbone attention maps, forwarded as-is (may be None if the backbone did not return them).
    attentions: tuple[torch.FloatTensor, ...] | None = None

class SiglipForImageClassification(SiglipPreTrainedModel):
    """SigLIP vision encoder with a linear multi-label classification head.

    The pooled image embedding produced by ``SiglipVisionModel`` is projected to
    ``config.num_labels`` scores. When ``config.num_labels`` is not positive the
    head is ``nn.Identity`` and ``logits`` is the pooled embedding itself.
    """

    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
    ):
        """Build the vision backbone and classifier head from ``config``.

        Args:
            config: vision config; ``hidden_size`` sets the head's input width
                and ``num_labels`` its output width.
        """
        super().__init__(config)

        self.siglip = SiglipVisionModel(config)

        # Classifier head; falls back to a pass-through when there are no labels.
        self.classifier = (
            nn.Linear(config.hidden_size, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> "SiglipForImageClassifierOutput":
        """Classify a batch of images.

        Args:
            pixel_values: preprocessed images, as produced by the image processor.
            labels: optional targets for ``BCEWithLogitsLoss``; must have the same
                shape as ``logits``. Integer targets (e.g. a multi-hot
                ``LongTensor``) are accepted and cast to the logits' dtype.
            output_attentions: forwarded to the vision backbone; when true its
                attention maps are returned in ``attentions``.
            output_hidden_states: forwarded to the vision backbone; when true its
                hidden states are returned in ``hidden_states``.

        Returns:
            ``SiglipForImageClassifierOutput`` holding the raw (pre-sigmoid)
            ``logits``, the backbone's ``pooler_output``, and ``loss`` when
            ``labels`` is given.
        """
        outputs = self.siglip(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooler_output = outputs.pooler_output
        logits = self.classifier(pooler_output)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss requires floating-point targets matching the
            # logits' dtype; integer labels would raise a dtype error.
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.to(dtype=logits.dtype))

        return SiglipForImageClassifierOutput(
            loss=loss,
            logits=logits,
            pooler_output=pooler_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

# Load the model configuration, image processor and weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_ID = "cella110n/siglip-tagger-FT3ep"

config = SiglipVisionConfig.from_pretrained(MODEL_ID)
processor = AutoImageProcessor.from_pretrained(MODEL_ID, config=config)
model = SiglipForImageClassification.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16
).to(device)

model.eval()
print("Model Loaded. device:", model.device)

from PIL import Image

# Load the input image and normalize its mode to RGB (handles RGBA,
# palette and grayscale files) before preprocessing.
img_path = "path/to/image"
img = Image.open(img_path).convert("RGB")

inputs = processor(images=img, return_tensors="pt")  # convert the image into model input tensors
print("Image processed.")

# Run inference; floating-point inputs are cast to the model's device and dtype.
with torch.no_grad():
    logits = (
        model(**inputs.to(model.device, model.dtype))
        .logits.detach()
        .cpu()
        .float()
    )

# The model's loss (SiglipForImageClassification.forward) is BCEWithLogitsLoss,
# so per-tag probabilities are the sigmoid of the logits. Clipping raw logits
# to [0, 1] does not yield probabilities.
probs = torch.sigmoid(logits)

prob_cutoff = 0.3  # only report tags whose probability exceeds this value

# Collect every tag above the cutoff across all images in the batch.
result = {}
for prediction in probs:
    for label_id, prob in enumerate(prediction.tolist()):
        if prob > prob_cutoff:
            result[model.config.id2label[label_id]] = prob

# Sort the tags by probability, highest first, and print them.
sorted_result = sorted(result.items(), key=lambda item: item[1], reverse=True)
for tag, prob in sorted_result:
    print(f"{tag}: {prob:.3f}")