import os
import random
from functools import partial
from typing import Any

import evaluate
import numpy as np
import torch
import torch.nn as nn
from datasets import Dataset, DatasetDict, load_dataset
from datasets.formatting.formatting import LazyBatch
from huggingface_hub import HfApi, create_repo, login
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import (CLIPImageProcessor, CLIPModel, CLIPTokenizerFast,
                          Trainer, TrainingArguments)

# Work around SSL certificate issues when downloading models/datasets and
# silence the tokenizers fork warning.
os.environ["CURL_CA_BUNDLE"] = ""
os.environ["TOKENIZERS_PARALLELISM"] = "false"
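
# Reproducibility: seed the Python, PyTorch and NumPy RNGs before any data splitting.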
def seed_all(seed: int) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)


seed_all(69)
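
# Load Oxford Pets and split it into roughly 70% train, 6% validation and 24% test
# (the smaller piece of the second split becomes the validation set).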
dataset = load_dataset("pcuenq/oxford-pets")
dataset_train_val = dataset['train'].train_test_split(test_size=0.3)
dataset_val_test = dataset_train_val['test'].train_test_split(test_size=0.2)
dataset = DatasetDict({
    "train": dataset_train_val['train'],
    "val": dataset_val_test['test'],
    "test": dataset_val_test['train'],
})
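
# Map class names to integer ids (and back); sorting keeps the ids stable across runs.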
labels = sorted(set(dataset['train']['label']))
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}

MODEL_NAME = "openai/clip-vit-base-patch32"
TOKENIZER = CLIPTokenizerFast.from_pretrained(MODEL_NAME)
IMAGE_PROCESSOR = CLIPImageProcessor.from_pretrained(MODEL_NAME)
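
# Text prompts are tokenized once and cached in the dataset via map(); image
# preprocessing is attached with set_transform() so pixel values are computed on the fly.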
def transform_class_labels(items: LazyBatch, tokenizer: CLIPTokenizerFast, label2id: dict[str, int]) -> dict[str, Any]:
    label_prompt = [f"a photo of {label}" for label in items["label"]]
    output = tokenizer(label_prompt, padding=True, return_tensors="pt")
    items["input_ids"] = output["input_ids"]
    items["attention_mask"] = output["attention_mask"]
    items["label_id"] = [label2id[label] for label in items["label"]]
    return items


def transform_image(items: LazyBatch, image_processor: CLIPImageProcessor) -> dict[str, Any]:
    output = image_processor(items["image"], return_tensors="pt")
    items["pixel_values"] = output["pixel_values"]
    return items


dataset = dataset.map(partial(transform_class_labels, tokenizer=TOKENIZER, label2id=label2id), batched=True)
dataset.set_transform(partial(transform_image, image_processor=IMAGE_PROCESSOR))
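
# Small utilities: device lookup, optional parameter freezing, and a trainable-parameter
# report. freeze_params is not used in the full fine-tuning run below; as an example,
# freeze_params(clip_model, freeze_top_percent=0.8) would freeze 80% of the parameters
# for a partial fine-tune.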
def get_module_device(module: nn.Module) -> torch.device:
    return next(module.parameters()).device


def freeze_params(module: nn.Module, freeze_top_percent: float = 1.0) -> None:
    # Freeze the first `freeze_top_percent` fraction of the module's parameters,
    # in the order returned by module.parameters().
    all_params_length = len(list(module.parameters()))
    for indx, param in enumerate(module.parameters()):
        if int(all_params_length * freeze_top_percent) <= indx:
            break
        param.requires_grad = False


def print_trainable_parameters(model: nn.Module) -> None:
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"Trainable params: {(trainable_params / 10**6):.4f}M || "
        f"All params: {(all_param / 10**6):.4f}M || "
        f"Trainable%: {100 * trainable_params / all_param:.2f}%"
    )
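
# A zero-shot-style classifier head: the class prompts are embedded once with CLIP's text
# encoder, and images are classified by scaled cosine similarity against those cached
# embeddings.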
class CLIPClassifier(nn.Module):
    def __init__(self, clip_model: CLIPModel, tokenizer: CLIPTokenizerFast, labels: list[str]):
        super().__init__()
        self.model = clip_model
        self.tokenizer = tokenizer
        self.logit_scale = self.model.logit_scale.exp()
        self.label2id = {label: i for i, label in enumerate(labels)}
        self.labels_embeddings = nn.Parameter(self.generate_labels_embeddings(labels))

    def generate_labels_embeddings(self, labels: list[str]) -> torch.Tensor:
        labels_inputs = self.tokenizer(
            [f"a photo of {label}" for label in labels],
            return_tensors="pt",
            padding=True,
        ).to(get_module_device(self.model))
        labels_embeddings = self.model.get_text_features(**labels_inputs)
        labels_embeddings /= labels_embeddings.norm(p=2, dim=-1, keepdim=True)
        return labels_embeddings

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        image_features = self.model.get_image_features(images)
        image_features /= image_features.norm(p=2, dim=-1, keepdim=True)
        return torch.matmul(image_features, self.labels_embeddings.T) * self.logit_scale
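
# Evaluation: run the classifier over a DataLoader and compute accuracy with the
# `evaluate` library.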
def calculate_accuracy(model: CLIPClassifier, dataloader: DataLoader) -> float:
    metric = evaluate.load("accuracy")
    predictions_list = []
    references_list = []
    device = get_module_device(model)
    for batch in tqdm(dataloader, total=len(dataloader), desc="Evaluate model on dataset"):
        batch["pixel_values"] = batch["pixel_values"].to(device)
        predictions = model(batch["pixel_values"])
        predictions_list.append(torch.argmax(predictions, dim=1))
        references_list.append(batch["label_id"])
    return metric.compute(
        predictions=torch.concat(predictions_list),
        references=torch.concat(references_list),
    )["accuracy"]
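
# Collators: the evaluation collator keeps label_id for accuracy computation, while the
# training collator below drops it because CLIPModel.forward() does not accept it and
# computes its own contrastive loss when return_loss=True.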
def collate_fn(items: list[dict[str, Any]]) -> dict[str, Any]:
    return {
        "pixel_values": torch.stack([item["pixel_values"] for item in items]),
        "input_ids": torch.tensor([item["input_ids"] for item in items]),
        "attention_mask": torch.tensor([item["attention_mask"] for item in items]),
        "label_id": torch.tensor([item["label_id"] for item in items]),
        "return_loss": True,
    }


@torch.no_grad()
def evaluate_clip_classifier(
    model: nn.Module,
    dataset: Dataset,
    tokenizer: CLIPTokenizerFast,
    labels: list[str],
    batch_size: int = 64,
    num_workers: int = 5,
    device: str = "cuda",
) -> None:
    clip_classifier = CLIPClassifier(model, tokenizer, labels)
    test_dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn
    )
    clip_classifier = clip_classifier.to(device)
    acc = calculate_accuracy(clip_classifier, test_dataloader)
    print(f"Model accuracy: {acc}")


def collate_train_fn(items: list[dict[str, Any]]) -> dict[str, Any]:
    items = collate_fn(items)
    items.pop("label_id")
    return items
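
# Shared TrainingArguments for the fine-tuning runs; fp16=True assumes a CUDA-capable GPU.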
def get_default_training_args(
    experiment_name: str,
    lr: float,
    batch_size: int = 256,
    num_epoch: int = 4,
    num_workers: int = 15,
) -> TrainingArguments:
    return TrainingArguments(
        experiment_name,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=lr,
        num_train_epochs=num_epoch,
        gradient_accumulation_steps=1,
        logging_steps=10,
        save_total_limit=2,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        fp16=True,
        remove_unused_columns=False,
        load_best_model_at_end=True,
        dataloader_num_workers=num_workers,
    )
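
# Full fine-tuning: every CLIP layer stays trainable. The best checkpoint (by eval loss)
# is reloaded at the end of training and then evaluated on the held-out test split.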
clip_full_finetuned = CLIPModel.from_pretrained(MODEL_NAME)
trainer = Trainer(
    model=clip_full_finetuned,
    args=get_default_training_args("clip-all-layers-tuning-oxford-pets", 3e-6),
    data_collator=collate_train_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"],
)

trainer.train()

print_trainable_parameters(clip_full_finetuned)
evaluate_clip_classifier(clip_full_finetuned, dataset['test'], TOKENIZER, labels)
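
# Push the fine-tuned checkpoint to the Hugging Face Hub. 'TOKEN' is a placeholder for a
# write-access token, and the checkpoint-84 path is specific to this particular run.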
login(token='TOKEN')
api = HfApi()
repo_url = create_repo(repo_id="DGurgurov/clip-vit-base-patch32-oxford-pets", exist_ok=True)
print(f"Repository created at: {repo_url}")

api.upload_folder(
    folder_path='clip-all-layers-tuning-oxford-pets/checkpoint-84',
    path_in_repo='',
    repo_id='DGurgurov/clip-vit-base-patch32-oxford-pets',
)
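
# Minimal model card describing the training setup, written into the checkpoint folder
# and uploaded alongside the weights.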
readme_content = """
# CLIP ViT Base Patch32 Fine-tuned on Oxford Pets

This model is a fine-tuned version of OpenAI's CLIP model on the Oxford Pets dataset.

## Training Information

- **Model Name**: openai/clip-vit-base-patch32
- **Dataset**: oxford-pets
- **Training Epochs**: 4
- **Batch Size**: 256
- **Learning Rate**: 3e-6
- **Accuracy**: 93.74%

## License

[MIT]
"""

with open('clip-all-layers-tuning-oxford-pets/checkpoint-84/README.md', 'w') as f:
    f.write(readme_content)

api.upload_file(
    path_or_fileobj='clip-all-layers-tuning-oxford-pets/checkpoint-84/README.md',
    path_in_repo='README.md',
    repo_id='DGurgurov/clip-vit-base-patch32-oxford-pets',
)