Upload clip_finetune.py
clip_finetune.py ADDED (+238 lines, -0)
@@ -0,0 +1,238 @@
import os
import random
from functools import partial
from typing import Any

import evaluate
import numpy as np
import torch
import torch.nn as nn
from datasets import Dataset, DatasetDict, load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (CLIPImageProcessor, CLIPModel, CLIPProcessor,
                          CLIPTokenizerFast, Trainer, TrainingArguments)
from datasets.formatting.formatting import LazyBatch
from huggingface_hub import HfApi, login, create_repo

# Environment settings
os.environ["CURL_CA_BUNDLE"] = ""
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Seed setting
def seed_all(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

seed_all(69)

# Dataset preparation
dataset = load_dataset("pcuenq/oxford-pets")
dataset_train_val = dataset['train'].train_test_split(test_size=0.3)
dataset_val_test = dataset_train_val['test'].train_test_split(test_size=0.2)
# Resulting split: ~70% train, ~6% val, ~24% test (val/test come from the 30% held out above)
dataset = DatasetDict({
    "train": dataset_train_val['train'],
    "val": dataset_val_test['test'],
    "test": dataset_val_test['train']
})

labels = set(dataset['train']['label'])
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}
labels = list(label2id)

MODEL_NAME = "openai/clip-vit-base-patch32"
TOKENIZER = CLIPTokenizerFast.from_pretrained(MODEL_NAME)
IMAGE_PROCESSOR = CLIPImageProcessor.from_pretrained(MODEL_NAME)

# Transformation functions
def transform_class_labels(items: LazyBatch, tokenizer: CLIPTokenizerFast, label2id: dict[str, int]) -> dict[str, Any]:
    # Tokenize an "a photo of <label>" prompt for every example and keep the numeric label id.
    label_prompt = [f"a photo of {label}" for label in items["label"]]
    output = tokenizer(label_prompt, padding=True, return_tensors="pt")
    items["input_ids"] = output["input_ids"]
    items["attention_mask"] = output["attention_mask"]
    items["label_id"] = [label2id[label] for label in items["label"]]
    return items

def transform_image(items: LazyBatch, image_processor: CLIPImageProcessor) -> dict[str, Any]:
    output = image_processor(items["image"], return_tensors="pt")
    items["pixel_values"] = output["pixel_values"]
    return items

# Text fields are precomputed once with map; image preprocessing runs lazily at access time.
dataset = dataset.map(partial(transform_class_labels, tokenizer=TOKENIZER, label2id=label2id), batched=True)
dataset.set_transform(partial(transform_image, image_processor=IMAGE_PROCESSOR))

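# Quick sanity check (hypothetical, not part of the original script): after map +
# set_transform, each example should expose the tokenized prompt plus a preprocessed
# image tensor, e.g.
#
#   sample = dataset["train"][0]
#   print(sample["pixel_values"].shape)   # expected: torch.Size([3, 224, 224])
#   print(sample["input_ids"], sample["label_id"])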
# Utility functions
def get_module_device(module: nn.Module) -> torch.device:
    return next(module.parameters()).device

def freeze_params(module: nn.Module, freeze_top_percent: float = 1.0) -> None:
    # Freeze the first `freeze_top_percent` fraction of the module's parameters
    # (in registration order); with the default of 1.0 the whole module is frozen.
    # Not used in the full fine-tuning run below, but useful for partial fine-tuning.
    all_params_length = len(list(module.parameters()))
    for indx, param in enumerate(module.parameters()):
        if int(all_params_length * freeze_top_percent) <= indx:
            break
        param.requires_grad = False

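# Hypothetical usage sketch (not executed in this script): freeze the lower 80% of the
# vision encoder's parameters and fine-tune only the remainder, e.g.
#
#   clip_partial = CLIPModel.from_pretrained(MODEL_NAME)
#   freeze_params(clip_partial.vision_model, freeze_top_percent=0.8)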
def print_trainable_parameters(model: nn.Module) -> None:
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"Trainable params: {(trainable_params / 10**6):.4f}M || All params: {(all_param / 10**6):.4f}M || Trainable%: {100 * trainable_params / all_param:.2f}%"
    )

# CLIP Classifier model
class CLIPClassifier(nn.Module):
    def __init__(self, clip_model: CLIPModel, tokenizer: CLIPTokenizerFast, labels: list[str]):
        super().__init__()
        self.model = clip_model
        self.tokenizer = tokenizer
        # Fixed softmax temperature taken from the (fine-tuned) CLIP model.
        self.logit_scale = self.model.logit_scale.exp()
        self.label2id = {label: i for i, label in enumerate(labels)}
        # Text embeddings for all class prompts, precomputed once at construction time.
        self.labels_embeddings = nn.Parameter(self.generate_labels_embeddings(labels))

    def generate_labels_embeddings(self, labels: list[str]) -> torch.Tensor:
        labels_inputs = self.tokenizer(
            [f"a photo of {label}" for label in labels],
            return_tensors="pt",
            padding=True,
        ).to(get_module_device(self.model))
        labels_embeddings = self.model.get_text_features(**labels_inputs)
        labels_embeddings /= labels_embeddings.norm(p=2, dim=-1, keepdim=True)
        return labels_embeddings

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # Cosine similarity between image embeddings and class-prompt embeddings,
        # scaled by logit_scale, gives one logit per label.
        image_features = self.model.get_image_features(images)
        image_features /= image_features.norm(p=2, dim=-1, keepdim=True)
        return torch.matmul(image_features, self.labels_embeddings.T) * self.logit_scale

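# Hypothetical usage sketch (not part of the original script): classify a single image
# with the classifier head above, assuming `model` is a (fine-tuned) CLIPModel and
# `image` is a PIL image:
#
#   classifier = CLIPClassifier(model, TOKENIZER, labels).to("cuda")
#   pixel_values = IMAGE_PROCESSOR(image, return_tensors="pt")["pixel_values"].to("cuda")
#   with torch.no_grad():
#       pred_id = classifier(pixel_values).argmax(dim=-1).item()
#   print(id2label[pred_id])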
# Evaluation function
def calculate_accuracy(model: CLIPClassifier, dataloader: DataLoader) -> float:
    metric = evaluate.load("accuracy")
    predictions_list = []
    references_list = []
    device = get_module_device(model)
    for batch in tqdm(dataloader, total=len(dataloader), desc="Evaluate model on dataset"):
        batch["pixel_values"] = batch["pixel_values"].to(device)
        predictions = model(batch["pixel_values"])
        predictions_list.append(torch.argmax(predictions, dim=1))
        references_list.append(batch["label_id"])
    return metric.compute(
        predictions=torch.concat(predictions_list),
        references=torch.concat(references_list),
    )["accuracy"]

def collate_fn(items: list[dict[str, Any]]) -> dict[str, Any]:
    return {
        "pixel_values": torch.stack([item["pixel_values"] for item in items]),
        "input_ids": torch.tensor([item["input_ids"] for item in items]),
        "attention_mask": torch.tensor([item["attention_mask"] for item in items]),
        "label_id": torch.tensor([item["label_id"] for item in items]),
        # return_loss=True makes CLIPModel compute its contrastive image-text loss,
        # which is what the Trainer minimizes during fine-tuning.
        "return_loss": True,
    }

@torch.no_grad()
def evaluate_clip_classifier(
    model: nn.Module,
    dataset: Dataset,
    tokenizer: CLIPTokenizerFast,
    labels: list[str],
    batch_size: int = 64,
    num_workers: int = 5,
    device: str = "cuda",
) -> None:
    clip_classifier = CLIPClassifier(model, tokenizer, labels)
    test_dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn
    )
    clip_classifier = clip_classifier.to(device)
    acc = calculate_accuracy(clip_classifier, test_dataloader)
    print(f"Model accuracy: {acc}")

def collate_train_fn(items: list[dict[str, Any]]) -> dict[str, Any]:
    # Same as collate_fn, but without label ids: CLIP training only needs the
    # image-text pairs, not the classification labels.
    items = collate_fn(items)
    items.pop("label_id")
    return items

def get_default_training_args(
    experiment_name: str,
    lr: float,
    batch_size: int = 256,
    num_epoch: int = 4,
    num_workers: int = 15,
) -> TrainingArguments:
    return TrainingArguments(
        experiment_name,
        per_device_train_batch_size=batch_size,
        learning_rate=lr,
        num_train_epochs=num_epoch,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=1,
        logging_steps=10,
        save_total_limit=2,
        evaluation_strategy="epoch",  # renamed to `eval_strategy` in recent transformers releases
        save_strategy="epoch",
        fp16=True,
        # Keep all dataset columns so the custom collators still see pixel_values/label_id.
        remove_unused_columns=False,
        load_best_model_at_end=True,
        dataloader_num_workers=num_workers,
    )

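# Optional baseline (hypothetical, not executed here): measuring zero-shot accuracy of
# the pretrained checkpoint before any fine-tuning gives a reference point for the
# fine-tuned numbers reported below.
#
#   clip_zero_shot = CLIPModel.from_pretrained(MODEL_NAME)
#   evaluate_clip_classifier(clip_zero_shot, dataset["test"], TOKENIZER, labels)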
# Training: full fine-tuning of all CLIP layers on the Oxford Pets prompts
clip_full_finetuned = CLIPModel.from_pretrained(MODEL_NAME)
trainer = Trainer(
    model=clip_full_finetuned,
    args=get_default_training_args("clip-all-layers-tuning-oxford-pets", 3e-6),
    data_collator=collate_train_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"],
)

trainer.train()

print_trainable_parameters(clip_full_finetuned)
evaluate_clip_classifier(clip_full_finetuned, dataset['test'], TOKENIZER, labels)

# Hugging Face Hub interaction
login(token='TOKEN')
api = HfApi()
repo_url = create_repo(repo_id="DGurgurov/clip-vit-base-patch32-oxford-pets", exist_ok=True)
print(f"Repository created at: {repo_url}")

# Upload the final checkpoint produced by the run above.
api.upload_folder(
    folder_path='clip-all-layers-tuning-oxford-pets/checkpoint-84',
    path_in_repo='',
    repo_id='DGurgurov/clip-vit-base-patch32-oxford-pets'
)

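# Loading the pushed model later (hypothetical example): the checkpoint folder contains
# the model weights and config, so CLIPModel.from_pretrained works on the repo id; the
# processor can be loaded from the base checkpoint if its files were not uploaded.
#
#   model = CLIPModel.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")
#   processor = CLIPProcessor.from_pretrained(MODEL_NAME)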
# README creation
readme_content = """
# CLIP ViT Base Patch32 Fine-tuned on Oxford Pets

This model is a fine-tuned version of OpenAI's CLIP model on the Oxford Pets dataset.

## Training Information

- **Model Name**: openai/clip-vit-base-patch32
- **Dataset**: oxford-pets
- **Training Epochs**: 4
- **Batch Size**: 256
- **Learning Rate**: 3e-6
- **Accuracy**: 93.74%

## License
[MIT]
"""

with open('clip-all-layers-tuning-oxford-pets/checkpoint-84/README.md', 'w') as f:
    f.write(readme_content)

api.upload_file(
    path_or_fileobj='clip-all-layers-tuning-oxford-pets/checkpoint-84/README.md',
    path_in_repo='README.md',
    repo_id='DGurgurov/clip-vit-base-patch32-oxford-pets'
)