File size: 16,282 Bytes
d32f7f0
 
09a0bd6
 
 
 
 
 
 
 
 
 
0d8b94f
09a0bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d32f7f0
09a0bd6
 
 
 
 
 
 
 
 
b3e7614
d32f7f0
 
09a0bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d32f7f0
120c986
 
 
09a0bd6
 
 
d32f7f0
b3e7614
 
d32f7f0
09a0bd6
 
d32f7f0
 
 
 
 
 
 
09a0bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45ea980
 
 
 
 
 
 
 
 
d32f7f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45ea980
d32f7f0
 
09a0bd6
 
 
 
 
 
 
 
 
d32f7f0
 
45ea980
d32f7f0
 
 
 
 
 
 
 
 
b3e7614
d32f7f0
09a0bd6
b3e7614
09a0bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
from transformers import TFPreTrainedModel, PreTrainedTokenizer, BatchEncoding

from tensorflow.keras.models import Model, load_model, Sequential
from tensorflow.keras.layers import Layer, Dense, concatenate, Input, add, Dropout, LayerNormalization, MultiHeadAttention, Embedding
import tensorflow as tf
import numpy as np

from typing import Dict

import re
import unicodedata

from .configuration_bilma import BilmaConfig

# Text-preprocessing constants, copied from preprocessing.py.
BLANK = ' '

# Shared flags: case-insensitive, multiline, dot matches newline.
RE_OPS = re.I | re.M | re.S
RE_USR = re.compile(r"@\S+", RE_OPS)                     # '@' followed by non-whitespace
RE_TAG = re.compile(r"#\S+", RE_OPS)                     # '#' followed by non-whitespace
RE_URL = re.compile(r"(http|ftp|https)://\S+", RE_OPS)   # http/ftp/https URLs
RE_NUM = re.compile(r"[-+]?\d+\.?\d*", RE_OPS)           # optionally signed ints/decimals

SYMBOLS_ = "()[]¿?¡!{}~<>|"
# Punctuation set: the characters ; : , . @ \ - " / plus everything in SYMBOLS_.
SYMBOLS = set(";:,.@\\-\"/") | set(SYMBOLS_)



# ------------------
# Class declaration
# ------------------


class TFBilma(TFPreTrainedModel):
    """Hugging Face ``TFPreTrainedModel`` wrapper around the Keras model built by ``bilma``.

    The wrapped Keras model is stored in ``self.model``. ``call`` returns a
    one-entry dict whose key depends on the configuration:

    * ``config.include_top`` true      -> ``{"logits": ...}``
    * ``config.add_head`` is None      -> ``{"last_hidden_state": ...}``
    * otherwise                        -> ``{"label": ...}``
    """
    config_class = BilmaConfig
    main_input_name = "input_ids"
    #base_model_prefix = "bilma"

    def __init__(self, config):
        # Kept before super().__init__() as in the original; `input_signature`
        # and `call` read these attributes. Do not reorder without checking
        # TFPreTrainedModel.__init__.
        self.seq_max_length = config.seq_max_length
        self.include_top = config.include_top
        self.add_head = config.add_head
        super().__init__(config)

        self.model = bilma(num_enc=config.num_hidden_layers,
                           embed_dim=config.hidden_size,
                           max_length=config.seq_max_length,
                           num_heads=config.num_attention_heads,
                           # NOTE(review): the feed-forward size reuses hidden_size — confirm intended.
                           ff_dim=config.hidden_size,
                           vocab_size=config.vocab_size,
                           rate=config.hidden_dropout_prob,
                           include_top=config.include_top,
                           add_head=config.add_head,
                           pooling=config.pooling)

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """All-ones tensors shaped after ``input_signature``.

        Unknown dimensions become 2, except an unknown leading (batch)
        dimension, which becomes 1.
        """
        dummies = {}
        for key, spec in self.input_signature.items():
            dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
            if spec.shape[0] is None:
                dummy_shape[0] = 1
            dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype)
        return dummies

    @property
    def input_signature(self) -> Dict[str, tf.TensorSpec]:
        """A single int32 ``input_ids`` spec of shape (batch, seq_max_length)."""
        sig = {}
        sig["input_ids"] = tf.TensorSpec([None, self.seq_max_length], tf.int32, name="input_ids")
        return sig

    def call(self, inputs):
        """Run the wrapped Keras model.

        Args:
            inputs: a ``dict`` / ``BatchEncoding`` holding ``"input_ids"``, or the
                ids tensor itself (passed through unchanged).

        Returns:
            dict with a single entry; see the class docstring for its key.
        """
        if isinstance(inputs, (dict, BatchEncoding)):
            # `bilma` declares its Input without a dtype (Keras default float32),
            # so ids taken from a mapping are cast to float32 first.
            ins = tf.cast(inputs["input_ids"], tf.float32)
        else:
            ins = inputs
        if self.include_top:
            return {"logits": self.model(ins)}
        if self.add_head is None:
            return {"last_hidden_state": self.model(ins)}
        return {"label": self.model(ins)}

    @staticmethod
    def get_loss_function(ignore_id=0):
        """Return the masked cross-entropy loss built by ``loss_function``.

        A staticmethod, so it can be called on the class or on an instance.
        """
        return loss_function(ignore_id)

    @staticmethod
    def get_acc_function(ignore_id=0):
        """Return the masked token-accuracy metric built by ``accuracy_function``.

        A staticmethod, so it can be called on the class or on an instance.
        """
        return accuracy_function(ignore_id)
    
    
