import streamlit as st
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
from t5_vae_flax_alt.src.t5_vae import FlaxT5VaeForAutoencoding
st.set_page_config(
    page_title="T5-VAE",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="expanded"
)
st.title('T5-VAE πππ')
st.markdown('''
This is a variational autoencoder trained on text.
It lets you interpolate between texts at a high level; try it out!
See how it works [here](http://fras.uk/ml/large%20prior-free%20models/transformer-vae/2020/08/13/Transformers-as-Variational-Autoencoders.html).
''')
st.markdown('''
### [t5-vae-python](https://huggingface.co/flax-community/t5-vae-python)
This model is trained on lines of Python code from GitHub ([dataset](https://huggingface.co/datasets/Fraser/python-lines)).
''')
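# Load the tokenizer and the Python-code checkpoint once; st.cache keeps them in memory
# across Streamlit reruns instead of reloading the weights on every interaction.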
@st.cache(allow_output_mutation=True)
def get_model():
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = FlaxT5VaeForAutoencoding.from_pretrained("flax-community/t5-vae-python")
    assert model.params['t5']['shared']['embedding'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
    return model, tokenizer
model, tokenizer = get_model()
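# Build teacher-forcing decoder inputs: append a pad token to each sequence, shift the
# result right so decoding starts from decoder_start_token_id, and extend the attention
# mask to match the new length.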
def add_decoder_input_ids(examples):
    arr_input_ids = jnp.array(examples["input_ids"])
    pad = tokenizer.pad_token_id * jnp.ones((arr_input_ids.shape[0], 1), dtype=jnp.int32)
    arr_pad_input_ids = jnp.concatenate((arr_input_ids, pad), axis=1)
    examples['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, tokenizer.pad_token_id, model.config.decoder_start_token_id)
    arr_attention_mask = jnp.array(examples['attention_mask'])
    ones = jnp.ones((arr_attention_mask.shape[0], 1), dtype=jnp.int32)
    examples['decoder_attention_mask'] = jnp.concatenate((ones, arr_attention_mask), axis=1)
    for k in ['decoder_input_ids', 'decoder_attention_mask']:
        examples[k] = examples[k].tolist()
    return examples
def prepare_inputs(inputs):
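    # Convert the tokenizer's Python lists to JAX arrays before adding the decoder inputs.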
    for k, v in inputs.items():
        inputs[k] = jnp.array(v)
    return add_decoder_input_ids(inputs)
def get_latent(text):
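    # Encode a single string and return its latent code (dropping the batch dimension).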
    return model(**prepare_inputs(tokenizer([text]))).latent_codes[0]
def tokens_from_latent(latent_codes):
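    # Autoregressively generate token ids conditioned on a single latent code.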
    model.config.is_encoder_decoder = True
    output_ids = model.generate(
        latent_codes=jnp.array([latent_codes]),
        bos_token_id=model.config.decoder_start_token_id,
        min_length=1,
        max_length=32,
    )
    return output_ids
def slerp(ratio, t1, t2):
    '''
    Perform a spherical interpolation between 2 vectors.
    Most of the volume of a high-dimensional orange is in the skin, not the pulp.
    This also applies to multivariate Gaussian distributions.
    To that end we can interpolate between samples by following the surface of an n-dimensional sphere rather than a straight line.
    Args:
        ratio: Interpolation ratio.
        t1: Tensor 1
        t2: Tensor 2
    '''
    low_norm = t1 / jnp.linalg.norm(t1, axis=1, keepdims=True)
    high_norm = t2 / jnp.linalg.norm(t2, axis=1, keepdims=True)
    omega = jnp.arccos((low_norm * high_norm).sum(1))
    so = jnp.sin(omega)
    res = (jnp.sin((1.0 - ratio) * omega) / so)[0] * t1 + (jnp.sin(ratio * omega) / so)[0] * t2
    return res
def decode(cnt, ratio, txt_1, txt_2):
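    # Interpolate between the latents of the two texts at the given ratio and decode the
    # result back to text, writing progress messages to the given Streamlit container.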
    if not txt_1 or not txt_2:
        return ''
    cnt.write('Getting latents...')
    lt_1, lt_2 = get_latent(txt_1), get_latent(txt_2)
    lt_new = slerp(ratio, lt_1, lt_2)
    cnt.write('Decoding latent...')
    tkns = tokens_from_latent(lt_new)
    return tokenizer.decode(tkns.sequences[0], skip_special_tokens=True)
in_1 = st.text_input("A line of Python code.", "x = a - 1")
in_2 = st.text_input("Another line of Python code.", "x = a + 10 * 2")
r = st.slider('Python Interpolation Ratio', min_value=0.0, max_value=1.0, value=0.5)
container = st.empty()
container.write('Loading...')
out = decode(container, r, in_1, in_2)
container.empty()
st.write('Output: ' + out)
st.markdown('''
### [t5-vae-wiki](https://huggingface.co/flax-community/t5-vae-wiki)
This model is trained on just 5% of the sentences on Wikipedia.
We'll release another model trained on the full [dataset](https://github.com/ChunyuanLI/Optimus/blob/master/download_datasets.md) soon.
''')
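# Same loader for the Wikipedia-sentence checkpoint; assigning its result to the
# module-level `model` and `tokenizer` below makes get_latent and decode use the wiki model.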
@st.cache(allow_output_mutation=True)
def get_wiki_model():
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = FlaxT5VaeForAutoencoding.from_pretrained("flax-community/t5-vae-wiki")
    assert model.params['t5']['shared']['embedding'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
    return model, tokenizer
model, tokenizer = get_wiki_model()
in_1 = st.text_input("A sentence.", "Children are looking for the water to be clear.")
in_2 = st.text_input("Another sentence.", "There are two people playing soccer.")
r = st.slider('English Interpolation Ratio', min_value=0.0, max_value=1.0, value=0.5)
container = st.empty()
container.write('Loading...')
out = decode(container, r, in_1, in_2)
container.empty()
st.write('Output: ' + out)
st.markdown('''
Try arithmetic in latent space.
Latent codes are found for each sentence and arithmetic is done with them.
It runs the sum `C + (B - A) = ?`
''')
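# Latent-space analogy: move C by the offset (B - A) and decode the shifted latent.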
def arithmetic(cnt, txt_a, txt_b, txt_c):
    if not txt_a or not txt_b or not txt_c:
        return ''
    cnt.write('getting latents...')
    lt_a, lt_b, lt_c = get_latent(txt_a), get_latent(txt_b), get_latent(txt_c)
    lt_d = lt_c + (lt_b - lt_a)
    cnt.write('decoding C + (B - A)...')
    tkns = tokens_from_latent(lt_d)
    return tokenizer.decode(tkns.sequences[0], skip_special_tokens=True)
in_a = st.text_input("A", "A girl makes a silly face.")
in_b = st.text_input("B", "Two girls are playing soccer.")
in_c = st.text_input("C", "A girl is looking through a microscope.")
st.markdown('''
A is to B as C is to...
''')
container = st.empty()
container.write('Loading...')
out = arithmetic(container, in_a, in_b, in_c)
container.empty()
st.write('Output: ' + out)