Space status: Runtime error

miguelcarv committed 34f251f ("first commit"); 1 parent: fb2630f

Files changed:
- app.py +95 -55
- pheye_builder/__init__.py +1 -0
- pheye_builder/__pycache__/__init__.cpython-311.pyc +0 -0
- pheye_builder/__pycache__/encoder.cpython-311.pyc +0 -0
- pheye_builder/__pycache__/factory.cpython-311.pyc +0 -0
- pheye_builder/__pycache__/phEYE.cpython-311.pyc +0 -0
- pheye_builder/__pycache__/utils.cpython-311.pyc +0 -0
- pheye_builder/__pycache__/wrapper_lm.cpython-311.pyc +0 -0
- pheye_builder/__pycache__/xattn.cpython-311.pyc +0 -0
- pheye_builder/encoder.py +179 -0
- pheye_builder/factory.py +126 -0
- pheye_builder/phEYE.py +220 -0
- pheye_builder/utils.py +48 -0
- pheye_builder/wrapper_lm.py +132 -0
- pheye_builder/xattn.py +159 -0
- requirements.txt +8 -1
app.py
CHANGED
@@ -1,63 +1,103 @@
Most of the body of the previous 63-line app.py (old lines 4-63) is removed; the removed content is collapsed in this view. The new 103-line app.py is:

import gradio as gr
from huggingface_hub import InferenceClient
import json
from pheye_builder import create_model_and_transforms
from huggingface_hub import hf_hub_download
import torch
from PIL import Image
import os
import requests


def get_config(hf_model_path):
    config_path = hf_hub_download(hf_model_path, "config.json")

    with open(config_path, "r") as f:
        config = json.load(f)

    return config


def get_model_path(hf_model_path):
    return hf_hub_download(hf_model_path, "checkpoint.pt")


HF_MODEL = "miguelcarv/Pheye-x2-672"
config = get_config(HF_MODEL)

print("Got config")

model, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path=config["encoder"],
    lang_decoder_path=config["decoder"],
    tokenizer_path=config["tokenizer"],
    cross_attn_every_n_layers=config["cross_interval"],
    level=config["level"],
    reduce_factor=config["reduce"],
    from_layer=config["from_layer"],
    encoder_dtype=eval(config["encoder_dtype"]),
    decoder_dtype=eval(config["decoder_dtype"]),
    dtype=eval(config["other_params_dtype"])
)

if config["first_level"]:
    model.vision_encoder.add_first_level_adapter()

print("Created model")

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = get_model_path(HF_MODEL)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model = model.to(DEVICE)

print("Loaded model")

SYSTEM_PROMPT = "You are an AI visual assistant and you are seeing a single image. You will receive an instruction regarding that image. Your goal is to follow the instruction as faithfully as you can."

whiteboard = Image.open(requests.get("https://c1.staticflickr.com/7/6168/6207108414_a8833f410e_o.jpg", stream=True).raw).convert('RGB')
taxi_image = Image.open(requests.get("https://llava.hliu.cc/file=/nobackup/haotian/tmp/gradio/ca10383cc943e99941ecffdc4d34c51afb2da472/extreme_ironing.jpg", stream=True).raw).convert('RGB')


def generate_answer(img, question, max_new_tokens, num_beams):

    image = [img]
    prompt = [f"{SYSTEM_PROMPT}\n\nInstruction: {question}\nOutput:"]
    inputs = tokenizer(prompt, padding='longest', return_tensors='pt')
    print("Generating a response with the following parameters:")
    print(f"""Question: {question}\nMax New Tokens: {max_new_tokens}\nNum Beams: {num_beams}""")

    model.eval()
    with torch.no_grad():
        outputs = model.generate(vision_x=image,
                                 lang_x=inputs.input_ids.to(DEVICE),
                                 device=DEVICE,
                                 max_new_tokens=max_new_tokens,
                                 num_beams=num_beams,
                                 eos_token_id=tokenizer.eos_token_id,
                                 pad_token_id=tokenizer.pad_token_id,
                                 attention_mask=inputs.attention_mask.to(DEVICE))
    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].split("Output:")[-1].lstrip()

    return answer


# Create the Gradio interface
iface = gr.Interface(
    fn=generate_answer,
    inputs=[
        gr.Image(type="pil", label="Image"),
        gr.Textbox(label="Question"),
        gr.Slider(minimum=5, maximum=500, step=1, value=50, label="Max New Tokens"),
        gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Num Beams")
    ],
    outputs=gr.Textbox(label="Answer"),
    title="<h1 style='text-align: center; display: block;'>Pheye-x2 672x672 pixels</h1>",
    examples=[[taxi_image, "What is unusual about this image?"], [whiteboard, "What is the main topic of the whiteboard?"]]
)


if __name__ == "__main__":
    # Launch the Gradio app
    iface.launch()
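A minimal sketch (not part of this commit) of how app.py assembles the prompt and recovers the answer from the decoded generation; the decoded string below is a made-up example used only to illustrate the "Output:" split:

SYSTEM_PROMPT = ("You are an AI visual assistant and you are seeing a single image. "
                 "You will receive an instruction regarding that image. Your goal is "
                 "to follow the instruction as faithfully as you can.")

question = "What is unusual about this image?"
prompt = f"{SYSTEM_PROMPT}\n\nInstruction: {question}\nOutput:"

# The model echoes the prompt, so the answer is whatever follows "Output:".
decoded = prompt + " A man is ironing clothes on the back of a moving taxi."
answer = decoded.split("Output:")[-1].lstrip()
print(answer)  # -> "A man is ironing clothes on the back of a moving taxi."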
pheye_builder/__init__.py
ADDED
@@ -0,0 +1 @@
from .factory import create_model_and_transforms
pheye_builder/__pycache__/__init__.cpython-311.pyc
ADDED (binary file, 250 Bytes)
pheye_builder/__pycache__/encoder.cpython-311.pyc
ADDED (binary file, 12.6 kB)
pheye_builder/__pycache__/factory.cpython-311.pyc
ADDED (binary file, 5.9 kB)
pheye_builder/__pycache__/phEYE.cpython-311.pyc
ADDED (binary file, 10.2 kB)
pheye_builder/__pycache__/utils.cpython-311.pyc
ADDED (binary file, 2.28 kB)
pheye_builder/__pycache__/wrapper_lm.cpython-311.pyc
ADDED (binary file, 7.12 kB)
pheye_builder/__pycache__/xattn.cpython-311.pyc
ADDED (binary file, 7.7 kB)
pheye_builder/encoder.py
ADDED
@@ -0,0 +1,179 @@
from transformers import CLIPModel
from torch import nn
from peft import LoraConfig, get_peft_model
import torch
from torch import nn
import PIL
from PIL.Image import BICUBIC
import math
from torchvision import transforms
import torch.nn.functional as F


# level 4 which has 21 patches was being used in previous experiments so now I can't remove it or won't be able to load older models....
LEVELS_TO_PATCHES = {
    1: 1,
    2: 5,
    3: 10,
    4: 21
}

def cut_image_patches(image: PIL.Image, encoder_resolution: int = 224):

    coordinates = []

    width, height = image.size

    width_tiles = [i * encoder_resolution for i in range(math.ceil(width / encoder_resolution) - 1)]
    width_tiles.append(width - encoder_resolution)
    height_tiles = [i * encoder_resolution for i in range(math.ceil(height / encoder_resolution) - 1)]
    height_tiles.append(height - encoder_resolution)

    for w in width_tiles:
        for h in height_tiles:
            coordinates.append((w, h, w + encoder_resolution, h + encoder_resolution))

    cropped_images = [image.crop(c) for c in coordinates]

    return cropped_images


class Encoder(nn.Module):

    def __init__(self, clip_name, level = 2, dtype = None, use_dropout = True) -> None:
        super().__init__()

        if level not in LEVELS_TO_PATCHES:
            raise ValueError("Resolution not supported")

        self.n_patches = LEVELS_TO_PATCHES[level]
        self.vision_model = CLIPModel.from_pretrained(clip_name, torch_dtype=dtype).vision_model
        self.has_first_adapter = False
        self.image_size = self.vision_model.config.image_size
        self.patch_size = self.vision_model.config.patch_size
        self.use_dropout = use_dropout
        self.dtype = dtype

        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        self.norm_lvl_1 = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
        self.norm_lvl_2 = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)

        # this was being used in previous experiments so now I can't remove it or won't be able to load older models....
        self.norm_lvl_3 = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)

        if level == 1:
            self.connector = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
        else:
            self.connector = Position(self.n_patches, self.vision_model.config.hidden_size, dtype=dtype)

        config_level2 = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "patch_embedding", "fc1", "fc2"],
            lora_dropout=0.05 if self.use_dropout else 0,
            bias="none"
        )
        self.vision_model = get_peft_model(self.vision_model, config_level2, "second")

    def add_first_level_adapter(self):

        config_224 = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "patch_embedding", "fc1", "fc2"],
            lora_dropout=0.05 if self.use_dropout else 0,
            bias="none"
        )

        self.vision_model.add_adapter("first", config_224)
        self.has_first_adapter = True

    def forward(self, images: list, device = "cpu", **kwargs):
        """
        shape (B, C, H, W) in list form
        """
        B = len(images)
        h = int((self.image_size / self.patch_size) ** 2 + 1)
        resized_images = {1: [], 2: []}

        for i in images:
            resized_images[1].append(self.image_transform(i.resize((self.image_size, self.image_size), resample=BICUBIC)))

            if self.n_patches == 5:
                for crop in cut_image_patches(i.resize((self.image_size * 2, self.image_size * 2), resample=BICUBIC), encoder_resolution=self.image_size):
                    resized_images[2].append(self.image_transform(crop))
            elif self.n_patches == 10:
                for crop in cut_image_patches(i.resize((self.image_size * 3, self.image_size * 3), resample=BICUBIC), encoder_resolution=self.image_size):
                    resized_images[2].append(self.image_transform(crop))

        vision_features = []
        for res, imgs in resized_images.items():
            if imgs != []:
                resized_images[res] = torch.stack(imgs, dim = 0).to(device)

                if res == 1 and self.has_first_adapter:
                    self.vision_model.set_adapter("first")
                    vision_features.append(self.norm_lvl_1(self.vision_model(resized_images[res]).last_hidden_state))
                elif res == 1:
                    with self.vision_model.disable_adapter():
                        vision_features.append(self.norm_lvl_1(self.vision_model(resized_images[res]).last_hidden_state))
                elif res == 2:
                    self.vision_model.set_adapter("second")
                    if self.n_patches == 5:
                        vision_features.append(self.norm_lvl_2(self.vision_model(resized_images[res]).last_hidden_state.view(B, h * 4, -1)))
                    elif self.n_patches == 10:
                        vision_features.append(self.norm_lvl_2(self.vision_model(resized_images[res]).last_hidden_state.view(B, h * 9, -1)))

        vision_features = torch.cat(vision_features, dim = 1)
        vision_features = self.connector(vision_features)

        return vision_features


class Position(nn.Module):

    def __init__(self, n_patches, dim, dtype) -> None:
        super().__init__()

        self.embedding = nn.Embedding(max(LEVELS_TO_PATCHES.values()), dim, dtype=dtype)
        self.n_patches = n_patches

        self.apply(self._init_weights)

    def forward(self, vision_features):

        batch_size, seq_len, dim = vision_features.size()
        single_encoder_dim = seq_len // self.n_patches
        device = vision_features.get_device()

        pos = torch.LongTensor(list(range(self.n_patches))).to(device if device != -1 else "cpu")
        pos = torch.repeat_interleave(self.embedding(pos).unsqueeze(0), single_encoder_dim, 1).expand(batch_size, -1, -1)

        return vision_features + pos

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        for name, p in module.named_parameters():
            if name == "fc1.weight" or name == "fc2.weight" or name == "to_out.weight":
                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.n_decoder_layers)))
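A minimal sketch (not part of this commit) of the tiling math behind cut_image_patches; the helper below mirrors the box computation to show why level 2 yields 1 + 4 = 5 patches and level 3 yields 1 + 9 = 10, matching LEVELS_TO_PATCHES:

import math

def crop_boxes(width, height, encoder_resolution=224):
    # Same logic as cut_image_patches: tile origins every encoder_resolution
    # pixels, with the last tile snapped to the image border.
    width_tiles = [i * encoder_resolution for i in range(math.ceil(width / encoder_resolution) - 1)]
    width_tiles.append(width - encoder_resolution)
    height_tiles = [i * encoder_resolution for i in range(math.ceil(height / encoder_resolution) - 1)]
    height_tiles.append(height - encoder_resolution)
    return [(w, h, w + encoder_resolution, h + encoder_resolution)
            for w in width_tiles for h in height_tiles]

print(crop_boxes(448, 448))       # 4 boxes: the level-2 crops, plus the global 224x224 view = 5 patches
print(len(crop_boxes(672, 672)))  # 9 boxes: the level-3 crops, plus the global view = 10 patches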
pheye_builder/factory.py
ADDED
@@ -0,0 +1,126 @@
from typing import Optional

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch

from .phEYE import phEYE
from .wrapper_lm import phEYELMMixin
from .utils import extend_instance
from .encoder import Encoder


def create_model_and_transforms(
    clip_vision_encoder_path: str,
    lang_decoder_path: str,
    tokenizer_path: str,
    dtype,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: str = None,
    freeze_lm_embeddings: bool = True,
    cache_dir: Optional[str] = None,
    level: int = 2,
    encoder_dtype: torch.dtype = None,
    decoder_dtype: torch.dtype = None,
    use_dropout: bool = False,
    **pheye_kwargs,
):
    """
    Initialize a phEYE model from a pretrained vision encoder and language encoder.
    Appends special tokens to the tokenizer and freezes backbones.

    Args:
        clip_vision_encoder_path (str): path to pretrained CLIP model (e.g. "openai/clip-vit-base-patch32")
        lang_decoder_path (str): path to pretrained language decoder
        tokenizer_path (str): path to pretrained tokenizer
        cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
        use_local_files (bool, optional): whether to use local files. Defaults to False.
        decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
        freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings.
        cache_dir (str, optional): path to cache directory for downloading HF weights.
    Returns:
        phEYE: phEYE model from pretrained vision and language encoders
        Tokenizer: A tokenizer for the language model
    """

    vision_encoder = Encoder(clip_vision_encoder_path, level=level, dtype=encoder_dtype, use_dropout=use_dropout)

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )

    if text_tokenizer.pad_token is None:
        text_tokenizer.pad_token = text_tokenizer.eos_token

    lang_config = AutoConfig.from_pretrained(lang_decoder_path)
    lang_encoder = AutoModelForCausalLM.from_config(
        lang_config,
        torch_dtype=decoder_dtype
    )

    lang_encoder.config.decoder_start_token_id = None
    lang_encoder.config.pad_token_id = text_tokenizer.pad_token_id

    # convert LM to phEYELM
    extend_instance(lang_encoder, phEYELMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)

    model = phEYE(
        vision_encoder,
        lang_encoder,
        vis_dim=vision_encoder.vision_model.config.hidden_size,
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        dtype=dtype,
        **pheye_kwargs,
    )

    # Freeze all parameters
    model.lang_encoder.requires_grad_(False)
    assert sum(p.numel() for p in model.lang_encoder.parameters() if p.requires_grad) == 0

    # Unfreeze cross_attn_layers and, optionally, LM input embeddings
    model.lang_encoder.cross_attn_layers.requires_grad_(True)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)

    print(
        f"phEYE model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
    )

    return model, text_tokenizer


def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
    )


__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gpt": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
    "phi": "model.layers"
}
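A minimal sketch (not part of this commit) of how _infer_decoder_layers_attr_name resolves the decoder-layers attribute from the model's class name; the class names below are examples of Hugging Face naming used only to illustrate the substring match:

KNOWN = {
    "opt": "model.decoder.layers",
    "llama": "model.layers",
    "phi": "model.layers",
}

def infer(class_name):
    # Case-insensitive substring match against the known table.
    for key, attr in KNOWN.items():
        if key in class_name.lower():
            return attr
    raise ValueError("Unknown decoder; pass decoder_layers_attr_name explicitly.")

print(infer("PhiForCausalLM"))  # -> "model.layers"
print(infer("OPTForCausalLM"))  # -> "model.decoder.layers"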
pheye_builder/phEYE.py
ADDED
@@ -0,0 +1,220 @@
import torch
from peft import LoraConfig, get_peft_model
from torch import nn
import os


class phEYE(nn.Module):
    def __init__(
        self,
        vision_encoder: nn.Module,
        lang_encoder: nn.Module,
        vis_dim: int,
        dtype: torch.dtype,
        cross_attn_every_n_layers: int = 1,
        gradient_checkpointing: bool = False,
        reduce_factor = 1,
        from_layer = 0
    ):
        """
        Args:
            vision_encoder (nn.Module): module with CLIP vision model
            lang_encoder (nn.Module): HF causal language model
            vis_dim (int): Dimension of the visual features.
                Visual features are projected to match this shape along the last dimension.
            cross_attn_every_n_layers (int, optional): How often to apply cross attention after a transformer layer. Defaults to 1.
        """
        super().__init__()
        self.vis_dim = vis_dim
        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size

        self.vision_encoder = vision_encoder
        self.lang_encoder = lang_encoder
        self.lang_encoder.init_pheye(
            lang_hidden_size=self.lang_dim,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            gradient_checkpointing=gradient_checkpointing,
            reduce_factor=reduce_factor,
            from_layer=from_layer,
            dtype=dtype
        )
        self._use_gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        vision_x: list,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        clear_conditioned_layers: bool = True,
        past_key_values = None,
        use_cache: bool = False,
        device="cpu",
        is_textcaps = False
    ):
        """
        Forward pass of phEYE.

        Args:
            vision_x (list): Vision input
                shape (B, C, H, W)
            lang_x (torch.Tensor): Language input ids
                shape (B, txt_seq)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            labels (torch.Tensor, optional): Labels. Defaults to None.
            clear_conditioned_layers: if True, clear the conditioned layers
                once the forward pass is completed. Set this to False if the
                same set of images will be reused in a subsequent forward pass.
            past_key_values: pre-computed values to pass to the language model.
                See past_key_values documentation in Hugging Face CausalLM models.
            use_cache: whether to use cached key values. See use_cache
                documentation in Hugging Face CausalLM models.
        """
        assert (
            self.lang_encoder.initialized_pheye
        ), "Wrapper layers are not initialized. Please call `init_pheye` first."

        assert (
            self.lang_encoder._use_cached_vision_x or vision_x is not None
        ), "Must provide either vision_x or have precached media using cache_media()."

        if self.lang_encoder._use_cached_vision_x:
            # Case: use cached; vision_x should be cached and other
            # vision-related inputs should not be provided.
            assert (
                vision_x is None
            ), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first."
            assert self.lang_encoder.is_conditioned()

        else:
            # Case: do not use caching (i.e. this is a standard forward pass)
            self._encode_vision_x(vision_x=vision_x, device=device, is_textcaps=is_textcaps)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output

    def generate(
        self,
        vision_x: list,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        device = "cpu",
        **kwargs,
    ):
        """
        Generate text conditioned on vision and language inputs.

        Args:
            vision_x (list): Vision input
                shape (B, C, H, W)
            lang_x (torch.Tensor): Language input
                shape (B, T_txt)
            **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs:
                max_length (int, optional): Maximum length of the output. Defaults to None.
                attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
                num_beams (int, optional): Number of beams. Defaults to 1.
                max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
                temperature (float, optional): Temperature. Defaults to 1.0.
                top_k (int, optional): Top k. Defaults to 50.
                top_p (float, optional): Top p. Defaults to 1.0.
                no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
                length_penalty (float, optional): Length penalty. Defaults to 1.0.
                num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
                do_sample (bool, optional): Do sample. Defaults to False.
                early_stopping (bool, optional): Early stopping. Defaults to False.
        Returns:
            torch.Tensor: lang_x with generated tokens appended to it
        """
        num_beams = kwargs.pop("num_beams", 1)

        self.lang_encoder._use_cached_vision_x = True
        self._encode_vision_x(vision_x=vision_x, device=device, repeat=num_beams)

        output = self.lang_encoder.generate(
            input_ids=lang_x,
            attention_mask=attention_mask,
            num_beams=num_beams,
            **kwargs,
        )

        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_vision_x = False
        return output

    def _encode_vision_x(self, vision_x: list, device="cpu", repeat = 1, is_textcaps = False):
        """
        Compute vision features by passing images through the vision encoder and conditioning the language model.
        Args:
            vision_x (list): Vision input
                shape (B, C, H, W)
        """
        if is_textcaps:
            vision_x = vision_x[::5]
            repeat = 5

        vision_x = self.vision_encoder(vision_x, device=device)

        if repeat > 1:
            vision_x = vision_x.repeat_interleave(repeat, dim=0)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)

    def cache_media(self, vision_x: list, device="cpu"):
        """
        Cache vision_x features from a list of images for log-likelihood evaluation.
        This is not meant to be used to cache things for generate().
        Args:
            vision_x (list): Vision input
                shape (B, C, H, W)
        """
        self._encode_vision_x(vision_x=vision_x, device=device)
        self.lang_encoder._use_cached_vision_x = True

    def uncache_media(self):
        """
        Clear all conditioning.
        """
        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_vision_x = False

    def save_model(self, _path):
        os.mkdir(_path)
        torch.save(self.vision_encoder.state_dict(), _path + "vision_encoder.pt")
        torch.save(self.lang_encoder.state_dict(), _path + "lang_encoder.pt")

    def add_lora_decoder(self):

        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"],
            lora_dropout=0.05,
            bias="none"
        )

        self.lang_encoder.old_decoder_blocks = get_peft_model(self.lang_encoder.old_decoder_blocks, config)

    def merge_and_unload(self):
        self.lang_encoder.old_decoder_blocks = self.lang_encoder.old_decoder_blocks.merge_and_unload()
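A minimal sketch (not part of this commit) of why _encode_vision_x repeats the vision features with repeat_interleave during beam search: Hugging Face generate() expands the text batch to batch_size * num_beams, so the features conditioned into each WrapperLayer must be expanded the same way to keep rows aligned. Requires torch; the shapes are made up for illustration:

import torch

batch_size, num_beams, img_seq, vis_dim = 2, 3, 5, 4
vision_x = torch.randn(batch_size, img_seq, vis_dim)

# Each image's features are duplicated once per beam, in batch order.
expanded = vision_x.repeat_interleave(num_beams, dim=0)
print(expanded.shape)                         # torch.Size([6, 5, 4])
print(torch.equal(expanded[0], expanded[1]))  # True: beams of the same image share features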
pheye_builder/utils.py
ADDED
@@ -0,0 +1,48 @@
def extend_instance(obj, mixin):
    """Apply mixins to a class instance after creation"""
    base_cls = obj.__class__
    base_cls_name = obj.__class__.__name__
    obj.__class__ = type(
        base_cls_name, (mixin, base_cls), {}
    )  # mixin needs to go first for our forward() logic to work


def getattr_recursive(obj, att):
    """
    Return nested attribute of obj
    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
    """
    if att == "":
        return obj
    i = att.find(".")
    if i < 0:
        return getattr(obj, att)
    else:
        return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])


def setattr_recursive(obj, att, val):
    """
    Set nested attribute of obj
    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
    """
    if "." in att:
        obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
    setattr(obj, att.split(".")[-1], val)


def apply_with_stopping_condition(
    module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
):
    if stopping_condition(module):
        return
    if apply_condition(module):
        apply_fn(module, **other_args)
    for child in module.children():
        apply_with_stopping_condition(
            child,
            apply_fn,
            apply_condition=apply_condition,
            stopping_condition=stopping_condition,
            **other_args
        )
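A minimal sketch (not part of this commit) of what extend_instance and getattr_recursive do, shown on plain Python objects; the classes here are made up for illustration:

class Base:
    def greet(self):
        return "base"

class Mixin:
    def greet(self):
        # super() walks past the mixin to the original class, which is how
        # phEYELMMixin.forward can delegate to the wrapped language model.
        return "mixin first, then " + super().greet()

def extend_instance(obj, mixin):
    base_cls = obj.__class__
    obj.__class__ = type(base_cls.__name__, (mixin, base_cls), {})

obj = Base()
extend_instance(obj, Mixin)
print(obj.greet())  # -> "mixin first, then base"

class Node:
    pass

root = Node()
root.child = Node()
root.child.value = 42

def getattr_recursive(obj, att):
    # Dotted attribute lookup, e.g. "child.value" -> obj.child.value
    head, _, tail = att.partition(".")
    return getattr_recursive(getattr(obj, head), tail) if tail else getattr(obj, head)

print(getattr_recursive(root, "child.value"))  # -> 42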
pheye_builder/wrapper_lm.py
ADDED
@@ -0,0 +1,132 @@
import torch.nn as nn
from .xattn import CrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive


class WrapperLayer(nn.Module):
    """
    WrapperLayer is a wrapper around a CrossAttentionBlock and a DecoderLayer.
    """

    def __init__(
        self, cross_attn_layer, decoder_layer, gradient_checkpointing=False
    ):
        super().__init__()
        self.cross_attn_layer = cross_attn_layer
        self.decoder_layer = decoder_layer
        self.vis_x = None
        if self.cross_attn_layer is not None:
            self.cross_attn_layer._use_gradient_checkpointing = (
                gradient_checkpointing
            )
        self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None

    # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
    def condition_vis_x(self, vis_x):
        self.vis_x = vis_x

    def forward(
        self,
        lang_x,
        attention_mask=None,
        **decoder_layer_kwargs,
    ):
        # Cross attention
        if self.cross_attn_layer is not None:
            if self.vis_x is None:
                raise ValueError("vis_x must be conditioned before forward pass")

            lang_x = self.cross_attn_layer(
                lang_x,
                self.vis_x
            )

        # Normal decoder layer
        lang_x = self.decoder_layer(
            lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
        )

        return lang_x


class phEYELMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_pheye(
        self,
        lang_hidden_size,
        vis_hidden_size,
        dtype,
        cross_attn_every_n_layers,
        gradient_checkpointing,
        reduce_factor=1,
        from_layer=0
    ):
        """
        Initialize phEYE by adding new cross-attention layers to the decoder.
        """
        self.old_decoder_blocks = self._get_decoder_layers()
        self.cross_attn_layers = nn.ModuleList(
            [
                CrossAttentionBlock(
                    dim_text=lang_hidden_size, dim_visual=vis_hidden_size, reduce_factor=reduce_factor, layer_idx=layer_idx, n_decoder_layers=len(self.old_decoder_blocks), dtype=dtype
                )
                if (layer_idx + 1) % cross_attn_every_n_layers == 0 and layer_idx >= from_layer
                else None
                for layer_idx, _ in enumerate(self._get_decoder_layers())
            ]
        )
        self.init_pheye_layers(gradient_checkpointing)
        self.initialized_pheye = True
        self._use_cached_vision_x = False

    def init_pheye_layers(self, gradient_checkpointing):
        """
        Re-initializes the WrapperLayers.
        Propagates any changes made to self.cross_attn_layers or self.old_decoder_blocks.
        """
        self._set_decoder_layers(
            nn.ModuleList(
                [
                    WrapperLayer(
                        cross_attn_layer, decoder_layer, gradient_checkpointing
                    )
                    for cross_attn_layer, decoder_layer in zip(
                        self.cross_attn_layers, self.old_decoder_blocks
                    )
                ]
            )
        )

    def forward(self, input_ids, attention_mask, **kwargs):
        if not self.initialized_pheye:
            raise ValueError(
                "phEYE layers are not initialized. Please call `init_pheye` first."
            )

        kwargs["input_ids"] = input_ids
        kwargs["attention_mask"] = attention_mask
        return super().forward(**kwargs)  # Call the other parent's forward method

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(l.is_conditioned() for l in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
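A minimal sketch (not part of this commit) of which decoder layers receive a CrossAttentionBlock in init_pheye: a layer gets one when (layer_idx + 1) % cross_attn_every_n_layers == 0 and layer_idx >= from_layer; the rest are wrapped with cross_attn_layer=None and behave as plain decoder layers. The layer counts below are examples:

def cross_attn_layout(n_layers, cross_attn_every_n_layers, from_layer=0):
    # Indices of decoder layers that get a cross-attention block.
    return [
        layer_idx
        for layer_idx in range(n_layers)
        if (layer_idx + 1) % cross_attn_every_n_layers == 0 and layer_idx >= from_layer
    ]

print(cross_attn_layout(24, 4))                # -> [3, 7, 11, 15, 19, 23]
print(cross_attn_layout(24, 4, from_layer=8))  # -> [11, 15, 19, 23]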
pheye_builder/xattn.py
ADDED
@@ -0,0 +1,159 @@
"""
Based on: https://github.com/lucidrains/flamingo-pytorch
"""

from einops import rearrange
from einops_exts import rearrange_many
from torch import einsum, nn
import math

def exists(val):
    return val is not None


class FeedForward(nn.Module):

    def __init__(self, dim, dtype, reduce_factor = 1):
        super().__init__()
        mult = 4
        self.norm = nn.LayerNorm(dim, dtype=dtype)
        inner_dim = int(dim * mult) // reduce_factor

        self.fc1 = nn.Linear(dim, inner_dim, dtype=dtype)
        self.fc2 = nn.Linear(inner_dim, dim, dtype=dtype)
        self.act = nn.GELU()

    def forward(self, x):

        x = self.norm(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)

        return x

# cross attention
class CrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim_text,
        dim_visual,
        dtype,
        dim_head=64,
        reduce_factor=1
    ):
        super().__init__()
        self.scale = dim_head**-0.5
        max_dim = max(dim_text, dim_visual)
        self.heads = max_dim // dim_head
        assert max_dim % dim_head == 0, f"Number of heads in CrossAttention is not an int - {self.heads}"
        inner_dim = max_dim // reduce_factor

        self.norm = nn.LayerNorm(dim_text, dtype=dtype)

        self.to_q = nn.Linear(dim_text, inner_dim, dtype=dtype)
        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, dtype=dtype)
        #self.to_kv_second = nn.Linear(dim_visual, inner_dim * 2)
        self.to_out = nn.Linear(inner_dim, dim_text, dtype=dtype)

    def forward(self, x, media):
        """
        Args:
            x (torch.Tensor): text features
                shape (B, txt_seq, D_txt)
            media (torch.Tensor): image features
                shape (B, img_seq, D_img) where img_seq is the number of concatenated features from the ViT. For example:
                for an encoder of 224x224 with patch size 14 and processing images of 896x896 (with 3 levels) it will be (1 + 4 + 16) * 257 = 5397
        """

        h = self.heads

        x = self.norm(x)
        q = self.to_q(x)

        k, v = self.to_kv(media).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)

        q = q * self.scale

        sim = einsum("... i d, ... j d -> ... i j", q, k)

        attn = sim.softmax(dim=-1)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)

# cross attention block
class CrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        dim_text,
        dim_visual,
        dtype,
        dim_head=64,
        reduce_factor = 1,
        layer_idx=0,
        n_decoder_layers = 24
    ):
        super().__init__()
        self.attn = CrossAttention(
            dim_text=dim_text,
            dim_visual=dim_visual,
            dim_head=dim_head,
            reduce_factor=reduce_factor,
            dtype=dtype
        )

        self.ff = FeedForward(dim_text, reduce_factor=reduce_factor, dtype=dtype)
        self.layer_idx = layer_idx
        self.n_decoder_layers = n_decoder_layers

        self.apply(self._init_weights)

    def forward(
        self,
        x,
        media
    ):

        x = (
            self.attn(
                x,
                media
            )
            + x
        )

        x = self.ff(x) + x

        return x

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        for name, p in module.named_parameters():
            if name == "fc2.weight" or name == "to_out.weight":
                p.data.normal_(mean=0.0, std=(0.01 / math.sqrt(2 * max(self.n_decoder_layers, 36))))
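A minimal sketch (not part of this commit) of the tensor shapes CrossAttentionBlock works with: text queries attend over the concatenated vision tokens and the output stays in the text dimension. Assumes the pheye_builder package from this commit is importable and that torch, einops, and einops-exts (see requirements.txt) are installed; the dimensions are examples:

import torch
from pheye_builder.xattn import CrossAttentionBlock

dim_text, dim_visual = 2048, 1024
block = CrossAttentionBlock(dim_text=dim_text, dim_visual=dim_visual, dtype=torch.float32)

text = torch.randn(1, 16, dim_text)          # (B, txt_seq, D_txt)
media = torch.randn(1, 5 * 257, dim_visual)  # (B, img_seq, D_img): 5 patches of 257 CLIP tokens at level 2

out = block(text, media)
print(out.shape)  # torch.Size([1, 16, 2048]): residual output keeps the text shape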
requirements.txt
CHANGED
@@ -1 +1,8 @@
The file previously pinned only huggingface_hub==0.22.2; it now pins:

huggingface_hub==0.22.2
transformers==4.37.0
pillow==10.3.0
torch==2.1.1
torchvision==0.16.1
peft==0.7.0
einops==0.6.1
einops-exts==0.0.4