# copied from bilma_model.py
# --------------------------

def loss_function(ignore_id=0):
    """Build a masked sparse categorical cross-entropy loss.

    The returned ``loss(real, pred)`` treats ``pred`` as logits and zeroes the
    per-position loss wherever ``real == ignore_id``. It then sums over axis 1
    (presumably the sequence axis) and divides by the number of kept positions
    in that row. A row with no kept positions yields 0 rather than NaN.
    """
    per_position_xent = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

    def loss(real, pred):
        # The inner name `loss` matters: `load` registers this function under it.
        token_loss = per_position_xent(real, pred)
        keep = tf.cast(tf.math.logical_not(tf.math.equal(real, ignore_id)), dtype=token_loss.dtype)
        row_total = tf.reduce_sum(token_loss * keep, axis=1)
        row_count = tf.reduce_sum(keep, axis=1)
        return tf.math.divide_no_nan(row_total, row_count)

    return loss

def accuracy_function(ignore_id=0):
    """Build a masked token-accuracy metric.

    The returned ``acc_mlm(real, pred)`` compares ``argmax(pred, axis=2)`` with
    ``real``. It counts hits only where ``real != ignore_id`` and returns
    hits / counted positions over the whole batch (a single scalar, 0.0 when
    nothing is counted).
    """
    def acc_mlm(real, pred):
        # The inner name `acc_mlm` matters: `load` registers the metric under it.
        counted = tf.math.logical_not(tf.math.equal(real, ignore_id))
        hits = tf.equal(tf.cast(real, tf.int64), tf.argmax(pred, axis=2))
        hits = tf.math.logical_and(counted, hits)
        n_hits = tf.reduce_sum(tf.cast(hits, dtype=tf.float32))
        n_counted = tf.reduce_sum(tf.cast(counted, dtype=tf.float32))
        return tf.math.divide_no_nan(n_hits, n_counted)

    return acc_mlm

