Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as T
|
3 |
+
from dalle_pytorch import VQGanVAE
|
4 |
+
from dalle.models import DALLE_Klue_Roberta
|
5 |
+
from transformers import AutoTokenizer
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
import yaml
|
9 |
+
from easydict import EasyDict
|
10 |
+
|
11 |
+
# --- Configuration -----------------------------------------------------------
# Paths to the trained DALL-E config/checkpoint and the VQGAN image tokenizer.
dalle_config_path = 'configs/dalle_config.yaml'
dalle_path = 'results/dalle_uk_final.pt'

# NOTE(review): absolute, machine-specific paths — consider making these
# configurable (env var / CLI) before deploying on another host.
vqgan_config_path = '/home/brad/Development/taming-transformers/configs/VQGAN_blue.yaml'
vqgan_path = '/home/brad/Development/taming-transformers/logs/2022-07-21T12-44-12_VQGAN_blue/checkpoints/best.ckpt'

# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Korean RoBERTa tokenizer used to encode the text prompts.
tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")

# Load the DALL-E hyper-parameter block from the YAML config.
# NOTE(review): yaml.Loader can instantiate arbitrary Python objects; fine for
# a trusted local config, but use yaml.safe_load if this file could ever come
# from an untrusted source.
with open(dalle_config_path, "r") as f:
    dalle_config = yaml.load(f, Loader=yaml.Loader)
DALLE_CFG = EasyDict(dalle_config["DALLE_CFG"])

DALLE_CFG.VOCAB_SIZE = tokenizer.vocab_size

# VQGAN acts as the image encoder/decoder (image tokens <-> pixels).
vae = VQGanVAE(
    vqgan_model_path=vqgan_path,
    vqgan_config_path=vqgan_config_path
)

DALLE_CFG.IMAGE_SIZE = vae.image_size

# Transformer hyper-parameters for the DALL-E model, taken from the config.
dalle_params = dict(
    num_text_tokens=tokenizer.vocab_size,
    text_seq_len=DALLE_CFG.TEXT_SEQ_LEN,
    depth=DALLE_CFG.DEPTH,
    heads=DALLE_CFG.HEADS,
    dim_head=DALLE_CFG.DIM_HEAD,
    reversible=DALLE_CFG.REVERSIBLE,
    loss_img_weight=DALLE_CFG.LOSS_IMG_WEIGHT,
    attn_types=DALLE_CFG.ATTN_TYPES,
    ff_dropout=DALLE_CFG.FF_DROPOUT,
    attn_dropout=DALLE_CFG.ATTN_DROPOUT,
    stable=DALLE_CFG.STABLE,
    shift_tokens=DALLE_CFG.SHIFT_TOKENS,
    rotary_emb=DALLE_CFG.ROTARY_EMB,
)

dalle = DALLE_Klue_Roberta(
    vae=vae,
    wte_dir="models/roberta_large_wte.pt",
    wpe_dir="models/roberta_large_wpe.pt",
    **dalle_params
).to(device)


# Restore the trained weights.
# FIX: map_location previously hard-coded torch.device('cuda:0'), which
# crashes on CPU-only hosts even though `device` falls back to CPU above;
# map the checkpoint onto whatever device the model actually lives on.
loaded_obj = torch.load(dalle_path, map_location=device)
# NOTE(review): this rebinds `dalle_params` (shadowing the dict above) and
# `vae_params` is never used — kept as-is for checkpoint-format compatibility.
dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']
dalle.load_state_dict(weights)
# FIX: this app only does inference, and dropout is configured above —
# switch to eval mode so generation is not perturbed by active dropout.
dalle.eval()
62 |
+
def text_to_montage(text):
    """Generate a single image from a text prompt.

    The prompt is encoded with the module-level RoBERTa tokenizer to a
    fixed-length sequence, passed through the DALL-E model's sampler, and
    the resulting image tensor is converted to a PIL image for Gradio.
    """
    # Tokenize to exactly TEXT_SEQ_LEN tokens (pad or truncate as needed).
    tokens = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=DALLE_CFG.TEXT_SEQ_LEN,
        add_special_tokens=True,
        return_token_type_ids=True,  # for RoBERTa
    ).to(device)

    # Sample an image conditioned on the encoded prompt.
    generated = dalle.generate_images(
        tokens['input_ids'],
        mask=tokens['attention_mask'],
        filter_thres=0.9  # topk sampling at 0.9
    )

    # Drop the batch dimension and convert the tensor to a PIL image.
    to_pil = T.ToPILImage()
    return to_pil(generated.squeeze())
84 |
+
# Minimal Gradio UI: one text box in, one generated image out.
demo = gr.Interface(fn=text_to_montage, inputs="text", outputs="image")

# Bind to all network interfaces so the demo is reachable from other hosts
# (not just localhost) — intended for containerized / remote deployment.
demo.launch(server_name="0.0.0.0")