import argparse
import itertools
import json
import math
import os
import random
from typing import List, Literal, Tuple

import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms.functional as TF
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from diffusers import AutoencoderKL, DDPMScheduler
from diffusers.training_utils import compute_snr
from diffusers.utils.import_utils import is_xformers_available
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from ip_adapter.ip_adapter import Resampler
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
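
# Training script: fine-tunes the try-on UNet (an SDXL inpainting UNet with an
# IP-Adapter image-prompt branch) on the VITON-HD dataset. A frozen garment UNet
# ("GarmentNet") extracts reference features from the clean cloth image, while a
# DensePose map and an agnostic mask condition the inpainting.
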
class VitonHDDataset(data.Dataset):
    def __init__(
        self,
        dataroot_path: str,
        phase: Literal["train", "test"],
        order: Literal["paired", "unpaired"] = "paired",
        size: Tuple[int, int] = (512, 384),
    ):
        super(VitonHDDataset, self).__init__()
        self.dataroot = dataroot_path
        self.phase = phase
        self.height = size[0]
        self.width = size[1]
        self.size = size

        self.norm = transforms.Normalize([0.5], [0.5])
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.transform2D = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
        self.toTensor = transforms.ToTensor()

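        # Build a text annotation per garment from the VITON-HD tag file
        # (sleeve length, neckline, item type); it is later used to compose captions.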
        with open(
            os.path.join(dataroot_path, phase, "vitonhd_" + phase + "_tagged.json"), "r"
        ) as file1:
            data1 = json.load(file1)

        annotation_list = [
            "sleeveLength",
            "neckLine",
            "item",
        ]

        self.annotation_pair = {}
        for k, v in data1.items():
            for elem in v:
                annotation_str = ""
                for template in annotation_list:
                    for tag in elem["tag_info"]:
                        if (
                            tag["tag_name"] == template
                            and tag["tag_category"] is not None
                        ):
                            annotation_str += tag["tag_category"]
                            annotation_str += " "
                self.annotation_pair[elem["file_name"]] = annotation_str

        self.order = order

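        # {phase}_pairs.txt lists (person image, cloth image) pairs. During
        # training, and for the "paired" test order, each person is matched with
        # their own garment; the "unpaired" order uses the cloth listed in the file.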
        im_names = []
        c_names = []
        dataroot_names = []

        filename = os.path.join(dataroot_path, f"{phase}_pairs.txt")

        with open(filename, "r") as f:
            for line in f.readlines():
                if phase == "train":
                    im_name, _ = line.strip().split()
                    c_name = im_name
                else:
                    if order == "paired":
                        im_name, _ = line.strip().split()
                        c_name = im_name
                    else:
                        im_name, c_name = line.strip().split()

                im_names.append(im_name)
                c_names.append(c_name)
                dataroot_names.append(dataroot_path)

        self.im_names = im_names
        self.c_names = c_names
        self.dataroot_names = dataroot_names
        self.flip_transform = transforms.RandomHorizontalFlip(p=1)
        self.clip_processor = CLIPImageProcessor()

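    # Returns one sample: the person image, the CLIP-preprocessed and raw cloth
    # tensors, the inpainting mask, the masked (garment-agnostic) person image,
    # the DensePose map, and caption strings built from the garment tags.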
    def __getitem__(self, index):
        c_name = self.c_names[index]
        im_name = self.im_names[index]

        if c_name in self.annotation_pair:
            cloth_annotation = self.annotation_pair[c_name]
        else:
            cloth_annotation = "shirts"

        cloth = Image.open(os.path.join(self.dataroot, self.phase, "cloth", c_name))

        im_pil_big = Image.open(
            os.path.join(self.dataroot, self.phase, "image", im_name)
        ).resize((self.width, self.height))

        image = self.transform(im_pil_big)

        mask = Image.open(
            os.path.join(
                self.dataroot,
                self.phase,
                "agnostic-mask",
                im_name.replace(".jpg", "_mask.png"),
            )
        ).resize((self.width, self.height))
        mask = self.toTensor(mask)
        mask = mask[:1]

        densepose_name = im_name
        densepose_map = Image.open(
            os.path.join(self.dataroot, self.phase, "image-densepose", densepose_name)
        )
        pose_img = self.toTensor(densepose_map)

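        # Train-time augmentation: horizontal flip, color jitter, random scale,
        # and random translation, applied consistently to the person image,
        # mask, and pose map (color jitter is also applied to the cloth).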
        if self.phase == "train":
            if random.random() > 0.5:
                cloth = self.flip_transform(cloth)
                mask = self.flip_transform(mask)
                image = self.flip_transform(image)
                pose_img = self.flip_transform(pose_img)

            if random.random() > 0.5:
                color_jitter = transforms.ColorJitter(
                    brightness=0.5, contrast=0.3, saturation=0.5, hue=0.5
                )
                fn_idx, b, c, s, h = transforms.ColorJitter.get_params(
                    color_jitter.brightness,
                    color_jitter.contrast,
                    color_jitter.saturation,
                    color_jitter.hue,
                )

                image = TF.adjust_contrast(image, c)
                image = TF.adjust_brightness(image, b)
                image = TF.adjust_hue(image, h)
                image = TF.adjust_saturation(image, s)

                cloth = TF.adjust_contrast(cloth, c)
                cloth = TF.adjust_brightness(cloth, b)
                cloth = TF.adjust_hue(cloth, h)
                cloth = TF.adjust_saturation(cloth, s)

            if random.random() > 0.5:
                scale_val = random.uniform(0.8, 1.2)
                image = transforms.functional.affine(
                    image, angle=0, translate=[0, 0], scale=scale_val, shear=0
                )
                mask = transforms.functional.affine(
                    mask, angle=0, translate=[0, 0], scale=scale_val, shear=0
                )
                pose_img = transforms.functional.affine(
                    pose_img, angle=0, translate=[0, 0], scale=scale_val, shear=0
                )

            if random.random() > 0.5:
                shift_valx = random.uniform(-0.2, 0.2)
                shift_valy = random.uniform(-0.2, 0.2)
                image = transforms.functional.affine(
                    image,
                    angle=0,
                    translate=[shift_valx * image.shape[-1], shift_valy * image.shape[-2]],
                    scale=1,
                    shear=0,
                )
                mask = transforms.functional.affine(
                    mask,
                    angle=0,
                    translate=[shift_valx * mask.shape[-1], shift_valy * mask.shape[-2]],
                    scale=1,
                    shear=0,
                )
                pose_img = transforms.functional.affine(
                    pose_img,
                    angle=0,
                    translate=[
                        shift_valx * pose_img.shape[-1],
                        shift_valy * pose_img.shape[-2],
                    ],
                    scale=1,
                    shear=0,
                )

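        # The agnostic mask marks the garment region; after inversion, 1 means
        # "keep this pixel". im_mask is the person image with the garment region
        # zeroed out, and inpaint_mask (computed below) flips it back so that 1
        # marks the area to be repainted.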
        mask = 1 - mask

        cloth_trim = self.clip_processor(images=cloth, return_tensors="pt").pixel_values

        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        im_mask = image * mask

        pose_img = self.norm(pose_img)

        result = {}
        result["c_name"] = c_name
        result["image"] = image
        result["cloth"] = cloth_trim
        result["cloth_pure"] = self.transform(cloth)
        result["inpaint_mask"] = 1 - mask
        result["im_mask"] = im_mask
        result["caption"] = "model is wearing " + cloth_annotation
        result["caption_cloth"] = "a photo of " + cloth_annotation
        result["annotation"] = cloth_annotation
        result["pose_img"] = pose_img

        return result

    def __len__(self):
        return len(self.im_names)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", required=False, help="Path to pretrained model or model identifier from huggingface.co/models.")
    parser.add_argument("--pretrained_garmentnet_path", type=str, default="stabilityai/stable-diffusion-xl-base-1.0", required=False, help="Path to the pretrained model used as the garment reference UNet (GarmentNet).")
    parser.add_argument("--checkpointing_epoch", type=int, default=10, help="Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming training using `--resume_from_checkpoint`.")
    parser.add_argument("--pretrained_ip_adapter_path", type=str, default="ckpt/ip_adapter/ip-adapter-plus_sdxl_vit-h.bin", help="Path to pretrained IP-Adapter weights. If not specified, weights are initialized randomly.")
    parser.add_argument("--image_encoder_path", type=str, default="ckpt/image_encoder", required=False, help="Path to the CLIP image encoder.")
    parser.add_argument("--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.")
    parser.add_argument("--width", type=int, default=768)
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--logging_steps", type=int, default=1000, help="Run a test-set visualization and log the training loss every X steps.")
    parser.add_argument("--output_dir", type=str, default="output", help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--snr_gamma", type=float, default=None, help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. More details here: https://arxiv.org/abs/2303.09556.")
    parser.add_argument("--num_tokens", type=int, default=16, help="Number of IP-Adapter image tokens.")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate to use.")
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--train_batch_size", type=int, default=6, help="Batch size (per device) for the training dataloader.")
    parser.add_argument("--test_batch_size", type=int, default=4, help="Batch size (per device) for the test dataloader.")
    parser.add_argument("--num_train_epochs", type=int, default=130)
    parser.add_argument("--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.")
    parser.add_argument("--noise_offset", type=float, default=None, help="Noise offset.")
    parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes.")
    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], help="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.")
    parser.add_argument("--guidance_scale", type=float, default=2.0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_inference_steps", type=int, default=30)
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank.")
    parser.add_argument("--data_dir", type=str, default="/home/omnious/workspace/yisol/Dataset/VITON-HD/zalando", help="Path to the VITON-HD dataset root.")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args

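# Example launch (hypothetical paths; assumes this script is saved as train_xl.py
# and that `accelerate config` has already been run):
#   accelerate launch train_xl.py \
#       --data_dir /path/to/VITON-HD \
#       --output_dir output \
#       --mixed_precision fp16 \
#       --train_batch_size 6 --gradient_checkpointing
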
def main():
    args = parse_args()
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir)
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        project_config=accelerator_project_config,
    )

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", rescale_betas_zero_snr=True)
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder_2")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch.float16)
    unet_encoder = UNet2DConditionModel_ref.from_pretrained(args.pretrained_garmentnet_path, subfolder="unet")
    # GarmentNet is an SDXL base UNet; drop its added time/text embedding so it
    # only consumes latents, timesteps, and text features.
    unet_encoder.config.addition_embed_type = None
    unet_encoder.config["addition_embed_type"] = None
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)

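    # The try-on UNet is the SDXL inpainting UNet with an IP-Adapter image-prompt
    # branch: cross-attention processors are loaded from the IP-Adapter checkpoint
    # and a Resampler projects CLIP image features into args.num_tokens tokens.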
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", low_cpu_mem_usage=False, device_map=None)
    unet.config.encoder_hid_dim = image_encoder.config.hidden_size
    unet.config.encoder_hid_dim_type = "ip_image_proj"
    unet.config["encoder_hid_dim"] = image_encoder.config.hidden_size
    unet.config["encoder_hid_dim_type"] = "ip_image_proj"

    state_dict = torch.load(args.pretrained_ip_adapter_path, map_location="cpu")

    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
    adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

    image_proj_model = Resampler(
        dim=image_encoder.config.hidden_size,
        depth=4,
        dim_head=64,
        heads=20,
        num_queries=args.num_tokens,
        embedding_dim=image_encoder.config.hidden_size,
        output_dim=unet.config.cross_attention_dim,
        ff_mult=4,
    ).to(accelerator.device, dtype=torch.float32)

    image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
    image_proj_model.requires_grad_(True)

    unet.encoder_hid_proj = image_proj_model

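    # Expand conv_in from 9 to 13 input channels: 4 noisy latents + 1 mask +
    # 4 masked-image latents (copied from the inpainting UNet) plus 4 new,
    # zero-initialized channels for the DensePose latents.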
    conv_new = torch.nn.Conv2d(
        in_channels=4 + 4 + 1 + 4,
        out_channels=unet.conv_in.out_channels,
        kernel_size=3,
        padding=1,
    )
    torch.nn.init.kaiming_normal_(conv_new.weight)
    conv_new.weight.data = conv_new.weight.data * 0.0

    conv_new.weight.data[:, :9] = unet.conv_in.weight.data
    conv_new.bias.data = unet.conv_in.bias.data

    unet.conv_in = conv_new
    unet.config["in_channels"] = 13
    unet.config.in_channels = 13

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    text_encoder_2.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)
    unet_encoder.to(accelerator.device, dtype=weight_dtype)

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    text_encoder_2.requires_grad_(False)
    image_encoder.requires_grad_(False)
    unet_encoder.requires_grad_(False)
    unet.requires_grad_(True)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        unet_encoder.enable_gradient_checkpointing()
    unet.train()

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

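    # Only the try-on UNet is optimized (this includes the IP image-projection
    # module, which is attached as unet.encoder_hid_proj); VAE, text encoders,
    # image encoder, and GarmentNet stay frozen.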
    # Materialize the parameter list so it can be reused by clip_grad_norm_
    # after the optimizer has been constructed (a bare chain iterator would be
    # exhausted after its first use).
    params_to_opt = list(itertools.chain(unet.parameters()))

    optimizer = optimizer_class(
        params_to_opt,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    train_dataset = VitonHDDataset(
        dataroot_path=args.data_dir,
        phase="train",
        order="paired",
        size=(args.height, args.width),
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        pin_memory=True,
        shuffle=False,
        batch_size=args.train_batch_size,
        num_workers=16,
    )
    test_dataset = VitonHDDataset(
        dataroot_path=args.data_dir,
        phase="test",
        order="paired",
        size=(args.height, args.width),
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=args.test_batch_size,
        num_workers=4,
    )

    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    unet, image_proj_model, unet_encoder, image_encoder, optimizer, train_dataloader, test_dataloader = accelerator.prepare(
        unet, image_proj_model, unet_encoder, image_encoder, optimizer, train_dataloader, test_dataloader
    )
    initial_global_step = 0

    # Recompute after accelerator.prepare, which may shard the dataloader.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )
    global_step = 0
    first_epoch = 0
    train_loss = 0.0

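    # Training loop. Every args.logging_steps steps, the main process builds a
    # TryonPipeline from the current weights and renders one test batch into
    # args.output_dir as a visual check.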
    for epoch in range(first_epoch, args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet), accelerator.accumulate(image_proj_model):
                if global_step % args.logging_steps == 0:
                    if accelerator.is_main_process:
                        with torch.no_grad():
                            with torch.cuda.amp.autocast():
                                unwrapped_unet = accelerator.unwrap_model(unet)
                                newpipe = TryonPipeline.from_pretrained(
                                    args.pretrained_model_name_or_path,
                                    unet=unwrapped_unet,
                                    vae=vae,
                                    scheduler=noise_scheduler,
                                    tokenizer=tokenizer,
                                    tokenizer_2=tokenizer_2,
                                    text_encoder=text_encoder,
                                    text_encoder_2=text_encoder_2,
                                    image_encoder=image_encoder,
                                    unet_encoder=unet_encoder,
                                    torch_dtype=torch.float16,
                                    add_watermarker=False,
                                    safety_checker=None,
                                ).to(accelerator.device)
                                with torch.no_grad():
                                    for sample in test_dataloader:
                                        img_emb_list = []
                                        for i in range(sample["cloth"].shape[0]):
                                            img_emb_list.append(sample["cloth"][i])

                                        prompt = sample["caption"]

                                        num_prompts = sample["cloth"].shape[0]
                                        negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

                                        if not isinstance(prompt, List):
                                            prompt = [prompt] * num_prompts
                                        if not isinstance(negative_prompt, List):
                                            negative_prompt = [negative_prompt] * num_prompts

                                        image_embeds = torch.cat(img_emb_list, dim=0)

                                        with torch.inference_mode():
                                            (
                                                prompt_embeds,
                                                negative_prompt_embeds,
                                                pooled_prompt_embeds,
                                                negative_pooled_prompt_embeds,
                                            ) = newpipe.encode_prompt(
                                                prompt,
                                                num_images_per_prompt=1,
                                                do_classifier_free_guidance=True,
                                                negative_prompt=negative_prompt,
                                            )

                                        prompt = sample["caption_cloth"]
                                        negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

                                        if not isinstance(prompt, List):
                                            prompt = [prompt] * num_prompts
                                        if not isinstance(negative_prompt, List):
                                            negative_prompt = [negative_prompt] * num_prompts

                                        with torch.inference_mode():
                                            (
                                                prompt_embeds_c,
                                                _,
                                                _,
                                                _,
                                            ) = newpipe.encode_prompt(
                                                prompt,
                                                num_images_per_prompt=1,
                                                do_classifier_free_guidance=False,
                                                negative_prompt=negative_prompt,
                                            )

                                        generator = torch.Generator(newpipe.device).manual_seed(args.seed) if args.seed is not None else None
                                        images = newpipe(
                                            prompt_embeds=prompt_embeds,
                                            negative_prompt_embeds=negative_prompt_embeds,
                                            pooled_prompt_embeds=pooled_prompt_embeds,
                                            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                                            num_inference_steps=args.num_inference_steps,
                                            generator=generator,
                                            strength=1.0,
                                            pose_img=sample["pose_img"],
                                            text_embeds_cloth=prompt_embeds_c,
                                            cloth=sample["cloth_pure"].to(accelerator.device),
                                            mask_image=sample["inpaint_mask"],
                                            image=(sample["image"] + 1.0) / 2.0,
                                            height=args.height,
                                            width=args.width,
                                            guidance_scale=args.guidance_scale,
                                            ip_adapter_image=image_embeds,
                                        )[0]

                                        for i in range(len(images)):
                                            images[i].save(os.path.join(args.output_dir, str(global_step) + "_" + str(i) + "_" + "test.jpg"))
                                        break
                        del unwrapped_unet
                        del newpipe
                        torch.cuda.empty_cache()

                pixel_values = batch["image"].to(dtype=vae.dtype)
                model_input = vae.encode(pixel_values).latent_dist.sample()
                model_input = model_input * vae.config.scaling_factor

                masked_latents = vae.encode(
                    batch["im_mask"].reshape(batch["image"].shape).to(dtype=vae.dtype)
                ).latent_dist.sample()
                masked_latents = masked_latents * vae.config.scaling_factor
                masks = batch["inpaint_mask"]

                mask = torch.stack(
                    [
                        torch.nn.functional.interpolate(masks, size=(args.height // 8, args.width // 8))
                    ]
                )
                mask = mask.reshape(-1, 1, args.height // 8, args.width // 8)

                pose_map = vae.encode(batch["pose_img"].to(dtype=vae.dtype)).latent_dist.sample()
                pose_map = pose_map * vae.config.scaling_factor

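                # The 13-channel UNet input is assembled below as
                # [noisy latents (4), resized inpaint mask (1),
                #  masked-image latents (4), DensePose latents (4)],
                # matching the expanded conv_in layout.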
                noise = torch.randn_like(model_input)

                bsz = model_input.shape[0]
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
                )

                noisy_latents = noise_scheduler.add_noise(model_input, noise, timesteps)
                latent_model_input = torch.cat([noisy_latents, mask, masked_latents, pose_map], dim=1)

                text_input_ids = tokenizer(
                    batch["caption"],
                    max_length=tokenizer.model_max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                ).input_ids
                text_input_ids_2 = tokenizer_2(
                    batch["caption"],
                    max_length=tokenizer_2.model_max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                ).input_ids

                encoder_output = text_encoder(text_input_ids.to(accelerator.device), output_hidden_states=True)
                text_embeds = encoder_output.hidden_states[-2]
                encoder_output_2 = text_encoder_2(text_input_ids_2.to(accelerator.device), output_hidden_states=True)
                pooled_text_embeds = encoder_output_2[0]
                text_embeds_2 = encoder_output_2.hidden_states[-2]
                encoder_hidden_states = torch.concat([text_embeds, text_embeds_2], dim=-1)

                def compute_time_ids(original_size, crops_coords_top_left=(0, 0)):
                    # SDXL micro-conditioning: (original H, W) + crop offsets + (target H, W).
                    target_size = (args.height, args.width)
                    add_time_ids = list(original_size + crops_coords_top_left + target_size)
                    add_time_ids = torch.tensor([add_time_ids])
                    add_time_ids = add_time_ids.to(accelerator.device)
                    return add_time_ids

                add_time_ids = torch.cat(
                    [compute_time_ids((args.height, args.width)) for i in range(bsz)]
                )

                img_emb_list = []
                for i in range(bsz):
                    img_emb_list.append(batch["cloth"][i])

                image_embeds = torch.cat(img_emb_list, dim=0)
                image_embeds = image_encoder(image_embeds, output_hidden_states=True).hidden_states[-2]
                ip_tokens = image_proj_model(image_embeds)

                unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}
                unet_added_cond_kwargs["image_embeds"] = ip_tokens

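                # GarmentNet path: encode the clean cloth image with the VAE and run
                # the frozen reference UNet on it; its intermediate features are
                # passed to the try-on UNet as garment_features.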
                cloth_values = batch["cloth_pure"].to(accelerator.device, dtype=vae.dtype)
                cloth_values = vae.encode(cloth_values).latent_dist.sample()
                cloth_values = cloth_values * vae.config.scaling_factor

                text_input_ids = tokenizer(
                    batch["caption_cloth"],
                    max_length=tokenizer.model_max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                ).input_ids
                text_input_ids_2 = tokenizer_2(
                    batch["caption_cloth"],
                    max_length=tokenizer_2.model_max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                ).input_ids

                encoder_output = text_encoder(text_input_ids.to(accelerator.device), output_hidden_states=True)
                text_embeds_cloth = encoder_output.hidden_states[-2]
                encoder_output_2 = text_encoder_2(text_input_ids_2.to(accelerator.device), output_hidden_states=True)
                text_embeds_2_cloth = encoder_output_2.hidden_states[-2]
                text_embeds_cloth = torch.concat([text_embeds_cloth, text_embeds_2_cloth], dim=-1)

                down, reference_features = unet_encoder(cloth_values, timesteps, text_embeds_cloth, return_dict=False)
                reference_features = list(reference_features)

                noise_pred = unet(
                    latent_model_input,
                    timesteps,
                    encoder_hidden_states,
                    added_cond_kwargs=unet_added_cond_kwargs,
                    garment_features=reference_features,
                ).sample

                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
                elif noise_scheduler.config.prediction_type == "sample":
                    target = model_input
                    # The model predicts the sample directly, so remove the noise
                    # residual from the prediction before computing the loss.
                    noise_pred = noise_pred - noise
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

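                # Min-SNR-gamma weighting (https://arxiv.org/abs/2303.09556):
                # per-timestep weight = min(SNR_t, gamma) / SNR_t, with SNR_t + 1
                # used in place of SNR_t for v-prediction.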
                if args.snr_gamma is None:
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
                else:
                    snr = compute_snr(noise_scheduler, timesteps)
                    if noise_scheduler.config.prediction_type == "v_prediction":
                        snr = snr + 1
                    mse_loss_weights = (
                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                    )

                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none")
                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                    loss = loss.mean()

                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_to_opt, 1.0)

                optimizer.step()
                optimizer.zero_grad()

                # Count an optimization step only when gradients were synced,
                # i.e. once per gradient-accumulation cycle.
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1
                    accelerator.log({"train_loss": train_loss}, step=global_step)
                    train_loss = 0.0
                logs = {"step_loss": loss.detach().item()}
                progress_bar.set_postfix(**logs)

                if global_step >= args.max_train_steps:
                    break

                if global_step % args.checkpointing_epoch == 0:
                    if accelerator.is_main_process:
                        unwrapped_unet = accelerator.unwrap_model(
                            unet, keep_fp32_wrapper=True
                        )
                        pipeline = TryonPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=unwrapped_unet,
                            vae=vae,
                            scheduler=noise_scheduler,
                            tokenizer=tokenizer,
                            tokenizer_2=tokenizer_2,
                            text_encoder=text_encoder,
                            text_encoder_2=text_encoder_2,
                            image_encoder=image_encoder,
                            unet_encoder=unet_encoder,
                            torch_dtype=torch.float16,
                            add_watermarker=False,
                            safety_checker=None,
                        )
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        pipeline.save_pretrained(save_path)
                        del pipeline


if __name__ == "__main__":
    main()