Kandinsky2.1 / kandinsky2_1_model.py
ai-forever's picture
Create kandinsky2_1_model.py
78a6221
raw
history blame
23.3 kB
from transformers import AutoTokenizer
from PIL import Image
import cv2
import torch
from omegaconf import OmegaConf
import math
from copy import deepcopy
import torch.nn.functional as F
import numpy as np
import clip
from transformers import AutoTokenizer
from kandinsky2.model.text_encoders import TextEncoder
from kandinsky2.vqgan.autoencoder import VQModelInterface, AutoencoderKL, MOVQ
from kandinsky2.model.samplers import DDIMSampler, PLMSSampler
from kandinsky2.model.model_creation import create_model, create_gaussian_diffusion
from kandinsky2.model.prior import PriorDiffusionModel, CustomizedTokenizer
from kandinsky2.utils import prepare_image, q_sample, process_images, prepare_mask
class Kandinsky2_1:
    """Kandinsky 2.1 generation pipeline.

    Wires together the two tokenizers, the prior model, the text encoder,
    the CLIP model, an optional image autoencoder and the diffusion model,
    and exposes the text2img / image mixing / img2img / inpainting entry
    points built on top of them.
    """

    def __init__(
        self,
        config,
        model_path,
        prior_path,
        device,
        task_type="text2img"
    ):
        """Build every sub-model and move it to ``device``.

        Args:
            config: Mapping with the model, prior, text-encoder and image
                encoder settings. NOTE: it is modified in place
                (``model_config`` keys ``up``, ``inpainting`` and
                ``cache_text_emb`` are overwritten).
            model_path: Path to the diffusion model checkpoint.
            prior_path: Path to the prior checkpoint.
            device: Device the models are moved to.
            task_type: ``"text2img"`` or ``"inpainting"``.

        Raises:
            ValueError: If ``task_type`` is not supported, or if the
                configured image encoder name is unknown.
        """
        self.config = config
        self.device = device
        self.use_fp16 = self.config["model_config"]["use_fp16"]
        self.task_type = task_type
        self.clip_image_size = config["clip_image_size"]
        if task_type == "text2img":
            self.config["model_config"]["up"] = False
            self.config["model_config"]["inpainting"] = False
        elif task_type == "inpainting":
            self.config["model_config"]["up"] = False
            self.config["model_config"]["inpainting"] = True
        else:
            raise ValueError("Only text2img and inpainting is available")

        self.tokenizer1 = AutoTokenizer.from_pretrained(self.config["tokenizer_name"])
        self.tokenizer2 = CustomizedTokenizer()

        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        clip_mean, clip_std = torch.load(
            config["prior"]["clip_mean_std_path"], map_location="cpu"
        )
        self.prior = PriorDiffusionModel(
            config["prior"]["params"],
            self.tokenizer2,
            clip_mean,
            clip_std,
        )
        self.prior.load_state_dict(torch.load(prior_path, map_location="cpu"), strict=False)
        if self.use_fp16:
            self.prior = self.prior.half()

        self.text_encoder = TextEncoder(**self.config["text_enc_params"])
        if self.use_fp16:
            self.text_encoder = self.text_encoder.half()

        self.clip_model, self.preprocess = clip.load(
            config["clip_name"], device=self.device, jit=False
        )
        self.clip_model.eval()

        image_enc_params = self.config["image_enc_params"]
        if image_enc_params is not None:
            self.use_image_enc = True
            self.scale = image_enc_params["scale"]
            encoder_name = image_enc_params["name"]
            if encoder_name == "AutoencoderKL":
                self.image_encoder = AutoencoderKL(**image_enc_params["params"])
            elif encoder_name == "VQModelInterface":
                self.image_encoder = VQModelInterface(**image_enc_params["params"])
            elif encoder_name == "MOVQ":
                self.image_encoder = MOVQ(**image_enc_params["params"])
                # NOTE(review): the checkpoint load is kept inside the MOVQ
                # branch; the flattened source does not show its nesting --
                # confirm against the upstream file.
                self.image_encoder.load_state_dict(
                    torch.load(image_enc_params["ckpt_path"], map_location="cpu")
                )
            else:
                # Previously an unknown name surfaced later as an
                # AttributeError on the missing ``image_encoder``.
                raise ValueError(f"Unknown image encoder: {encoder_name}")
            self.image_encoder.eval()
        else:
            self.use_image_enc = False
            # Keep the attribute defined so later code can test it.
            self.image_encoder = None

        self.config["model_config"]["cache_text_emb"] = True
        self.model = create_model(**self.config["model_config"])
        self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
        if self.use_fp16:
            self.model.convert_to_fp16()
            if self.use_image_enc:
                self.image_encoder = self.image_encoder.half()
            self.model_dtype = torch.float16
        else:
            self.model_dtype = torch.float32

        # Only touch the image encoder when one was configured.
        if self.use_image_enc:
            self.image_encoder = self.image_encoder.to(self.device).eval()
        self.text_encoder = self.text_encoder.to(self.device).eval()
        self.prior = self.prior.to(self.device).eval()
        self.model.eval()
        self.model.to(self.device)

    def get_new_h_w(self, h, w):
        """Return ``(ceil(h / 64) * 8, ceil(w / 64) * 8)`` as integers."""
        return (h + 63) // 64 * 8, (w + 63) // 64 * 8

    @torch.no_grad()
    def encode_text(self, text_encoder, tokenizer, prompt, batch_size):
        """Encode ``prompt`` and the empty prompt, ``batch_size`` times each.

        The batch is laid out as ``batch_size`` copies of ``prompt`` followed
        by ``batch_size`` empty strings.

        Returns:
            The ``(full_emb, pooled_emb)`` pair returned by ``text_encoder``.
        """
        text_encoding = tokenizer(
            [prompt] * batch_size + [""] * batch_size,
            max_length=77,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        tokens = text_encoding["input_ids"].to(self.device)
        mask = text_encoding["attention_mask"].to(self.device)
        full_emb, pooled_emb = text_encoder(tokens=tokens, mask=mask)
        return full_emb, pooled_emb

    @torch.no_grad()
    def generate_clip_emb(
        self,
        prompt,
        batch_size=1,
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
    ):
        """Run the prior on ``prompt`` and return its output embedding.

        The prompt batch is concatenated with the tokens of
        ``negative_prior_prompt`` before being passed through the CLIP text
        transformer and the prior.

        Returns:
            The prior output cast to ``self.model_dtype``.
        """
        prompts_batch = [prompt] * batch_size
        prior_cf_scales_batch = torch.tensor(
            [prior_cf_scale] * len(prompts_batch), device=self.device
        )
        max_txt_length = self.prior.model.text_ctx
        tok, mask = self.tokenizer2.padded_tokens_and_mask(
            prompts_batch, max_txt_length
        )
        cf_token, cf_mask = self.tokenizer2.padded_tokens_and_mask(
            [negative_prior_prompt], max_txt_length
        )
        # Broadcast the single negative prompt across the prompt batch.
        if cf_token.shape != tok.shape:
            cf_token = cf_token.expand(tok.shape[0], -1)
            cf_mask = cf_mask.expand(tok.shape[0], -1)
        tok = torch.cat([tok, cf_token], dim=0)
        mask = torch.cat([mask, cf_mask], dim=0)
        tok, mask = tok.to(device=self.device), mask.to(device=self.device)

        clip_dtype = self.clip_model.dtype
        x = self.clip_model.token_embedding(tok).type(clip_dtype)
        x = x + self.clip_model.positional_embedding.type(clip_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(clip_dtype)
        txt_feat_seq = x
        # One feature per sequence, taken at the position of the largest
        # token id, then projected.
        txt_feat = (
            x[torch.arange(x.shape[0]), tok.argmax(dim=-1)]
            @ self.clip_model.text_projection
        )
        txt_feat = txt_feat.float().to(self.device)
        txt_feat_seq = txt_feat_seq.float().to(self.device)
        img_feat = self.prior(
            txt_feat,
            txt_feat_seq,
            mask,
            prior_cf_scales_batch,
            timestep_respacing=prior_steps,
        )
        return img_feat.to(self.model_dtype)

    @torch.no_grad()
    def encode_images(self, image, is_pil=False):
        """Encode ``image`` with the CLIP image encoder.

        When ``is_pil`` is true the image is first run through
        ``self.preprocess`` and given a leading batch dimension.
        """
        if is_pil:
            image = self.preprocess(image).unsqueeze(0).to(self.device)
        return self.clip_model.encode_image(image).to(self.model_dtype)

    @torch.no_grad()
    def generate_img(
        self,
        prompt,
        img_prompt,
        batch_size=1,
        diffusion=None,
        guidance_scale=7,
        init_step=None,
        noise=None,
        init_img=None,
        img_mask=None,
        h=512,
        w=512,
        sampler="ddim_sampler",
        num_steps=50,
    ):
        """Sample latents with the diffusion model and decode them.

        Args:
            prompt: Text prompt passed to the text encoder.
            img_prompt: Image embedding stored as ``image_emb`` in the model
                kwargs.
            batch_size: Number of samples returned; the model is run on
                ``batch_size * 2`` items (conditional + unconditional).
            diffusion: Diffusion object used by the chosen sampler.
            guidance_scale: Weight of the conditional/unconditional mix.
            init_step: Forwarded to the sampler.
            noise: Optional starting latents (cast to float32).
            init_img: Latent image; required for the inpainting task.
            img_mask: Latent mask; required for the inpainting task.
            h, w: Requested output size; the decoded images are cropped to it.
            sampler: ``"p_sampler"``, ``"ddim_sampler"`` or ``"plms_sampler"``.
            num_steps: Number of steps for the DDIM/PLMS samplers.

        Returns:
            The result of ``process_images`` on the decoded samples.

        Raises:
            ValueError: If the sampler name is unknown, or the instance is an
                inpainting one and ``init_img`` / ``img_mask`` is missing.
        """
        if self.task_type == "inpainting" and (init_img is None or img_mask is None):
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError("init_img and img_mask are required for inpainting")

        new_h, new_w = self.get_new_h_w(h, w)
        full_batch_size = batch_size * 2
        # Decide once; the closure below must not depend on a name that is
        # later rebound to a sampler object.
        use_p_sampler = sampler == "p_sampler"

        if init_img is not None and self.use_fp16:
            init_img = init_img.half()
        if img_mask is not None and self.use_fp16:
            img_mask = img_mask.half()

        model_kwargs = {}
        model_kwargs["full_emb"], model_kwargs["pooled_emb"] = self.encode_text(
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer1,
            prompt=prompt,
            batch_size=batch_size,
        )
        model_kwargs["image_emb"] = img_prompt

        if self.task_type == "inpainting":
            init_img = init_img.to(self.device)
            img_mask = img_mask.to(self.device)
            model_kwargs["inpaint_image"] = init_img * img_mask
            model_kwargs["inpaint_mask"] = img_mask

        def model_fn(x_t, ts, **kwargs):
            # Classifier-free guidance: run the first half of the batch
            # twice and blend the conditional / unconditional predictions.
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)
            eps, rest = model_out[:, :4], model_out[:, 4:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            if use_p_sampler:
                return torch.cat([eps, rest], dim=1)
            return eps

        if noise is not None:
            noise = noise.float()

        if self.task_type == "inpainting":
            def denoised_fun(x_start):
                # Mix the clamped prediction with the original latent via
                # the mask.
                x_start = x_start.clamp(-2, 2)
                return x_start * (1 - img_mask) + init_img * img_mask
        else:
            def denoised_fun(x):
                return x.clamp(-2, 2)

        # Build the sampler (and reject unknown names) before touching the
        # model cache, as the original code did.
        latent_sampler = None
        if not use_p_sampler:
            if sampler == "ddim_sampler":
                sampler_cls = DDIMSampler
            elif sampler == "plms_sampler":
                sampler_cls = PLMSSampler
            else:
                raise ValueError("Only ddim_sampler and plms_sampler is available")
            latent_sampler = sampler_cls(
                model=model_fn,
                old_diffusion=diffusion,
                schedule="linear",
            )

        self.model.del_cache()
        try:
            if use_p_sampler:
                samples = diffusion.p_sample_loop(
                    model_fn,
                    (full_batch_size, 4, new_h, new_w),
                    device=self.device,
                    noise=noise,
                    progress=True,
                    model_kwargs=model_kwargs,
                    init_step=init_step,
                    denoised_fn=denoised_fun,
                )
            else:
                samples, _ = latent_sampler.sample(
                    num_steps,
                    full_batch_size,
                    (4, new_h, new_w),
                    conditioning=model_kwargs,
                    x_T=noise,
                    init_step=init_step,
                )
        finally:
            # Always drop the model cache, even if sampling failed.
            self.model.del_cache()

        # Keep only the first half of the doubled batch.
        samples = samples[:batch_size]

        if self.use_image_enc:
            if self.use_fp16:
                samples = samples.half()
            samples = self.image_encoder.decode(samples / self.scale)

        samples = samples[:, :, :h, :w]
        return process_images(samples)

    @torch.no_grad()
    def create_zero_img_emb(self, batch_size):
        """Return the CLIP embedding of an all-zero image, repeated per batch item."""
        img = torch.zeros(1, 3, self.clip_image_size, self.clip_image_size).to(self.device)
        return self.encode_images(img, is_pil=False).repeat(batch_size, 1)

    def _negative_image_emb(
        self,
        negative_decoder_prompt,
        batch_size,
        prior_cf_scale,
        prior_steps,
        negative_prior_prompt,
    ):
        """Return the unconditional image embedding for the decoder.

        An empty ``negative_decoder_prompt`` yields the zero-image embedding;
        otherwise the prompt is run through the prior.
        """
        if negative_decoder_prompt == "":
            return self.create_zero_img_emb(batch_size=batch_size)
        return self.generate_clip_emb(
            negative_decoder_prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )

    def _create_diffusion(self, sampler, num_steps):
        """Return ``(config_copy, diffusion)`` for the given sampler.

        Works on a deep copy of ``self.config``; for ``"p_sampler"`` the
        ``timestep_respacing`` entry is set to ``str(num_steps)``.
        """
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])
        return config, diffusion

    @torch.no_grad()
    def generate_text2img(
        self,
        prompt,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        """Generate images from a text prompt.

        Returns:
            Whatever ``generate_img`` returns for the sampled batch.
        """
        # generate clip embeddings
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )
        zero_image_emb = self._negative_image_emb(
            negative_decoder_prompt,
            batch_size,
            prior_cf_scale,
            prior_steps,
            negative_prior_prompt,
        )
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)

        # load diffusion
        _, diffusion = self._create_diffusion(sampler, num_steps)
        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
        )

    @torch.no_grad()
    def mix_images(
        self,
        images_texts,
        weights,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        """Generate images from a weighted sum of text/image embeddings.

        Args:
            images_texts: Non-empty sequence whose items are either strings
                (run through the prior) or images (run through the CLIP
                image encoder with ``is_pil=True``).
            weights: One weight per item of ``images_texts``.

        Returns:
            Whatever ``generate_img`` returns; the decoder text prompt is
            the empty string.
        """
        assert len(images_texts) == len(weights) and len(images_texts) > 0

        # generate clip embeddings
        image_emb = None
        for item, weight in zip(images_texts, weights):
            if isinstance(item, str):
                emb = self.generate_clip_emb(
                    item,
                    batch_size=1,
                    prior_cf_scale=prior_cf_scale,
                    prior_steps=prior_steps,
                    negative_prior_prompt=negative_prior_prompt,
                )
            else:
                emb = self.encode_images(item, is_pil=True)
            weighted = emb * weight
            image_emb = weighted if image_emb is None else image_emb + weighted

        image_emb = image_emb.repeat(batch_size, 1)
        zero_image_emb = self._negative_image_emb(
            negative_decoder_prompt,
            batch_size,
            prior_cf_scale,
            prior_steps,
            negative_prior_prompt,
        )
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)

        # load diffusion
        _, diffusion = self._create_diffusion(sampler, num_steps)
        return self.generate_img(
            prompt="",
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
        )

    @torch.no_grad()
    def generate_img2img(
        self,
        prompt,
        pil_img,
        strength=0.7,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        """Generate images starting from a noised version of ``pil_img``.

        The image is encoded with the image encoder, noised with
        ``q_sample`` at a timestep derived from ``strength``, and used as
        the starting latent. ``negative_prior_prompt`` and
        ``negative_decoder_prompt`` default to the previous behaviour
        (empty prior negative prompt, zero-image decoder embedding).

        Returns:
            Whatever ``generate_img`` returns for the sampled batch.
        """
        # generate clip embeddings
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )
        zero_image_emb = self._negative_image_emb(
            negative_decoder_prompt,
            batch_size,
            prior_cf_scale,
            prior_steps,
            negative_prior_prompt,
        )
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)

        # load diffusion
        config, diffusion = self._create_diffusion(sampler, num_steps)

        image = prepare_image(pil_img, h=h, w=w).to(self.device)
        if self.use_fp16:
            image = image.half()
        image = self.image_encoder.encode(image) * self.scale

        # NOTE(review): strength == 1 gives start_step == 0 and therefore
        # index -1 into timestep_map -- confirm this is intended.
        start_step = int(diffusion.num_timesteps * (1 - strength))
        image = q_sample(
            image,
            torch.tensor(diffusion.timestep_map[start_step - 1]).to(self.device),
            schedule_name=config["diffusion_config"]["noise_schedule"],
            num_steps=config["diffusion_config"]["steps"],
        )
        # generate_img samples batch_size * 2 items, so the starting latent
        # must have the same batch size (the old factor of 2 only matched
        # batch_size == 1). Assumes the encoded image has batch size 1.
        image = image.repeat(2 * batch_size, 1, 1, 1)

        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
            noise=image,
            init_step=start_step,
        )

    @torch.no_grad()
    def generate_inpainting(
        self,
        prompt,
        pil_img,
        img_mask,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        """Inpaint ``pil_img`` according to ``img_mask`` and ``prompt``.

        Args:
            pil_img: Source image, passed to ``prepare_image``.
            img_mask: NumPy array; it is resized to the latent resolution
                with nearest-neighbour interpolation and passed through
                ``prepare_mask``.
            negative_decoder_prompt: When non-empty, its prior embedding is
                used as the unconditional image embedding (it was
                previously accepted but ignored).

        Returns:
            Whatever ``generate_img`` returns for the sampled batch.
        """
        # generate clip embeddings
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )
        zero_image_emb = self._negative_image_emb(
            negative_decoder_prompt,
            batch_size,
            prior_cf_scale,
            prior_steps,
            negative_prior_prompt,
        )
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)

        # load diffusion
        _, diffusion = self._create_diffusion(sampler, num_steps)

        image = prepare_image(pil_img, w, h).to(self.device)
        if self.use_fp16:
            image = image.half()
        image = self.image_encoder.encode(image) * self.scale

        # Bring the mask to the spatial size of the latent image.
        image_shape = tuple(image.shape[-2:])
        img_mask = torch.from_numpy(img_mask).unsqueeze(0).unsqueeze(0)
        img_mask = F.interpolate(
            img_mask,
            image_shape,
            mode="nearest",
        )
        img_mask = prepare_mask(img_mask).to(self.device)
        if self.use_fp16:
            img_mask = img_mask.half()

        # Match the batch_size * 2 items sampled by generate_img (the old
        # factor of 2 only matched batch_size == 1). Assumes the encoded
        # image and the mask have batch size 1.
        image = image.repeat(2 * batch_size, 1, 1, 1)
        img_mask = img_mask.repeat(2 * batch_size, 1, 1, 1)

        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
            init_img=image,
            img_mask=img_mask,
        )
