import random
import torch
import torch.nn as nn
import numpy as np

from .helpers import GatedCrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive


class FlamingoLayer(nn.Module):
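    """Wraps a single decoder layer so that cached visual features can be spliced
    into the text hidden states before the wrapped layer runs."""
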
    def __init__(self, decoder_layer):
        super().__init__()
        self.decoder_layer = decoder_layer
        self.vis_x = None
        self.image_nums = None
        self.image_start_index_list = None
        self.media_locations = None
        self.add_visual_token = False
        self.input_ids = None
        # Set later by condition_vis_x / condition_attend_previous; initialized
        # here so an unconditioned layer never raises AttributeError.
        self.num_beams = None
        self.visual_tokens = None
        self.data_list = None
        self.attend_previous = None

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None

    # This conditioning mechanism is adapted from the flamingo-mini implementation (https://github.com/dhansmair/flamingo-mini/).
    def condition_vis_x(self, vis_x, image_nums=None, image_start_index_list=None, num_beams=None, visual_tokens=None, data_list=None):
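        """Cache visual features and the bookkeeping needed to splice them into
        the text hidden states in forward().

        vis_x: per-image visual feature tensors (the second-to-last dim is the
            number of visual tokens per image).
        image_nums: number of images in each sample of the batch.
        image_start_index_list: per-sample start positions of the image
            placeholder spans in the token sequence.
        visual_tokens / data_list: optional extra visual embeddings and their
            (batch index, position) targets.
        """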
        self.vis_x = vis_x
        self.image_nums = image_nums
        self.image_start_index_list = image_start_index_list
        self.num_beams = num_beams
        self.visual_tokens = visual_tokens
        self.data_list = data_list
        self.input_ids = None

    def condition_media_locations(self, media_locations):
        self.media_locations = media_locations

    def condition_attend_previous(self, attend_previous):
        self.attend_previous = attend_previous

    def forward(
        self,
        hidden_states,  # name kept to match the Hugging Face decoder-layer interface
        attention_mask=None,
        **decoder_layer_kwargs,
    ):
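        """Inject any cached visual features into hidden_states, then run the
        wrapped decoder layer."""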
        if self.media_locations is None:
            raise ValueError("media_locations must be conditioned before forward pass")

        if self.vis_x is not None:
            if self.training:
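                # Training: overwrite each sample's image-placeholder spans with the
                # corresponding per-image visual features.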
                single_length = self.vis_x.shape[-2]
                image_nums = self.image_nums
                image_start_index_list = self.image_start_index_list
                image_nums = [0] + np.cumsum(image_nums).tolist()
                for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)):
                    for index in start_indices:
                        if image_num_begin < image_num_end:
                            hidden_states[i, index:index+single_length] = self.vis_x[image_num_begin]
                            image_num_begin += 1

                if self.visual_tokens is not None and len(self.visual_tokens) != 0:
                    for i, (x, y) in enumerate(self.data_list):
                        if len(self.visual_tokens[i].shape) > 1:
                            # multi-token visual embedding: fill the span ending at position y
                            hidden_states[x, y+1-self.visual_tokens[i].shape[0]:y+1] = self.visual_tokens[i]
                        else:
                            # single-token visual embedding
                            hidden_states[x, y] = self.visual_tokens[i]

            elif not self.training:
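                # Inference: inject visual features only on the first forward pass,
                # i.e. before any key/value cache has been populated.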
                if (
                    ("past_key_value" in decoder_layer_kwargs and decoder_layer_kwargs["past_key_value"] is None) or
                    ("layer_past" in decoder_layer_kwargs and decoder_layer_kwargs["layer_past"] is None)
                ):
                    single_length = self.vis_x.shape[-2]
                    image_nums = self.image_nums
                    image_start_index_list = self.image_start_index_list
                    image_nums = [0] + np.cumsum(image_nums).tolist()
                    for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)):
                        for index in start_indices:
                            if image_num_begin < image_num_end:
                                hidden_states[i, index:index+single_length] = self.vis_x[image_num_begin]
                                image_num_begin += 1
                    if self.visual_tokens is not None and len(self.visual_tokens) != 0:
                        for i, (x, y) in enumerate(self.data_list):
                            if len(self.visual_tokens[i].shape) > 1:
                                # multi-token visual embedding: fill the span ending at position y
                                hidden_states[x, y+1-self.visual_tokens[i].shape[0]:y+1] = self.visual_tokens[i]
                            else:
                                # single-token visual embedding
                                hidden_states[x, y] = self.visual_tokens[i]
        hidden_states = self.decoder_layer(
            hidden_states, attention_mask=attention_mask, **decoder_layer_kwargs
        )
        return hidden_states


class FlamingoLMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_flamingo(
        self,
        media_token_id,
        use_media_placement_augmentation,
    ):
        """
        Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
        """
        self._set_decoder_layers(
            nn.ModuleList(
                [FlamingoLayer(decoder_layer) for decoder_layer in self._get_decoder_layers()]
            )
        )
        self.media_token_id = media_token_id
        self.use_media_placement_augmentation = use_media_placement_augmentation
        self.initialized_flamingo = True

    def forward(self, *input, **kwargs):
        """Condition the Flamingo layers on the media locations before forward()"""
        if not getattr(self, "initialized_flamingo", False):
            raise ValueError(
                "Flamingo layers are not initialized. Please call `init_flamingo` first."
            )

        input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
        media_locations = input_ids == self.media_token_id
        attend_previous = (
            (random.random() < 0.5) if self.use_media_placement_augmentation else True
        )

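        # The decoder-layer container lives under a different attribute depending on
        # the backbone (GPT-2/CodeGen, GPT-NeoX, or a generic get_decoder()).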
        if (
            "gpt2" in self.__class__.__name__.lower()
            or "codegen" in self.__class__.__name__.lower()
        ):
            for layer in self.transformer.h:
                layer.condition_media_locations(media_locations)
                layer.condition_attend_previous(attend_previous)
        elif "gptneox" in self.__class__.__name__.lower():
            for layer in self.gpt_neox.layers:
                layer.condition_media_locations(media_locations)
                layer.condition_attend_previous(attend_previous)
        else:
            for layer in self.get_decoder().layers:
                layer.condition_media_locations(media_locations)
                layer.condition_attend_previous(attend_previous)
        return super().forward(
            *input, **kwargs
        )  # Call the other parent's forward method

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(l.is_conditioned() for l in self._get_decoder_layers())

    def clear_conditioned_layers(self):
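        """Reset all cached visual conditioning on every decoder layer."""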
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_media_locations(None)
            layer.condition_attend_previous(None)