import os
import random
import argparse
import json
import itertools
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from transformers import CLIPImageProcessor
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
from src.unet_hacked_tryon import UNet2DConditionModel
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from ip_adapter.ip_adapter import Resampler
from diffusers.utils.import_utils import is_xformers_available
from typing import Literal, Tuple, List
import torch.utils.data as data
import math
from tqdm.auto import tqdm
from diffusers.training_utils import compute_snr
import torchvision.transforms.functional as TF
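
# Fine-tunes an SDXL inpainting UNet for virtual try-on on VITON-HD: the masked garment
# region of the person image is inpainted, conditioned on the garment photo through an
# IP-Adapter image projection and reference features from a frozen garment UNet encoder.

# VitonHDDataset yields the person image, garment image, agnostic mask, and DensePose map,
# plus captions built from the vitonhd_<phase>_tagged.json garment annotations.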
class VitonHDDataset(data.Dataset):
def __init__(
self,
dataroot_path: str,
phase: Literal["train", "test"],
order: Literal["paired", "unpaired"] = "paired",
size: Tuple[int, int] = (512, 384),
):
super(VitonHDDataset, self).__init__()
self.dataroot = dataroot_path
self.phase = phase
self.height = size[0]
self.width = size[1]
self.size = size
self.norm = transforms.Normalize([0.5], [0.5])
self.transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.transform2D = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
self.toTensor = transforms.ToTensor()
with open(
os.path.join(dataroot_path, phase, "vitonhd_" + phase + "_tagged.json"), "r"
) as file1:
data1 = json.load(file1)
annotation_list = [
# "colors",
# "textures",
"sleeveLength",
"neckLine",
"item",
]
self.annotation_pair = {}
for k, v in data1.items():
for elem in v:
annotation_str = ""
for template in annotation_list:
for tag in elem["tag_info"]:
if (
tag["tag_name"] == template
and tag["tag_category"] is not None
):
annotation_str += tag["tag_category"]
annotation_str += " "
self.annotation_pair[elem["file_name"]] = annotation_str
self.order = order
im_names = []
c_names = []
dataroot_names = []
        filename = os.path.join(dataroot_path, f"{phase}_pairs.txt")
with open(filename, "r") as f:
for line in f.readlines():
if phase == "train":
im_name, _ = line.strip().split()
c_name = im_name
else:
if order == "paired":
im_name, _ = line.strip().split()
c_name = im_name
else:
im_name, c_name = line.strip().split()
im_names.append(im_name)
c_names.append(c_name)
dataroot_names.append(dataroot_path)
self.im_names = im_names
self.c_names = c_names
self.dataroot_names = dataroot_names
self.flip_transform = transforms.RandomHorizontalFlip(p=1)
self.clip_processor = CLIPImageProcessor()
def __getitem__(self, index):
c_name = self.c_names[index]
im_name = self.im_names[index]
# subject_txt = self.txt_preprocess['train']("shirt")
if c_name in self.annotation_pair:
cloth_annotation = self.annotation_pair[c_name]
else:
cloth_annotation = "shirts"
cloth = Image.open(os.path.join(self.dataroot, self.phase, "cloth", c_name))
im_pil_big = Image.open(
os.path.join(self.dataroot, self.phase, "image", im_name)
).resize((self.width,self.height))
image = self.transform(im_pil_big)
# load parsing image
mask = Image.open(os.path.join(self.dataroot, self.phase, "agnostic-mask", im_name.replace('.jpg','_mask.png'))).resize((self.width,self.height))
mask = self.toTensor(mask)
mask = mask[:1]
densepose_name = im_name
densepose_map = Image.open(
os.path.join(self.dataroot, self.phase, "image-densepose", densepose_name)
)
        pose_img = self.toTensor(densepose_map)  # values in [0, 1]; normalized to [-1, 1] later via self.norm
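        # Train-only augmentations: a shared random flip for image/mask/cloth/pose, a shared
        # color jitter for image and cloth, and shared scale/shift for image, mask, and pose.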
if self.phase == "train":
if random.random() > 0.5:
cloth = self.flip_transform(cloth)
mask = self.flip_transform(mask)
image = self.flip_transform(image)
pose_img = self.flip_transform(pose_img)
if random.random()>0.5:
color_jitter = transforms.ColorJitter(brightness=0.5, contrast=0.3, saturation=0.5, hue=0.5)
fn_idx, b, c, s, h = transforms.ColorJitter.get_params(color_jitter.brightness, color_jitter.contrast, color_jitter.saturation,color_jitter.hue)
image = TF.adjust_contrast(image, c)
image = TF.adjust_brightness(image, b)
image = TF.adjust_hue(image, h)
image = TF.adjust_saturation(image, s)
cloth = TF.adjust_contrast(cloth, c)
cloth = TF.adjust_brightness(cloth, b)
cloth = TF.adjust_hue(cloth, h)
cloth = TF.adjust_saturation(cloth, s)
if random.random() > 0.5:
scale_val = random.uniform(0.8, 1.2)
image = transforms.functional.affine(
image, angle=0, translate=[0, 0], scale=scale_val, shear=0
)
mask = transforms.functional.affine(
mask, angle=0, translate=[0, 0], scale=scale_val, shear=0
)
pose_img = transforms.functional.affine(
pose_img, angle=0, translate=[0, 0], scale=scale_val, shear=0
)
if random.random() > 0.5:
shift_valx = random.uniform(-0.2, 0.2)
shift_valy = random.uniform(-0.2, 0.2)
image = transforms.functional.affine(
image,
angle=0,
translate=[shift_valx * image.shape[-1], shift_valy * image.shape[-2]],
scale=1,
shear=0,
)
mask = transforms.functional.affine(
mask,
angle=0,
translate=[shift_valx * mask.shape[-1], shift_valy * mask.shape[-2]],
scale=1,
shear=0,
)
pose_img = transforms.functional.affine(
pose_img,
angle=0,
translate=[
shift_valx * pose_img.shape[-1],
shift_valy * pose_img.shape[-2],
],
scale=1,
shear=0,
)
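        # Invert and binarize the agnostic mask so 1 = region to keep and 0 = garment region;
        # im_mask is then the person image with the try-on region blanked out.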
mask = 1-mask
cloth_trim = self.clip_processor(images=cloth, return_tensors="pt").pixel_values
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
im_mask = image * mask
pose_img = self.norm(pose_img)
result = {}
result["c_name"] = c_name
result["image"] = image
result["cloth"] = cloth_trim
result["cloth_pure"] = self.transform(cloth)
result["inpaint_mask"] = 1-mask
result["im_mask"] = im_mask
result["caption"] = "model is wearing " + cloth_annotation
result["caption_cloth"] = "a photo of " + cloth_annotation
result["annotation"] = cloth_annotation
result["pose_img"] = pose_img
return result
def __len__(self):
return len(self.im_names)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--pretrained_model_name_or_path",type=str,default="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",required=False,help="Path to pretrained model or model identifier from huggingface.co/models.",)
parser.add_argument("--pretrained_garmentnet_path",type=str,default="stabilityai/stable-diffusion-xl-base-1.0",required=False,help="Path to pretrained model or model identifier from huggingface.co/models.",)
parser.add_argument("--checkpointing_epoch",type=int,default=10,help=("Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"" training using `--resume_from_checkpoint`."),)
parser.add_argument("--pretrained_ip_adapter_path",type=str,default="ckpt/ip_adapter/ip-adapter-plus_sdxl_vit-h.bin",help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",)
parser.add_argument("--image_encoder_path",type=str,default="ckpt/image_encoder",required=False,help="Path to CLIP image encoder",)
parser.add_argument("--gradient_checkpointing",action="store_true",help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",)
parser.add_argument("--width",type=int,default=768,)
parser.add_argument("--height",type=int,default=1024,)
parser.add_argument("--gradient_accumulation_steps",type=int,default=1,help="Number of updates steps to accumulate before performing a backward/update pass.",)
parser.add_argument("--logging_steps",type=int,default=1000,help=("Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"" training using `--resume_from_checkpoint`."),)
parser.add_argument("--output_dir",type=str,default="output",help="The output directory where the model predictions and checkpoints will be written.",)
parser.add_argument("--snr_gamma",type=float,default=None,help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. ""More details here: https://arxiv.org/abs/2303.09556.",)
parser.add_argument("--num_tokens",type=int,default=16,help=("IP adapter token nums"),)
parser.add_argument("--learning_rate",type=float,default=1e-5,help="Learning rate to use.",)
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--train_batch_size", type=int, default=6, help="Batch size (per device) for the training dataloader.")
parser.add_argument("--test_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
parser.add_argument("--num_train_epochs", type=int, default=130)
parser.add_argument("--max_train_steps",type=int,default=None,help="Total number of training steps to perform. If provided, overrides num_train_epochs.",)
parser.add_argument("--noise_offset", type=float, default=None, help="noise offset")
parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes.")
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
parser.add_argument("--mixed_precision",type=str,default=None,choices=["no", "fp16", "bf16"],help=("Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),)
parser.add_argument("--guidance_scale",type=float,default=2.0,)
parser.add_argument("--seed", type=int, default=42,)
parser.add_argument("--num_inference_steps",type=int,default=30,)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--data_dir", type=str, default="/home/omnious/workspace/yisol/Dataset/VITON-HD/zalando", help="For distributed training: local_rank")
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
def main():
args = parse_args()
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir)
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
gradient_accumulation_steps=args.gradient_accumulation_steps,
project_config=accelerator_project_config,
)
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler",rescale_betas_zero_snr=True)
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder_2")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,subfolder="vae",torch_dtype=torch.float16,)
unet_encoder = UNet2DConditionModel_ref.from_pretrained(args.pretrained_garmentnet_path, subfolder="unet")
unet_encoder.config.addition_embed_type = None
unet_encoder.config["addition_embed_type"] = None
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
#customize unet start
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet",low_cpu_mem_usage=False, device_map=None)
unet.config.encoder_hid_dim = image_encoder.config.hidden_size
unet.config.encoder_hid_dim_type = "ip_image_proj"
unet.config["encoder_hid_dim"] = image_encoder.config.hidden_size
unet.config["encoder_hid_dim_type"] = "ip_image_proj"
state_dict = torch.load(args.pretrained_ip_adapter_path, map_location="cpu")
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
adapter_modules.load_state_dict(state_dict["ip_adapter"],strict=True)
#ip-adapter
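    # The Resampler maps CLIP image-encoder hidden states to `num_tokens` IP-Adapter tokens
    # in the UNet cross-attention dimension; its weights are loaded from the same checkpoint.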
image_proj_model = Resampler(
dim=image_encoder.config.hidden_size,
depth=4,
dim_head=64,
heads=20,
num_queries=args.num_tokens,
embedding_dim=image_encoder.config.hidden_size,
output_dim=unet.config.cross_attention_dim,
ff_mult=4,
).to(accelerator.device, dtype=torch.float32)
image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
image_proj_model.requires_grad_(True)
unet.encoder_hid_proj = image_proj_model
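    # Extend conv_in from 9 to 13 input channels (4 noisy latents + 1 mask + 4 masked-image
    # latents + 4 DensePose latents). The first 9 channels copy the pretrained inpainting
    # weights; the new pose channels start at zero so behavior is unchanged at initialization.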
conv_new = torch.nn.Conv2d(
in_channels=4+4+1+4,
out_channels=unet.conv_in.out_channels,
kernel_size=3,
padding=1,
)
torch.nn.init.kaiming_normal_(conv_new.weight)
conv_new.weight.data = conv_new.weight.data * 0.
conv_new.weight.data[:, :9] = unet.conv_in.weight.data
conv_new.bias.data = unet.conv_in.bias.data
unet.conv_in = conv_new # replace conv layer in unet
unet.config['in_channels'] = 13 # update config
unet.config.in_channels = 13 # update config
#customize unet end
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
vae.to(accelerator.device)
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder_2.to(accelerator.device, dtype=weight_dtype)
image_encoder.to(accelerator.device, dtype=weight_dtype)
unet_encoder.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
image_encoder.requires_grad_(False)
unet_encoder.requires_grad_(False)
unet.requires_grad_(True)
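    # Only the try-on UNet is trained (its parameters now include the IP-Adapter projection
    # attached as encoder_hid_proj); VAE, text encoders, image encoder, and garment UNet stay frozen.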
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
unet_encoder.enable_gradient_checkpointing()
unet.train()
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
    # Materialize the parameter list so it can also be reused for gradient clipping later.
    params_to_opt = list(unet.parameters())
optimizer = optimizer_class(
params_to_opt,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
train_dataset = VitonHDDataset(
dataroot_path=args.data_dir,
phase="train",
order="paired",
size=(args.height, args.width),
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
pin_memory=True,
shuffle=False,
batch_size=args.train_batch_size,
num_workers=16,
)
test_dataset = VitonHDDataset(
dataroot_path=args.data_dir,
phase="test",
order="paired",
size=(args.height, args.width),
)
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
shuffle=False,
batch_size=args.test_batch_size,
num_workers=4,
)
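    # The test dataloader is only used for periodic qualitative sampling during training.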
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
unet,image_proj_model,unet_encoder,image_encoder,optimizer,train_dataloader,test_dataloader = accelerator.prepare(unet, image_proj_model,unet_encoder,image_encoder,optimizer,train_dataloader,test_dataloader)
initial_global_step = 0
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Train!
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
global_step = 0
first_epoch = 0
train_loss=0.0
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet), accelerator.accumulate(image_proj_model):
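                # Every `logging_steps` steps, the main process assembles the full try-on
                # pipeline from the current weights and saves sample generations on test data.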
if global_step % args.logging_steps == 0:
if accelerator.is_main_process:
with torch.no_grad():
with torch.cuda.amp.autocast():
unwrapped_unet= accelerator.unwrap_model(unet)
newpipe = TryonPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unwrapped_unet,
vae= vae,
scheduler=noise_scheduler,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
image_encoder=image_encoder,
unet_encoder = unet_encoder,
torch_dtype=torch.float16,
add_watermarker=False,
safety_checker=None,
).to(accelerator.device)
with torch.no_grad():
for sample in test_dataloader:
img_emb_list = []
for i in range(sample['cloth'].shape[0]):
img_emb_list.append(sample['cloth'][i])
prompt = sample["caption"]
num_prompts = sample['cloth'].shape[0]
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_embeds = torch.cat(img_emb_list,dim=0)
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = newpipe.encode_prompt(
prompt,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt = sample["caption_cloth"]
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
with torch.inference_mode():
(
prompt_embeds_c,
_,
_,
_,
) = newpipe.encode_prompt(
prompt,
num_images_per_prompt=1,
do_classifier_free_guidance=False,
negative_prompt=negative_prompt,
)
generator = torch.Generator(newpipe.device).manual_seed(args.seed) if args.seed is not None else None
images = newpipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=args.num_inference_steps,
generator=generator,
strength = 1.0,
pose_img = sample['pose_img'],
text_embeds_cloth=prompt_embeds_c,
cloth = sample["cloth_pure"].to(accelerator.device),
mask_image=sample['inpaint_mask'],
image=(sample['image']+1.0)/2.0,
height=args.height,
width=args.width,
guidance_scale=args.guidance_scale,
ip_adapter_image = image_embeds,
)[0]
for i in range(len(images)):
images[i].save(os.path.join(args.output_dir,str(global_step)+"_"+str(i)+"_"+"test.jpg"))
break
del unwrapped_unet
del newpipe
torch.cuda.empty_cache()
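                # --- Training step ---
                # Encode the person image, masked person image, and DensePose map into VAE
                # latents; together with the mask they form the 13-channel UNet input.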
pixel_values = batch["image"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
masked_latents = vae.encode(
batch["im_mask"].reshape(batch["image"].shape).to(dtype=vae.dtype)
).latent_dist.sample()
masked_latents = masked_latents * vae.config.scaling_factor
masks = batch["inpaint_mask"]
                # Resize the mask to the latent resolution, since it is concatenated to the latents.
                mask = torch.nn.functional.interpolate(masks, size=(args.height // 8, args.width // 8))
pose_map = vae.encode(batch["pose_img"].to(dtype=vae.dtype)).latent_dist.sample()
pose_map = pose_map * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
# Add noise to the latents according to the noise magnitude at each timestep
noisy_latents = noise_scheduler.add_noise(model_input, noise, timesteps)
latent_model_input = torch.cat([noisy_latents, mask,masked_latents,pose_map], dim=1)
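                # SDXL text conditioning: the penultimate hidden states of both text encoders are
                # concatenated for cross-attention; the pooled embedding from text_encoder_2 goes
                # into the added conditioning together with the time ids.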
text_input_ids = tokenizer(
batch['caption'],
max_length=tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
text_input_ids_2 = tokenizer_2(
batch['caption'],
max_length=tokenizer_2.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
encoder_output = text_encoder(text_input_ids.to(accelerator.device), output_hidden_states=True)
text_embeds = encoder_output.hidden_states[-2]
encoder_output_2 = text_encoder_2(text_input_ids_2.to(accelerator.device), output_hidden_states=True)
pooled_text_embeds = encoder_output_2[0]
text_embeds_2 = encoder_output_2.hidden_states[-2]
encoder_hidden_states = torch.concat([text_embeds, text_embeds_2], dim=-1) # concat
def compute_time_ids(original_size, crops_coords_top_left = (0,0)):
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
                    target_size = (args.height, args.width)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device)
return add_time_ids
add_time_ids = torch.cat(
                    [compute_time_ids((args.height, args.width)) for i in range(bsz)]
)
img_emb_list = []
for i in range(bsz):
img_emb_list.append(batch['cloth'][i])
image_embeds = torch.cat(img_emb_list,dim=0)
image_embeds = image_encoder(image_embeds, output_hidden_states=True).hidden_states[-2]
ip_tokens =image_proj_model(image_embeds)
# add cond
unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}
unet_added_cond_kwargs["image_embeds"] = ip_tokens
cloth_values = batch["cloth_pure"].to(accelerator.device,dtype=vae.dtype)
cloth_values = vae.encode(cloth_values).latent_dist.sample()
cloth_values = cloth_values * vae.config.scaling_factor
text_input_ids = tokenizer(
batch['caption_cloth'],
max_length=tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
text_input_ids_2 = tokenizer_2(
batch['caption_cloth'],
max_length=tokenizer_2.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
encoder_output = text_encoder(text_input_ids.to(accelerator.device), output_hidden_states=True)
text_embeds_cloth = encoder_output.hidden_states[-2]
encoder_output_2 = text_encoder_2(text_input_ids_2.to(accelerator.device), output_hidden_states=True)
text_embeds_2_cloth = encoder_output_2.hidden_states[-2]
text_embeds_cloth = torch.concat([text_embeds_cloth, text_embeds_2_cloth], dim=-1) # concat
down,reference_features = unet_encoder(cloth_values,timesteps, text_embeds_cloth,return_dict=False)
reference_features = list(reference_features)
noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states,added_cond_kwargs=unet_added_cond_kwargs,garment_features=reference_features).sample
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
elif noise_scheduler.config.prediction_type == "sample":
# We set the target to latents here, but the model_pred will return the noise sample prediction.
target = model_input
# We will have to subtract the noise residual from the prediction to get the target sample.
model_pred = model_pred - noise
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
if args.snr_gamma is None:
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective requires that we add one to SNR values before we divide by them.
snr = snr + 1
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(params_to_opt, 1.0)
optimizer.step()
optimizer.zero_grad()
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
logs = {"step_loss": loss.detach().item()}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
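            # Save a full try-on pipeline every `checkpointing_epoch` global steps (main process only).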
if global_step % args.checkpointing_epoch == 0:
if accelerator.is_main_process:
                    # Export the current UNet inside a full try-on pipeline so the checkpoint
                    # can be loaded directly for inference.
unwrapped_unet = accelerator.unwrap_model(
unet, keep_fp32_wrapper=True
)
pipeline = TryonPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unwrapped_unet,
vae= vae,
scheduler=noise_scheduler,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
image_encoder=image_encoder,
unet_encoder=unet_encoder,
torch_dtype=torch.float16,
add_watermarker=False,
safety_checker=None,
)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
pipeline.save_pretrained(save_path)
del pipeline
if __name__ == "__main__":
main()