|
import gradio as gr |
|
from PIL import Image |
|
import requests |
|
|
|
from tld.denoiser import Denoiser |
|
from tld.diffusion import DiffusionGenerator |
|
|
|
from diffusers import AutoencoderKL, AutoencoderTiny |
|
from tqdm import tqdm |
|
import clip |
|
import torch |
|
import numpy as np |
|
import torchvision.utils as vutils |
|
import torchvision.transforms as transforms |
|
from torch.utils.data import DataLoader, TensorDataset |
|
from PIL import Image |
|
|
|
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Shared tensor -> PIL.Image converter used when returning generated images.
to_pil = transforms.ToPILImage()
|
|
|
|
|
def download_file(url, filename, timeout=60):
    """Stream the body of ``url`` into the local file ``filename``.

    The response is written in 8192-byte chunks so the whole file is never
    held in memory. Data first goes to a temporary ``<filename>.part`` file,
    which is moved onto ``filename`` only after the full body was received.
    An interrupted or failed download therefore never leaves a truncated file
    at ``filename`` and never clobbers a previously downloaded copy.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path; replaced if it already exists.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot hang the caller indefinitely.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``raise_for_status``).
        requests.RequestException: On connection failures or timeouts.
    """
    import contextlib
    import os

    tmp_filename = os.fspath(filename) + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_filename, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on POSIX and Windows: readers see either the old file or the
        # complete new one, never a partial write.
        os.replace(tmp_filename, filename)
    except BaseException:
        # Drop the partial download so it cannot be mistaken for a valid file.
        with contextlib.suppress(OSError):
            os.remove(tmp_filename)
        raise
|
|
|
@torch.no_grad()
def encode_text(label, model):
    """Embed ``label`` with a CLIP text encoder and return the result on CPU.

    Runs under ``torch.no_grad()``, so no autograd graph is recorded.
    ``truncate=True`` asks the CLIP tokenizer to truncate over-long prompts.

    Args:
        label: Prompt text to tokenize (this file's caller passes a list of
            prompt strings).
        model: A CLIP model exposing ``encode_text``.

    Returns:
        The text-embedding tensor produced by ``model``, moved to the CPU.
    """
    tokens = clip.tokenize(label, truncate=True)
    embedding = model.encode_text(tokens.to(device))
    return embedding.cpu()
|
|
|
def _filename_safe(text, max_len=100):
    """Reduce ``text`` to a short, path-safe file-name fragment."""
    import re

    # Keep only ASCII letters, digits, space, '_', '.', '-'. This removes path
    # separators (no directory traversal) and keeps length == byte count, so
    # the truncation below keeps the final name under typical 255-byte limits.
    cleaned = re.sub(r"[^A-Za-z0-9 _.-]+", "_", text).strip(" .")
    return cleaned[:max_len] or "prompt"


def generate_image_from_text(prompt, class_guidance=6, seed=11, num_imgs=1, img_size=32):
    """Generate ``num_imgs`` images for ``prompt`` and return them as one grid.

    The prompt is repeated ``num_imgs`` times and embedded with the
    module-level ``clip_model``. The module-level ``diffuser`` then samples
    images, which are tiled into a single PIL image. That image is saved to
    the working directory and returned.

    Args:
        prompt: Text prompt (user input from the Gradio textbox).
        class_guidance: Guidance scale forwarded to ``diffuser.generate``.
        seed: Seed forwarded to ``diffuser.generate``.
        num_imgs: Number of images to generate and tile.
        img_size: Currently unused; kept for interface compatibility.

    Returns:
        A ``PIL.Image`` with the generated images tiled with 4px padding.
    """
    n_iter = 15
    # Images per grid row; never 0, since make_grid divides by it.
    nrow = max(1, int(np.sqrt(num_imgs)))

    cur_prompts = [prompt] * num_imgs
    labels = encode_text(cur_prompts, clip_model)
    out, _ = diffuser.generate(
        labels=labels,
        num_imgs=num_imgs,
        class_guidance=class_guidance,
        seed=seed,
        n_iter=n_iter,
        exponent=1,
        scale_factor=8,
        sharp_f=0,
        bright_f=0,
    )

    # (out + 1) / 2 then clip(0, 1): presumably the generator emits values
    # in [-1, 1] -- TODO confirm against DiffusionGenerator.
    grid = vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)
    out = to_pil(grid.float().clip(0, 1))

    # The prompt is untrusted: sanitize it before using it in a file path,
    # otherwise '/' escapes the directory and long prompts raise OSError.
    out.save(f'{_filename_safe(prompt)}_cfg:{class_guidance}_seed:{seed}.png')

    print("Images Generated and Saved. They will shortly output below.")
    return out
|
|
|
|
|
# Model / image configuration shared by the setup below.
vae_scale_factor = 8
img_size = 32
model_dtype = torch.float32

# Denoiser checkpoint; fetched into the working directory on every start-up.
file_url = "https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth"
local_filename = "state_dict_378000.pth"
download_file(file_url, local_filename)

denoiser = Denoiser(image_size=img_size, noise_embed_dims=256, patch_size=2,
                    embed_dim=768, dropout=0, n_layers=12)

# NOTE(review): torch.load unpickles a file downloaded from the network, which
# can execute arbitrary code if the source is compromised. Consider passing
# weights_only=True once the deployed torch version is confirmed to support it.
state_dict = torch.load(local_filename, map_location=torch.device('cpu'))

# Cast to the target dtype, load weights on CPU, then move to the device.
denoiser = denoiser.to(model_dtype)
denoiser.load_state_dict(state_dict)
denoiser = denoiser.to(device)

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix",
                                    torch_dtype=model_dtype).to(device)

clip_model, preprocess = clip.load("ViT-L/14")
clip_model = clip_model.to(device)

diffuser = DiffusionGenerator(denoiser, vae, device, model_dtype)
|
|
|
|
|
# Web UI: a prompt textbox plus a guidance slider, mapped positionally onto
# generate_image_from_text(prompt, class_guidance).
iface = gr.Interface(
    fn=generate_image_from_text,
    inputs=[
        "text",
        # Same 0-100 range as the bare "slider" shortcut, but with an explicit
        # default matching the function's own class_guidance default of 6.
        gr.Slider(minimum=0, maximum=100, value=6, label="Class guidance"),
    ],
    outputs="image",
    title="Text-to-Image Generator",
    description="Enter a text prompt to generate an image.",
)

iface.launch()