from adaface.adaface_wrapper import AdaFaceWrapper
import torch
#import torch.nn.functional as F
from PIL import Image
import numpy as np
import os, argparse, glob, re
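
# A minimal usage sketch. The script filename "infer.py" and the subject folder
# "subjects/alice" below are hypothetical placeholders, not names from the repo:
#   python infer.py --subject subjects/alice --prompt "a woman z in superman costume"
#   python infer.py --randface --prompt "a woman z on the beach"
# The placeholder token "z" (see --subject_string) marks where the learned
# identity is injected into the prompt.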

def save_images(images, num_images_per_row, subject_name, prompt, perturb_std, save_dir="samples-ada"):
    if num_images_per_row > len(images):
        num_images_per_row = len(images)

    os.makedirs(save_dir, exist_ok=True)
    num_columns = int(np.ceil(len(images) / num_images_per_row))
    # Paste all images into a single grid image and save it in save_dir.
    grid_image = Image.new('RGB', (512 * num_images_per_row, 512 * num_columns))
    for i, image in enumerate(images):
        image = image.resize((512, 512))
        grid_image.paste(image, (512 * (i % num_images_per_row), 512 * (i // num_images_per_row)))

    prompt_sig = prompt.replace(" ", "_").replace(",", "_")
    grid_filepath = os.path.join(save_dir, f"{subject_name}-{prompt_sig}-perturb{perturb_std:.02f}.png")
    # If the file already exists, append an increasing count to the filename to avoid overwriting.
    if os.path.exists(grid_filepath):
        grid_count = 2
        grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-perturb{perturb_std:.02f}-{grid_count}.png')
        while os.path.exists(grid_filepath):
            grid_count += 1
            grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-perturb{perturb_std:.02f}-{grid_count}.png')

    grid_image.save(grid_filepath)
    print(f"Saved to {grid_filepath}")

def seed_everything(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Mirror pytorch_lightning's seed_everything(), which exports the same variable.
    os.environ["PL_GLOBAL_SEED"] = str(seed)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pipeline", type=str, default="text2img",
                        choices=["text2img", "img2img", "text2img3", "flux"],
                        help="Type of pipeline to use (default: text2img)")
    parser.add_argument("--base_model_path", type=str, default=None,
                        help="Path of the base model checkpoint (default: None, using the official model)")
    parser.add_argument('--adaface_ckpt_paths', type=str, nargs="+",
                        default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt'])
    parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["arc2face"],
                        choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
    # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
    parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None,
                        help="CFG scales of the output embeddings of the ID2Ada prompt encoders")
    parser.add_argument("--main_unet_filepath", type=str, default=None,
                        help="Path to the checkpoint of the main UNet model, if you want to replace the default UNet within --base_model_path")
    parser.add_argument("--extra_unet_dirpaths", type=str, nargs="*",
                        default=['models/ensemble/rv4-unet', 'models/ensemble/ar18-unet'],
                        help="Extra paths to the checkpoints of the UNet models")
    parser.add_argument('--unet_weights', type=float, nargs="+", default=[4, 2, 1],
                        help="Weights for the UNet models")
    parser.add_argument("--subject", type=str,
                        help="Path to an example image of the subject, or a folder of such images")
    parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
    parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate")
    parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
    parser.add_argument("--noise", dest='perturb_std', type=float, default=0)
    parser.add_argument("--randface", action="store_true")
    parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
                        help="Guidance scale for the diffusion model")
    parser.add_argument("--id_cfg_scale", type=float, default=6,
                        help="CFG scale when generating the identity embeddings")
    parser.add_argument("--subject_string", type=str, default="z",
                        help="Subject placeholder string used in prompts to denote the concept.")
    parser.add_argument("--num_images_per_row", type=int, default=4,
                        help="Number of images to display in a row in the output grid image.")
    parser.add_argument("--num_inference_steps", type=int, default=50,
                        help="Number of DDIM inference steps")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
    parser.add_argument("--seed", type=int, default=42,
                        help="The seed for reproducible sampling. Set to -1 to disable.")
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    if args.seed != -1:
        seed_everything(args.seed)

    # A bare GPU index such as "0" is shorthand for "cuda:0".
    if re.match(r"^\d+$", args.device):
        args.device = f"cuda:{args.device}"
    print(f"Using device {args.device}")

    # The extra-UNet ensemble is only used with the text2img and img2img pipelines.
    if args.pipeline not in ["text2img", "img2img"]:
        args.extra_unet_dirpaths = None
        args.unet_weights = None
    adaface = AdaFaceWrapper(args.pipeline, args.base_model_path,
                             args.adaface_encoder_types, args.adaface_ckpt_paths,
                             args.adaface_encoder_cfg_scales,
                             args.subject_string, args.num_inference_steps,
                             unet_types=None,
                             main_unet_filepath=args.main_unet_filepath,
                             extra_unet_dirpaths=args.extra_unet_dirpaths,
                             unet_weights=args.unet_weights, device=args.device)
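
    # Collect the subject's example images (skipped with --randface, which samples a random identity instead).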
    if not args.randface:
        image_folder = args.subject
        if image_folder.endswith("/"):
            image_folder = image_folder[:-1]

        if os.path.isfile(image_folder):
            # Use the name of the containing folder (the second-to-last part of the path) as the subject name.
            subject_name = os.path.basename(os.path.dirname(image_folder))
            image_paths = [image_folder]
        else:
            subject_name = os.path.basename(image_folder)
            image_types = ["*.jpg", "*.png", "*.jpeg"]
            alltype_image_paths = []
            for image_type in image_types:
                # glob returns the full paths.
                image_paths = glob.glob(os.path.join(image_folder, image_type))
                if len(image_paths) > 0:
                    alltype_image_paths.extend(image_paths)

            # Filter out "*_mask.png" images.
            alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]
            # image_paths contains at most args.example_image_count full image paths.
            if args.example_image_count > 0:
                image_paths = alltype_image_paths[:args.example_image_count]
            else:
                image_paths = alltype_image_paths
    else:
        subject_name = None
        image_paths = None
        image_folder = None
subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name
rand_init_id_embs = torch.randn(1, 512)
init_id_embs = rand_init_id_embs if args.randface else None
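    # Initial latent noise: 4 latent channels at 64x64, i.e. the SD1.5 latent size for 512x512 images.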
    noise = torch.randn(args.out_image_count, 4, 64, 64).to(args.device)
    # args.perturb_std: the *relative* std of the noise added to the face embeddings.
    # A noise level of 0.08 could change gender, but 0.06 is usually safe.
    # adaface_subj_embs is not used directly. It is generated for the side effect of
    # updating the text encoder within this function call.
    adaface_subj_embs = \
        adaface.prepare_adaface_embeddings(image_paths, init_id_embs,
                                           perturb_at_stage='img_prompt_emb',
                                           perturb_std=args.perturb_std, update_text_encoder=True)
    images = adaface(noise, args.prompt, None, 'append', args.guidance_scale, args.out_image_count, verbose=True)
    save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.perturb_std)