import numpy as np
import tensorflow as tf

# Set Transformer building blocks (multihead set attention, induced
# self-attention, and attention pooling), written against the TF1 compat API.

def layer_norm(x, eps=1e-6):
    # Normalize over the set and feature axes; eps guards against division
    # by zero for constant inputs.
    mean = tf.reduce_mean(input_tensor=x, axis=[1, 2], keepdims=True)
    std = tf.math.reduce_std(x, axis=[1, 2], keepdims=True)
    return (x - mean) / (std + eps)

def set_attention(Q, K, dim, num_heads, name='set_attention'):
    """Multihead attention block: attend from the query set Q to the set K."""
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        q = tf.compat.v1.layers.dense(Q, dim, name='query')
        k = tf.compat.v1.layers.dense(K, dim, name='key')
        v = tf.compat.v1.layers.dense(K, dim, name='value')

        # Split the feature axis into heads and fold the heads into the batch axis.
        q_ = tf.concat(tf.split(q, num_heads, axis=-1), axis=0)  # [B*Nh,Nq,dim/Nh]
        k_ = tf.concat(tf.split(k, num_heads, axis=-1), axis=0)  # [B*Nh,Nk,dim/Nh]
        v_ = tf.concat(tf.split(v, num_heads, axis=-1), axis=0)  # [B*Nh,Nk,dim/Nh]

        # Scaled dot-product attention with a residual connection; scaling by
        # the full dim (not dim/num_heads) follows the Set Transformer
        # reference implementation.
        logits = tf.matmul(q_, k_, transpose_b=True) / np.sqrt(dim)  # [B*Nh,Nq,Nk]
        A = tf.nn.softmax(logits, axis=-1)
        o = q_ + tf.matmul(A, v_)

        # Undo the head split, then apply a residual feed-forward layer.
        o = tf.concat(tf.split(o, num_heads, axis=0), axis=-1)
        # o = layer_norm(o)  # optional; local helper replaces the removed tf.contrib.layers.layer_norm
        o = o + tf.compat.v1.layers.dense(o, dim, activation=tf.nn.relu, name='output')
        # o = layer_norm(o)

    return o

def set_transformer(inputs, layer_sizes, name, num_heads=4, num_inds=16):
    """Stack of induced set-attention blocks (ISAB) with learned inducing points."""
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        out = inputs
        for i, size in enumerate(layer_sizes):
            # Learned inducing points, tiled across the batch.
            inds = tf.compat.v1.get_variable(
                f'inds_{i}', shape=[1, num_inds, size], dtype=tf.float32, trainable=True,
                initializer=tf.compat.v1.keras.initializers.VarianceScaling(
                    scale=1.0, mode="fan_avg", distribution="uniform"))
            inds = tf.tile(inds, [tf.shape(input=out)[0], 1, 1])
            # Attend the inducing points to the set, then the set to that
            # summary, so the attention cost is linear in the set size.
            tmp = set_attention(inds, out, size, num_heads, name=f'self_attn_{i}_pre')
            out = set_attention(out, tmp, size, num_heads, name=f'self_attn_{i}_post')

    return out

def set_pooling(inputs, name, num_heads=4):
    """Pool a set to a single vector by attending a learned seed to the set."""
    B = tf.shape(input=inputs)[0]
    d = inputs.get_shape().as_list()[-1]
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        seed = tf.compat.v1.get_variable(
            'pool_seed', shape=[1, 1, d], dtype=tf.float32, trainable=True,
            initializer=tf.compat.v1.keras.initializers.VarianceScaling(
                scale=1.0, mode="fan_avg", distribution="uniform"))
        seed = tf.tile(seed, [B, 1, 1])
        out = set_attention(seed, inputs, d, num_heads, name='pool_attn')
        out = tf.squeeze(out, axis=1)  # [B,1,d] -> [B,d]

    return out
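
# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): encode a batch of
# random sets and pool each one to a single vector. The batch size, set size,
# feature dimension, and layer sizes below are illustrative assumptions.
if __name__ == '__main__':
    tf.compat.v1.disable_eager_execution()  # the code above uses the TF1 graph API

    # Sets of variable size with 12 features per element.
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, None, 12])
    feats = set_transformer(x, layer_sizes=[64, 64], name='encoder')  # [B,N,64]
    pooled = set_pooling(feats, name='pool')                          # [B,64]

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        batch = np.random.randn(8, 30, 12).astype(np.float32)
        out = sess.run(pooled, feed_dict={x: batch})
        print(out.shape)  # (8, 64)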