Spaces:
Running
Running
File size: 8,051 Bytes
08545c6 5953c71 08545c6 5953c71 08545c6 5953c71 08545c6 5953c71 08545c6 5953c71 08545c6 5953c71 08545c6 5953c71 08545c6 5953c71 08545c6 5953c71 08545c6 5953c71 08545c6 5953c71 08545c6 5953c71 08545c6 5953c71 08545c6 5953c71 08545c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
import os
import numpy as np
import pickle
import torch
import transformers
from PIL import Image
from open_clip import create_model_from_pretrained, create_model_and_transforms
import json
# XLM model functions
from multilingual_clip import pt_multilingual_clip
from model_loading import load_model
class CustomDataSet(torch.utils.data.Dataset):
    """Dataset that serves transformed images from a directory by file name.

    Attributes are set from the constructor arguments:
    ``main_dir`` is the directory joined in front of every image name,
    ``transform`` is the callable applied to each opened image, and
    ``total_imgs`` is the indexable collection of image file names.
    """

    def __init__(self, main_dir, compose, image_name_list):
        """Store the image directory, the transform and the image names."""
        self.total_imgs = image_name_list
        self.transform = compose
        self.main_dir = main_dir

    def __len__(self):
        """Return how many image names the dataset holds."""
        image_names = self.total_imgs
        return len(image_names)

    def get_image_name(self, idx):
        """Return the file name stored at position ``idx``."""
        image_names = self.total_imgs
        return image_names[idx]

    def __getitem__(self, idx):
        """Open the image at position ``idx`` and return it transformed."""
        image_name = self.total_imgs[idx]
        opened = Image.open(os.path.join(self.main_dir, image_name))
        return self.transform(opened)
def features_pickle(file_path=None):
    """Load and return the object pickled in the file at ``file_path``.

    Args:
        file_path: Path of the pickle file to read (opened in binary mode).

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    NOTE(review): unpickling can execute arbitrary code, so this should only
    be pointed at trusted, locally produced files.
    """
    with open(file_path, mode='rb') as pickle_file:
        return pickle.load(pickle_file)
def dataset_loading(file_name):
    """Read a JSON Lines file and return its records ordered by ``id``.

    Args:
        file_name: Path of a file holding one JSON object per line. Every
            object must provide an ``"id"`` and an ``"image_name"`` key.

    Returns:
        A tuple ``(sorted_data, image_name_list)`` where ``sorted_data`` is
        the list of parsed records sorted by ``"id"`` and
        ``image_name_list`` holds each record's ``"image_name"`` in that
        same order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``"id"`` or ``"image_name"``.
    """
    # Decode explicitly as UTF-8 so non-ASCII text in the records does not
    # depend on the platform's default encoding.
    with open(file_name, encoding="utf-8") as filino:
        # Blank lines (e.g. a trailing empty line) are skipped instead of
        # crashing json.loads.
        data = [json.loads(line) for line in filino if line.strip()]
    sorted_data = sorted(data, key=lambda record: record['id'])
    image_name_list = [record["image_name"] for record in sorted_data]
    return sorted_data, image_name_list
def text_encoder(language_model, text):
    """Encode ``text`` and return the raw and the L2-normalized embeddings.

    Args:
        language_model: Callable that maps ``text`` to a NumPy array of
            embeddings.
        text: Query (or queries) passed unchanged to ``language_model``.

    Returns:
        A tuple ``(embedding, norm_embedding)``: the array returned by
        ``language_model`` and the same array with every vector along the
        last axis scaled to unit Euclidean length.
    """
    embedding = language_model(text)
    # Normalize each embedding on its own (last axis). A single global norm
    # is only correct when exactly one embedding is present; for one row both
    # forms give the same result.
    norms = np.linalg.norm(embedding, axis=-1, keepdims=True)
    norm_embedding = embedding / norms
    return embedding, norm_embedding