def mean_vectors(inputs, enc_vectors, max_length):
    """Mean-pool encoder vectors over the positions that precede token id 3.

    NOTE(review): id 3 is presumably the end-of-sequence / [SEP] id of the
    BILMA vocabulary — confirm against the tokenizer.

    Args:
        inputs: token ids; ``bilma`` passes its (batch, max_length) Input.
        enc_vectors: encoder output, (batch, max_length, dim) when called from ``bilma``.
        max_length: number of sequence positions.

    Returns:
        A (batch, dim) tensor. Each row is the mean of the vectors before the
        row's first id-3 token. A row with no id-3 token averages all
        ``max_length`` vectors. A row whose first token is id 3 has nothing to
        average and yields zeros.
    """
    # 1.0 for every non-id-3 position. The running product stays 1 up to the
    # first id-3 token and is 0 from there on, so its row sum is the number of
    # positions before that token. Unlike tf.where(inputs == 3), this gives
    # exactly one length per row. Rows with zero or several id-3 tokens can
    # therefore no longer push the lengths out of alignment with the batch.
    not_eos = 1.0 - tf.cast(tf.equal(inputs, 3), tf.float32)
    lengths = tf.reduce_sum(tf.math.cumprod(not_eos, axis=1), axis=1)

    C = tf.sequence_mask(tf.cast(lengths, tf.int32), maxlen=max_length, dtype=tf.float32)
    C = tf.reshape(C, (-1, max_length, 1))
    S = tf.reduce_sum(enc_vectors * C, 1)
    # divide_no_nan (as in loss_function/accuracy_function): zero-length rows give 0, not NaN.
    return tf.math.divide_no_nan(S, tf.expand_dims(lengths, 1))

def mean_diff_vectors(inputs, enc_vectors, max_length):
    """Pool encoder vectors via their differences from the masked mean.

    Each row's valid span is the positions before its id-3 token, located with
    ``tf.where``. This assumes exactly one id-3 token per row; otherwise the
    lengths no longer line up with the batch.

    NOTE(review): two problems are preserved here as-is:
      * ``mu`` is (batch, dim) and ``masked`` is (batch, max_length, dim), so
        ``mu - masked`` broadcasts mu's batch axis against the sequence axis.
        This only works when batch is 1 or equals max_length.
      * For a batch of one, the differences over the valid span sum to zero and
        each masked position adds ``mu``. The result is therefore
        ``(max_length - pos) / pos * mu``, which is probably not the intended
        statistic.
    ``bilma`` does not use this function.
    """
    eos_coords = tf.where(inputs == 3)
    lengths = tf.transpose(eos_coords)[1]
    keep = tf.sequence_mask(lengths, maxlen=max_length, dtype=tf.float32)
    keep = tf.reshape(keep, (-1, max_length, 1))
    masked = enc_vectors * keep
    denom = tf.expand_dims(tf.cast(lengths, tf.float32), (1))
    mu = tf.reduce_sum(masked, 1) / denom
    return tf.reduce_sum(mu - masked, 1) / denom

def max_vectors(inputs, enc_vectors, max_length):
    """Max-pool encoder vectors over the positions that precede token id 3.

    NOTE(review): id 3 is presumably the end-of-sequence / [SEP] id — confirm.

    Args:
        inputs: token ids; ``bilma`` passes its (batch, max_length) Input.
        enc_vectors: encoder output, (batch, max_length, dim) when called from ``bilma``.
        max_length: number of sequence positions.

    Returns:
        A (batch, dim) tensor holding the element-wise max over each row's
        valid positions. The valid positions are those before the first id-3
        token, or all positions if the row has none. A row whose first token is
        id 3 yields zeros.
    """
    # Per-row count of positions before the first id-3 token (see mean_vectors):
    # exactly one length per row, whatever the number of id-3 tokens.
    not_eos = 1.0 - tf.cast(tf.equal(inputs, 3), tf.float32)
    lengths = tf.cast(tf.reduce_sum(tf.math.cumprod(not_eos, axis=1), axis=1), tf.int32)
    valid = tf.reshape(tf.sequence_mask(lengths, maxlen=max_length), (-1, max_length, 1))

    # Excluded positions get the lowest representable value so they never win
    # the max. Multiplying by a 0/1 mask let those positions contribute 0, which
    # beat any feature that is negative at every valid position.
    masked = tf.where(valid, enc_vectors, enc_vectors.dtype.min)
    pooled = tf.reduce_max(masked, 1)

    # A row with no valid positions would otherwise be dtype.min everywhere.
    has_valid = tf.reshape(tf.greater(lengths, 0), (-1, 1))
    return tf.where(has_valid, pooled, tf.zeros_like(pooled))