import os
from huggingface_hub import hf_hub_url, cached_download
from copy import deepcopy
from omegaconf.dictconfig import DictConfig
def get_kandinsky2_1(
    device,
    task_type="text2img",
    cache_dir="/tmp/kandinsky2",
    use_auth_token=None,
    use_flash_attention=False,
):
    """Download the Kandinsky 2.1 weights and build a ``Kandinsky2_1`` model.

    Files are fetched from the ``sberbank-ai/Kandinsky_2.1`` repository into
    ``<cache_dir>/2_1`` (text-encoder files into its ``text_encoder``
    sub-directory), and the config paths are pointed at the downloaded files.

    Args:
        device: Device passed to ``Kandinsky2_1``.
        task_type: ``"text2img"`` or ``"inpainting"``; selects the decoder
            checkpoint.
        cache_dir: Root directory for downloaded files.
        use_auth_token: Forwarded to ``cached_download``.
        use_flash_attention: Stored in ``model_config.use_flash_attention``.

    Returns:
        The constructed ``Kandinsky2_1`` instance.

    Raises:
        ValueError: If ``task_type`` is not supported.
    """
    model_names = {
        "text2img": "decoder_fp16.ckpt",
        "inpainting": "inpainting_fp16.ckpt",
    }
    # Validate before any download; an unknown task used to fail later with
    # an UnboundLocalError on ``model_name``.
    if task_type not in model_names:
        raise ValueError("Only text2img and inpainting is available")
    model_name = model_names[task_type]
    prior_name = "prior_fp16.ckpt"
    repo_id = "sberbank-ai/Kandinsky_2.1"

    cache_dir = os.path.join(cache_dir, "2_1")
    cache_dir_text_en = os.path.join(cache_dir, "text_encoder")
    config = DictConfig(deepcopy(CONFIG_2_1))
    config["model_config"]["use_flash_attention"] = use_flash_attention

    def fetch(repo_filename, target_dir, local_name):
        # Download one repository file into ``target_dir`` as ``local_name``.
        cached_download(
            hf_hub_url(repo_id=repo_id, filename=repo_filename),
            cache_dir=target_dir,
            force_filename=local_name,
            use_auth_token=use_auth_token,
        )

    fetch(model_name, cache_dir, model_name)
    fetch(prior_name, cache_dir, prior_name)
    for name in [
        "config.json",
        "pytorch_model.bin",
        "sentencepiece.bpe.model",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]:
        fetch(f"text_encoder/{name}", cache_dir_text_en, name)
    fetch("movq_final.ckpt", cache_dir, "movq_final.ckpt")
    fetch("ViT-L-14_stats.th", cache_dir, "ViT-L-14_stats.th")

    config["tokenizer_name"] = cache_dir_text_en
    config["text_enc_params"]["model_path"] = cache_dir_text_en
    config["prior"]["clip_mean_std_path"] = os.path.join(cache_dir, "ViT-L-14_stats.th")
    config["image_enc_params"]["ckpt_path"] = os.path.join(cache_dir, "movq_final.ckpt")

    cache_model_name = os.path.join(cache_dir, model_name)
    cache_prior_name = os.path.join(cache_dir, prior_name)
    model = Kandinsky2_1(config, cache_model_name, cache_prior_name, device, task_type=task_type)
    return model
def get_kandinsky2(
    device,
    task_type="text2img",
    cache_dir="/tmp/kandinsky2",
    use_auth_token=None,
    model_version="2.1",
    use_flash_attention=False,
):
    """Build a Kandinsky model of the requested version.

    Args:
        device: Device passed to the model factory/constructor.
        task_type: Task name passed to the model factory/constructor.
        cache_dir: Download directory (used for versions 2.0 and 2.1).
        use_auth_token: Download token (used for versions 2.0 and 2.1).
        model_version: ``"2.0"``, ``"2.1"`` or ``"2.2"``.
        use_flash_attention: Only forwarded for version 2.1.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_version`` is not one of the supported values.
    """
    if model_version == "2.0":
        return get_kandinsky2_0(
            device,
            task_type=task_type,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
        )
    if model_version == "2.1":
        return get_kandinsky2_1(
            device,
            task_type=task_type,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            use_flash_attention=use_flash_attention,
        )
    if model_version == "2.2":
        return Kandinsky2_2(device=device, task_type=task_type)
    # The old message omitted 2.2 even though it is handled above.
    raise ValueError("Only 2.0, 2.1 and 2.2 are available")