Spaces:
Running
on
Zero
Running
on
Zero
File size: 12,561 Bytes
ad88a0b f0b9ada ad88a0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
from adaface.adaface_wrapper import AdaFaceWrapper
import torch
#import torch.nn.functional as F
from PIL import Image
import numpy as np
import os, argparse, glob, re, shutil
def str2bool(v):
    """Convert a command-line token into a bool (usable as an argparse ``type``).

    Values that already are ``bool`` are returned unchanged. Anything else is
    lower-cased and compared against the accepted spellings below.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def seed_everything(seed):
    """Seed NumPy and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also switches cuDNN to deterministic mode with benchmarking disabled, and
    records the seed in the ``PL_GLOBAL_SEED`` environment variable.
    """
    # Same order as before: NumPy, torch CPU, then every CUDA device.
    for apply_seed in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    os.environ["PL_GLOBAL_SEED"] = f"{seed}"
def parse_args(argv=None):
    """Parse the command-line options of the AdaFace image-translation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        The populated ``argparse.Namespace``. Note that ``--noise`` is stored
        as ``perturb_std`` and ``--scale`` as ``guidance_scale``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_path", type=str, default='models/realisticvision/realisticVisionV40_v40VAE.safetensors',
                        help="Path to the UNet checkpoint (default: RealisticVision 4.0)")
    parser.add_argument('--adaface_ckpt_paths', type=str, nargs="+",
                        default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt'])
    parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["arc2face"],
                        choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
    # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
    parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None,
                        help="CFG scales of output embeddings of the ID2Ada prompt encoders")
    parser.add_argument('--extra_unet_dirpaths', type=str, nargs="*",
                        default=['models/ensemble/rv4-unet', 'models/ensemble/ar18-unet'],
                        help="Extra paths to the checkpoints of the UNet models")
    parser.add_argument('--unet_weights', type=float, nargs="+", default=[4, 2, 1],
                        help="Weights for the UNet models")
    parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images")
    # If True, the input folder contains images of mixed subjects.
    # If False, the input folder contains multiple subfolders, each of which contains images of the same subject.
    parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?",
                        help="Whether the input folder contains images of mixed subjects")
    parser.add_argument("--max_images_per_subject", type=int, default=5, help="Number of example images used per subject")
    # The script truncates its list of *subjects* to this count, so the help
    # text talks about subjects rather than images.
    parser.add_argument("--trans_subject_count", type=int, default=-1,
                        help="Maximum number of subjects to translate (<= 0: no limit)")
    parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images")
    parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image")
    parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder")
    parser.add_argument("--noise", dest='perturb_std', type=float, default=0,
                        help="Relative std of the noise added to the face embeddings (0: no noise)")
    parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
                        help="Guidance scale for the diffusion model")
    parser.add_argument("--ref_img_strength", type=float, default=0.8,
                        help="Strength of the reference image in the output image.")
    parser.add_argument("--subject_string",
                        type=str, default="z",
                        help="Subject placeholder string used in prompts to denote the concept.")
    parser.add_argument("--prompt", type=str, default="a person z")
    parser.add_argument("--num_images_per_row", type=int, default=4,
                        help="Number of images to display in a row in the output grid image.")
    parser.add_argument("--num_inference_steps", type=int, default=50,
                        help="Number of DDIM inference steps")
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
    parser.add_argument("--seed", type=int, default=42,
                        help="the seed (for reproducible sampling). Set to -1 to disable.")
    # argv=None falls back to sys.argv[1:], exactly as parse_args() did before.
    return parser.parse_args(argv)
# File-name patterns (matched by glob) that count as input images.
IMAGE_PATTERNS = ("*.jpg", "*.png", "*.jpeg")
# Suffix that marks the mask file accompanying an image: "<stem>_mask.png".
MASK_SUFFIX = "_mask.png"


def list_subject_images(folder):
    """Return the sorted image paths found directly inside ``folder``.

    Files whose name ends with ``_mask.png`` are masks, not inputs, and are
    left out. Sub-directories are not searched.
    """
    # Escape the folder so that characters such as "[" in a directory name
    # are not interpreted as glob wildcards.
    escaped_folder = glob.escape(folder)
    image_paths = []
    for pattern in IMAGE_PATTERNS:
        # glob returns the full path.
        image_paths.extend(glob.glob(os.path.join(escaped_folder, pattern)))
    # Test the file name only, so a directory called e.g. "x_mask.png.d"
    # does not cause every image inside it to be dropped.
    return sorted(path for path in image_paths
                  if not os.path.basename(path).endswith(MASK_SUFFIX))


