# adaface-animate / adaface/adaface_translate.py
# adaface-neurips
# Integrate do_neg_id_prompt_weight, fix bugs, various refinements
# f0b9ada
# raw | history | blame | 12.6 kB
from adaface.adaface_wrapper import AdaFaceWrapper
import torch
#import torch.nn.functional as F
from PIL import Image
import numpy as np
import os, argparse, glob, re, shutil
def str2bool(v):
    """Interpret a command-line value as a boolean (argparse ``type=`` helper).

    A ``bool`` is returned unchanged (argparse passes ``const=True`` through as-is).
    Strings are matched case-insensitively against common yes/no spellings.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean spelling.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def seed_everything(seed):
    """Seed all random number generators used by this script for reproducible sampling.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA devices).
    Switches cuDNN to deterministic, non-benchmarking mode. Exports the seed as
    ``PL_GLOBAL_SEED`` so code that reads that variable sees the same value.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import: the stdlib RNG was previously left unseeded despite the function's name.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PL_GLOBAL_SEED"] = str(seed)
def parse_args(argv=None):
    """Define and parse the command-line options of the AdaFace translation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default, used by the
            script entry point) parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace. ``--noise`` is stored as ``perturb_std`` and ``--scale``
        as ``guidance_scale``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_path", type=str, default='models/realisticvision/realisticVisionV40_v40VAE.safetensors',
                        help="Path to the UNet checkpoint (default: RealisticVision 4.0)")
    parser.add_argument('--adaface_ckpt_paths', type=str, nargs="+",
                        default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt'],
                        help="Path(s) to the AdaFace checkpoint(s)")
    parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["arc2face"],
                        choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
    # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
    parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None,
                        help="CFG scales of output embeddings of the ID2Ada prompt encoders")
    parser.add_argument('--extra_unet_dirpaths', type=str, nargs="*",
                        default=['models/ensemble/rv4-unet', 'models/ensemble/ar18-unet'],
                        help="Extra paths to the checkpoints of the UNet models")
    parser.add_argument('--unet_weights', type=float, nargs="+", default=[4, 2, 1],
                        help="Weights for the UNet models")
    parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images")
    # If True, the input folder contains images of mixed subjects.
    # If False, the input folder contains multiple subfolders, each of which contains images of the same subject.
    parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?",
                        help="Whether the input folder contains images of mixed subjects")
    parser.add_argument("--max_images_per_subject", type=int, default=5, help="Number of example images used per subject")
    # Previously documented as a number of images, but it caps the number of subjects.
    parser.add_argument("--trans_subject_count", type=int, default=-1,
                        help="Maximum number of subjects to translate (<= 0: no limit)")
    parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images")
    parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image")
    parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder")
    parser.add_argument("--noise", dest='perturb_std', type=float, default=0,
                        help="Relative std of the noise added to the face embeddings "
                             "(0.06 is usually safe; 0.08 may change gender)")
    parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
                        help="Guidance scale for the diffusion model")
    parser.add_argument("--ref_img_strength", type=float, default=0.8,
                        help="Strength of the reference image in the output image.")
    parser.add_argument("--subject_string",
                        type=str, default="z",
                        help="Subject placeholder string used in prompts to denote the concept.")
    parser.add_argument("--prompt", type=str, default="a person z",
                        help="Generation prompt; normally contains the --subject_string placeholder")
    parser.add_argument("--num_images_per_row", type=int, default=4,
                        help="Number of images to display in a row in the output grid image.")
    parser.add_argument("--num_inference_steps", type=int, default=50,
                        help="Number of DDIM inference steps")
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
    parser.add_argument("--seed", type=int, default=42,
                        help="the seed (for reproducible sampling). Set to -1 to disable.")
    args = parser.parse_args(argv)
    return args
def _list_subject_images(folder):
    """Return the sorted image paths directly inside ``folder``, excluding ``*_mask.png`` files."""
    image_paths = []
    for pattern in ("*.jpg", "*.png", "*.jpeg"):
        # glob returns paths that already include ``folder``.
        image_paths.extend(glob.glob(os.path.join(folder, pattern)))
    # Masks sit next to their source images; they may be copied, but are never translated.
    return sorted(path for path in image_paths if not path.endswith("_mask.png"))


def _collect_subjects(in_folder, is_mix_subj_folder, max_images_per_subject, trans_subject_count):
    """Group the input images into subjects.

    Layout rules:
      * ``in_folder`` is a file: it is the single image of the single subject.
      * ``in_folder`` directly contains images: only that folder is scanned.
      * otherwise: each (sorted) subdirectory of ``in_folder`` is scanned.
    Folders without images are skipped.

    Grouping:
      * ``is_mix_subj_folder`` False: all images of a scanned folder form one subject,
        truncated to ``max_images_per_subject`` when that is > 0.
      * ``is_mix_subj_folder`` True: every image is its own subject.
    ``trans_subject_count`` > 0 caps the number of subjects returned.

    Returns:
        ``(subject_folders, images_by_subject)``: parallel lists. Entry ``i`` holds the
        folder of subject ``i`` and the list of its image paths.
    """
    if os.path.isfile(in_folder):
        return [os.path.dirname(in_folder)], [[in_folder]]

    top_level_images = _list_subject_images(in_folder)
    if top_level_images:
        scan = [(in_folder, top_level_images)]
    else:
        # Lazy, so that hitting trans_subject_count stops globbing the remaining folders.
        subfolders = (os.path.join(in_folder, name) for name in sorted(os.listdir(in_folder)))
        scan = ((sub, _list_subject_images(sub)) for sub in subfolders if os.path.isdir(sub))

    subject_folders, images_by_subject = [], []
    for folder, image_paths in scan:
        # An empty subject would crash the batching step later.
        if not image_paths:
            continue
        if is_mix_subj_folder:
            # Each image in the folder is treated as an individual subject.
            images_by_subject.extend([path] for path in image_paths)
            subject_folders.extend([folder] * len(image_paths))
        else:
            if max_images_per_subject > 0:
                image_paths = image_paths[:max_images_per_subject]
            images_by_subject.append(image_paths)
            subject_folders.append(folder)
        if trans_subject_count > 0 and len(subject_folders) >= trans_subject_count:
            break

    if trans_subject_count > 0:
        images_by_subject = images_by_subject[:trans_subject_count]
        subject_folders = subject_folders[:trans_subject_count]
    return subject_folders, images_by_subject


def _load_images(image_paths, device, size=(512, 512)):
    """Load images as one float batch of shape [N, 3, H, W], values in [0, 1], on ``device``.

    ``size`` is (width, height), as expected by PIL's ``Image.resize``.
    """
    batch = []
    for image_path in image_paths:
        # Context manager so the underlying file handle is released promptly.
        with Image.open(image_path) as src:
            rgb = src.convert("RGB").resize(size)
        # [H, W, 3] -> [3, H, W].
        batch.append(torch.tensor(np.array(rgb).transpose(2, 0, 1)))
    # All input images of a subject go into one batch; this assumes max_images_per_subject is small.
    return torch.stack(batch).float().to(device) / 255.0


if __name__ == "__main__":
    args = parse_args()
    if args.seed != -1:
        seed_everything(args.seed)

    # screen -dm -L -Logfile trans_rv4-2.txt accelerate launch --multi_gpu --num_processes=2 scripts/adaface-translate.py
    # --adaface_ckpt_paths logs/subjects-celebrity2024-05-16T17-22-46_zero3-ada/checkpoints/embeddings_gs-30000.pt
    # --base_model_path models/realisticvision/realisticVisionV40_v40VAE.safetensors --in_folder /path/to/VGGface2_HQ_masks/
    # --is_mix_subj_folder 0 --out_folder /path/to/VGGface2_HQ_masks_rv4a --copy_masks --num_gpus 2

    if args.num_gpus > 1:
        from accelerate import PartialState
        distributed_state = PartialState()
        args.device = distributed_state.device
        process_index = distributed_state.process_index
    else:
        # A bare GPU index such as "1" is shorthand for "cuda:1".
        if re.match(r"^\d+$", args.device):
            args.device = f"cuda:{args.device}"
        distributed_state = None
        process_index = 0

    adaface = AdaFaceWrapper("img2img", args.base_model_path,
                             args.adaface_encoder_types, args.adaface_ckpt_paths,
                             args.adaface_encoder_cfg_scales,
                             args.subject_string, args.num_inference_steps,
                             unet_types=None,
                             extra_unet_dirpaths=args.extra_unet_dirpaths, unet_weights=args.unet_weights,
                             device=args.device)

    subject_folders, images_by_subject = _collect_subjects(
        args.in_folder, args.is_mix_subj_folder,
        args.max_images_per_subject, args.trans_subject_count)

    # Subject folders are mapped to output folders relative to this root. For a single
    # input file it is the file's directory; otherwise outputs would land next to
    # (and overwrite) the input image.
    in_root = os.path.dirname(args.in_folder) if os.path.isfile(args.in_folder) else args.in_folder

    out_image_count = 0
    out_mask_count = 0
    if not args.out_folder.endswith("/"):
        args.out_folder += "/"

    if args.num_gpus > 1:
        # Split the subjects across the GPUs (round-robin by process index).
        subject_folders = subject_folders[process_index::args.num_gpus]
        images_by_subject = images_by_subject[process_index::args.num_gpus]

    for subject_folder, image_paths in zip(subject_folders, images_by_subject):
        # In mixed-subject mode each subject is one image, identified by its file name;
        # otherwise the subject is identified by its folder.
        images_sig = os.path.basename(image_paths[0]) if args.is_mix_subj_folder else subject_folder
        print(f"Translating {images_sig}...")

        with torch.no_grad():
            # Called for its side effect: with update_text_encoder=True the subject embeddings
            # are incorporated into the text encoder, so the return value is not needed.
            # perturb_std is the *relative* std of the noise added to the face embeddings.
            adaface.prepare_adaface_embeddings(image_paths, None,
                                               perturb_at_stage='img_prompt_emb',
                                               perturb_std=args.perturb_std,
                                               update_text_encoder=True)

        # Replace the first occurrence of the input root with out_folder.
        subject_out_folder = subject_folder.replace(in_root, args.out_folder, 1)
        # exist_ok: in mixed-subject mode many subjects (possibly on different
        # processes) share one folder, so check-then-create would race.
        os.makedirs(subject_out_folder, exist_ok=True)
        print(f"Output images will be saved to {subject_out_folder}")

        in_images = _load_images(image_paths, args.device)
        num_inputs = len(in_images)
        num_out_images = num_inputs * args.out_count_per_input_image

        with torch.no_grad():
            out_images = adaface(in_images, args.prompt, None, 'append', args.guidance_scale,
                                 num_out_images, ref_img_strength=args.ref_img_strength)

        for img_i, img in enumerate(out_images):
            # Assumed output order: in_1, ..., in_n, in_1, ..., in_n, ... (one pass per copy).
            subj_i, copy_i = img_i % num_inputs, img_i // num_inputs
            src_path = image_paths[subj_i]
            image_filename_stem, image_fileext = os.path.splitext(os.path.basename(src_path))
            # The first copy keeps the input's file name; later copies get a "_<copy>" suffix.
            copy_suffix = "" if copy_i == 0 else f"_{copy_i}"
            img.save(os.path.join(subject_out_folder, f"{image_filename_stem}{copy_suffix}{image_fileext}"))

            if args.copy_masks:
                # Derive the mask path from the extension only, not via a whole-path str.replace.
                mask_path = os.path.splitext(src_path)[0] + "_mask.png"
                if os.path.exists(mask_path):
                    shutil.copy(mask_path,
                                os.path.join(subject_out_folder, f"{image_filename_stem}{copy_suffix}_mask.png"))
                    out_mask_count += 1

        out_image_count += len(out_images)

    print(f"{out_image_count} output images and {out_mask_count} masks saved to {args.out_folder}")