def compare_embeddings(logit_scale, img_embs, txt_embs):
    """Return scaled cosine similarities between text and image embeddings.

    Args:
        logit_scale: Factor multiplied onto the normalized text embeddings.
        img_embs: Tensor of image embeddings, one per row.
        txt_embs: Tensor of text embeddings, one per row.

    Returns:
        Tensor whose rows correspond to text embeddings and whose columns
        correspond to image embeddings.
    """
    # Scale every row to unit length so the matrix product is a cosine score.
    img_lengths = img_embs.norm(dim=-1, keepdim=True)
    txt_lengths = txt_embs.norm(dim=-1, keepdim=True)
    unit_images = img_embs / img_lengths
    unit_texts = txt_embs / txt_lengths
    scaled_texts = logit_scale * unit_texts
    return scaled_texts @ unit_images.t()
# Done
def compare_embeddings_text(full_text_embds, txt_embs):
    """Return cosine similarities between query and stored text embeddings.

    Args:
        full_text_embds: Tensor of stored text embeddings, one per row.
        txt_embs: Tensor of query text embeddings, one per row.

    Returns:
        Tensor whose rows correspond to ``txt_embs`` and whose columns
        correspond to ``full_text_embds``. No logit scale is applied.
    """
    stored_lengths = full_text_embds.norm(dim=-1, keepdim=True)
    query_lengths = txt_embs.norm(dim=-1, keepdim=True)
    unit_stored = full_text_embds / stored_lengths
    unit_queries = txt_embs / query_lengths
    return unit_queries @ unit_stored.t()
def _top_k_indices(probs, k):
    """Return up to ``k`` row indices of ``probs`` ranked by column 0, best first."""
    ascending = np.argsort(probs, axis=0)[:, 0]
    # Slicing clamps automatically when k exceeds the number of rows.
    return ascending[::-1][:max(k, 0)]


def find_image(language_model, clip_model, text_query, dataset, image_features, text_features_new, sorted_data, images_path, num=1):
    """Retrieve the best-matching images and records for a text query.

    Args:
        language_model: Callable turning ``text_query`` into a NumPy embedding.
        clip_model: Model object exposing a ``logit_scale`` tensor.
        text_query: The query text to search with.
        dataset: Object providing ``get_image_name(idx)``.
        image_features: NumPy array of precomputed image embeddings.
        text_features_new: NumPy array of precomputed text embeddings.
        sorted_data: Sequence of records indexed in the same order as the
            features; each record provides a ``'caption_ar'`` entry.
        images_path: Prefix concatenated directly in front of the image name.
        num: Number of results wanted. Values larger than the number of
            available items return all of them; values below 1 return
            empty results.

    Returns:
        A tuple ``(file_paths, labels, json_data, json_text)``:
        ``file_paths`` is a list of ``(path, caption)`` tuples,
        ``labels`` maps ``" Image # i"`` to the matching row of the image
        softmax scores, ``json_data`` maps ``" Image # i"`` to the matching
        record, and ``json_text`` maps ``" Text # j"`` to the record of the
        j-th best text match.
    """
    embedding, _ = text_encoder(language_model, text_query)
    query_embedding = torch.from_numpy(embedding)
    logit_scale = clip_model.logit_scale.exp().float().to('cpu')

    image_logits = compare_embeddings(logit_scale, torch.from_numpy(image_features), query_embedding)
    caption_logits = compare_embeddings_text(torch.from_numpy(text_features_new), query_embedding)

    # Transposed so that rows index dataset items and column 0 is the query.
    probs = image_logits.softmax(dim=-1).cpu().detach().numpy().T
    probs_text = caption_logits.softmax(dim=-1).cpu().detach().numpy().T

    file_paths = []
    labels, json_data = {}, {}
    # The ranking is computed once instead of re-sorting for every result.
    for i, idx in enumerate(_top_k_indices(probs, num), start=1):
        path = images_path + dataset.get_image_name(idx)
        file_paths.append((path, f"{sorted_data[idx]['caption_ar']}"))
        labels[f" Image # {i}"] = probs[idx]
        json_data[f" Image # {i}"] = sorted_data[idx]

    json_text = {}
    for j, idx in enumerate(_top_k_indices(probs_text, num), start=1):
        json_text[f" Text # {j}"] = sorted_data[idx]

    return file_paths, labels, json_data, json_text
