import math

import torch
from torch import nn
from torchvision import transforms
from PIL import Image
from PIL.Image import BICUBIC
from peft import LoraConfig, get_peft_model
from transformers import CLIPModel
# Level 4 (21 patches) was used in earlier experiments; it has to stay in the
# mapping so checkpoints trained with it can still be loaded.
LEVELS_TO_PATCHES = {
    1: 1,
    2: 5,
    3: 10,
    4: 21,
}
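# Worked arithmetic for the mapping above: a level encodes one global view plus
# an n x n grid of crops, so level 2 -> 1 + 2*2 = 5 and level 3 -> 1 + 3*3 = 10
# (matching the 2x and 3x upscales in Encoder.forward). Level 4's 21 patches
# are kept only for checkpoint compatibility; forward no longer builds them.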
def cut_image_patches(image: Image.Image, encoder_resolution: int = 224):
    """Tile `image` into encoder_resolution x encoder_resolution crops.

    Tile origins lie on a regular grid; the last tile on each axis is anchored
    to the image border, so tiles overlap when the image size is not a
    multiple of `encoder_resolution`.
    """
    coordinates = []
    width, height = image.size
    width_tiles = [i * encoder_resolution for i in range(math.ceil(width / encoder_resolution) - 1)]
    width_tiles.append(width - encoder_resolution)
    height_tiles = [i * encoder_resolution for i in range(math.ceil(height / encoder_resolution) - 1)]
    height_tiles.append(height - encoder_resolution)
    for w in width_tiles:
        for h in height_tiles:
            coordinates.append((w, h, w + encoder_resolution, h + encoder_resolution))
    cropped_images = [image.crop(c) for c in coordinates]
    return cropped_images
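# Worked example (sketch, inputs are illustrative):
#   crops = cut_image_patches(Image.new("RGB", (448, 448)))
#   len(crops) == 4  # origins [0, 224] per axis: a clean 2x2 grid
#   crops = cut_image_patches(Image.new("RGB", (300, 300)))
#   len(crops) == 4  # origins [0, 76] per axis: the four crops overlap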
class Encoder(nn.Module):
    def __init__(self, clip_name, level=2, dtype=None, use_dropout=True) -> None:
        super().__init__()
        if level not in LEVELS_TO_PATCHES:
            raise ValueError(f"Level {level} not supported; choose one of {sorted(LEVELS_TO_PATCHES)}")
        self.n_patches = LEVELS_TO_PATCHES[level]
        self.vision_model = CLIPModel.from_pretrained(clip_name, torch_dtype=dtype).vision_model
        self.has_first_adapter = False
        self.image_size = self.vision_model.config.image_size
        self.patch_size = self.vision_model.config.patch_size
        self.use_dropout = use_dropout
        self.dtype = dtype
        # CLIP's own normalization statistics.
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        self.norm_lvl_1 = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
        self.norm_lvl_2 = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
        # norm_lvl_3 was used in earlier experiments; it has to stay so older
        # checkpoints can still be loaded.
        self.norm_lvl_3 = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
        if level == 1:
            self.connector = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
        else:
            self.connector = Position(self.n_patches, self.vision_model.config.hidden_size, dtype=dtype)
        # LoRA adapter ("second") applied to the high-resolution crops.
        config_level2 = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "patch_embedding", "fc1", "fc2"],
            lora_dropout=0.05 if self.use_dropout else 0,
            bias="none",
        )
        self.vision_model = get_peft_model(self.vision_model, config_level2, "second")
    def add_first_level_adapter(self):
        """Attach a second LoRA adapter ("first"), used for the global resized view."""
        config_224 = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "patch_embedding", "fc1", "fc2"],
            lora_dropout=0.05 if self.use_dropout else 0,
            bias="none",
        )
        self.vision_model.add_adapter("first", config_224)
        self.has_first_adapter = True
    def forward(self, images: list, device="cpu", **kwargs):
        """Encode a list of B PIL images into a (B, n_patches * h, hidden_size)
        tensor, where h is the number of tokens per crop (patch tokens + CLS)."""
        B = len(images)
        h = (self.image_size // self.patch_size) ** 2 + 1
        resized_images = {1: [], 2: []}
        for i in images:
            # Level 1: one global view, resized to the encoder resolution.
            resized_images[1].append(self.image_transform(i.resize((self.image_size, self.image_size), resample=BICUBIC)))
            # Level 2: upscale, then tile into a 2x2 (5 patches) or 3x3 (10 patches) grid.
            if self.n_patches == 5:
                for crop in cut_image_patches(i.resize((self.image_size * 2, self.image_size * 2), resample=BICUBIC), encoder_resolution=self.image_size):
                    resized_images[2].append(self.image_transform(crop))
            elif self.n_patches == 10:
                for crop in cut_image_patches(i.resize((self.image_size * 3, self.image_size * 3), resample=BICUBIC), encoder_resolution=self.image_size):
                    resized_images[2].append(self.image_transform(crop))
        vision_features = []
        for res, imgs in resized_images.items():
            if imgs:
                resized_images[res] = torch.stack(imgs, dim=0).to(device)
                if res == 1 and self.has_first_adapter:
                    self.vision_model.set_adapter("first")
                    vision_features.append(self.norm_lvl_1(self.vision_model(resized_images[res]).last_hidden_state))
                elif res == 1:
                    # No adapter for the global view: run the frozen base model.
                    with self.vision_model.disable_adapter():
                        vision_features.append(self.norm_lvl_1(self.vision_model(resized_images[res]).last_hidden_state))
                elif res == 2:
                    self.vision_model.set_adapter("second")
                    # Fold the per-crop batch (B * n_crops, h, dim) back to (B, n_crops * h, dim).
                    if self.n_patches == 5:
                        vision_features.append(self.norm_lvl_2(self.vision_model(resized_images[res]).last_hidden_state.view(B, h * 4, -1)))
                    elif self.n_patches == 10:
                        vision_features.append(self.norm_lvl_2(self.vision_model(resized_images[res]).last_hidden_state.view(B, h * 9, -1)))
        vision_features = torch.cat(vision_features, dim=1)
        vision_features = self.connector(vision_features)
        return vision_features
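    # Shape sketch (assuming CLIP ViT-L/14 at 224px, hidden_size 1024, which is
    # an illustrative choice, not fixed by this module): h = (224 // 14) ** 2 + 1
    # = 257 tokens per crop, so a level-2 encoder returns (B, 5 * 257, 1024) =
    # (B, 1285, 1024).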
class Position(nn.Module):
    """Adds a learned per-crop position embedding: each of the n_patches crops
    gets its own embedding, broadcast over that crop's tokens."""
    def __init__(self, n_patches, dim, dtype) -> None:
        super().__init__()
        # Sized to the largest level so checkpoints stay loadable across levels.
        self.embedding = nn.Embedding(max(LEVELS_TO_PATCHES.values()), dim, dtype=dtype)
        self.n_patches = n_patches
        self.apply(self._init_weights)

    def forward(self, vision_features):
        batch_size, seq_len, dim = vision_features.size()
        tokens_per_crop = seq_len // self.n_patches
        pos = torch.arange(self.n_patches, device=vision_features.device)
        # (1, n_patches, dim) -> repeat each crop's embedding over its tokens -> (B, seq_len, dim)
        pos = torch.repeat_interleave(self.embedding(pos).unsqueeze(0), tokens_per_crop, 1).expand(batch_size, -1, -1)
        return vision_features + pos
    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
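

# Minimal smoke-test sketch, not part of the original module. The checkpoint
# name and image size are illustrative assumptions; any CLIP checkpoint with a
# vision tower should work.
if __name__ == "__main__":
    encoder = Encoder("openai/clip-vit-base-patch32", level=2)
    encoder.add_first_level_adapter()
    dummy = Image.new("RGB", (640, 480))
    with torch.no_grad():
        feats = encoder([dummy])
    # ViT-B/32 at 224px: h = (224 // 32) ** 2 + 1 = 50 tokens per crop, and
    # level 2 has 5 crops, so the expected shape is (1, 250, 768).
    print(feats.shape)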