SunderAli17 committed on
Commit 67a498b
1 Parent(s): a4f6bc0

Create infer.py

Files changed (1)
  1. functions/infer.py +381 -0
functions/infer.py ADDED
@@ -0,0 +1,381 @@
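+ """Batch inference script for SAKBIR (InstantIR-style) blind image restoration
+ on top of SDXL.
+
+ Example invocation (a minimal sketch; every path below is a placeholder for
+ your own checkpoints and data):
+
+     python functions/infer.py \
+         --sdxl_path stabilityai/stable-diffusion-xl-base-1.0 \
+         --instantir_path ./checkpoints/instantir \
+         --test_path ./test_images \
+         --out_path ./output
+ """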
+ import os
+ import argparse
+
+ import numpy as np
+ import torch
+ from PIL import Image
+
+ from diffusers import DDPMScheduler
+
+ from module.ip_adapter.utils import load_adapter_to_pipe
+ from pipelines.lcm_single_step_scheduler import LCMSingleStepScheduler
+ from pipelines.sdxl_SAKBIR import SAKBIRPipeline
+
+
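+ # Note (added for clarity): the helper below walks the UNet's down/mid/up
+ # blocks (skipping resnets) and stamps the children of every
+ # transformer_blocks container with a dotted `full_name` attribute, so
+ # individual attention submodules can later be identified by name.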
+ def name_unet_submodules(unet):
+     """Tag each transformer block's children with a `full_name` attribute."""
+     def recursive_find_module(name, module, end=False):
+         if end:
+             # Reached a transformer block: name its immediate children.
+             for sub_name, sub_module in module.named_children():
+                 sub_module.full_name = f"{name}.{sub_name}"
+             return
+         if "up_blocks" not in name and "down_blocks" not in name and "mid_block" not in name:
+             return
+         elif "resnets" in name:
+             return
+         for sub_name, sub_module in module.named_children():
+             end = sub_name == "transformer_blocks"
+             recursive_find_module(f"{name}.{sub_name}", sub_module, end)
+
+     for name, module in unet.named_children():
+         recursive_find_module(name, module)
+
+
+ def resize_img(input_image, max_side=1280, min_side=1024, size=None,
+                pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
+     w, h = input_image.size
+     if size is not None:
+         w_resize_new, h_resize_new = size
+     else:
+         # ratio = min_side / min(h, w)
+         # w, h = round(ratio*w), round(ratio*h)
+         ratio = max_side / max(h, w)
+         input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
+         w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
+         h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
+     input_image = input_image.resize([w_resize_new, h_resize_new], mode)
+
+     if pad_to_max_side:
+         # Center the resized image on a white max_side x max_side canvas.
+         res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
+         offset_x = (max_side - w_resize_new) // 2
+         offset_y = (max_side - h_resize_new) // 2
+         res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
+         input_image = Image.fromarray(res)
+     return input_image
+
+
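+ # Worked example (added for clarity): a 2000x1500 input yields
+ # ratio = 1280/2000 = 0.64, an intermediate resize to 1280x960, and final
+ # dimensions (1280//64)*64 = 1280 by (960//64)*64 = 960; both are multiples
+ # of 64, consistent with the divisible-by-64 check enforced in __main__.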
+ def tensor_to_pil(images):
+     """
+     Convert an image tensor or a batch of image tensors to PIL image(s).
+     """
+     images = images.clamp(0, 1)
+     images_np = images.detach().cpu().numpy()
+     if images_np.ndim == 4:
+         images_np = np.transpose(images_np, (0, 2, 3, 1))
+     elif images_np.ndim == 3:
+         images_np = np.transpose(images_np, (1, 2, 0))
+         images_np = images_np[None, ...]
+     images_np = (images_np * 255).round().astype("uint8")
+     if images_np.shape[-1] == 1:
+         # Special case for grayscale (single-channel) images.
+         pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images_np]
+     else:
+         pil_images = [Image.fromarray(image[:, :, :3]) for image in images_np]
+
+     return pil_images
+
+
+ def calc_mean_std(feat, eps=1e-5):
+     """Calculate mean and std for adaptive_instance_normalization.
+
+     Args:
+         feat (Tensor): 4D tensor.
+         eps (float): A small value added to the variance to avoid
+             divide-by-zero. Default: 1e-5.
+     """
+     size = feat.size()
+     assert len(size) == 4, 'The input feature should be 4D tensor.'
+     b, c = size[:2]
+     feat_var = feat.view(b, c, -1).var(dim=2) + eps
+     feat_std = feat_var.sqrt().view(b, c, 1, 1)
+     feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+     return feat_mean, feat_std
+
+
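+ # Note (added for clarity): the function below implements adaptive instance
+ # normalization (AdaIN),
+ #     AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y),
+ # where mu and sigma are the per-channel mean and std from calc_mean_std:
+ # content features are re-normalized to carry the style features' statistics.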
+ def adaptive_instance_normalization(content_feat, style_feat):
+     size = content_feat.size()
+     style_mean, style_std = calc_mean_std(style_feat)
+     content_mean, content_std = calc_mean_std(content_feat)
+     normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+     return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
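+ # Overview (comment added for orientation): main() loads the SDXL-based
+ # restoration pipeline, attaches the LQ-Adapter image-prompt projector and
+ # the previewer LoRA, loads the aggregator weights, then restores the test
+ # images in batches and writes the results under the output directory.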
+ def main(args, device):
+
+     # Load pretrained models.
+     pipe = SAKBIRPipeline.from_pretrained(
+         args.sdxl_path,
+         torch_dtype=torch.float16,
+     )
+
+     # Image prompt projector.
+     print("Loading LQ-Adapter...")
+     load_adapter_to_pipe(
+         pipe,
+         args.adapter_model_path if args.adapter_model_path is not None else os.path.join(args.instantir_path, 'adapter.pt'),
+         args.vision_encoder_path,
+         use_clip_encoder=args.use_clip_encoder,
+     )
+
+     # Prepare the previewer (LCM LoRA).
+     previewer_lora_path = args.previewer_lora_path if args.previewer_lora_path is not None else args.instantir_path
+     if previewer_lora_path is not None:
+         lora_alpha = pipe.prepare_previewers(previewer_lora_path)
+         print(f"use lora alpha {lora_alpha}")
+     pipe.to(device=device, dtype=torch.float16)
+     pipe.scheduler = DDPMScheduler.from_pretrained(args.sdxl_path, subfolder="scheduler")
+     lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
+
+     # Load aggregator weights.
+     print("Loading checkpoint...")
+     pretrained_state_dict = torch.load(os.path.join(args.instantir_path, "aggregator.pt"), map_location="cpu")
+     pipe.aggregator.load_state_dict(pretrained_state_dict)
+     pipe.aggregator.to(device, dtype=torch.float16)
+
+     #################### Restoration ####################
+
+     post_fix = f"_{args.post_fix}" if args.post_fix else ""
+     os.makedirs(os.path.join(args.out_path, post_fix), exist_ok=True)
+
+     # Skip inputs that already have a restored output, then group the rest
+     # into batches of args.batch_size.
+     processed_imgs = os.listdir(os.path.join(args.out_path, post_fix))
+     lq_files = []
+     lq_batch = []
+     if os.path.isfile(args.test_path):
+         test_dir = os.path.dirname(args.test_path)
+         all_inputs = [os.path.basename(args.test_path)]
+     else:
+         test_dir = args.test_path
+         all_inputs = os.listdir(args.test_path)
+         all_inputs.sort()
+     for file in all_inputs:
+         if file in processed_imgs:
+             print(f"Skip {file}")
+             continue
+         lq_batch.append(file)
+         if len(lq_batch) == args.batch_size:
+             lq_files.append(lq_batch)
+             lq_batch = []
+
+     if len(lq_batch) > 0:
+         lq_files.append(lq_batch)
+
+     for lq_batch in lq_files:
+         generator = torch.Generator(device=device).manual_seed(args.seed)
+         pil_lqs = [Image.open(os.path.join(test_dir, file)) for file in lq_batch]
+         if args.width is None or args.height is None:
+             lq = [resize_img(pil_lq.convert("RGB"), size=None) for pil_lq in pil_lqs]
+         else:
+             lq = [resize_img(pil_lq.convert("RGB"), size=(args.width, args.height)) for pil_lq in pil_lqs]
+
+         # If denoising_start < 1000, build a truncated timestep schedule that
+         # spans only [steps_offset, denoising_start) for the restoration call;
+         # otherwise let the pipeline use its default schedule.
+         timesteps = None
+         if args.denoising_start < 1000:
+             step_size = args.denoising_start // args.num_inference_steps
+             timesteps = [
+                 i * step_size + pipe.scheduler.config.steps_offset
+                 for i in range(args.num_inference_steps)
+             ][::-1]
+         if args.prompt is None or len(args.prompt) == 0:
+             prompt = (
+                 "Photorealistic, highly detailed, hyper detailed photo-realistic maximum detail, 32k, "
+                 "ultra HD, extreme meticulous detailing, skin pore detailing, "
+                 "hyper sharpness, perfect without deformations, "
+                 "taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading."
+             )
+         else:
+             prompt = args.prompt
+         if not isinstance(prompt, list):
+             prompt = [prompt]
+         if len(prompt) == 1:
+             # A single prompt is shared across the whole batch.
+             prompt = prompt * len(lq)
+         if args.neg_prompt is None or len(args.neg_prompt) == 0:
+             neg_prompt = (
+                 "blurry, out of focus, unclear, depth of field, over-smooth, "
+                 "sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, "
+                 "dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, "
+                 "watermark, signature, jpeg artifacts, deformed, lowres"
+             )
+         else:
+             neg_prompt = args.neg_prompt
+         if not isinstance(neg_prompt, list):
+             neg_prompt = [neg_prompt]
+         if len(neg_prompt) == 1:
+             neg_prompt = neg_prompt * len(lq)
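+         # Restoration call. Per the CLI help below: preview_start is the
+         # proportion of early timesteps run without previewing (to enhance
+         # fidelity to the input); creative_start, passed as control_guidance_end,
+         # is where guided restoration ends and creative rendering begins
+         # (1.0 means no creative restoration).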
+         image = pipe(
+             prompt=prompt,
+             image=lq,
+             num_inference_steps=args.num_inference_steps,
+             generator=generator,
+             timesteps=timesteps,
+             negative_prompt=neg_prompt,
+             guidance_scale=args.cfg,
+             previewer_scheduler=lcm_scheduler,
+             preview_start=args.preview_start,
+             control_guidance_end=args.creative_start,
+         ).images
+
+         if args.save_preview_row:
+             os.makedirs("./lcm", exist_ok=True)
+             for i, lcm_image in enumerate(image[1]):
+                 lcm_image.save(f"./lcm/{i}.png")
+         for i, rec_image in enumerate(image):
+             rec_image.save(os.path.join(args.out_path, post_fix, lq_batch[i]))
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="InstantIR pipeline")
+     parser.add_argument(
+         "--sdxl_path",
+         type=str,
+         required=True,
+         help="Path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--previewer_lora_path",
+         type=str,
+         default=None,
+         help="Path to LCM LoRA or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--pretrained_vae_model_name_or_path",
+         type=str,
+         default=None,
+         help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+     )
+     parser.add_argument(
+         "--instantir_path",
+         type=str,
+         required=True,
+         help="Path to the pretrained InstantIR model.",
+     )
+     parser.add_argument(
+         "--vision_encoder_path",
+         type=str,
+         default='facebook/dinov2-large',  # public DINOv2 checkpoint; override with a local path if needed
+         help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--adapter_model_path",
+         type=str,
+         default=None,
+         help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--adapter_tokens",
+         type=int,
+         default=64,
+         help="Number of tokens to use in the IP-Adapter cross-attention mechanism.",
+     )
+     parser.add_argument(
+         "--use_clip_encoder",
+         action="store_true",
+         help="Use CLIP as the image encoder instead of DINO.",
+     )
+     parser.add_argument(
+         "--denoising_start",
+         type=int,
+         default=1000,
+         help="Diffusion start timestep.",
+     )
+     parser.add_argument(
+         "--num_inference_steps",
+         type=int,
+         default=30,
+         help="Diffusion steps.",
+     )
+     parser.add_argument(
+         "--creative_start",
+         type=float,
+         default=1.0,
+         help="Proportion of timesteps for creative restoration. 1.0 means no creative restoration, while 0.0 means completely free rendering.",
+     )
+     parser.add_argument(
+         "--preview_start",
+         type=float,
+         default=0.0,
+         help="Proportion of early timesteps to skip previewing, to enhance fidelity to the input.",
+     )
+     parser.add_argument(
+         "--resolution",
+         type=int,
+         default=1024,
+         help="Base resolution for restoration.",
+     )
+     parser.add_argument(
+         "--batch_size",
+         type=int,
+         default=6,
+         help="Test batch size.",
+     )
+     parser.add_argument(
+         "--width",
+         type=int,
+         default=None,
+         help="Output image width.",
+     )
+     parser.add_argument(
+         "--height",
+         type=int,
+         default=None,
+         help="Output image height.",
+     )
+     parser.add_argument(
+         "--cfg",
+         type=float,
+         default=7.0,
+         help="Scale of Classifier-Free Guidance (CFG).",
+     )
+     parser.add_argument(
+         "--post_fix",
+         type=str,
+         default=None,
+         help="Subfolder name for restoration output under the output directory.",
+     )
+     parser.add_argument(
+         "--variant",
+         type=str,
+         default='fp16',
+         help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16.",
+     )
+     parser.add_argument(
+         "--revision",
+         type=str,
+         default=None,
+         help="Revision of pretrained model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--save_preview_row",
+         action="store_true",
+         help="Whether or not to save the intermediate LCM outputs.",
+     )
+     parser.add_argument(
+         "--prompt",
+         type=str,
+         default='',
+         nargs="+",
+         help=(
+             "A set of prompts for creative restoration. Provide either a matching number of test images,"
+             " or a single prompt to be used with all inputs."
+         ),
+     )
+     parser.add_argument(
+         "--neg_prompt",
+         type=str,
+         default='',
+         nargs="+",
+         help=(
+             "A set of negative prompts for creative restoration. Provide either a matching number of test images,"
+             " or a single negative prompt to be used with all inputs."
+         ),
+     )
+     parser.add_argument(
+         "--test_path",
+         type=str,
+         required=True,
+         help="Path to a test image or a directory of test images.",
+     )
+     parser.add_argument(
+         "--out_path",
+         type=str,
+         default="./output",
+         help="Output directory.",
+     )
+     parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible inference.")
+     args = parser.parse_args()
+
+     # If only one of width/height is given, use it for both dimensions.
+     args.height = args.height or args.width
+     args.width = args.width or args.height
+     if args.height is not None and (args.width % 64 != 0 or args.height % 64 != 0):
+         raise ValueError("Image resolution must be divisible by 64.")
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     main(args, device)