|
from PIL import Image |
|
from IPython.display import display |
|
import torch as th |
|
|
|
from glide_text2im.download import load_checkpoint |
|
from glide_text2im.model_creation import ( |
|
create_model_and_diffusion, |
|
model_and_diffusion_defaults, |
|
model_and_diffusion_defaults_upsampler |
|
) |
|
# Run on the GPU when available; fp16 weights are only enabled on CUDA.
has_cuda = th.cuda.is_available()
device = th.device('cuda' if has_cuda else 'cpu')


def _build_model(opts, checkpoint_name):
    """Create a model/diffusion pair from ``opts`` and load its weights.

    The model is put in eval mode, converted to fp16 when CUDA is present,
    moved to ``device``, and then loaded from the checkpoint registered
    under ``checkpoint_name``.

    Returns:
        The ``(model, diffusion)`` tuple from ``create_model_and_diffusion``.
    """
    net, diff = create_model_and_diffusion(**opts)
    net.eval()
    if has_cuda:
        net.convert_to_fp16()
    net.to(device)
    net.load_state_dict(load_checkpoint(checkpoint_name, device))
    return net, diff


# Base text-conditional model. The respacing string is presumably the
# number of sampling steps to use — TODO confirm against glide_text2im.
options = model_and_diffusion_defaults()
options['use_fp16'] = has_cuda
options['timestep_respacing'] = '100'
model, diffusion = _build_model(options, 'base')
print('total base parameters', sum(p.numel() for p in model.parameters()))

# Upsampler model; 'fast27' is a named respacing preset passed through as-is.
options_up = model_and_diffusion_defaults_upsampler()
options_up['use_fp16'] = has_cuda
options_up['timestep_respacing'] = 'fast27'
model_up, diffusion_up = _build_model(options_up, 'upsample')
print('total upsampler parameters', sum(p.numel() for p in model_up.parameters()))
|
def show_images(batch: th.Tensor):
    """Display a batch of images inline as one horizontal strip.

    Values are mapped from [-1, 1] to 8-bit pixels; anything outside that
    range is clamped. The permute/reshape below assumes a 4-D tensor laid
    out as (N, 3, H, W) — TODO confirm for any new caller — and places the
    N images side by side in an image of height H.
    """
    height = batch.shape[2]
    pixels = ((batch + 1) * 127.5).round().clamp(0, 255)
    pixels = pixels.to(th.uint8).cpu()
    # (N, C, H, W) -> (H, N, W, C) -> (H, N*W, 3): images left to right.
    strip = pixels.permute(2, 0, 3, 1).reshape([height, -1, 3])
    display(Image.fromarray(strip.numpy()))
|
|
|
# Demo prompt and sampling settings. batch_size, guidance_scale and
# upsample_temp are also read as module globals by the Gradio callback.
prompt = "an oil painting of a corgi"
batch_size = 1
guidance_scale = 3.0

# Multiplier applied to the upsampler's initial noise.
upsample_temp = 0.997

# Text tokens for the conditional half of the batch.
tokens = model.tokenizer.encode(prompt)
tokens, mask = model.tokenizer.padded_tokens_and_mask(
    tokens, options['text_ctx']
)

# Classifier-free guidance: an empty prompt forms the unconditional half,
# so the model is run on twice the requested batch size.
full_batch_size = batch_size * 2
uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
    [], options['text_ctx']
)

# Conditional rows first, unconditional rows second; model_fn relies on
# this ordering when it splits the output in half.
model_kwargs = dict(
    tokens=th.tensor(
        [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
    ),
    mask=th.tensor(
        [mask] * batch_size + [uncond_mask] * batch_size,
        dtype=th.bool,
        device=device,
    ),
)


def model_fn(x_t, ts, **kwargs):
    """Classifier-free-guided wrapper around the base model.

    Both halves of the batch are fed the same noisy image (the first
    half), so the conditional and unconditional predictions are directly
    comparable. The guided prediction is
    ``uncond + guidance_scale * (cond - uncond)``.

    Assumes the model output stacks a 3-channel noise prediction before
    any remaining channels, which are passed through unchanged (as in the
    upstream GLIDE example) — TODO confirm for other checkpoints.
    """
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)