def cls_vectors(inputs, enc_vectors, max_length):
    """Return the encoder vector at position 0 of every row: (batch, dim).

    ``inputs`` and ``max_length`` are unused. They are accepted so all pooling
    functions share one signature.
    """
    return enc_vectors[:, 0, :]


def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512, vocab_size=9739, rate=0.1, include_top=True, add_head=None, pooling=None):
    """Build the BILMA Keras functional model.

    Token ids (Input named ``input_ids``, length ``max_length``) pass through an
    Embedding and an ``Encoder`` stack. After that:

    * ``include_top`` true: a Dense layer of ``vocab_size`` units without
      activation is applied at every position.
    * otherwise: optional pooling first. ``"mean"``, ``"cls"`` and ``"max"`` are
      recognised; any other value, including None, keeps the full sequence.
      Then, if ``add_head`` is given, ReLU Dense layers of sizes
      ``add_head[:-1]`` follow, plus a softmax Dense layer of ``add_head[-1]``
      units.

    Returns:
        A Keras ``Model`` named ``"bilma_model"``.
    """
    token_ids = Input(shape=(max_length, ), name='input_ids')
    embedding = Embedding(vocab_size, embed_dim, mask_zero=False, name="bilma/embedding")
    embedded = embedding(token_ids)

    encoder = Encoder(num_enc, embed_dim, max_length, num_heads, ff_dim, rate=rate, name="bilma/encoder")
    hidden = encoder(embedded)

    if include_top:
        logits = Dense(vocab_size, use_bias=True, name="bilma/dense_final")(hidden)
        return Model(inputs=token_ids, outputs=logits, name="bilma_model")

    # Pooling strategies by name; an unrecognised name leaves `hidden` unpooled.
    pooler = {"mean": mean_vectors, "cls": cls_vectors, "max": max_vectors}.get(pooling)
    if pooler is not None:
        hidden = pooler(token_ids, hidden, max_length)

    if add_head is not None:
        for idx, units in enumerate(add_head[:-1]):
            hidden = Dense(units, use_bias=True, activation="relu", name=f"bilma/dense_ex_{idx}")(hidden)
        hidden = Dense(add_head[-1], use_bias=True, activation="softmax", name="bilma/dense_ex_final")(hidden)

    return Model(inputs=token_ids, outputs=hidden, name="bilma_model")

def load(model_file):
    """Load a saved BILMA Keras model from ``model_file``.

    Registers the custom encoder layers plus the loss/metric closures under the
    names their inner functions carry (``loss``, ``acc_mlm``).
    """
    registry = dict(
        EncoderBlock=EncoderBlock,
        Encoder=Encoder,
        loss=loss_function(),
        acc_mlm=accuracy_function(),
    )
    return load_model(model_file, custom_objects=registry)


