from transformers import AutoTokenizer
from PIL import Image
import cv2
import torch
from omegaconf import OmegaConf
import math
from copy import deepcopy
import torch.nn.functional as F
import numpy as np
import clip

from kandinsky2.model.text_encoders import TextEncoder
from kandinsky2.vqgan.autoencoder import VQModelInterface, AutoencoderKL, MOVQ
from kandinsky2.model.samplers import DDIMSampler, PLMSSampler
from kandinsky2.model.model_creation import create_model, create_gaussian_diffusion
from kandinsky2.model.prior import PriorDiffusionModel, CustomizedTokenizer
from kandinsky2.utils import prepare_image, q_sample, process_images, prepare_mask


class Kandinsky2_1:
    def __init__(
        self,
        config,
        model_path,
        prior_path,
        device,
        task_type="text2img",
    ):
        self.config = config
        self.device = device
        self.use_fp16 = self.config["model_config"]["use_fp16"]
        self.task_type = task_type
        self.clip_image_size = config["clip_image_size"]
        if task_type == "text2img":
            self.config["model_config"]["up"] = False
            self.config["model_config"]["inpainting"] = False
        elif task_type == "inpainting":
            self.config["model_config"]["up"] = False
            self.config["model_config"]["inpainting"] = True
        else:
            raise ValueError("Only text2img and inpainting are available")

        self.tokenizer1 = AutoTokenizer.from_pretrained(self.config["tokenizer_name"])
        self.tokenizer2 = CustomizedTokenizer()
        clip_mean, clip_std = torch.load(
            config["prior"]["clip_mean_std_path"], map_location="cpu"
        )
        self.prior = PriorDiffusionModel(
            config["prior"]["params"],
            self.tokenizer2,
            clip_mean,
            clip_std,
        )
        self.prior.load_state_dict(torch.load(prior_path, map_location="cpu"), strict=False)
        if self.use_fp16:
            self.prior = self.prior.half()
        self.text_encoder = TextEncoder(**self.config["text_enc_params"])
        if self.use_fp16:
            self.text_encoder = self.text_encoder.half()
        self.clip_model, self.preprocess = clip.load(
            config["clip_name"], device=self.device, jit=False
        )
        self.clip_model.eval()

        if self.config["image_enc_params"] is not None:
            self.use_image_enc = True
            self.scale = self.config["image_enc_params"]["scale"]
            if self.config["image_enc_params"]["name"] == "AutoencoderKL":
                self.image_encoder = AutoencoderKL(
                    **self.config["image_enc_params"]["params"]
                )
            elif self.config["image_enc_params"]["name"] == "VQModelInterface":
                self.image_encoder = VQModelInterface(
                    **self.config["image_enc_params"]["params"]
                )
            elif self.config["image_enc_params"]["name"] == "MOVQ":
                self.image_encoder = MOVQ(**self.config["image_enc_params"]["params"])
            self.image_encoder.load_state_dict(
                torch.load(self.config["image_enc_params"]["ckpt_path"], map_location="cpu")
            )
            self.image_encoder.eval()
        else:
            self.use_image_enc = False

        self.config["model_config"]["cache_text_emb"] = True
        self.model = create_model(**self.config["model_config"])
        self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
        if self.use_fp16:
            self.model.convert_to_fp16()
            self.image_encoder = self.image_encoder.half()
            self.model_dtype = torch.float16
        else:
            self.model_dtype = torch.float32
        self.image_encoder = self.image_encoder.to(self.device).eval()
        self.text_encoder = self.text_encoder.to(self.device).eval()
        self.prior = self.prior.to(self.device).eval()
        self.model.eval()
        self.model.to(self.device)

    def get_new_h_w(self, h, w):
        new_h = h // 64
        if h % 64 != 0:
            new_h += 1
        new_w = w // 64
        if w % 64 != 0:
            new_w += 1
        return new_h * 8, new_w * 8
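    # Note: the decoder works in a latent space downsampled 8x from pixel space,
    # so get_new_h_w rounds the requested size up to the next multiple of 64 and
    # divides by 8. For example, h = w = 512 gives 64x64 latents, while
    # h = w = 600 gives 80x80 latents; generate_img crops the decoded result
    # back to the requested h x w.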

    def encode_text(self, text_encoder, tokenizer, prompt, batch_size):
        text_encoding = tokenizer(
            [prompt] * batch_size + [""] * batch_size,
            max_length=77,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        tokens = text_encoding["input_ids"].to(self.device)
        mask = text_encoding["attention_mask"].to(self.device)
        full_emb, pooled_emb = text_encoder(tokens=tokens, mask=mask)
        return full_emb, pooled_emb

    def generate_clip_emb(
        self,
        prompt,
        batch_size=1,
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
    ):
        prompts_batch = [prompt for _ in range(batch_size)]
        prior_cf_scales_batch = [prior_cf_scale] * len(prompts_batch)
        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
        max_txt_length = self.prior.model.text_ctx
        tok, mask = self.tokenizer2.padded_tokens_and_mask(
            prompts_batch, max_txt_length
        )
        cf_token, cf_mask = self.tokenizer2.padded_tokens_and_mask(
            [negative_prior_prompt], max_txt_length
        )
        if not (cf_token.shape == tok.shape):
            cf_token = cf_token.expand(tok.shape[0], -1)
            cf_mask = cf_mask.expand(tok.shape[0], -1)
        tok = torch.cat([tok, cf_token], dim=0)
        mask = torch.cat([mask, cf_mask], dim=0)
        tok, mask = tok.to(device=self.device), mask.to(device=self.device)

        x = self.clip_model.token_embedding(tok).type(self.clip_model.dtype)
        x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(self.clip_model.dtype)
        txt_feat_seq = x
        txt_feat = (
            x[torch.arange(x.shape[0]), tok.argmax(dim=-1)]
            @ self.clip_model.text_projection
        )
        txt_feat, txt_feat_seq = txt_feat.float().to(self.device), txt_feat_seq.float().to(self.device)
        img_feat = self.prior(
            txt_feat,
            txt_feat_seq,
            mask,
            prior_cf_scales_batch,
            timestep_respacing=prior_steps,
        )
        return img_feat.to(self.model_dtype)
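    # Note: generate_clip_emb runs the raw CLIP text transformer (rather than
    # clip_model.encode_text) so that both the pooled text feature and the full
    # token sequence are available to the diffusion prior, which maps them to a
    # CLIP image embedding. The negative prior prompt is tokenized alongside the
    # positive prompts so the prior can apply classifier-free guidance weighted
    # by prior_cf_scale.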

    def encode_images(self, image, is_pil=False):
        if is_pil:
            image = self.preprocess(image).unsqueeze(0).to(self.device)
        return self.clip_model.encode_image(image).to(self.model_dtype)

    def generate_img(
        self,
        prompt,
        img_prompt,
        batch_size=1,
        diffusion=None,
        guidance_scale=7,
        init_step=None,
        noise=None,
        init_img=None,
        img_mask=None,
        h=512,
        w=512,
        sampler="ddim_sampler",
        num_steps=50,
    ):
        new_h, new_w = self.get_new_h_w(h, w)
        full_batch_size = batch_size * 2
        model_kwargs = {}
        if init_img is not None and self.use_fp16:
            init_img = init_img.half()
        if img_mask is not None and self.use_fp16:
            img_mask = img_mask.half()
        model_kwargs["full_emb"], model_kwargs["pooled_emb"] = self.encode_text(
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer1,
            prompt=prompt,
            batch_size=batch_size,
        )
        model_kwargs["image_emb"] = img_prompt
        if self.task_type == "inpainting":
            init_img = init_img.to(self.device)
            img_mask = img_mask.to(self.device)
            model_kwargs["inpaint_image"] = init_img * img_mask
            model_kwargs["inpaint_mask"] = img_mask

        def model_fn(x_t, ts, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)
            eps, rest = model_out[:, :4], model_out[:, 4:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            if sampler == "p_sampler":
                return torch.cat([eps, rest], dim=1)
            else:
                return eps
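        # model_fn above applies classifier-free guidance: the batch holds the
        # conditional half first and the unconditional half second, and the
        # guided prediction is eps_uncond + guidance_scale * (eps_cond - eps_uncond).
        # The p_sampler also needs the predicted variance channels ("rest"),
        # while the DDIM/PLMS samplers consume eps only.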

        if noise is not None:
            noise = noise.float()

        if self.task_type == "inpainting":
            def denoised_fun(x_start):
                x_start = x_start.clamp(-2, 2)
                return x_start * (1 - img_mask) + init_img * img_mask
        else:
            def denoised_fun(x):
                return x.clamp(-2, 2)

        if sampler == "p_sampler":
            self.model.del_cache()
            samples = diffusion.p_sample_loop(
                model_fn,
                (full_batch_size, 4, new_h, new_w),
                device=self.device,
                noise=noise,
                progress=True,
                model_kwargs=model_kwargs,
                init_step=init_step,
                denoised_fn=denoised_fun,
            )[:batch_size]
            self.model.del_cache()
        else:
            if sampler == "ddim_sampler":
                sampler = DDIMSampler(
                    model=model_fn,
                    old_diffusion=diffusion,
                    schedule="linear",
                )
            elif sampler == "plms_sampler":
                sampler = PLMSSampler(
                    model=model_fn,
                    old_diffusion=diffusion,
                    schedule="linear",
                )
            else:
                raise ValueError("Only ddim_sampler, plms_sampler and p_sampler are available")
            self.model.del_cache()
            samples, _ = sampler.sample(
                num_steps,
                batch_size * 2,
                (4, new_h, new_w),
                conditioning=model_kwargs,
                x_T=noise,
                init_step=init_step,
            )
            self.model.del_cache()
            samples = samples[:batch_size]

        if self.use_image_enc:
            if self.use_fp16:
                samples = samples.half()
            samples = self.image_encoder.decode(samples / self.scale)
        samples = samples[:, :, :h, :w]
        return process_images(samples)

    def create_zero_img_emb(self, batch_size):
        img = torch.zeros(1, 3, self.clip_image_size, self.clip_image_size).to(self.device)
        return self.encode_images(img, is_pil=False).repeat(batch_size, 1)

    def generate_text2img(
        self,
        prompt,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        # generate clip embeddings
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )
        if negative_decoder_prompt == "":
            zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
        else:
            zero_image_emb = self.generate_clip_emb(
                negative_decoder_prompt,
                batch_size=batch_size,
                prior_cf_scale=prior_cf_scale,
                prior_steps=prior_steps,
                negative_prior_prompt=negative_prior_prompt,
            )
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)

        # load diffusion
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])
        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
        )
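    # Usage sketch (illustrative; prompt and parameters are placeholders):
    #   images = model.generate_text2img(
    #       "red cat, 4k photo",
    #       num_steps=100, batch_size=1, guidance_scale=4,
    #       h=768, w=768, sampler="p_sampler",
    #       prior_cf_scale=4, prior_steps="5",
    #   )
    #   images[0].save("cat.png")  # process_images returns PIL images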

    def mix_images(
        self,
        images_texts,
        weights,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        assert len(images_texts) == len(weights) and len(images_texts) > 0

        # generate clip embeddings
        image_emb = None
        for i in range(len(images_texts)):
            if image_emb is None:
                if isinstance(images_texts[i], str):
                    image_emb = weights[i] * self.generate_clip_emb(
                        images_texts[i],
                        batch_size=1,
                        prior_cf_scale=prior_cf_scale,
                        prior_steps=prior_steps,
                        negative_prior_prompt=negative_prior_prompt,
                    )
                else:
                    image_emb = self.encode_images(images_texts[i], is_pil=True) * weights[i]
            else:
                if isinstance(images_texts[i], str):
                    image_emb = image_emb + weights[i] * self.generate_clip_emb(
                        images_texts[i],
                        batch_size=1,
                        prior_cf_scale=prior_cf_scale,
                        prior_steps=prior_steps,
                        negative_prior_prompt=negative_prior_prompt,
                    )
                else:
                    image_emb = image_emb + self.encode_images(images_texts[i], is_pil=True) * weights[i]
        image_emb = image_emb.repeat(batch_size, 1)

        if negative_decoder_prompt == "":
            zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
        else:
            zero_image_emb = self.generate_clip_emb(
                negative_decoder_prompt,
                batch_size=batch_size,
                prior_cf_scale=prior_cf_scale,
                prior_steps=prior_steps,
                negative_prior_prompt=negative_prior_prompt,
            )
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)

        # load diffusion
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])
        return self.generate_img(
            prompt="",
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
        )
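    # Usage sketch (illustrative): entries of images_texts may be prompt strings
    # or PIL images; their CLIP embeddings are blended with the given weights
    # before decoding, e.g.
    #   images = model.mix_images(
    #       ["a cat", Image.open("dog.png")], weights=[0.3, 0.7],
    #       num_steps=100, batch_size=1, guidance_scale=4, h=768, w=768,
    #   )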

    def generate_img2img(
        self,
        prompt,
        pil_img,
        strength=0.7,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
    ):
        # generate clip embeddings
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
        )
        zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)

        # load diffusion
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])

        image = prepare_image(pil_img, h=h, w=w).to(self.device)
        if self.use_fp16:
            image = image.half()
        image = self.image_encoder.encode(image) * self.scale
        start_step = int(diffusion.num_timesteps * (1 - strength))
        image = q_sample(
            image,
            torch.tensor(diffusion.timestep_map[start_step - 1]).to(self.device),
            schedule_name=config["diffusion_config"]["noise_schedule"],
            num_steps=config["diffusion_config"]["steps"],
        )
        image = image.repeat(2, 1, 1, 1)
        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
            noise=image,
            init_step=start_step,
        )
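    # Note: generate_img2img encodes the source image into latents, partially
    # noises it with q_sample up to start_step (derived from strength), and
    # passes it as the starting noise so denoising resumes from that step.
    # Usage sketch (illustrative; prompt and file name are placeholders):
    #   out = model.generate_img2img("a portrait, oil painting",
    #                                Image.open("face.png"), strength=0.7)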

    def generate_inpainting(
        self,
        prompt,
        pil_img,
        img_mask,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        # generate clip embeddings
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )
        zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)

        # load diffusion
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])

        image = prepare_image(pil_img, w, h).to(self.device)
        if self.use_fp16:
            image = image.half()
        image = self.image_encoder.encode(image) * self.scale
        image_shape = tuple(image.shape[-2:])
        img_mask = torch.from_numpy(img_mask).unsqueeze(0).unsqueeze(0)
        img_mask = F.interpolate(
            img_mask,
            image_shape,
            mode="nearest",
        )
        img_mask = prepare_mask(img_mask).to(self.device)
        if self.use_fp16:
            img_mask = img_mask.half()
        image = image.repeat(2, 1, 1, 1)
        img_mask = img_mask.repeat(2, 1, 1, 1)
        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
            init_img=image,
            img_mask=img_mask,
        )
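    # Usage sketch (illustrative): img_mask is a float numpy array of shape
    # (h, w); it is resized to the latent resolution and passed through
    # prepare_mask, and is then used to blend the original latents with the
    # generated ones during sampling, e.g.
    #   mask = np.ones((768, 768), dtype=np.float32)
    #   mask[:, :300] = 0
    #   out = model.generate_inpainting("a hat", Image.open("photo.png"), mask,
    #                                   num_steps=100, h=768, w=768)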


import os
from huggingface_hub import hf_hub_url, cached_download
from omegaconf.dictconfig import DictConfig

# NOTE: CONFIG_2_1, get_kandinsky2_0 and Kandinsky2_2, referenced below, are
# expected to be defined elsewhere in the kandinsky2 package.


def get_kandinsky2_1(
    device,
    task_type="text2img",
    cache_dir="/tmp/kandinsky2",
    use_auth_token=None,
    use_flash_attention=False,
):
    cache_dir = os.path.join(cache_dir, "2_1")
    config = DictConfig(deepcopy(CONFIG_2_1))
    config["model_config"]["use_flash_attention"] = use_flash_attention
    if task_type == "text2img":
        model_name = "decoder_fp16.ckpt"
        config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=model_name)
    elif task_type == "inpainting":
        model_name = "inpainting_fp16.ckpt"
        config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=model_name)
    else:
        raise ValueError("Only text2img and inpainting are available")
    cached_download(
        config_file_url,
        cache_dir=cache_dir,
        force_filename=model_name,
        use_auth_token=use_auth_token,
    )
    prior_name = "prior_fp16.ckpt"
    config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=prior_name)
    cached_download(
        config_file_url,
        cache_dir=cache_dir,
        force_filename=prior_name,
        use_auth_token=use_auth_token,
    )
    cache_dir_text_en = os.path.join(cache_dir, "text_encoder")
    for name in [
        "config.json",
        "pytorch_model.bin",
        "sentencepiece.bpe.model",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]:
        config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=f"text_encoder/{name}")
        cached_download(
            config_file_url,
            cache_dir=cache_dir_text_en,
            force_filename=name,
            use_auth_token=use_auth_token,
        )
    config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename="movq_final.ckpt")
    cached_download(
        config_file_url,
        cache_dir=cache_dir,
        force_filename="movq_final.ckpt",
        use_auth_token=use_auth_token,
    )
    config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename="ViT-L-14_stats.th")
    cached_download(
        config_file_url,
        cache_dir=cache_dir,
        force_filename="ViT-L-14_stats.th",
        use_auth_token=use_auth_token,
    )
    config["tokenizer_name"] = cache_dir_text_en
    config["text_enc_params"]["model_path"] = cache_dir_text_en
    config["prior"]["clip_mean_std_path"] = os.path.join(cache_dir, "ViT-L-14_stats.th")
    config["image_enc_params"]["ckpt_path"] = os.path.join(cache_dir, "movq_final.ckpt")
    cache_model_name = os.path.join(cache_dir, model_name)
    cache_prior_name = os.path.join(cache_dir, prior_name)
    model = Kandinsky2_1(config, cache_model_name, cache_prior_name, device, task_type=task_type)
    return model


def get_kandinsky2(
    device,
    task_type="text2img",
    cache_dir="/tmp/kandinsky2",
    use_auth_token=None,
    model_version="2.1",
    use_flash_attention=False,
):
    if model_version == "2.0":
        model = get_kandinsky2_0(
            device,
            task_type=task_type,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
        )
    elif model_version == "2.1":
        model = get_kandinsky2_1(
            device,
            task_type=task_type,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            use_flash_attention=use_flash_attention,
        )
    elif model_version == "2.2":
        model = Kandinsky2_2(device=device, task_type=task_type)
    else:
        raise ValueError("Only 2.0, 2.1 and 2.2 are available")
    return model
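

# Illustrative usage sketch: loads the 2.1 text2img pipeline and renders a
# single image. The prompt, sampler settings and output path are placeholders.
if __name__ == "__main__":
    device = "cuda"  # a GPU is assumed for the fp16 checkpoints
    model = get_kandinsky2(device, task_type="text2img", model_version="2.1")
    images = model.generate_text2img(
        "red cat, 4k photo",
        num_steps=100,
        batch_size=1,
        guidance_scale=4,
        h=768,
        w=768,
        sampler="p_sampler",
        prior_cf_scale=4,
        prior_steps="5",
    )
    images[0].save("kandinsky_out.png")  # process_images returns PIL images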