|
import argparse |
|
import os |
|
# Export the CUDA toolkit location and put its bin directory on PATH.
# This runs before torch is imported further down in this file.
# NOTE(review): presumably needed so CUDA-dependent packages can locate the
# toolkit/nvcc at import time -- confirm; the paths are hard-coded.
os.environ['CUDA_HOME'] = '/usr/local/cuda'
os.environ['PATH'] += ':/usr/local/cuda/bin'
|
from datetime import datetime |
|
|
|
import numpy as np |
|
import torch |
|
from diffusers.image_processor import VaeImageProcessor |
|
from huggingface_hub import snapshot_download |
|
from PIL import Image |
|
def _skip_torchscript(f):
    """Return ``f`` unchanged; used as a stand-in for ``torch.jit.script``."""
    return f


# Replace torch.jit.script with an identity function so that anything passed
# to it (or decorated with it) is returned as-is instead of being scripted.
# NOTE(review): presumably required by the model imports that follow -- confirm.
torch.jit.script = _skip_torchscript
|
from model.cloth_masker2 import AutoMasker, vis_mask |
|
from model.pipeline import CatVTONPipeline |
|
from utils import init_weight_dtype, resize_and_crop, resize_and_padding |
|
|
|
|
|
import cv2 |
|
|
|
def parse_args():
    """Parse the command-line options used by this script.

    Returns:
        argparse.Namespace: Parsed options with the attributes
        ``base_model_path``, ``resume_path``, ``output_dir``, ``width``,
        ``height``, ``repaint``, ``allow_tf32``, ``mixed_precision`` and
        ``local_rank``. If the ``LOCAL_RANK`` environment variable is set to
        a value other than ``-1``, that value replaces ``local_rank``.

    Raises:
        ValueError: If ``LOCAL_RANK`` is set but is not an integer.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="booksforcharlie/stable-diffusion-inpainting",
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--resume_path",
        type=str,
        default="zhengchong/CatVTON",
        help=(
            "The Path to the checkpoint of trained tryon model."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="resource/demo/output",
        help="The output directory where the model predictions will be written.",
    )

    parser.add_argument(
        "--width",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--repaint",
        action="store_true",
        help="Whether to repaint the result image with the original background."
    )
    # NOTE(review): store_true combined with default=True means this option is
    # always True and cannot be switched off from the command line.
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    # The LOCAL_RANK override below reads args.local_rank, so the option must
    # exist; without it the function raised AttributeError whenever LOCAL_RANK
    # was present in the environment.
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local process rank; overridden by the LOCAL_RANK environment variable when that is set.",
    )

    args = parser.parse_args()

    # A LOCAL_RANK exported by the launcher takes precedence over the CLI value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
|
|
|
# Read the CLI options and fetch the checkpoint repository named by
# --resume_path; snapshot_download's return value is used below as the
# directory that holds the checkpoints.
args = parse_args()
repo_path = snapshot_download(repo_id=args.resume_path)

# Processor configured for masks: no normalisation, binarised, grayscale.
# It is not referenced again in the visible part of this script.
mask_processor = VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)

# The auto-masker takes its two checkpoints from sub-folders of the
# downloaded repository. The device is hard-coded to 'cuda'.
_densepose_dir = os.path.join(repo_path, "DensePose")
_schp_dir = os.path.join(repo_path, "SCHP")
automasker = AutoMasker(
    densepose_ckpt=_densepose_dir,
    schp_ckpt=_schp_dir,
    device='cuda',
)
|
|
|
# Load the bundled example photo and force it to RGB.
person_image = Image.open("./resource/demo/example/person/men/m_lvl0.png")
person_image = person_image.convert("RGB")

# Run the auto-masker for the 'short dress' category and keep the entry
# stored under the 'mask' key of its result.
_masker_result = automasker(person_image, 'short dress')
mask = _masker_result['mask']

# Build a visualisation from the photo and the mask.
masked_person = vis_mask(person_image, mask)

# Write both outputs into the current working directory.
mask.save("./test_mask.png")
masked_person.save("./test_masked_person.png")
|
|