# 
# Copied from transformer_text.py   
# -------------------------------
class EncoderBlock(Layer):
    """One post-norm transformer encoder block.

    ``call`` runs these steps in order:
      1. Self-attention, then dropout, a residual add and ``layernorm1``.
      2. A two-layer feed-forward network (GELU Dense of ``ff_dim`` units, then
         a Dense back to ``patch_dim``), then dropout, a residual add and
         ``layernorm2``.

    Args:
        layer_num: index of this block; embedded in the sub-layer names.
        patch_dim: size of the input/output vectors; also the attention ``key_dim``.
        num_heads: number of attention heads.
        ff_dim: hidden size of the feed-forward network.
        rate: dropout rate of both dropout layers.
    """
    def __init__(self, layer_num, patch_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        # Constructor arguments kept for get_config(); `ln` is the layer number, not a LayerNorm.
        self.ln = layer_num
        self.p_d = patch_dim
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate
        
        # NOTE(review): key_dim is the full patch_dim per head rather than
        # patch_dim // num_heads. key_dim determines the attention projection
        # shapes, so changing it breaks existing weights.
        # The explicit sub-layer names presumably keep weight names stable for
        # saved checkpoints — keep them (and the creation order) unchanged.
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=patch_dim, name=f"bilma/MHA_{layer_num}")
        self.ffn = Sequential(
            #[Conv1D(ff_dim, kernel_size=1, activation=tf.nn.gelu), 
            # Conv1D(patch_dim, kernel_size=1),]
            [Dense(ff_dim, activation=tf.nn.gelu, name=f"bilma/dense1_{layer_num}"), 
             Dense(patch_dim, name=f"bilma/dense2_{layer_num}")] 
        )
        #self.layernorm0 = LayerNormalization(epsilon=1e-6)
        self.layernorm1 = LayerNormalization(epsilon=1e-6, name=f"ln1_{layer_num}")
        self.layernorm2 = LayerNormalization(epsilon=1e-6, name=f"ln2_{layer_num}")
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)
        
    def get_config(self):
        """Base layer config plus the constructor arguments, for re-instantiation."""
        config = super(EncoderBlock, self).get_config()
        config.update({"layer_num":self.ln, "patch_dim":self.p_d, "num_heads":self.n_h, "ff_dim":self.f_d, "rate":self.rate})
        return config

    def call(self, inputs, training=False):
        """Apply the block; the output has the same shape as ``inputs``.

        ``training`` toggles the dropout layers. ``Encoder`` passes it positionally.
        """
        #inputs = self.layernorm0(inputs)
        # Self-attention: query and key/value are both `inputs`.
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(add([inputs, attn_output]))
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(add([out1, ffn_output]))
    

class DecoderBlock(Layer):
    """Post-norm transformer decoder block.

    Sub-layers, each followed by dropout, a residual add and a LayerNorm:
      1. Masked self-attention.
      2. Cross-attention over the encoder output.
      3. Feed-forward network (GELU Dense of ``ff_dim`` units, then Dense to ``embed_dim``).

    Args:
        embed_dim: size of the input/output vectors; also the attention ``key_dim``.
        num_heads: number of attention heads.
        ff_dim: hidden size of the feed-forward network.
        rate: dropout rate of the three dropout layers.
    """
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        # Constructor arguments kept for get_config().
        self.e_d = embed_dim
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate

        self.att1 = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.att2 = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = Sequential(
            [Dense(ff_dim, activation=tf.nn.gelu),
             Dense(embed_dim),]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)
        self.dropout3 = Dropout(rate)
        # Own normalization for the cross-attention sub-layer. It used to reuse
        # layernorm1, sharing its parameters with the self-attention sub-layer.
        # It is created last so layernorm1/layernorm2 keep their existing roles.
        self.layernorm3 = LayerNormalization(epsilon=1e-6)

    def get_config(self):
        """Base layer config plus the constructor arguments, for re-instantiation."""
        config = super(DecoderBlock, self).get_config()
        config.update({"embed_dim":self.e_d, "num_heads":self.n_h, "ff_dim":self.f_d, "rate":self.rate})
        return config

    def call(self, inputs, encoder_output, look_ahead_mask, padding_mask, training=None):
        """Apply the block.

        Args:
            inputs: decoder input sequence.
            encoder_output: key/value sequence for cross-attention.
            look_ahead_mask: attention mask for the self-attention sub-layer.
            padding_mask: attention mask for the cross-attention sub-layer.
            training: toggles the dropout layers.

        Returns:
            Tuple ``(output, self_attention_scores, cross_attention_scores)``.
        """
        y, attn_output1 = self.att1(inputs, inputs, attention_mask=look_ahead_mask, return_attention_scores=True)
        y = self.dropout1(y, training=training)
        out1 = self.layernorm1(add([inputs, y]))

        y, attn_encoder = self.att2(out1, encoder_output, attention_mask=padding_mask, return_attention_scores=True)
        y = self.dropout2(y, training=training)
        out2 = self.layernorm3(add([out1, y]))

        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        final_output = self.layernorm2(out2 + ffn_output)

        return final_output, attn_output1, attn_encoder

