import torch
from PIL import Image
import os, argparse, glob
import numpy as np
from .face_id_to_ada_prompt import create_id2ada_prompt_encoder
from .util import create_consistentid_pipeline
from .arc2face_models import create_arc2face_pipeline
from transformers import CLIPTextModel

def save_images(images, subject_name, id2img_prompt_encoder_type,
                prompt, perturb_std, save_dir="samples-ada"):
    os.makedirs(save_dir, exist_ok=True)
    # Save the 4 output images as a 2x2 grid image in save_dir.
    grid_image = Image.new('RGB', (512 * 2, 512 * 2))
    for i, image in enumerate(images):
        image = image.resize((512, 512))
        grid_image.paste(image, (512 * (i % 2), 512 * (i // 2)))

    prompt_sig = prompt.replace(" ", "_").replace(",", "_")
    grid_filepath = os.path.join(save_dir,
                                 "-".join([subject_name, id2img_prompt_encoder_type,
                                           prompt_sig, f"perturb{perturb_std:.02f}.png"]))
    # If the file already exists, append an increasing count to avoid overwriting.
    if os.path.exists(grid_filepath):
        grid_count = 2
        grid_filepath = os.path.join(save_dir,
                                     "-".join([subject_name, id2img_prompt_encoder_type,
                                               prompt_sig, f"perturb{perturb_std:.02f}", str(grid_count)]) + ".png")
        while os.path.exists(grid_filepath):
            grid_count += 1
            grid_filepath = os.path.join(save_dir,
                                         "-".join([subject_name, id2img_prompt_encoder_type,
                                                   prompt_sig, f"perturb{perturb_std:.02f}", str(grid_count)]) + ".png")

    grid_image.save(grid_filepath)
    print(f"Saved to {grid_filepath}")

def seed_everything(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PL_GLOBAL_SEED"] = str(seed)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # --base_model_path models/Realistic_Vision_V4.0_noVAE
    parser.add_argument("--base_model_path", type=str, default="models/sar/sar.safetensors")
    # The encoder type is required; neither pipeline branch below would be taken without it.
    parser.add_argument("--id2img_prompt_encoder_type", type=str, required=True,
                        choices=["arc2face", "consistentID"],
                        help="Type of the ID2Img prompt encoder")
    parser.add_argument("--subject", type=str, default="subjects-celebrity/taylorswift")
    parser.add_argument("--example_image_count", type=int, default=5,
                        help="Number of example images to use")
    parser.add_argument("--out_image_count", type=int, default=4,
                        help="Number of images to generate")
    parser.add_argument("--init_img", type=str, default=None)
    parser.add_argument("--prompt", type=str,
                        default="portrait photo of a person in superman costume")
    parser.add_argument("--use_core_only", action="store_true")
    parser.add_argument("--truncate_prompt_at", type=int, default=-1,
                        help="Truncate the prompt embeddings to this length")
    parser.add_argument("--randface", action="store_true")
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--perturb_std", type=float, default=1)
    args = parser.parse_args()

    if args.seed > 0:
        seed_everything(args.seed)

    if args.id2img_prompt_encoder_type == "arc2face":
        pipeline = create_arc2face_pipeline(args.base_model_path)
        use_teacher_neg = False
    elif args.id2img_prompt_encoder_type == "consistentID":
        pipeline = create_consistentid_pipeline(args.base_model_path)
        use_teacher_neg = True

    pipeline = pipeline.to('cuda', torch.float16)
    # When adaface_ckpt_path (the second argument) is left as its default None,
    # create_id2ada_prompt_encoder() returns an id2ada_prompt_encoder object with
    # .subj_basis_generator uninitialized. That's fine here, since we don't use the
    # subj_basis_generator to generate ada embeddings.
    id2img_prompt_encoder = create_id2ada_prompt_encoder([args.id2img_prompt_encoder_type],
                                                         num_static_img_suffix_embs=0)
    id2img_prompt_encoder.to('cuda')

    if not args.randface:
        image_folder = args.subject
        if image_folder.endswith("/"):
            image_folder = image_folder[:-1]

        if os.path.isfile(image_folder):
            # A single image file: use its parent folder name as the subject name.
            subject_name = os.path.basename(os.path.dirname(image_folder))
            image_paths = [image_folder]
        else:
            subject_name = os.path.basename(image_folder)
            image_types = ["*.jpg", "*.png", "*.jpeg"]
            alltype_image_paths = []
            for image_type in image_types:
                # glob returns full paths.
                image_paths = glob.glob(os.path.join(image_folder, image_type))
                if len(image_paths) > 0:
                    alltype_image_paths.extend(image_paths)
            # image_paths contains at most args.example_image_count full image paths.
            image_paths = alltype_image_paths[:args.example_image_count]
    else:
        subject_name = None
        image_paths = None
        image_folder = None

    subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name
    id_batch_size = args.out_image_count

    # Keep a reference to the pipeline's text encoder, and load the original CLIP text
    # encoder; the latter is swapped in before encode_prompt() below.
    text_encoder = pipeline.text_encoder
    orig_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14",
                                                      torch_dtype=torch.float16).to("cuda")
    # Shared initial latent noise, so images are comparable across perturb_std values.
    noise = torch.randn(args.out_image_count, 4, 64, 64, device='cuda', dtype=torch.float16)

    if args.randface:
        # Use a random face ID embedding in place of embeddings extracted from real photos.
        init_id_embs = torch.randn(1, 512, device='cuda', dtype=torch.float16)
        if args.id2img_prompt_encoder_type == "arc2face":
            pre_clip_features = None
        elif args.id2img_prompt_encoder_type == "consistentID":
            # For ConsistentID, random CLIP features work much better than zero CLIP features.
            rand_clip_fgbg_features = torch.randn(1, 514, 1280, device='cuda', dtype=torch.float16)
            pre_clip_features = rand_clip_fgbg_features
        else:
            breakpoint()
    else:
        init_id_embs = None
        pre_clip_features = None

    # perturb_std is the *relative* std of the noise added to the face ID embeddings.
    # For Arc2Face, a perturb_std of 0.08 can change the gender, but 0.06 is usually safe.
    # For ConsistentID, the image prompt embeddings are extremely robust to noise:
    # perturb_std can be set to 0.5 and only leads to a slight change in the result images.
    # It seems ConsistentID mainly relies on the CLIP features, rather than the face ID embeddings.
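    # A minimal sketch of what a relative-std perturbation typically amounts to; the
    # actual implementation lives in get_img_prompt_embs() and may differ in detail:
    #   emb = emb + torch.randn_like(emb) * perturb_std * emb.std()
    # i.e., the noise std is perturb_std times the embedding's own std, so
    # perturb_std=0.08 adds noise at 8% of the embedding's scale.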

    # Generate twice: once with the requested perturbation, once unperturbed as a reference.
    for perturb_std in (args.perturb_std, 0):
        # id_prompt_emb is in the image prompt space.
        # neg_id_prompt_emb is used by ConsistentID only.
        face_image_count, faceid_embeds, id_prompt_emb, neg_id_prompt_emb = \
            id2img_prompt_encoder.get_img_prompt_embs(
                init_id_embs=init_id_embs,
                pre_clip_features=pre_clip_features,
                image_paths=image_paths,
                image_objs=None,
                id_batch_size=id_batch_size,
                perturb_at_stage='img_prompt_emb',
                perturb_std=perturb_std,
                avg_at_stage='id_emb',
                verbose=True)

        # Encode the composition prompt with the original CLIP text encoder.
        pipeline.text_encoder = orig_text_encoder
        comp_prompt = args.prompt
        negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        # prompt_embeds_, negative_prompt_embeds_: [4, 77, 768]
        prompt_embeds_, negative_prompt_embeds_ = \
            pipeline.encode_prompt(comp_prompt, device='cuda',
                                   num_images_per_prompt=args.out_image_count,
                                   do_classifier_free_guidance=True,
                                   negative_prompt=negative_prompt)
        #pipeline.text_encoder = text_encoder

        # Append the ID prompt embeddings to the prompt embeddings.
        # For Arc2Face, id_prompt_emb can be either pre- or post-pended.
        # But for ConsistentID, id_prompt_emb has to be **post-pended**;
        # otherwise the result images are blank.
        full_negative_prompt_embeds_ = negative_prompt_embeds_
        if args.truncate_prompt_at >= 0:
            prompt_embeds_ = prompt_embeds_[:, :args.truncate_prompt_at]
            negative_prompt_embeds_ = negative_prompt_embeds_[:, :args.truncate_prompt_at]

        prompt_embeds_ = torch.cat([prompt_embeds_, id_prompt_emb], dim=1)
        # M: number of ID prompt tokens.
        M = id_prompt_emb.shape[1]
        if (not use_teacher_neg) or neg_id_prompt_emb is None:
            # For Arc2Face, neg_id_prompt_emb is None, so we append the last M negative prompt
            # embeddings to keep the negative prompt embeddings the same length as the prompt embeddings.
            negative_prompt_embeds_ = torch.cat([negative_prompt_embeds_,
                                                 full_negative_prompt_embeds_[:, -M:]], dim=1)
        else:
            # NOTE: For ConsistentID, neg_id_prompt_emb has to be present in the negative
            # prompt embeddings; otherwise the result images are cartoonish.
            negative_prompt_embeds_ = torch.cat([negative_prompt_embeds_, neg_id_prompt_emb], dim=1)

        if args.use_core_only:
            # Use only the ID prompt embeddings, dropping the composition prompt.
            prompt_embeds_ = id_prompt_emb
            if (not use_teacher_neg) or neg_id_prompt_emb is None:
                negative_prompt_embeds_ = full_negative_prompt_embeds_[:, :M]
            else:
                negative_prompt_embeds_ = neg_id_prompt_emb
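
        # Shape walk-through for the default path (no --use_core_only, no truncation),
        # assuming the [4, 77, 768] dims noted above:
        #   prompt_embeds_ [4, 77, 768] cat id_prompt_emb [4, M, 768] -> [4, 77+M, 768]
        # negative_prompt_embeds_ is padded (or given neg_id_prompt_emb) to the same
        # [4, 77+M, 768] shape, since classifier-free guidance needs matching lengths.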

        for guidance_scale in [6]:
            images = pipeline(latents=noise,
                              prompt_embeds=prompt_embeds_,
                              negative_prompt_embeds=negative_prompt_embeds_,
                              num_inference_steps=50,
                              guidance_scale=guidance_scale,
                              num_images_per_prompt=1).images
            save_images(images, subject_name, args.id2img_prompt_encoder_type,
                        f"guide{guidance_scale}", perturb_std)
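
# Example invocations (a sketch; the module path is hypothetical, so substitute this
# file's actual package path, and run from the repo root so the relative imports resolve):
#   python -m adaface.test_img_prompt --id2img_prompt_encoder_type arc2face \
#       --subject subjects-celebrity/taylorswift --perturb_std 0.06 --seed 42
#   python -m adaface.test_img_prompt --id2img_prompt_encoder_type consistentID --randface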