# Sample from the base model; keep only the conditional half.
model.del_cache()
samples = diffusion.p_sample_loop(
    model_fn,
    (full_batch_size, 3, options["image_size"], options["image_size"]),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=None,
)[:batch_size]
model.del_cache()

show_images(samples)
|
|
|
|
|
import gradio as gr |
|
def _sample_base_for_prompt(text):
    """Sample ``batch_size`` low-resolution images for ``text`` from the base model.

    Uses classifier-free guidance with ``guidance_scale``: the batch is
    doubled, with the second half conditioned on an empty prompt.
    """
    tokens = model.tokenizer.encode(text)
    tokens, mask = model.tokenizer.padded_tokens_and_mask(
        tokens, options['text_ctx']
    )
    uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
        [], options['text_ctx']
    )
    full_batch_size = batch_size * 2
    # Conditional rows first, unconditional rows second.
    base_kwargs = dict(
        tokens=th.tensor(
            [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
        ),
        mask=th.tensor(
            [mask] * batch_size + [uncond_mask] * batch_size,
            dtype=th.bool,
            device=device,
        ),
    )

    def guided_model_fn(x_t, ts, **kwargs):
        # Feed the same noisy image to both halves, then blend the
        # conditional and unconditional predictions. Assumes the first 3
        # output channels are the noise prediction (as in the upstream
        # GLIDE example) — TODO confirm.
        half = x_t[: len(x_t) // 2]
        combined = th.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
        eps = th.cat([half_eps, half_eps], dim=0)
        return th.cat([eps, rest], dim=1)

    model.del_cache()
    low_res = diffusion.p_sample_loop(
        guided_model_fn,
        (full_batch_size, 3, options["image_size"], options["image_size"]),
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=base_kwargs,
        cond_fn=None,
    )[:batch_size]
    model.del_cache()
    return low_res


def _upsample_for_prompt(text, low_res):
    """Upsample ``low_res`` with the upsampler model, conditioned on ``text``."""
    tokens = model_up.tokenizer.encode(text)
    tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
        tokens, options_up['text_ctx']
    )
    up_kwargs = dict(
        # Snap the low-res input to 8-bit levels, presumably to match the
        # quantized images the upsampler saw in training — TODO confirm.
        low_res=((low_res + 1) * 127.5).round() / 127.5 - 1,
        tokens=th.tensor([tokens] * batch_size, device=device),
        mask=th.tensor(
            [mask] * batch_size,
            dtype=th.bool,
            device=device,
        ),
    )
    model_up.del_cache()
    up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
    up_samples = diffusion_up.ddim_sample_loop(
        model_up,
        up_shape,
        noise=th.randn(up_shape, device=device) * upsample_temp,
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=up_kwargs,
        cond_fn=None,
    )[:batch_size]
    model_up.del_cache()
    return up_samples


def _to_pil_strip(batch):
    """Convert a [-1, 1] image batch into one side-by-side PIL image.

    Assumes ``batch`` is (N, 3, H, W) — TODO confirm.
    """
    pixels = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    strip = pixels.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
    return Image.fromarray(strip.numpy())


def generate_upsampled_image_from_text(prompt):
    """Generate an upsampled image for ``prompt`` (Gradio callback).

    Runs both GLIDE stages for the prompt actually supplied. The previous
    version upsampled the module-level corgi sample and returned nothing.

    Args:
        prompt: Text from the Gradio text box.

    Returns:
        A ``PIL.Image.Image`` showing the ``batch_size`` upsampled results
        side by side, suitable for a Gradio ``"image"`` output.
    """
    low_res = _sample_base_for_prompt(prompt)
    up_samples = _upsample_for_prompt(prompt, low_res)
    return _to_pil_strip(up_samples)
|
# Wire the text-to-image callback into a simple Gradio UI:
# a text box in, an image out.
demo = gr.Interface(
    fn=generate_upsampled_image_from_text,
    inputs="text",
    outputs="image",
)
demo.launch()