# NOTE(review): the lines below this comment were web-page scrape artifacts
# captured alongside the source — "Spaces:", two "Runtime error" status lines,
# "File size: 7,770 Bytes", commit hash 0b7b08a, and a 1..174 line-number
# gutter. They are commented out / summarized here so the module parses as
# Python; no program content was removed.
import random
import torch
import torch.nn as nn
import numpy as np
from .helpers import GatedCrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive
class FlamingoLayer(nn.Module):
    """Wrapper around a language-model decoder layer that writes conditioned
    visual features into the text hidden states before the wrapped layer runs.

    The wrapper keeps the decoder layer's calling convention
    (``hidden_states``, ``attention_mask``, plus any extra kwargs) so it can
    replace a layer inside an existing HuggingFace decoder stack unchanged.
    """

    def __init__(self, decoder_layer):
        super().__init__()
        self.decoder_layer = decoder_layer
        # Conditioning state, populated by the condition_* methods below.
        # vis_x is indexed as vis_x[image_idx] and sliced on dim -2, so it is
        # presumably (num_images, vis_len, hidden_dim) — TODO confirm at caller.
        self.vis_x = None
        self.image_nums = None               # number of images per batch row
        self.image_start_index_list = None   # per-row insertion offsets
        self.media_locations = None
        self.add_visual_token = False
        self.input_ids = None
        # Fix: these were previously created only inside condition_vis_x /
        # condition_attend_previous, so touching them before conditioning
        # raised AttributeError. Initialize them up front.
        self.num_beams = None
        self.visual_tokens = None
        self.data_list = None
        self.attend_previous = None

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None

    # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
    def condition_vis_x(self, vis_x, image_nums=None, image_start_index_list=None, num_beams=None, visual_tokens=None, data_list=None):
        """Store visual features and bookkeeping for the next forward pass.

        Passing ``vis_x=None`` (as clear_conditioned_layers does) resets the
        layer to the unconditioned state.
        """
        self.vis_x = vis_x
        self.image_nums = image_nums
        self.image_start_index_list = image_start_index_list
        self.num_beams = num_beams
        self.visual_tokens = visual_tokens
        self.data_list = data_list
        self.input_ids = None

    def condition_media_locations(self, media_locations):
        """Record the boolean mask of media-token positions (required before forward)."""
        self.media_locations = media_locations

    def condition_attend_previous(self, attend_previous):
        """Record the attend-previous flag (stored only; not read in forward here)."""
        self.attend_previous = attend_previous

    def _inject_visual_features(self, hidden_states):
        """Overwrite spans of ``hidden_states`` in place with conditioned visual features.

        ``vis_x`` is flat over the whole batch, so ``image_nums`` is prefix-summed
        to find the slice of images owned by each batch row; each start index in
        ``image_start_index_list[row]`` then consumes the row's next image.
        Extra single/multi-token embeddings from ``visual_tokens``/``data_list``
        are written afterwards.
        """
        single_length = self.vis_x.shape[-2]
        # boundaries[i]:boundaries[i+1] is the range of image indices for row i.
        boundaries = [0] + np.cumsum(self.image_nums).tolist()
        for row, (begin, end, start_indices) in enumerate(
            zip(boundaries[:-1], boundaries[1:], self.image_start_index_list)
        ):
            for index in start_indices:
                if begin < end:
                    hidden_states[row, index:index + single_length] = self.vis_x[begin]
                    begin += 1
        if self.visual_tokens is not None and len(self.visual_tokens) != 0:
            for i, (x, y) in enumerate(self.data_list):
                token = self.visual_tokens[i]
                if len(token.shape) > 1:
                    # Multi-token embedding: span ends at position y (inclusive).
                    hidden_states[x, y + 1 - token.shape[0]:y + 1] = token
                else:
                    hidden_states[x, y] = token

    def forward(
        self,
        hidden_states,  # alignment with hugging face name
        attention_mask=None,
        **decoder_layer_kwargs,
    ):
        """Inject visual features into ``hidden_states`` when appropriate, then
        run the wrapped decoder layer.

        Injection happens in training mode, or at inference on the first
        generation step only (detected by an empty KV cache); later steps reuse
        the cached keys/values.

        Raises:
            ValueError: if condition_media_locations() was never called.
        """
        if self.media_locations is None:
            raise ValueError("media_locations must be conditioned before forward pass")
        if self.vis_x is not None:
            if self.training:
                self._inject_visual_features(hidden_states)
            else:
                # Different HF architectures name the cache kwarg differently.
                no_cache_yet = (
                    ("past_key_value" in decoder_layer_kwargs and decoder_layer_kwargs["past_key_value"] is None)
                    or ("layer_past" in decoder_layer_kwargs and decoder_layer_kwargs["layer_past"] is None)
                )
                if no_cache_yet:
                    self._inject_visual_features(hidden_states)
        hidden_states = self.decoder_layer(
            hidden_states, attention_mask=attention_mask, **decoder_layer_kwargs
        )
        return hidden_states
class FlamingoLMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.

    Expects to be mixed into a HuggingFace causal LM; ``super().forward`` is
    the wrapped model's forward.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        # Dotted attribute path (e.g. "model.layers") locating the decoder's
        # layer list on the wrapped model.
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_flamingo(
        self,
        media_token_id,
        use_media_placement_augmentation,
    ):
        """
        Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
        """
        self._set_decoder_layers(
            nn.ModuleList(
                [FlamingoLayer(decoder_layer) for decoder_layer in self._get_decoder_layers()]
            )
        )
        self.media_token_id = media_token_id
        self.use_media_placement_augmentation = use_media_placement_augmentation
        self.initialized_flamingo = True

    def _layers_to_condition(self):
        """Return the decoder layer list for the wrapped architecture family.

        GPT-2/CodeGen and GPT-NeoX expose their layers under different
        attributes; everything else is assumed to follow the generic
        ``get_decoder().layers`` convention.
        """
        cls_name = self.__class__.__name__.lower()
        if "gpt2" in cls_name or "codegen" in cls_name:
            return self.transformer.h
        if "gptneox" in cls_name:
            return self.gpt_neox.layers
        return self.get_decoder().layers

    def forward(self, *input, **kwargs):
        """Condition the Flamingo layers on the media locations before forward()"""
        # Fix: before init_flamingo runs, the attribute does not exist at all,
        # so a plain `self.initialized_flamingo` raised AttributeError instead
        # of the intended ValueError. getattr with a default restores the
        # documented failure mode.
        if not getattr(self, "initialized_flamingo", False):
            raise ValueError(
                "Flamingo layers are not initialized. Please call `init_flamingo` first."
            )
        input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
        media_locations = input_ids == self.media_token_id
        # With augmentation on, randomly flip whether text attends to the
        # previous or following media (50/50); otherwise always previous.
        attend_previous = (
            (random.random() < 0.5) if self.use_media_placement_augmentation else True
        )
        for layer in self._layers_to_condition():
            layer.condition_media_locations(media_locations)
            layer.condition_attend_previous(attend_previous)
        return super().forward(
            *input, **kwargs
        )  # Call the other parent's forward method

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(l.is_conditioned() for l in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        """Drop all per-call conditioning state from every decoder layer."""
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_media_locations(None)
            layer.condition_attend_previous(None)
# NOTE(review): trailing "|" scrape artifact removed.