def collect_subjects(in_path, is_mix_subj_folder=False,
                     max_images_per_subject=5, trans_subject_count=-1):
    """Group the input images into subjects.

    Args:
        in_path: A single image file, or a folder of images and/or subfolders.
        is_mix_subj_folder: If True, every image is an individual subject.
            If False, all images of one folder form one subject.
        max_images_per_subject: Cap on the images kept per subject in the
            non-mixed mode (<= 0: keep all).
        trans_subject_count: Cap on the number of subjects (<= 0: no limit).

    Returns:
        ``(subject_folders, images_by_subject)``: two parallel lists holding,
        for each subject, its folder and the list of its image paths.
    """
    if os.path.isfile(in_path):
        return [os.path.dirname(in_path)], [[in_path]]

    # The previous code and the option's help text disagreed on which layout
    # each mode scans (the folder itself vs. its subfolders). Scan both the
    # folder and its immediate sub-directories, and skip folders without
    # images, so that either layout works in either mode.
    sub_folders = [os.path.join(in_path, name) for name in sorted(os.listdir(in_path))]
    candidate_folders = [in_path] + [path for path in sub_folders if os.path.isdir(path)]

    subject_folders = []
    images_by_subject = []
    for folder in candidate_folders:
        image_paths = list_subject_images(folder)
        if not image_paths:
            # An empty subject would crash later when the batch is built.
            continue

        if is_mix_subj_folder:
            # Each image in the folder is treated as an individual subject.
            images_by_subject.extend([path] for path in image_paths)
            subject_folders.extend([folder] * len(image_paths))
        else:
            # Keep at most max_images_per_subject images of this subject.
            if max_images_per_subject > 0:
                image_paths = image_paths[:max_images_per_subject]
            images_by_subject.append(image_paths)
            subject_folders.append(folder)

        if 0 < trans_subject_count <= len(subject_folders):
            break

    if trans_subject_count > 0:
        images_by_subject = images_by_subject[:trans_subject_count]
        subject_folders = subject_folders[:trans_subject_count]
    return subject_folders, images_by_subject


def resolve_out_folder(subject_folder, in_root, out_folder):
    """Map ``subject_folder`` (inside ``in_root``) to its mirror under ``out_folder``.

    A relative-path computation is used instead of a textual replace, so the
    result is correct when ``in_root`` lacks or has a trailing slash, when its
    text occurs elsewhere in the path, and when the subject folder *is* the
    input root (the result is then ``out_folder`` itself).
    """
    # An empty string means "current directory" (dirname of a bare file name).
    relative = os.path.relpath(subject_folder or ".", in_root or ".")
    return os.path.normpath(os.path.join(out_folder, relative))


def mask_path_for(image_path):
    """Return the path of the mask belonging to ``image_path``: ``<stem>_mask.png``."""
    # splitext only strips the final extension; str.replace would also rewrite
    # a matching ".jpg"/".png" that happens to occur in a directory name.
    stem, _ext = os.path.splitext(image_path)
    return stem + MASK_SUFFIX


def load_image_batch(image_paths, device, size=(512, 512)):
    """Load images as one float tensor of shape [N, 3, H, W] in [0, 1] on ``device``.

    Every image is converted to RGB and resized to ``size``. All images of a
    subject are put into one batch; this assumes the number of images is small.
    """
    batch = []
    for image_path in image_paths:
        # The context manager closes the underlying file handle.
        with Image.open(image_path) as pil_image:
            pixels = np.array(pil_image.convert("RGB").resize(size))
        # [H, W, 3] -> [3, H, W], then add the batch dimension.
        batch.append(torch.tensor(pixels.transpose(2, 0, 1)).unsqueeze(0).float())
    # Normalize the pixel values to [0, 1] and honour the requested device
    # (rather than always using the default CUDA device).
    return (torch.cat(batch, dim=0) / 255.0).to(device)