class AraClip:
    """Holds the AraClip text model, the SigLIP image model and dataset metadata.

    Construction loads the text model through ``load_model``, the
    ``ViT-B-16-SigLIP-512`` model and its preprocessing through ``open_clip``,
    and the XTD and Flicker8k JSONL metadata files.
    """

    def __init__(self):
        """Load both models and the metadata of both datasets."""
        self.text_model = load_model(
            'bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M',
            in_features=768,
            out_features=768,
        )
        self.clip_model, self.compose = create_model_from_pretrained('hf-hub:timm/ViT-B-16-SigLIP-512')

        xtd_records, xtd_names = dataset_loading("photos/en_ar_XTD10_edited_v2.jsonl")
        self.sorted_data_xtd = xtd_records
        self.image_name_list_xtd = xtd_names

        flicker_records, flicker_names = dataset_loading("photos/flicker_8k.jsonl")
        self.sorted_data_flicker8k = flicker_records
        self.image_name_list_flicker8k = flicker_names

    def language_model(self, queries):
        """Embed ``queries`` with the text model and return a NumPy array."""
        model_output = self.text_model(queries)
        return np.asarray(model_output.detach().to('cpu'))

    def load_pickle_file(self, file_name):
        """Return the object stored in the pickle file ``file_name``."""
        return features_pickle(file_name)

    def load_xtd_dataset(self):
        """Build a ``CustomDataSet`` over the XTD10 images."""
        return CustomDataSet("photos/XTD10_dataset", self.compose, self.image_name_list_xtd)

    def load_flicker8k_dataset(self):
        """Build a ``CustomDataSet`` over the Flicker8k images."""
        return CustomDataSet("photos/Flicker8k_Dataset", self.compose, self.image_name_list_flicker8k)
# Shared AraClip instance used by predict(); creating it runs
# AraClip.__init__ when this statement executes.
araclip = AraClip()
def predict(text, num, dadtaset_select):
    """Search the selected dataset with the AraClip model.

    Args:
        text: Query text.
        num: Number of results wanted; converted with ``int``.
        dadtaset_select: ``"XTD dataset"`` selects the XTD data; any other
            value selects the Flicker8k data.

    Returns:
        The ``(image_paths, labels, json_data, json_text)`` tuple produced
        by ``find_image``.
    """
    if dadtaset_select == "XTD dataset":
        dataset = araclip.load_xtd_dataset()
        image_features = araclip.load_pickle_file("cashed_pickles/XTD_pickles/image_features_XTD_1000_images_arabert_siglib_best_model.pickle")
        # Bug fix: the image-features pickle used to be passed here a second
        # time, so text matching ran against image embeddings. The file name
        # follows the image_features_*/text_features_* pairing used by every
        # other dataset/model combination in this module.
        # NOTE(review): confirm this pickle exists in cashed_pickles/XTD_pickles.
        text_features = araclip.load_pickle_file("cashed_pickles/XTD_pickles/text_features_XTD_1000_images_arabert_siglib_best_model.pickle")
        sorted_data = araclip.sorted_data_xtd
        images_path = 'photos/XTD10_dataset/'
    else:
        dataset = araclip.load_flicker8k_dataset()
        image_features = araclip.load_pickle_file("cashed_pickles/flicker_8k/image_features_flicker_8k_images_arabert_siglib_best_model.pickle")
        text_features = araclip.load_pickle_file("cashed_pickles/flicker_8k/text_features_flicker_8k_images_arabert_siglib_best_model.pickle")
        sorted_data = araclip.sorted_data_flicker8k
        images_path = "photos/Flicker8k_Dataset/"

    image_paths, labels, json_data, json_text = find_image(
        araclip.language_model,
        araclip.clip_model,
        text,
        dataset,
        image_features,
        text_features,
        sorted_data,
        images_path,
        num=int(num),
    )
    return image_paths, labels, json_data, json_text
