import torch
from PIL import Image
import os, argparse, glob
import numpy as np
from .face_id_to_ada_prompt import create_id2ada_prompt_encoder
from .util            import create_consistentid_pipeline
from .arc2face_models import create_arc2face_pipeline
from transformers import CLIPTextModel
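
# Example invocation (hypothetical; the exact module path depends on the package layout, and the
# relative imports above require running this file as a module):
#   python -m <package>.<this_script> --id2img_prompt_encoder_type arc2face \
#       --subject subjects-celebrity/taylorswift --out_image_count 4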

def save_images(images, subject_name, id2img_prompt_encoder_type,
                prompt, perturb_std, save_dir = "samples-ada"):
    os.makedirs(save_dir, exist_ok=True)
    # Tile the 4 images into a 2x2 grid image and save it under save_dir.
    grid_image = Image.new('RGB', (512 * 2, 512 * 2))
    for i, image in enumerate(images):
        image = image.resize((512, 512))
        grid_image.paste(image, (512 * (i % 2), 512 * (i // 2)))

    prompt_sig = prompt.replace(" ", "_").replace(",", "_")
    grid_filepath = os.path.join(save_dir, 
                "-".join([subject_name, id2img_prompt_encoder_type, 
                          prompt_sig, f"perturb{perturb_std:.02f}.png"]))
    
    # If the file already exists, append an increasing count to the filename to avoid overwriting.
    grid_count = 2
    while os.path.exists(grid_filepath):
        grid_filepath = os.path.join(save_dir, 
                        "-".join([ subject_name, id2img_prompt_encoder_type, 
                                   prompt_sig, f"perturb{perturb_std:.02f}", str(grid_count) ]) + ".png")
        grid_count += 1

    grid_image.save(grid_filepath)
    print(f"Saved to {grid_filepath}")

def seed_everything(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PL_GLOBAL_SEED"] = str(seed)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # --base_model_path models/Realistic_Vision_V4.0_noVAE
    parser.add_argument("--base_model_path", type=str, default="models/sar/sar.safetensors")    
    parser.add_argument("--id2img_prompt_encoder_type", type=str, required=True,
                        choices=["arc2face", "consistentID"],
                        help="Type of the ID2Img prompt encoder")
    parser.add_argument("--subject", type=str, default="subjects-celebrity/taylorswift")
    parser.add_argument("--example_image_count", type=int, default=5, help="Number of example images to use")
    parser.add_argument("--out_image_count",     type=int, default=4, help="Number of images to generate")
    parser.add_argument("--init_img", type=str, default=None)
    parser.add_argument("--prompt", type=str, default="portrait photo of a person in superman costume")
    parser.add_argument("--use_core_only", action="store_true")
    parser.add_argument("--truncate_prompt_at", type=int, default=-1,
                        help="Truncate the prompt embeddings to this number of tokens (-1: no truncation)")
    parser.add_argument("--randface", action="store_true")
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--perturb_std", type=float, default=1)

    args = parser.parse_args()
    if args.seed >= 0:
        seed_everything(args.seed)

    if args.id2img_prompt_encoder_type == "arc2face":
        pipeline = create_arc2face_pipeline(args.base_model_path)
        use_teacher_neg = False
    elif args.id2img_prompt_encoder_type == "consistentID":
        pipeline = create_consistentid_pipeline(args.base_model_path)
        use_teacher_neg = True

    pipeline = pipeline.to('cuda', torch.float16)

    # When the second argument, adaface_ckpt_path, is None (as it is here), create_id2ada_prompt_encoder()
    # returns an id2ada_prompt_encoder object with .subj_basis_generator uninitialized.
    # That doesn't matter, as we don't use the subj_basis_generator to generate ada embeddings.
    id2img_prompt_encoder = create_id2ada_prompt_encoder([args.id2img_prompt_encoder_type],
                                                         num_static_img_suffix_embs=0)
    id2img_prompt_encoder.to('cuda')

    if not args.randface:
        image_folder = args.subject
        if image_folder.endswith("/"):
            image_folder = image_folder[:-1]

        if os.path.isfile(image_folder):
            # Use the parent directory name (the second-to-last path component) as the subject name.
            subject_name = os.path.basename(os.path.dirname(image_folder))
            image_paths = [image_folder]

        else:
            subject_name = os.path.basename(image_folder)
            image_types = ["*.jpg", "*.png", "*.jpeg"]
            alltype_image_paths = []
            for image_type in image_types:
                # glob returns the full path.
                image_paths = glob.glob(os.path.join(image_folder, image_type))
                if len(image_paths) > 0:
                    alltype_image_paths.extend(image_paths)
            # image_paths contains at most args.example_image_count full image paths.
            image_paths = alltype_image_paths[:args.example_image_count]
    else:
        subject_name = None
        image_paths = None
        image_folder = None

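    # torch.seed() reseeds the default RNG with a nondeterministic seed and returns it;
    # here the returned value just serves as a unique tag for the random face.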
    subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name
    id_batch_size = args.out_image_count

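    # Keep a handle to the pipeline's current text encoder, and load the stock CLIP ViT-L/14
    # text encoder, which is swapped in below before encoding the compositional prompt.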
    text_encoder = pipeline.text_encoder
    orig_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).to("cuda") 

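    # Use one fixed initial latent for all settings below, so the generated grids differ only
    # due to the embedding perturbation (and guidance scale), not due to different noise.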
    noise = torch.randn(args.out_image_count, 4, 64, 64, device='cuda', dtype=torch.float16)
    if args.randface:
        init_id_embs = torch.randn(1, 512, device='cuda', dtype=torch.float16)
        if args.id2img_prompt_encoder_type == "arc2face":
            pre_clip_features = None
        elif args.id2img_prompt_encoder_type == "consistentID":
            # For ConsistentID, random clip features are much better than zero clip features.
            rand_clip_fgbg_features = torch.randn(1, 514, 1280, device='cuda', dtype=torch.float16)
            pre_clip_features       = rand_clip_fgbg_features
        else:
            breakpoint()
    else:
        init_id_embs = None
        pre_clip_features = None

    # perturb_std is the *relative* std of the noise added to the face ID embeddings.
    # For Arc2Face, a perturb_std of 0.08 can change the gender, but 0.06 is usually safe.
    # For ConsistentID, the image prompt embeddings are extremely robust to noise:
    # perturb_std can be set to 0.5 and only lead to a slight change in the resulting images.
    # It seems ConsistentID mainly relies on the CLIP features rather than the face ID embeddings.
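    # Generate once with the requested perturb_std, and once with 0 as an unperturbed baseline.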
    for perturb_std in (args.perturb_std, 0):
        # id_prompt_emb is in the image prompt space.
        # neg_id_prompt_emb is used in ConsistentID only.
        face_image_count, faceid_embeds, id_prompt_emb, neg_id_prompt_emb \
            = id2img_prompt_encoder.get_img_prompt_embs( \
                init_id_embs=init_id_embs,
                pre_clip_features=pre_clip_features,
                image_paths=image_paths,
                image_objs=None,
                id_batch_size=id_batch_size,
                perturb_at_stage='img_prompt_emb',
                perturb_std=perturb_std,
                avg_at_stage='id_emb',
                verbose=True)
        
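        # Encode the compositional prompt with the stock CLIP text encoder loaded above.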
        pipeline.text_encoder = orig_text_encoder

        comp_prompt     = args.prompt 
        negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        # prompt_embeds_, negative_prompt_embeds_: [4, 77, 768]
        prompt_embeds_, negative_prompt_embeds_ = \
            pipeline.encode_prompt(comp_prompt, device='cuda', num_images_per_prompt=args.out_image_count,
                                   do_classifier_free_guidance=True, negative_prompt=negative_prompt)
        # pipeline.text_encoder = text_encoder   # Optionally restore the original text encoder.
        # Append the ID prompt embeddings to the prompt embeddings.
        # For Arc2Face, id_prompt_emb can be either prepended or appended.
        # But for ConsistentID, id_prompt_emb has to be **appended**; otherwise the resulting images are blank.

        full_negative_prompt_embeds_ = negative_prompt_embeds_
        if args.truncate_prompt_at >= 0:
            prompt_embeds_ = prompt_embeds_[:, :args.truncate_prompt_at]
            negative_prompt_embeds_ = negative_prompt_embeds_[:, :args.truncate_prompt_at]
            
        prompt_embeds_ = torch.cat([prompt_embeds_, id_prompt_emb], dim=1)
        M = id_prompt_emb.shape[1]   # M: the number of ID prompt tokens appended to the prompt.
        
        if (not use_teacher_neg) or neg_id_prompt_emb is None:
            # For Arc2Face, neg_id_prompt_emb is None, so we append the last M negative prompt embeddings,
            # to make the negative prompt embeddings the same length as the prompt embeddings.
            negative_prompt_embeds_ = torch.cat([negative_prompt_embeds_, full_negative_prompt_embeds_[:, -M:]], dim=1)
        else:
            # NOTE: For ConsistentID, neg_id_prompt_emb has to be present in the negative prompt embeddings.
            # Otherwise, the resulting images look cartoonish.
            negative_prompt_embeds_ = torch.cat([negative_prompt_embeds_, neg_id_prompt_emb], dim=1)

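        # With --use_core_only, condition on the ID prompt embeddings alone, dropping the text prompt.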
        if args.use_core_only:
            prompt_embeds_ = id_prompt_emb
            if (not use_teacher_neg) or neg_id_prompt_emb is None:
                negative_prompt_embeds_ = full_negative_prompt_embeds_[:, :M]
            else:
                negative_prompt_embeds_ = neg_id_prompt_emb

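        # A single CFG scale for now; add values to the list to sweep guidance scales.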
        for guidance_scale in [6]:
            images = pipeline(latents=noise,
                              prompt_embeds=prompt_embeds_, 
                              negative_prompt_embeds=negative_prompt_embeds_, 
                              num_inference_steps=50, 
                              guidance_scale=guidance_scale, 
                              num_images_per_prompt=1).images

            save_images(images, subject_name, args.id2img_prompt_encoder_type, 
                        f"guide{guidance_scale}", perturb_std)