if __name__ == "__main__":
    args = parse_args()
    if args.seed != -1:
        seed_everything(args.seed)

    # screen -dm -L -Logfile trans_rv4-2.txt accelerate launch --multi_gpu --num_processes=2 scripts/adaface-translate.py
    # --adaface_ckpt_paths logs/subjects-celebrity2024-05-16T17-22-46_zero3-ada/checkpoints/embeddings_gs-30000.pt
    # --base_model_path models/realisticvision/realisticVisionV40_v40VAE.safetensors --in_folder /path/to/VGGface2_HQ_masks/
    # --is_mix_subj_folder 0 --out_folder /path/to/VGGface2_HQ_masks_rv4a --copy_masks --num_gpus 2

    # Single-process defaults, set up front so they are defined on every path.
    distributed_state = None
    process_index = 0
    if args.num_gpus > 1:
        from accelerate import PartialState
        distributed_state = PartialState()
        args.device = distributed_state.device
        process_index = distributed_state.process_index
    elif re.match(r"^\d+$", args.device):
        # A bare GPU index such as "1" is shorthand for "cuda:1".
        args.device = f"cuda:{args.device}"

    adaface = AdaFaceWrapper("img2img", args.base_model_path,
                             args.adaface_encoder_types, args.adaface_ckpt_paths,
                             args.adaface_encoder_cfg_scales,
                             args.subject_string, args.num_inference_steps,
                             unet_types=None,
                             extra_unet_dirpaths=args.extra_unet_dirpaths, unet_weights=args.unet_weights,
                             device=args.device)

    subject_folders, images_by_subject = collect_subjects(
        args.in_folder, args.is_mix_subj_folder,
        args.max_images_per_subject, args.trans_subject_count)
    # For a single input file, outputs are mapped relative to its directory, so
    # they land in out_folder instead of overwriting the input image.
    if os.path.isfile(args.in_folder):
        in_root = os.path.dirname(args.in_folder)
    else:
        in_root = args.in_folder

    out_image_count = 0
    out_mask_count = 0

    if not args.out_folder.endswith("/"):
        args.out_folder += "/"

    if args.num_gpus > 1:
        # Split the subjects across the GPUs.
        subject_folders = subject_folders[process_index::args.num_gpus]
        images_by_subject = images_by_subject[process_index::args.num_gpus]
        #subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject))

    for subject_folder, image_paths in zip(subject_folders, images_by_subject):
        # If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image.
        # Otherwise, we use the folder name as the signature of the images.
        images_sig = subject_folder if not args.is_mix_subj_folder else os.path.basename(image_paths[0])

        print(f"Translating {images_sig}...")
        with torch.no_grad():
            # The returned embeddings are already incorporated in the text
            # encoder and are not used explicitly, so they are not kept.
            adaface.prepare_adaface_embeddings(image_paths, None,
                                               perturb_at_stage='img_prompt_emb',
                                               perturb_std=args.perturb_std,
                                               update_text_encoder=True)

        subject_out_folder = resolve_out_folder(subject_folder, in_root, args.out_folder)
        # exist_ok avoids a check-then-create race between processes that
        # write into the same output folder.
        os.makedirs(subject_out_folder, exist_ok=True)
        print(f"Output images will be saved to {subject_out_folder}")

        # in_images: [N, 3, 512, 512], values in [0, 1].
        # NOTE: For simplicity, we do not check overly large batch sizes.
        in_images = load_image_batch(image_paths, args.device)
        num_out_images = len(in_images) * args.out_count_per_input_image

        with torch.no_grad():
            # args.perturb_std: the *relative* std of the noise added to the face embeddings.
            # A noise level of 0.08 could change gender, but 0.06 is usually safe.
            out_images = adaface(in_images, args.prompt, None, 'append', args.guidance_scale, num_out_images, ref_img_strength=args.ref_img_strength)

        for img_i, img in enumerate(out_images):
            # out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ...
            subj_i = img_i % len(in_images)
            copy_i = img_i // len(in_images)
            image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i]))
            # The first copy keeps the input's file name; later copies get "_<copy_i>".
            copy_suffix = "" if copy_i == 0 else f"_{copy_i}"
            img.save(os.path.join(subject_out_folder, f"{image_filename_stem}{copy_suffix}{image_fileext}"))

            if args.copy_masks:
                mask_path = mask_path_for(image_paths[subj_i])
                if os.path.exists(mask_path):
                    mask_out_path = os.path.join(subject_out_folder,
                                                 f"{image_filename_stem}{copy_suffix}{MASK_SUFFIX}")
                    shutil.copy(mask_path, mask_out_path)
                    out_mask_count += 1

        out_image_count += len(out_images)

    print(f"{out_image_count} output images and {out_mask_count} masks saved to {args.out_folder}")
|