class Mclip:
    """Holds the M-CLIP text model, its image model and dataset metadata.

    Construction loads the ``M-CLIP/XLM-Roberta-Large-Vit-B-16Plus``
    tokenizer and text model, the ``ViT-B-16-plus-240`` model with its
    preprocessing through ``open_clip``, and the XTD and Flicker8k JSONL
    metadata files.
    """

    def __init__(self) -> None:
        """Load tokenizer, both models and the metadata of both datasets."""
        mclip_name = 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus'
        self.tokenizer_mclip = transformers.AutoTokenizer.from_pretrained(mclip_name)
        self.text_model_mclip = pt_multilingual_clip.MultilingualCLIP.from_pretrained(mclip_name)
        self.clip_model_mclip, _, self.compose_mclip = create_model_and_transforms(
            'ViT-B-16-plus-240', pretrained="laion400m_e32"
        )

        xtd_records, xtd_names = dataset_loading("photos/en_ar_XTD10_edited_v2.jsonl")
        self.sorted_data_xtd = xtd_records
        self.image_name_list_xtd = xtd_names

        flicker_records, flicker_names = dataset_loading("photos/flicker_8k.jsonl")
        self.sorted_data_flicker8k = flicker_records
        self.image_name_list_flicker8k = flicker_names

    def language_model_mclip(self, queries):
        """Embed ``queries`` with the M-CLIP text model and return a NumPy array."""
        model_output = self.text_model_mclip.forward(queries, self.tokenizer_mclip)
        return np.asarray(model_output.detach().to('cpu'))

    def load_pickle_file(self, file_name):
        """Return the object stored in the pickle file ``file_name``."""
        return features_pickle(file_name)

    def load_xtd_dataset(self):
        """Build a ``CustomDataSet`` over the XTD10 images."""
        return CustomDataSet("photos/XTD10_dataset", self.compose_mclip, self.image_name_list_xtd)

    def load_flicker8k_dataset(self):
        """Build a ``CustomDataSet`` over the Flicker8k images."""
        return CustomDataSet("photos/Flicker8k_Dataset", self.compose_mclip, self.image_name_list_flicker8k)
# Shared Mclip instance used by predict_mclip(); creating it runs
# Mclip.__init__ when this statement executes.
mclip = Mclip()
def predict_mclip(text, num, dadtaset_select):
    """Search the selected dataset with the M-CLIP model.

    Args:
        text: Query text.
        num: Number of results wanted; converted with ``int``.
        dadtaset_select: ``"XTD dataset"`` selects the XTD data; any other
            value selects the Flicker8k data.

    Returns:
        The ``(image_paths, labels, json_data, json_text)`` tuple produced
        by ``find_image``.
    """
    use_xtd = dadtaset_select == "XTD dataset"
    text_embedder = mclip.language_model_mclip
    image_model = mclip.clip_model_mclip

    if use_xtd:
        return find_image(
            text_embedder,
            image_model,
            text,
            mclip.load_xtd_dataset(),
            mclip.load_pickle_file("cashed_pickles/XTD_pickles/image_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle"),
            mclip.load_pickle_file("cashed_pickles/XTD_pickles/text_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle"),
            mclip.sorted_data_xtd,
            'photos/XTD10_dataset/',
            num=int(num),
        )

    return find_image(
        text_embedder,
        image_model,
        text,
        mclip.load_flicker8k_dataset(),
        mclip.load_pickle_file("cashed_pickles/flicker_8k/image_features_flicker_8k_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle"),
        mclip.load_pickle_file("cashed_pickles/flicker_8k/text_features_flicker_8k_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle"),
        mclip.sorted_data_flicker8k,
        'photos/Flicker8k_Dataset/',
        num=int(num),
    )
|