class Encoder(Layer):
    """Stack of ``n`` EncoderBlocks preceded by scaling and positional encoding.

    Args:
        n: number of EncoderBlocks.
        embed_dim: size of the input vectors (the ``patch_dim`` of each block).
        max_length: number of positions in the precomputed positional-encoding table.
        num_heads: attention heads per block.
        ff_dim: hidden size of each block's feed-forward network.
        rate: dropout rate used inside every block.
    """
    def __init__(self, n, embed_dim, max_length, num_heads, ff_dim, rate=0.1, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.n = n
        self.embed_dim = embed_dim
        self.max_length = max_length
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate
        # Forward the configured dropout rate. It used to be hard-coded to 0.1,
        # so `rate` was recorded in get_config() but never reached the blocks.
        self._layers = [EncoderBlock(i, embed_dim, num_heads, ff_dim, rate=rate, name=f"enc_block_{i}") for i in range(n)]
        # float32 table of shape (1, max_length, embed_dim) from positional_encoding.
        self.pe = positional_encoding(self.max_length, self.embed_dim)

    def get_config(self):
        """Base layer config plus the constructor arguments, for re-instantiation."""
        config = super(Encoder, self).get_config()
        config.update({"n": self.n, "embed_dim":self.embed_dim, "max_length": self.max_length, "num_heads":self.n_h, "ff_dim":self.f_d, "rate":self.rate})
        return config

    def call(self, x, training=False):
        """Scale ``x`` by sqrt(embed_dim), add positional encoding, run every block."""
        x *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        # Slice the table to the actual sequence length of x.
        x = x + self.pe[:, :tf.shape(x)[1], :]
        for layer in self._layers:
            x = layer(x, training)
        return x

    
class Decoder(Layer):
    """Stack of ``n`` DecoderBlocks preceded by scaling and positional encoding.

    Args:
        n: number of DecoderBlocks.
        embed_dim: size of the input vectors.
        max_length: number of positions in the precomputed positional-encoding table.
        num_heads: attention heads per block.
        ff_dim: hidden size of each block's feed-forward network.
        rate: dropout rate used inside every block.
    """
    def __init__(self, n, embed_dim, max_length, num_heads, ff_dim, rate=0.1, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        self.n = n
        self.embed_dim = embed_dim
        self.max_length = max_length
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate
        # Forward the configured dropout rate (it used to be hard-coded to 0.1).
        self._layers = [DecoderBlock(embed_dim, num_heads, ff_dim, rate=rate) for _ in range(n)]
        # float32 table of shape (1, max_length, embed_dim) from positional_encoding.
        self.pe = positional_encoding(self.max_length, self.embed_dim)

    def get_config(self):
        """Base layer config plus the constructor arguments, for re-instantiation."""
        config = super(Decoder, self).get_config()
        config.update({"n": self.n, "embed_dim":self.embed_dim, "max_length": self.max_length, "num_heads":self.n_h, "ff_dim":self.f_d, "rate":self.rate})
        return config

    def call(self, x, encoder_output, look_ahead_mask, padding_mask, training):
        """Scale ``x``, add positional encoding, run every block.

        Returns only the final hidden states; the attention scores returned by
        each DecoderBlock are discarded.
        """
        x *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        x = x + self.pe[:, :tf.shape(x)[1], :]

        for layer in self._layers:
            x, _, _ = layer(x, encoder_output, look_ahead_mask, padding_mask, training)

        return x




# =========================================
#   M A S K S 
# =========================================
def create_padding_mask(seq):
    """Key-padding mask for self-attention.

    seq: (bs, max_length, emb_dim). A position counts as padding when all of
    its features are 0.
    Returns a float32 tensor (bs, max_length, max_length). Entry [b, i, j] is
    1.0 when position j of example b is not padding; it is the same for every i.
    """
    length = seq.shape[1]
    not_padding = tf.reduce_any(tf.not_equal(seq, 0), 2)
    # Repeat each example's row once per query position, then fold into a square.
    per_query = tf.repeat(not_padding, length, 0)
    square = tf.reshape(per_query, (-1, length, length))
    return tf.cast(square, tf.float32)


def create_cross_padding_mask(seq, target_seq):
    """Mask for decoder cross-attention.

    seq:        (bs, k, features); only its length k is used.
    target_seq: (bs, max_length, emb_dim)
    Returns a bool tensor (bs, max_length, k). Entry [b, i, j] is True when
    target position i of example b has any non-zero feature; it does not depend
    on j. Unlike create_padding_mask, the result is not cast to float32.
    """
    target_valid = tf.reduce_any(tf.cast(tf.not_equal(target_seq, 0), tf.bool), 2)
    stacked = tf.repeat(target_valid, seq.shape[1], 0)
    grid = tf.reshape(stacked, (-1, tf.shape(seq)[1], tf.shape(target_seq)[1]))
    return tf.transpose(grid, [0, 2, 1])


def create_look_ahead_mask(seq):
    """Causal mask for decoder self-attention.

    seq: (bs, max_length, emb_dim); only its batch size and static length are used.
    Returns a float32 tensor (bs, max_length, max_length). Entry [b, i, j] is
    1.0 when j <= i (the diagonal and below) and 0.0 otherwise.
    """
    length = seq.shape[1]
    causal = tf.linalg.band_part(tf.ones((length, length)), -1, 0)
    return tf.repeat(causal[tf.newaxis, ...], tf.shape(seq)[0], 0)


def create_masks(seq, target_seq):
    """Return (decoder self-attention mask, cross-attention mask).

    The decoder mask combines key padding with the causal look-ahead mask.
    """
    self_attention_mask = create_padding_mask(target_seq) * create_look_ahead_mask(target_seq)
    return self_attention_mask, create_cross_padding_mask(seq, target_seq)
        
    
def create_masks_looking_ahead(seq, target_seq):
    """Like ``create_masks``, but without the causal mask.

    Decoder positions may attend to later (non-padding) positions.
    Returns (decoder self-attention mask, cross-attention mask).
    """
    self_attention_mask = create_padding_mask(target_seq)
    return self_attention_mask, create_cross_padding_mask(seq, target_seq)
    
# =========================================
#   P O S I T I O N A L   E N C O D I N G
# =========================================
def get_angles(pos, i, d_model):
    """Angles ``pos / 10000 ** (2 * (i // 2) / d_model)`` for sinusoidal encodings.

    ``pos`` and ``i`` are broadcast against each other (e.g. a column of
    positions and a row of dimension indices).
    """
    exponent = (2 * (i // 2)) / np.float32(d_model)
    inverse_wavelength = 1 / np.power(10000, exponent)
    return pos * inverse_wavelength

@tf.autograph.experimental.do_not_convert
def positional_encoding(position, d_model):
    """Sinusoidal positional-encoding table.

    Returns a float32 tensor of shape (1, position, d_model). Even columns hold
    sin of the angles from ``get_angles``; odd columns hold cos.
    """
    positions = np.arange(position)[:, np.newaxis]
    dims = np.arange(d_model)[np.newaxis, :]
    table = get_angles(positions, dims, d_model)

    table[:, 0::2] = np.sin(table[:, 0::2])  # even indices: 2i
    table[:, 1::2] = np.cos(table[:, 1::2])  # odd indices: 2i+1

    return tf.cast(np.expand_dims(table, 0), dtype=tf.float32)

class PatchEncoder(Layer):
    """Linear projection of patches plus a learned position embedding.

    Args:
        num_patches: patches per example; also the size of the position table.
        projection_dim: output size of the projection and of each position embedding.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super(PatchEncoder, self).__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        # Sub-layers are created in the same order as before (projection first).
        self.projection = Dense(units=projection_dim)
        self.position_embedding = Embedding(input_dim=num_patches, output_dim=projection_dim)

    def get_config(self):
        """Base layer config plus the constructor arguments, for re-instantiation."""
        base = super(PatchEncoder, self).get_config()
        return {**base, "num_patches": self.num_patches, "projection_dim": self.projection_dim}

    def call(self, patch):
        """Project ``patch`` and add one embedding per position 0..num_patches-1."""
        position_ids = tf.range(start=0, limit=self.num_patches, delta=1)
        projected = self.projection(patch)
        return projected + self.position_embedding(position_ids)