Spaces:
Running
Running
import numpy as np | |
import tensorflow as tf | |
def layer_norm(x, epsilon=1e-6, axes=(1, 2)):
    """Normalize ``x`` to zero mean and unit variance over ``axes``.

    Mean and variance are computed per example over ``axes`` (by default
    axes 1 and 2, i.e. jointly over the second and third dimensions) and
    broadcast back onto ``x``. No learnable gain or bias is applied.

    Args:
        x: Float tensor. With the default ``axes`` it must have rank >= 3.
        epsilon: Small constant added to the variance before the square root,
            so a constant slice normalizes to zeros instead of NaN/inf.
        axes: Axes to reduce over when computing the statistics.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    mean, variance = tf.nn.moments(x, axes=list(axes), keepdims=True)
    # The previous form, (x - mean) / std, divided by zero whenever a slice was
    # constant. Adding epsilon to the variance keeps the result finite.
    return (x - mean) * tf.math.rsqrt(variance + epsilon)
def set_attention(Q, K, dim, num_heads, name='set_attention'):
    """Multi-head attention from queries ``Q`` onto keys/values taken from ``K``.

    Both inputs are projected to width ``dim`` by dense layers named 'query',
    'key' and 'value'; keys and values are both projected from ``K``. Each
    projection is cut into ``num_heads`` equal slices along the last axis, and
    the slices are stacked along axis 0 so that all heads are processed in one
    batched matmul. After softmax attention, the projected queries are added
    back residually. The heads are then re-joined along the last axis, and a
    residual ReLU dense layer named 'output' is applied.

    All variables are created under ``variable_scope(name)`` with AUTO_REUSE,
    so repeated calls with the same ``name`` share weights.

    Args:
        Q: Query tensor; presumably [batch, n_query, features] (the logits
            shape note below implies rank 3) -- TODO confirm with callers.
        K: Key/value tensor; presumably [batch, n_key, features].
        dim: Projection/output width; ``tf.split`` requires it to be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        name: Variable scope name.

    Returns:
        Tensor whose last dimension is ``dim`` and whose leading dimensions
        follow ``Q``.
    """
    def to_heads(tensor):
        # Slice the last axis into heads and stack them on axis 0.
        return tf.concat(tf.split(tensor, num_heads, axis=-1), axis=0)

    def from_heads(tensor):
        # Undo to_heads: un-stack axis 0 and re-join the heads on the last axis.
        return tf.concat(tf.split(tensor, num_heads, axis=0), axis=-1)

    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        projections = [
            tf.compat.v1.layers.dense(source, dim, name=layer_name)
            for source, layer_name in ((Q, 'query'), (K, 'key'), (K, 'value'))
        ]
        q_heads, k_heads, v_heads = [to_heads(p) for p in projections]

        # Scaled by sqrt(dim), the full projected width (not dim / num_heads).
        scale = np.sqrt(dim)
        scores = tf.matmul(q_heads, k_heads, transpose_b=True) / scale  # [B*Nh,Nq,Nk]
        weights = tf.nn.softmax(scores, axis=-1)
        attended = q_heads + tf.matmul(weights, v_heads)

        merged = from_heads(attended)
        # Layer normalization is not applied in this block, either before or
        # after the feed-forward residual.
        feed_forward = tf.compat.v1.layers.dense(
            merged, dim, activation=tf.nn.relu, name='output')
        return merged + feed_forward
def set_transformer(inputs, layer_sizes, name, num_heads=4, num_inds=16):
    """Apply a stack of induced set-attention blocks to ``inputs``.

    For the i-th entry ``width`` of ``layer_sizes``:
      1. A trainable tensor ``inds_{i}`` of shape [1, num_inds, width] is
         tiled across the batch.
      2. It attends to the current set (scope ``self_attn_{i}_pre``).
      3. The current set then attends to that result (scope
         ``self_attn_{i}_post``), which becomes the input of the next block.

    Args:
        inputs: Set tensor; presumably [batch, set_size, features] -- TODO
            confirm with callers.
        layer_sizes: Iterable of widths, one per block.
        name: Outer variable scope (AUTO_REUSE, so weights are shared across
            calls with the same name).
        num_heads: Attention heads per block.
        num_inds: Number of inducing points per block.

    Returns:
        The output of the last block, or ``inputs`` itself when
        ``layer_sizes`` is empty.
    """
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        hidden = inputs
        for layer_idx, width in enumerate(layer_sizes):
            # VarianceScaling(scale=1, fan_avg, uniform) is Glorot-uniform.
            glorot_uniform = tf.compat.v1.keras.initializers.VarianceScaling(
                scale=1.0, mode="fan_avg", distribution="uniform")
            inducing = tf.compat.v1.get_variable(
                f'inds_{layer_idx}',
                shape=[1, num_inds, width],
                dtype=tf.float32,
                trainable=True,
                initializer=glorot_uniform,
            )
            batch_size = tf.shape(input=hidden)[0]
            inducing = tf.tile(inducing, [batch_size, 1, 1])

            summary = set_attention(
                inducing, hidden, width, num_heads,
                name=f'self_attn_{layer_idx}_pre')
            hidden = set_attention(
                hidden, summary, width, num_heads,
                name=f'self_attn_{layer_idx}_post')
    return hidden
def set_pooling(inputs, name, num_heads=4):
    """Pool a set into one vector per example by attention from a learned seed.

    A trainable seed ``pool_seed`` of shape [1, 1, d] is tiled over the batch,
    where ``d`` is the static last dimension of ``inputs``. The seed attends to
    ``inputs`` through ``set_attention`` (scope ``pool_attn``), and the
    singleton query axis is then squeezed away.

    Args:
        inputs: Tensor with a statically known last dimension; presumably
            [batch, set_size, d] -- TODO confirm with callers.
        name: Variable scope name (AUTO_REUSE, so weights are shared across
            calls with the same name).
        num_heads: Attention heads; ``d`` must be divisible by it.

    Returns:
        Tensor of shape [batch, d].

    Raises:
        ValueError: If the last dimension of ``inputs`` is not statically known.
    """
    d = inputs.get_shape().as_list()[-1]
    if d is None:
        # get_variable needs a fully defined shape for the seed, so fail early
        # with a clear message instead of deep inside variable creation.
        raise ValueError(
            f"set_pooling requires a statically known last dimension; "
            f"got shape {inputs.get_shape()}")
    batch_size = tf.shape(input=inputs)[0]
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        seed = tf.compat.v1.get_variable(
            'pool_seed', shape=[1, 1, d], dtype=tf.float32, trainable=True,
            initializer=tf.compat.v1.keras.initializers.VarianceScaling(
                scale=1.0, mode="fan_avg", distribution="uniform"))
        seed = tf.tile(seed, [batch_size, 1, 1])
        pooled = set_attention(seed, inputs, d, num_heads, name='pool_attn')
        # Drop the length-1 query axis contributed by the single seed.
        pooled = tf.squeeze(pooled, axis=1)
    return pooled