miguelcarv committed
Commit 34f251f
1 Parent(s): fb2630f

first commit

app.py CHANGED
@@ -1,63 +1,103 @@
  import gradio as gr
  from huggingface_hub import InferenceClient

- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
      ],
  )


  if __name__ == "__main__":
-     demo.launch()
  import gradio as gr
  from huggingface_hub import InferenceClient
+ import json
+ from pheye_builder import create_model_and_transforms
+ from huggingface_hub import hf_hub_download
+ import torch
+ from PIL import Image
+ import os
+ import requests

+
+ def get_config(hf_model_path):
+     config_path = hf_hub_download(hf_model_path, "config.json")
+
+     with open(config_path, "r") as f:
+         config = json.load(f)
+
+     return config
+
+
+ def get_model_path(hf_model_path):
+     return hf_hub_download(hf_model_path, "checkpoint.pt")
+
+
+ HF_MODEL = "miguelcarv/Pheye-x2-672"
+ config = get_config(HF_MODEL)
+
+ print("Got config")
+
+ model, tokenizer = create_model_and_transforms(
+     clip_vision_encoder_path=config["encoder"],
+     lang_decoder_path=config["decoder"],
+     tokenizer_path=config["tokenizer"],
+     cross_attn_every_n_layers=config["cross_interval"],
+     level=config["level"],
+     reduce_factor=config["reduce"],
+     from_layer=config["from_layer"],
+     encoder_dtype=eval(config["encoder_dtype"]),
+     decoder_dtype=eval(config["decoder_dtype"]),
+     dtype=eval(config["other_params_dtype"])
+ )
+
+ if config["first_level"]:
+     model.vision_encoder.add_first_level_adapter()
+
+ print("Created model")
+
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model_path = get_model_path(HF_MODEL)
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
+ model = model.to(DEVICE)
+
+ print("Loaded model")
+
+ SYSTEM_PROMPT = "You are an AI visual assistant and you are seeing a single image. You will receive an instruction regarding that image. Your goal is to follow the instruction as faithfully as you can."
+
+ whiteboard = Image.open(requests.get("https://c1.staticflickr.com/7/6168/6207108414_a8833f410e_o.jpg", stream=True).raw).convert('RGB')
+ taxi_image = Image.open(requests.get("https://llava.hliu.cc/file=/nobackup/haotian/tmp/gradio/ca10383cc943e99941ecffdc4d34c51afb2da472/extreme_ironing.jpg", stream=True).raw).convert('RGB')
+
+
+ def generate_answer(img, question, max_new_tokens, num_beams):
+
+     image = [img]
+     prompt = [f"{SYSTEM_PROMPT}\n\nInstruction: {question}\nOutput:"]
+     inputs = tokenizer(prompt, padding='longest', return_tensors='pt')
+     print("Generating a response with the following parameters:")
+     print(f"""Question: {question}\nMax New Tokens: {max_new_tokens}\nNum Beams: {num_beams}""")
+
+     model.eval()
+     with torch.no_grad():
+         outputs = model.generate(vision_x=image,
+                                  lang_x=inputs.input_ids.to(DEVICE),
+                                  device=DEVICE,
+                                  max_new_tokens=max_new_tokens,
+                                  num_beams = num_beams,
+                                  eos_token_id = tokenizer.eos_token_id,
+                                  pad_token_id = tokenizer.pad_token_id,
+                                  attention_mask=inputs.attention_mask.to(DEVICE))
+     answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].split("Output:")[-1].lstrip()
+
+     return answer
+
+
+ # Create the Gradio interface
+ iface = gr.Interface(
+     fn=generate_answer,
+     inputs=[
+         gr.Image(type="pil", label="Image"),
+         gr.Textbox(label="Question"),
+         gr.Slider(minimum=5, maximum=500, step=1, value=50, label="Max New Tokens"),
+         gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Num Beams")
      ],
+     outputs=gr.Textbox(label="Answer"),
+     title="<h1 style='text-align: center; display: block;'>Pheye-x2 672x672 pixels</h1>",
+     examples=[[taxi_image, "What is unusual about this image?"], [whiteboard, "What is the main topic of the whiteboard?"]]
  )


+
+
  if __name__ == "__main__":
+     # Launch the Gradio app
+     iface.launch()
pheye_builder/__init__.py ADDED
@@ -0,0 +1 @@
+ from .factory import create_model_and_transforms
pheye_builder/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (250 Bytes).
pheye_builder/__pycache__/encoder.cpython-311.pyc ADDED
Binary file (12.6 kB).
pheye_builder/__pycache__/factory.cpython-311.pyc ADDED
Binary file (5.9 kB).
pheye_builder/__pycache__/phEYE.cpython-311.pyc ADDED
Binary file (10.2 kB).
pheye_builder/__pycache__/utils.cpython-311.pyc ADDED
Binary file (2.28 kB).
pheye_builder/__pycache__/wrapper_lm.cpython-311.pyc ADDED
Binary file (7.12 kB).
pheye_builder/__pycache__/xattn.cpython-311.pyc ADDED
Binary file (7.7 kB).
pheye_builder/encoder.py ADDED
@@ -0,0 +1,179 @@
+
+ from transformers import CLIPModel
+ from torch import nn
+ from peft import LoraConfig, get_peft_model
+ import torch
+ from torch import nn
+ import PIL
+ from PIL.Image import BICUBIC
+ import math
+ from torchvision import transforms
+ import torch.nn.functional as F
+
+
+ # level 4 which has 21 patches was being used in previous experiments so now I can't remove it or won't be able to load older models....
+ LEVELS_TO_PATCHES = {
+     1 : 1,
+     2 : 5,
+     3 : 10,
+     4 : 21
+ }
+
+ def cut_image_patches(image: PIL.Image, encoder_resolution: int = 224):
+
+     coordinates = []
+
+     width, height = image.size
+
+     width_tiles = [i*encoder_resolution for i in range(math.ceil(width/encoder_resolution)-1)]
+     width_tiles.append(width-encoder_resolution)
+     height_tiles = [i*encoder_resolution for i in range(math.ceil(height/encoder_resolution)-1)]
+     height_tiles.append(height-encoder_resolution)
+
+     for w in width_tiles:
+         for h in height_tiles:
+             coordinates.append((w,h,w+encoder_resolution,h+encoder_resolution))
+
+     cropped_images = [image.crop(c) for c in coordinates]
+
+     return cropped_images
+
+ class Encoder(nn.Module):
+
+     def __init__(self, clip_name, level = 2, dtype = None, use_dropout = True) -> None:
+         super().__init__()
+
+         if level not in LEVELS_TO_PATCHES:
+             raise ValueError("Resolution not supported")
+
+         self.n_patches = LEVELS_TO_PATCHES[level]
+         self.vision_model = CLIPModel.from_pretrained(clip_name, torch_dtype=dtype).vision_model
+         self.has_first_adapter = False
+         self.image_size = self.vision_model.config.image_size
+         self.patch_size = self.vision_model.config.patch_size
+         self.use_dropout = use_dropout
+         self.dtype = dtype
+
+         mean = (0.48145466, 0.4578275, 0.40821073)
+         std = (0.26862954, 0.26130258, 0.27577711)
+         self.image_transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize(mean=mean, std=std),
+         ])
+
+         self.norm_lvl_1 = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
+         self.norm_lvl_2 = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
+
+         # this was being used in previous experiments so now I can't remove it or won't be able to load older models....
+         self.norm_lvl_3 = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
+
+         if level == 1:
+             self.connector = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
+         else:
+             self.connector = Position(self.n_patches, self.vision_model.config.hidden_size, dtype=dtype)
+
+         config_level2 = LoraConfig(
+             r=16,
+             lora_alpha=32,
+             target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "patch_embedding", "fc1", "fc2"],
+             lora_dropout=0.05 if self.use_dropout else 0,
+             bias="none"
+         )
+         self.vision_model = get_peft_model(self.vision_model, config_level2, "second")
+
+     def add_first_level_adapter(self):
+
+         config_224 = LoraConfig(
+             r=8,
+             lora_alpha=16,
+             target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "patch_embedding", "fc1", "fc2"],
+             lora_dropout=0.05 if self.use_dropout else 0,
+             bias="none"
+         )
+
+         self.vision_model.add_adapter("first", config_224)
+         self.has_first_adapter = True
+
+
+     def forward(self, images: list, device = "cpu", **kwargs):
+         """
+         shape (B, C, H, W) in list form
+         """
+         B = len(images)
+         h = int((self.image_size/self.patch_size) ** 2 + 1)
+         resized_images = {1: [], 2: []}
+
+         for i in images:
+             resized_images[1].append(self.image_transform(i.resize((self.image_size,self.image_size), resample=BICUBIC)))
+
+             if self.n_patches == 5:
+                 for crop in cut_image_patches(i.resize((self.image_size * 2,self.image_size * 2), resample=BICUBIC), encoder_resolution=self.image_size):
+                     resized_images[2].append(self.image_transform(crop))
+             elif self.n_patches == 10:
+                 for crop in cut_image_patches(i.resize((self.image_size * 3,self.image_size * 3), resample=BICUBIC), encoder_resolution=self.image_size):
+                     resized_images[2].append(self.image_transform(crop))
+
+
+         vision_features = []
+         for res, imgs in resized_images.items():
+             if imgs != []:
+                 resized_images[res] = torch.stack(imgs, dim = 0).to(device)
+
+                 if res == 1 and self.has_first_adapter:
+                     self.vision_model.set_adapter("first")
+                     vision_features.append(self.norm_lvl_1(self.vision_model(resized_images[res]).last_hidden_state))
+                 elif res == 1:
+                     with self.vision_model.disable_adapter():
+                         vision_features.append(self.norm_lvl_1(self.vision_model(resized_images[res]).last_hidden_state))
+                 elif res == 2:
+                     self.vision_model.set_adapter("second")
+                     if self.n_patches == 5:
+                         vision_features.append(self.norm_lvl_2(self.vision_model(resized_images[res]).last_hidden_state.view(B, h * 4, -1)))
+                     elif self.n_patches == 10:
+                         vision_features.append(self.norm_lvl_2(self.vision_model(resized_images[res]).last_hidden_state.view(B, h * 9, -1)))
+
+         vision_features = torch.cat(vision_features, dim = 1)
+         vision_features = self.connector(vision_features)
+
+         return vision_features
+
+
+ class Position(nn.Module):
+
+     def __init__(self, n_patches, dim, dtype) -> None:
+         super().__init__()
+
+         self.embedding = nn.Embedding(max(LEVELS_TO_PATCHES.values()), dim, dtype=dtype)
+         self.n_patches = n_patches
+
+         self.apply(self._init_weights)
+
+     def forward(self, vision_features):
+
+         batch_size, seq_len, dim = vision_features.size()
+         single_encoder_dim = seq_len // self.n_patches
+         device = vision_features.get_device()
+
+         pos = torch.LongTensor(list(range(self.n_patches))).to(device if device != -1 else "cpu")
+         pos = torch.repeat_interleave(self.embedding(pos).unsqueeze(0), single_encoder_dim, 1).expand(batch_size, -1, -1)
+
+         return vision_features + pos
+
+
+     def _init_weights(self, module):
+         """Initialize the weights."""
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=0.02)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=0.02)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+         for name, p in module.named_parameters():
+             if name == "fc1.weight" or name == "fc2.weight" or name == "to_out.weight":
+                 p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.n_decoder_layers)))
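
For readers skimming the diff: a minimal, illustrative sketch (not part of the commit, assuming `pheye_builder` is importable) of what `cut_image_patches` produces at the default 224px encoder resolution. A 448x448 input is covered by four non-overlapping 224x224 crops; at level 2 the encoder then sees the resized full image plus those four crops, i.e. 5 encoder passes in total.

    from PIL import Image
    from pheye_builder.encoder import cut_image_patches

    img = Image.new("RGB", (448, 448))                      # toy image, 2x the encoder resolution
    crops = cut_image_patches(img, encoder_resolution=224)  # four (left, upper, right, lower) tiles
    print(len(crops), crops[0].size)                        # 4 (224, 224)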
pheye_builder/factory.py ADDED
@@ -0,0 +1,126 @@
+ from typing import Optional
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+ import torch
+
+ from .phEYE import phEYE
+ from .wrapper_lm import phEYELMMixin
+ from .utils import extend_instance
+ from .encoder import Encoder
+
+
+ def create_model_and_transforms(
+     clip_vision_encoder_path: str,
+     lang_decoder_path: str,
+     tokenizer_path: str,
+     dtype,
+     cross_attn_every_n_layers: int = 1,
+     use_local_files: bool = False,
+     decoder_layers_attr_name: str = None,
+     freeze_lm_embeddings: bool = True,
+     cache_dir: Optional[str] = None,
+     level: int = 2,
+     encoder_dtype : torch.dtype = None,
+     decoder_dtype : torch.dtype = None,
+     use_dropout : bool = False,
+     **pheye_kwargs,
+ ):
+     """
+     Initialize a phEYE model from a pretrained vision encoder and language decoder.
+     Appends special tokens to the tokenizer and freezes backbones.
+
+     Args:
+         clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
+         lang_decoder_path (str): path to pretrained language decoder
+         tokenizer_path (str): path to pretrained tokenizer
+         cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
+         use_local_files (bool, optional): whether to use local files. Defaults to False.
+         decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
+         freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings.
+         cache_dir (str, optional): path to cache directory for downloading HF weights.
+         level (int, optional): resolution level of the vision encoder (see LEVELS_TO_PATCHES in encoder.py). Defaults to 2.
+     Returns:
+         phEYE: phEYE model from pretrained vision and language encoders
+         Tokenizer: A tokenizer for the language model
+     """
+
+     vision_encoder = Encoder(clip_vision_encoder_path, level=level, dtype=encoder_dtype, use_dropout=use_dropout)
+
+
+     text_tokenizer = AutoTokenizer.from_pretrained(
+         tokenizer_path,
+         local_files_only=use_local_files,
+         trust_remote_code=True,
+         cache_dir=cache_dir,
+     )
+
+     if text_tokenizer.pad_token is None:
+         text_tokenizer.pad_token = text_tokenizer.eos_token
+
+     #print(lang_decoder_path)
+     lang_config = AutoConfig.from_pretrained(lang_decoder_path)
+     #print(lang_config)
+     lang_encoder = AutoModelForCausalLM.from_config(
+         lang_config,
+         #local_files_only=use_local_files,
+         #trust_remote_code=True,
+         torch_dtype=decoder_dtype
+     )
+
+     lang_encoder.config.decoder_start_token_id = None
+     lang_encoder.config.pad_token_id = text_tokenizer.pad_token_id
+
+     # convert LM to phEYELM
+     extend_instance(lang_encoder, phEYELMMixin)
+
+     if decoder_layers_attr_name is None:
+         decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
+     lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
+
+     model = phEYE(
+         vision_encoder,
+         lang_encoder,
+         vis_dim=vision_encoder.vision_model.config.hidden_size,
+         cross_attn_every_n_layers=cross_attn_every_n_layers,
+         dtype=dtype,
+         **pheye_kwargs,
+     )
+
+     # Freeze all language model parameters
+     model.lang_encoder.requires_grad_(False)
+     assert sum(p.numel() for p in model.lang_encoder.parameters() if p.requires_grad) == 0
+
+     # Unfreeze cross_attn_layers and, optionally, the LM input embeddings
+     model.lang_encoder.cross_attn_layers.requires_grad_(True)
+     if not freeze_lm_embeddings:
+         model.lang_encoder.get_input_embeddings().requires_grad_(True)
+
+     print(
+         f"phEYE model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
+     )
+
+     return model, text_tokenizer
+
+
+ def _infer_decoder_layers_attr_name(model):
+     for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
+         if k.lower() in model.__class__.__name__.lower():
+             return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
+
+     raise ValueError(
+         "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
+     )
+
+
+ __KNOWN_DECODER_LAYERS_ATTR_NAMES = {
+     "opt": "model.decoder.layers",
+     "gpt": "transformer.h",
+     "gpt-j": "transformer.h",
+     "pythia": "gpt_neox.layers",
+     "llama": "model.layers",
+     "gptneoxforcausallm": "gpt_neox.layers",
+     "mpt": "transformer.blocks",
+     "mosaicgpt": "transformer.blocks",
+     "phi" : "model.layers"
+ }
pheye_builder/phEYE.py ADDED
@@ -0,0 +1,220 @@
+ import torch
+ from peft import LoraConfig, get_peft_model
+ from torch import nn
+ import os
+
+
+ class phEYE(nn.Module):
+     def __init__(
+         self,
+         vision_encoder: nn.Module,
+         lang_encoder: nn.Module,
+         vis_dim: int,
+         dtype: torch.dtype,
+         cross_attn_every_n_layers: int = 1,
+         gradient_checkpointing: bool = False,
+         reduce_factor = 1,
+         from_layer = 0
+     ):
+         """
+         Args:
+             vision_encoder (nn.Module): module with the CLIP vision model
+             lang_encoder (nn.Module): HF causal language model
+             vis_dim (int): Dimension of the visual features.
+                 Visual features are projected to match this shape along the last dimension.
+             cross_attn_every_n_layers (int, optional): How often to apply cross attention after a transformer layer. Defaults to 1.
+         """
+         super().__init__()
+         self.vis_dim = vis_dim
+         if hasattr(lang_encoder.config, "d_model"):
+             self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
+         else:
+             self.lang_dim = lang_encoder.config.hidden_size
+
+         self.vision_encoder = vision_encoder
+         self.lang_encoder = lang_encoder
+         self.lang_encoder.init_pheye(
+             lang_hidden_size=self.lang_dim,
+             vis_hidden_size=self.vis_dim,
+             cross_attn_every_n_layers=cross_attn_every_n_layers,
+             gradient_checkpointing=gradient_checkpointing,
+             reduce_factor=reduce_factor,
+             from_layer=from_layer,
+             dtype=dtype
+         )
+         self._use_gradient_checkpointing = gradient_checkpointing
+
+     def forward(
+         self,
+         vision_x: list,
+         lang_x: torch.Tensor,
+         attention_mask: torch.Tensor = None,
+         labels: torch.Tensor = None,
+         clear_conditioned_layers: bool = True,
+         past_key_values = None,
+         use_cache: bool = False,
+         device="cpu",
+         is_textcaps = False
+     ):
+         """
+         Forward pass of phEYE.
+
+         Args:
+             vision_x (list): Vision input
+                 shape (B, C, H, W)
+             lang_x (torch.Tensor): Language input ids
+                 shape (B, txt_seq)
+             attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+             labels (torch.Tensor, optional): Labels. Defaults to None.
+             clear_conditioned_layers: if True, clear the conditioned layers
+                 once the forward pass is completed. Set this to False if the
+                 same set of images will be reused in another subsequent
+                 forward pass.
+             past_key_values: pre-computed values to pass to the language model.
+                 See past_key_values documentation in Hugging Face
+                 CausalLM models.
+             use_cache: whether to use cached key values. See use_cache
+                 documentation in Hugging Face CausalLM models.
+         """
+         assert (
+             self.lang_encoder.initialized_pheye
+         ), "Wrapper layers are not initialized. Please call `init_pheye` first."
+
+         assert (
+             self.lang_encoder._use_cached_vision_x or vision_x is not None
+         ), "Must provide either vision_x or have precached media using cache_media()."
+
+         if self.lang_encoder._use_cached_vision_x:
+             # Case: use cached; vision_x should be cached and other
+             # vision-related inputs should not be provided.
+             assert (
+                 vision_x is None
+             ), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first."
+             assert self.lang_encoder.is_conditioned()
+
+         else:
+             # Case: do not use caching (i.e. this is a standard forward pass)
+             self._encode_vision_x(vision_x=vision_x, device=device, is_textcaps=is_textcaps)
+
+         #print(f"Text features shape: {lang_x.shape}")
+         output = self.lang_encoder(
+             input_ids=lang_x,
+             attention_mask=attention_mask,
+             labels=labels,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+         )
+
+         if clear_conditioned_layers:
+             self.lang_encoder.clear_conditioned_layers()
+
+         return output
+
+     def generate(
+         self,
+         vision_x: list,
+         lang_x: torch.Tensor,
+         attention_mask: torch.Tensor = None,
+         device = "cpu",
+         **kwargs,
+     ):
+         """
+         Generate text conditioned on vision and language inputs.
+
+         Args:
+             vision_x (list): Vision input
+                 shape (B, C, H, W)
+             lang_x (torch.Tensor): Language input
+                 shape (B, T_txt)
+             **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs:
+                 max_length (int, optional): Maximum length of the output. Defaults to None.
+                 attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+                 num_beams (int, optional): Number of beams. Defaults to 1.
+                 max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
+                 temperature (float, optional): Temperature. Defaults to 1.0.
+                 top_k (int, optional): Top k. Defaults to 50.
+                 top_p (float, optional): Top p. Defaults to 1.0.
+                 no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
+                 length_penalty (float, optional): Length penalty. Defaults to 1.0.
+                 num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
+                 do_sample (bool, optional): Do sample. Defaults to False.
+                 early_stopping (bool, optional): Early stopping. Defaults to False.
+         Returns:
+             torch.Tensor: lang_x with generated tokens appended to it
+         """
+         num_beams = kwargs.pop("num_beams", 1)
+
+         self.lang_encoder._use_cached_vision_x = True
+         self._encode_vision_x(vision_x=vision_x, device=device, repeat=num_beams)
+
+         output = self.lang_encoder.generate(
+             input_ids=lang_x,
+             attention_mask=attention_mask,
+             num_beams=num_beams,
+             **kwargs,
+         )
+
+         self.lang_encoder.clear_conditioned_layers()
+         self.lang_encoder._use_cached_vision_x = False
+         return output
+
+     def _encode_vision_x(self, vision_x: list, device="cpu", repeat = 1, is_textcaps = False):
+         """
+         Compute vision features by passing images through the vision encoder and conditioning the language model.
+         Args:
+             vision_x (list): Vision input
+                 shape (B, C, H, W)
+         """
+         if is_textcaps:
+             vision_x = vision_x[::5]
+             repeat = 5
+
+         vision_x = self.vision_encoder(vision_x, device=device)
+
+         if repeat > 1:
+             vision_x = vision_x.repeat_interleave(repeat, dim=0)
+
+         for layer in self.lang_encoder._get_decoder_layers():
+             layer.condition_vis_x(vision_x)
+
+
+     def cache_media(self, vision_x: list, device="cpu"):
+         """
+         Cache vision_x features from a list of images for log-likelihood evaluation.
+         This is not meant to be used to cache things for generate().
+         Args:
+             vision_x (list): Vision input
+                 shape (B, C, H, W)
+         """
+         self._encode_vision_x(vision_x=vision_x, device=device)
+         self.lang_encoder._use_cached_vision_x = True
+
+     def uncache_media(self):
+         """
+         Clear all conditioning.
+         """
+         self.lang_encoder.clear_conditioned_layers()
+         self.lang_encoder._use_cached_vision_x = False
+
+     def save_model(self, _path):
+         os.mkdir(_path)
+         torch.save(self.vision_encoder.state_dict(), _path+"vision_encoder.pt")
+         torch.save(self.lang_encoder.state_dict(), _path+"lang_encoder.pt")
+
+     def add_lora_decoder(self):
+
+         config = LoraConfig(
+             r=16,
+             lora_alpha=32,
+             target_modules=["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"],
+             lora_dropout=0.05,
+             bias="none"
+         )
+
+         self.lang_encoder.old_decoder_blocks = get_peft_model(self.lang_encoder.old_decoder_blocks, config)
+
+     def merge_and_unload(self):
+         self.lang_encoder.old_decoder_blocks = self.lang_encoder.old_decoder_blocks.merge_and_unload()
+
pheye_builder/utils.py ADDED
@@ -0,0 +1,48 @@
+ def extend_instance(obj, mixin):
+     """Apply mixins to a class instance after creation"""
+     base_cls = obj.__class__
+     base_cls_name = obj.__class__.__name__
+     obj.__class__ = type(
+         base_cls_name, (mixin, base_cls), {}
+     )  # mixin needs to go first for our forward() logic to work
+
+
+ def getattr_recursive(obj, att):
+     """
+     Return nested attribute of obj
+     Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
+     """
+     if att == "":
+         return obj
+     i = att.find(".")
+     if i < 0:
+         return getattr(obj, att)
+     else:
+         return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
+
+
+ def setattr_recursive(obj, att, val):
+     """
+     Set nested attribute of obj
+     Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
+     """
+     if "." in att:
+         obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
+     setattr(obj, att.split(".")[-1], val)
+
+
+ def apply_with_stopping_condition(
+     module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
+ ):
+     if stopping_condition(module):
+         return
+     if apply_condition(module):
+         apply_fn(module, **other_args)
+     for child in module.children():
+         apply_with_stopping_condition(
+             child,
+             apply_fn,
+             apply_condition=apply_condition,
+             stopping_condition=stopping_condition,
+             **other_args
+         )
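
As an aside, a minimal sketch (toy classes, not part of the commit) of the mechanism `extend_instance` relies on: it swaps the instance's class for a dynamically created subclass whose MRO puts the mixin first, which is how `phEYELMMixin.forward` ends up wrapping the original decoder's forward in factory.py.

    from pheye_builder.utils import extend_instance

    class Greeter:
        def greet(self):
            return "hello"

    class LoudMixin:
        def greet(self):
            # super() resolves to the original class because the mixin comes first in the MRO
            return super().greet().upper()

    g = Greeter()
    extend_instance(g, LoudMixin)
    print(g.greet())   # HELLO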
pheye_builder/wrapper_lm.py ADDED
@@ -0,0 +1,132 @@
+ import torch.nn as nn
+ from .xattn import CrossAttentionBlock
+ from .utils import getattr_recursive, setattr_recursive
+
+
+ class WrapperLayer(nn.Module):
+     """
+     WrapperLayer is a wrapper around the CrossAttentionBlock and DecoderLayer.
+     """
+
+     def __init__(
+         self, cross_attn_layer, decoder_layer, gradient_checkpointing=False
+     ):
+         super().__init__()
+         self.cross_attn_layer = cross_attn_layer
+         self.decoder_layer = decoder_layer
+         self.vis_x = None
+         if self.cross_attn_layer is not None:
+             self.cross_attn_layer._use_gradient_checkpointing = (
+                 gradient_checkpointing
+             )
+         self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing
+
+     def is_conditioned(self) -> bool:
+         """Check whether the layer is conditioned."""
+         return self.vis_x is not None
+
+     # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
+     def condition_vis_x(self, vis_x):
+         self.vis_x = vis_x
+
+     def forward(
+         self,
+         lang_x,
+         attention_mask=None,
+         **decoder_layer_kwargs,
+     ):
+         # Cross attention
+         if self.cross_attn_layer is not None:
+             if self.vis_x is None:
+                 raise ValueError("vis_x must be conditioned before forward pass")
+
+             lang_x = self.cross_attn_layer(
+                 lang_x,
+                 self.vis_x
+             )
+
+         # Normal decoder layer
+         lang_x = self.decoder_layer(
+             lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
+         )
+
+         return lang_x
+
+
+ class phEYELMMixin(nn.Module):
+     """
+     Mixin to add cross-attention layers to a language model.
+     """
+
+     def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
+         self.decoder_layers_attr_name = decoder_layers_attr_name
+
+     def _get_decoder_layers(self):
+         return getattr_recursive(self, self.decoder_layers_attr_name)
+
+     def _set_decoder_layers(self, value):
+         setattr_recursive(self, self.decoder_layers_attr_name, value)
+
+     def init_pheye(
+         self,
+         lang_hidden_size,
+         vis_hidden_size,
+         dtype,
+         cross_attn_every_n_layers,
+         gradient_checkpointing,
+         reduce_factor=1,
+         from_layer=0
+     ):
+         """
+         Initialize phEYE by adding new cross-attention layers to the decoder.
+         """
+         self.old_decoder_blocks = self._get_decoder_layers()
+         self.cross_attn_layers = nn.ModuleList(
+             [
+                 CrossAttentionBlock(
+                     dim_text=lang_hidden_size, dim_visual=vis_hidden_size, reduce_factor=reduce_factor, layer_idx=layer_idx, n_decoder_layers=len(self.old_decoder_blocks), dtype=dtype
+                 )
+                 if (layer_idx + 1) % cross_attn_every_n_layers == 0 and layer_idx >= from_layer
+                 else None
+                 for layer_idx, _ in enumerate(self._get_decoder_layers())
+             ]
+         )
+         self.init_pheye_layers(gradient_checkpointing)
+         self.initialized_pheye = True
+         self._use_cached_vision_x = False
+
+     def init_pheye_layers(self, gradient_checkpointing):
+         """
+         Re-initializes the WrapperLayers.
+         Propagates any changes made to self.cross_attn_layers or self.old_decoder_blocks
+         """
+         self._set_decoder_layers(
+             nn.ModuleList(
+                 [
+                     WrapperLayer(
+                         cross_attn_layer, decoder_layer, gradient_checkpointing
+                     )
+                     for cross_attn_layer, decoder_layer in zip(
+                         self.cross_attn_layers, self.old_decoder_blocks
+                     )
+                 ]
+             )
+         )
+
+     def forward(self, input_ids, attention_mask, **kwargs):
+         if not self.initialized_pheye:
+             raise ValueError(
+                 "phEYE layers are not initialized. Please call `init_pheye` first."
+             )
+
+         kwargs["input_ids"] = input_ids
+         kwargs["attention_mask"] = attention_mask
+         return super().forward(**kwargs)  # Call the other parent's forward method
+
+     def is_conditioned(self) -> bool:
+         """Check whether all decoder layers are already conditioned."""
+         return all(l.is_conditioned() for l in self._get_decoder_layers())
+
+     def clear_conditioned_layers(self):
+         for layer in self._get_decoder_layers():
+             layer.condition_vis_x(None)
pheye_builder/xattn.py ADDED
@@ -0,0 +1,159 @@
+ """
+ Based on: https://github.com/lucidrains/flamingo-pytorch
+ """
+
+ from einops import rearrange
+ from einops_exts import rearrange_many
+ from torch import einsum, nn
+ import math
+
+ def exists(val):
+     return val is not None
+
+
+ class FeedForward(nn.Module):
+
+     def __init__(self, dim, dtype, reduce_factor = 1):
+         super().__init__()
+         mult = 4
+         self.norm = nn.LayerNorm(dim, dtype=dtype)
+         inner_dim = int(dim * mult) // reduce_factor
+
+         self.fc1 = nn.Linear(dim, inner_dim, dtype=dtype)
+         self.fc2 = nn.Linear(inner_dim, dim, dtype=dtype)
+         self.act = nn.GELU()
+
+     def forward(self, x):
+
+         x = self.norm(x)
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.fc2(x)
+
+         return x
+
+ # cross attention
+ class CrossAttention(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim_text,
+         dim_visual,
+         dtype,
+         dim_head=64,
+         reduce_factor=1
+     ):
+         super().__init__()
+         self.scale = dim_head**-0.5
+         max_dim = max(dim_text, dim_visual)
+         self.heads = max_dim // dim_head
+         assert max_dim % dim_head == 0, f"Number of heads in CrossAttention is not an int - {self.heads}"
+         inner_dim = max_dim // reduce_factor
+
+         self.norm = nn.LayerNorm(dim_text, dtype=dtype)
+
+         self.to_q = nn.Linear(dim_text, inner_dim, dtype=dtype)
+         self.to_kv = nn.Linear(dim_visual, inner_dim * 2, dtype=dtype)
+         #self.to_kv_second = nn.Linear(dim_visual, inner_dim * 2)
+         self.to_out = nn.Linear(inner_dim, dim_text, dtype=dtype)
+         #self.g = []
+         #self.l = []
+
+     def forward(self, x, media):
+         """
+         Args:
+             x (torch.Tensor): text features
+                 shape (B, txt_seq, D_txt)
+             media (torch.Tensor): image features
+                 shape (B, img_seq, D_img) where img_seq is the number of concatenated features from the ViT. For example:
+                 for an encoder of 224x224 with patch size 14 and processing images of 896x896 (with 3 levels) it will be (1 + 4 + 16) * 257 = 5397
+         """
+
+         h = self.heads
+
+         x = self.norm(x)
+         q = self.to_q(x)
+
+         k, v = self.to_kv(media).chunk(2, dim=-1)
+         """k_s, v_s = self.to_kv(media[:, 257:, :]).chunk(2, dim=-1)
+         k = torch.cat((k, k_s), 1)
+         v = torch.cat((v, v_s), 1)"""
+         q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
+
+         q = q * self.scale
+
+         sim = einsum("... i d, ... j d -> ... i j", q, k)
+
+         attn = sim.softmax(dim=-1)
+         #idk = torch.mean(attn.squeeze()[:, 65:, :], (0, 1))
+         #self.g.append(torch.sum(idk[:257]).item())
+         #self.l.append(torch.sum(idk[257:]).item())
+
+         out = einsum("... i j, ... j d -> ... i d", attn, v)
+         out = rearrange(out, "b h n d -> b n (h d)")
+         return self.to_out(out)
+
+ # cross attention block
+ class CrossAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim_text,
+         dim_visual,
+         dtype,
+         dim_head=64,
+         reduce_factor = 1,
+         layer_idx=0,
+         n_decoder_layers = 24
+     ):
+         super().__init__()
+         self.attn = CrossAttention(
+             dim_text=dim_text,
+             dim_visual=dim_visual,
+             dim_head=dim_head,
+             reduce_factor=reduce_factor,
+             dtype=dtype
+         )
+
+         self.ff = FeedForward(dim_text, reduce_factor=reduce_factor, dtype=dtype)
+         self.layer_idx = layer_idx
+         self.n_decoder_layers = n_decoder_layers
+
+         self.apply(self._init_weights)
+
+     def forward(
+         self,
+         x,
+         media
+     ):
+
+         x = (
+             self.attn(
+                 x,
+                 media
+             )
+             + x
+         )
+
+         x = self.ff(x) + x
+
+         return x
+
+     def _init_weights(self, module):
+         """Initialize the weights."""
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=0.01)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=0.02)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+         for name, p in module.named_parameters():
+             if name == "fc2.weight" or name == "to_out.weight":
+                 p.data.normal_(mean=0.0, std=(0.01 / math.sqrt(2 * max(self.n_decoder_layers, 36))))
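
To make the docstring example above concrete, a small back-of-the-envelope helper (illustrative only, not part of the commit): the media sequence length seen by CrossAttention is the number of encoder passes times the tokens produced per pass (patch tokens plus the CLS token).

    def media_seq_len(image_size=224, patch_size=14, n_patches=21):
        tokens_per_pass = (image_size // patch_size) ** 2 + 1   # 257 for a 224px, patch-14 ViT
        return n_patches * tokens_per_pass

    print(media_seq_len(n_patches=21))   # 5397, matching the 896x896 example in the docstring
    print(media_seq_len(n_patches=5))    # 1285, a level-2 configuration with a 224px encoder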
requirements.txt CHANGED
@@ -1 +1,8 @@
- huggingface_hub==0.22.2
+ huggingface_hub==0.22.2
+ transformers==4.37.0
+ pillow==10.3.0
+ torch==2.1.1
+ torchvision==0.16.1
+ peft==0.7.0
+ einops==0.6.1
+ einops-exts==0.0.4