|
from PIL import Image |
|
from IPython.display import display |
|
import torch as th |
|
|
|
from glide_text2im.download import load_checkpoint |
|
from glide_text2im.model_creation import ( |
|
create_model_and_diffusion, |
|
model_and_diffusion_defaults, |
|
model_and_diffusion_defaults_upsampler |
|
) |
|
|
|
# Select the compute device once, up front. Everything below (fp16 conversion,
# weight loading, tensor creation during sampling) keys off these two names;
# previously they were referenced without ever being defined.
has_cuda = th.cuda.is_available()
device = th.device('cuda' if has_cuda else 'cpu')

# Build the base model and its diffusion process from the library defaults.
options = model_and_diffusion_defaults()
options['use_fp16'] = has_cuda  # half precision is only enabled when running on CUDA
options['timestep_respacing'] = '100'  # NOTE(review): presumably 100 sampling steps, as in the GLIDE sample notebook
model, diffusion = create_model_and_diffusion(**options)
model.eval()
if has_cuda:
    model.convert_to_fp16()
model.to(device)
model.load_state_dict(load_checkpoint('base', device))
print('total base parameters', sum(x.numel() for x in model.parameters()))
|
def show_images(batch: th.Tensor):
    """Render every image of ``batch`` in a single strip via IPython ``display``.

    Assumes ``batch`` is laid out as (N, 3, H, W) with values nominally in
    [-1, 1] -- TODO confirm against the sampler output.
    """
    # Shift/scale so -1 maps to 0 and 1 maps to 255, clamp anything outside
    # that range, and move the result to the CPU as bytes.
    pixels = (batch + 1) * 127.5
    pixels = pixels.round().clamp(0, 255).to(th.uint8).cpu()

    # Reorder axes to (dim2, dim0, dim3, dim1) and fold the middle two
    # together, so the images end up next to each other in one array with
    # three channels last.
    height = batch.shape[2]
    strip = pixels.permute(2, 0, 3, 1).reshape([height, -1, 3])

    display(Image.fromarray(strip.numpy()))
|
|
|
# Sampling parameters shared by the code below.

# Module-level default prompt; the Gradio callback receives its own ``prompt``
# argument, which shadows this one inside the function.
prompt: str = ""

# Number of images sampled per request.
batch_size: int = 1

# Weight applied to (conditional - unconditional) eps during guided sampling.
guidance_scale: float = 3.0

# NOTE(review): not referenced anywhere in the visible code -- presumably
# intended for an upsampling stage; confirm before removing.
upsample_temp: float = 0.997
|
import gradio as gr |
|
def generate_image_from_text(prompt):
    """Sample image(s) from the base model for ``prompt`` and return them.

    The batch is built with ``batch_size`` text-conditioned entries followed
    by ``batch_size`` entries with an empty token list, and the two eps
    predictions are blended using ``guidance_scale``.

    Args:
        prompt: Text description of the image to generate.

    Returns:
        A ``PIL.Image.Image`` holding the ``batch_size`` samples side by side.
        The image is returned (rather than only displayed) so it can be used
        as the value of a Gradio ``"image"`` output.
    """
    # Encode the prompt and pad it to the model's text context length.
    tokens = model.tokenizer.encode(prompt)
    tokens, mask = model.tokenizer.padded_tokens_and_mask(
        tokens, options['text_ctx']
    )

    # Padded empty token list for the unconditional half of the batch.
    full_batch_size = batch_size * 2
    uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
        [], options['text_ctx']
    )

    # First half of the batch: prompt tokens; second half: empty tokens.
    model_kwargs = dict(
        tokens=th.tensor(
            [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
        ),
        mask=th.tensor(
            [mask] * batch_size + [uncond_mask] * batch_size,
            dtype=th.bool,
            device=device,
        ),
    )

    def model_fn(x_t, ts, **kwargs):
        """Run the model on a doubled batch and blend the two eps halves."""
        # Both halves get the same input; only the text conditioning passed
        # through kwargs differs between them.
        half = x_t[: len(x_t) // 2]
        combined = th.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        # The first 3 channels are eps; the remaining channels are passed
        # through unchanged (presumably the variance output -- TODO confirm).
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
        eps = th.cat([half_eps, half_eps], dim=0)
        return th.cat([eps, rest], dim=1)

    # Clear the model's cache before sampling, and again afterwards even if
    # sampling raises, so a failed request does not leave a stale cache.
    model.del_cache()
    try:
        samples = diffusion.p_sample_loop(
            model_fn,
            (full_batch_size, 3, options["image_size"], options["image_size"]),
            device=device,
            clip_denoised=True,
            progress=True,
            model_kwargs=model_kwargs,
            cond_fn=None,
        )[:batch_size]  # keep only the first (text-conditioned) half
    finally:
        model.del_cache()

    # Convert the samples to a single 8-bit RGB strip: scale so -1 -> 0 and
    # 1 -> 255, clamp, then lay the batch entries out next to each other.
    scaled = ((samples + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    strip = scaled.permute(2, 0, 3, 1).reshape([samples.shape[2], -1, 3])
    return Image.fromarray(strip.numpy())
|
# Web UI: a single text box feeds generate_image_from_text, and its result is
# shown in an image component.
demo = gr.Interface(
    fn=generate_image_from_text,
    inputs="text",
    outputs="image",
)

demo.launch()
|
|