SunderAli17 committed on
Commit
a3c492a
1 Parent(s): fb33b84

Create train_stage1_adapter.py

Files changed (1)
  1. functions/train_stage1_adapter.py +1259 -0
functions/train_stage1_adapter.py ADDED
@@ -0,0 +1,1259 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
15
+
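+ # InstantIR stage-1 adapter training: fine-tunes the IP-Adapter modules (image-projection
+ # resampler and adapter attention processors) of an SDXL UNet, conditioning on low-quality
+ # images synthesized on the fly with a RealESRGAN degradation pipeline.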
16
+ import argparse
17
+ import contextlib
18
+ import time
19
+ import gc
20
+ import logging
21
+ import math
22
+ import os
23
+ import random
24
+ import jsonlines
25
+ import functools
26
+ import shutil
27
+ import pyrallis
28
+ import itertools
29
+ from pathlib import Path
30
+ from collections import namedtuple, OrderedDict
31
+
32
+ import accelerate
33
+ import numpy as np
34
+ import torch
35
+ import torch.nn.functional as F
36
+ import torch.utils.checkpoint
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
41
+ from datasets import load_dataset
42
+ from packaging import version
43
+ from PIL import Image
44
+ from data.data_config import DataConfig
45
+ from basicsr.utils.degradation_pipeline import RealESRGANDegradation
46
+ from losses.loss_config import LossesConfig
47
+ from losses.losses import *
48
+ from torchvision import transforms
49
+ from torchvision.transforms.functional import crop
50
+ from tqdm.auto import tqdm
51
+ from transformers import (
52
+ AutoTokenizer,
53
+ PretrainedConfig,
54
+ CLIPImageProcessor, CLIPVisionModelWithProjection,
55
+ AutoImageProcessor, AutoModel)
56
+
57
+ import diffusers
58
+ from diffusers import (
59
+ AutoencoderKL,
60
+ AutoencoderTiny,
61
+ DDPMScheduler,
62
+ StableDiffusionXLPipeline,
63
+ UNet2DConditionModel,
64
+ )
65
+ from diffusers.optimization import get_scheduler
66
+ from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
67
+ from diffusers.utils.import_utils import is_xformers_available
68
+ from diffusers.utils.torch_utils import is_compiled_module
69
+
70
+ from pipelines.lcm_single_step_scheduler import LCMSingleStepScheduler
71
+ from utils.train_utils import (
72
+ seperate_ip_params_from_unet,
73
+ import_model_class_from_model_name_or_path,
74
+ tensor_to_pil,
75
+ get_train_dataset, prepare_train_dataset, collate_fn,
76
+ encode_prompt, importance_sampling_fn, extract_into_tensor
77
+ )
78
+ from module.ip_adapter.resampler import Resampler
79
+ from module.ip_adapter.attention_processor import init_attn_proc
80
+ from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds
81
+
82
+
83
+ if is_wandb_available():
84
+ import wandb
85
+
86
+
87
+ logger = get_logger(__name__)
88
+
89
+
90
+ def log_validation(unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
91
+ scheduler, image_encoder, image_processor, deg_pipeline,
92
+ args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False):
93
+ logger.info("Running validation... ")
94
+
95
+ image_logs = []
96
+
97
+ lq = [Image.open(lq_example) for lq_example in args.validation_image]
98
+
99
+ pipe = StableDiffusionXLPipeline(
100
+ vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
101
+ unet, scheduler, image_encoder, image_processor,
102
+ ).to(accelerator.device)
103
+
104
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
105
+ image = pipe(
106
+ prompt=[""]*len(lq),
107
+ ip_adapter_image=[lq],
108
+ num_inference_steps=20,
109
+ generator=generator,
110
+ guidance_scale=5.0,
111
+ height=args.resolution,
112
+ width=args.resolution,
113
+ ).images
114
+
115
+ if log_local:
116
+ for i, img in enumerate(tensor_to_pil(lq_img)):
117
+ img.save(f"./lq_{i}.png")
118
+ for i, img in enumerate(tensor_to_pil(gt_img)):
119
+ img.save(f"./gt_{i}.png")
120
+ for i, img in enumerate(image):
121
+ img.save(f"./lq_IPA_{i}.png")
122
+ return
123
+
124
+ tracker_key = "test" if is_final_validation else "validation"
125
+ for tracker in accelerator.trackers:
126
+ if tracker.name == "tensorboard":
127
+ images = [np.asarray(pil_img) for pil_img in image]
128
+ images = np.stack(images, axis=0)
129
+ if lq_img is not None and gt_img is not None:
130
+ input_lq = lq_img.detach().cpu()
131
+ input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1))
132
+ input_gt = gt_img.detach().cpu()
133
+ input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1))
134
+ tracker.writer.add_images("lq", input_lq[0], step, dataformats="CHW")
135
+ tracker.writer.add_images("gt", input_gt[0], step, dataformats="CHW")
136
+ tracker.writer.add_images("rec", images, step, dataformats="NHWC")
137
+ elif tracker.name == "wandb":
138
+ raise NotImplementedError("Wandb logging not implemented for validation.")
139
+ formatted_images = []
140
+
141
+ for log in image_logs:
142
+ images = log["images"]
143
+ validation_prompt = log["validation_prompt"]
144
+ validation_image = log["validation_image"]
145
+
146
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
147
+
148
+ for image in images:
149
+ image = wandb.Image(image, caption=validation_prompt)
150
+ formatted_images.append(image)
151
+
152
+ tracker.log({tracker_key: formatted_images})
153
+ else:
154
+ logger.warning(f"image logging not implemented for {tracker.name}")
155
+
156
+ gc.collect()
157
+ torch.cuda.empty_cache()
158
+
159
+ return image_logs
160
+
161
+
162
+ def parse_args(input_args=None):
163
+ parser = argparse.ArgumentParser(description="InstantIR stage-1 training.")
164
+ parser.add_argument(
165
+ "--pretrained_model_name_or_path",
166
+ type=str,
167
+ default=None,
168
+ required=True,
169
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
170
+ )
171
+ parser.add_argument(
172
+ "--pretrained_vae_model_name_or_path",
173
+ type=str,
174
+ default=None,
175
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
176
+ )
177
+ parser.add_argument(
178
+ "--feature_extractor_path",
179
+ type=str,
180
+ default=None,
181
+ help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
182
+ )
183
+ parser.add_argument(
184
+ "--pretrained_adapter_model_path",
185
+ type=str,
186
+ default=None,
187
+ help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
188
+ )
189
+ parser.add_argument(
190
+ "--adapter_tokens",
191
+ type=int,
192
+ default=64,
193
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
194
+ )
195
+ parser.add_argument(
196
+ "--use_clip_encoder",
197
+ action="store_true",
198
+ help="Whether or not to use DINO as image encoder, else CLIP encoder.",
199
+ )
200
+ parser.add_argument(
201
+ "--image_encoder_hidden_feature",
202
+ action="store_true",
203
+ help="Whether or not to use the penultimate hidden states as image embeddings.",
204
+ )
205
+ parser.add_argument(
206
+ "--losses_config_path",
207
+ type=str,
208
+ required=True,
209
+ default='config_files/losses.yaml',
+ help=("A YAML file containing losses to use and their weights."),
211
+ )
212
+ parser.add_argument(
213
+ "--data_config_path",
214
+ type=str,
215
+ default='config_files/IR_dataset.yaml',
216
+ help=("A folder containing the training data. "),
217
+ )
218
+ parser.add_argument(
219
+ "--variant",
220
+ type=str,
221
+ default=None,
222
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
223
+ )
224
+ parser.add_argument(
225
+ "--revision",
226
+ type=str,
227
+ default=None,
228
+ required=False,
229
+ help="Revision of pretrained model identifier from huggingface.co/models.",
230
+ )
231
+ parser.add_argument(
232
+ "--tokenizer_name",
233
+ type=str,
234
+ default=None,
235
+ help="Pretrained tokenizer name or path if not the same as model_name",
236
+ )
237
+ parser.add_argument(
238
+ "--output_dir",
239
+ type=str,
240
+ default="stage1_model",
241
+ help="The output directory where the model predictions and checkpoints will be written.",
242
+ )
243
+ parser.add_argument(
244
+ "--cache_dir",
245
+ type=str,
246
+ default=None,
247
+ help="The directory where the downloaded models and datasets will be stored.",
248
+ )
249
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
250
+ parser.add_argument(
251
+ "--resolution",
252
+ type=int,
253
+ default=512,
254
+ help=(
255
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
256
+ " resolution"
257
+ ),
258
+ )
259
+ parser.add_argument(
260
+ "--crops_coords_top_left_h",
261
+ type=int,
262
+ default=0,
263
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
264
+ )
265
+ parser.add_argument(
266
+ "--crops_coords_top_left_w",
267
+ type=int,
268
+ default=0,
269
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
270
+ )
271
+ parser.add_argument(
272
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
273
+ )
274
+ parser.add_argument("--num_train_epochs", type=int, default=1)
275
+ parser.add_argument(
276
+ "--max_train_steps",
277
+ type=int,
278
+ default=None,
279
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
280
+ )
281
+ parser.add_argument(
282
+ "--checkpointing_steps",
283
+ type=int,
284
+ default=2000,
285
+ help=(
286
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
287
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
288
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
289
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
290
+ "instructions."
291
+ ),
292
+ )
293
+ parser.add_argument(
294
+ "--checkpoints_total_limit",
295
+ type=int,
296
+ default=5,
297
+ help=("Max number of checkpoints to store."),
298
+ )
299
+ parser.add_argument(
300
+ "--resume_from_checkpoint",
301
+ type=str,
302
+ default=None,
303
+ help=(
304
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
305
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
306
+ ),
307
+ )
308
+ parser.add_argument(
309
+ "--gradient_accumulation_steps",
310
+ type=int,
311
+ default=1,
312
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
313
+ )
314
+ parser.add_argument(
315
+ "--gradient_checkpointing",
316
+ action="store_true",
317
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
318
+ )
319
+ parser.add_argument(
320
+ "--save_only_adapter",
321
+ action="store_true",
322
+ help="Only save extra adapter to save space.",
323
+ )
324
+ parser.add_argument(
325
+ "--importance_sampling",
326
+ action="store_true",
327
+ help="Whether or not to use importance sampling.",
328
+ )
329
+ parser.add_argument(
330
+ "--learning_rate",
331
+ type=float,
332
+ default=1e-4,
333
+ help="Initial learning rate (after the potential warmup period) to use.",
334
+ )
335
+ parser.add_argument(
336
+ "--scale_lr",
337
+ action="store_true",
338
+ default=False,
339
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
340
+ )
341
+ parser.add_argument(
342
+ "--lr_scheduler",
343
+ type=str,
344
+ default="constant",
345
+ help=(
346
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
347
+ ' "constant", "constant_with_warmup"]'
348
+ ),
349
+ )
350
+ parser.add_argument(
351
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
352
+ )
353
+ parser.add_argument(
354
+ "--lr_num_cycles",
355
+ type=int,
356
+ default=1,
357
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
358
+ )
359
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
360
+ parser.add_argument(
361
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
362
+ )
363
+ parser.add_argument(
364
+ "--dataloader_num_workers",
365
+ type=int,
366
+ default=0,
367
+ help=(
368
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
369
+ ),
370
+ )
371
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
372
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
373
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
374
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
375
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
376
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
377
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
378
+ parser.add_argument(
379
+ "--hub_model_id",
380
+ type=str,
381
+ default=None,
382
+ help="The name of the repository to keep in sync with the local `output_dir`.",
383
+ )
384
+ parser.add_argument(
385
+ "--logging_dir",
386
+ type=str,
387
+ default="logs",
388
+ help=(
389
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
390
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
391
+ ),
392
+ )
393
+ parser.add_argument(
394
+ "--allow_tf32",
395
+ action="store_true",
396
+ help=(
397
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
398
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
399
+ ),
400
+ )
401
+ parser.add_argument(
402
+ "--report_to",
403
+ type=str,
404
+ default="tensorboard",
405
+ help=(
406
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
407
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
408
+ ),
409
+ )
410
+ parser.add_argument(
411
+ "--mixed_precision",
412
+ type=str,
413
+ default=None,
414
+ choices=["no", "fp16", "bf16"],
415
+ help=(
416
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
417
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
418
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
419
+ ),
420
+ )
421
+ parser.add_argument(
422
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
423
+ )
424
+ parser.add_argument(
425
+ "--set_grads_to_none",
426
+ action="store_true",
427
+ help=(
428
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
429
+ " behaviors, so disable this argument if it causes any problems. More info:"
430
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
431
+ ),
432
+ )
433
+ parser.add_argument(
434
+ "--dataset_name",
435
+ type=str,
436
+ default=None,
437
+ help=(
438
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
439
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
440
+ " or to a folder containing files that 🤗 Datasets can understand."
441
+ ),
442
+ )
443
+ parser.add_argument(
444
+ "--dataset_config_name",
445
+ type=str,
446
+ default=None,
447
+ help="The config of the Dataset, leave as None if there's only one config.",
448
+ )
449
+ parser.add_argument(
450
+ "--train_data_dir",
451
+ type=str,
452
+ default=None,
453
+ help=(
454
+ "A folder containing the training data. Folder contents must follow the structure described in"
455
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
456
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
457
+ ),
458
+ )
459
+ parser.add_argument(
460
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
461
+ )
462
+ parser.add_argument(
463
+ "--conditioning_image_column",
464
+ type=str,
465
+ default="conditioning_image",
466
+ help="The column of the dataset containing the controlnet conditioning image.",
467
+ )
468
+ parser.add_argument(
469
+ "--caption_column",
470
+ type=str,
471
+ default="text",
472
+ help="The column of the dataset containing a caption or a list of captions.",
473
+ )
474
+ parser.add_argument(
475
+ "--max_train_samples",
476
+ type=int,
477
+ default=None,
478
+ help=(
479
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
480
+ "value if set."
481
+ ),
482
+ )
483
+ parser.add_argument(
484
+ "--text_drop_rate",
485
+ type=float,
486
+ default=0.05,
487
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
488
+ )
489
+ parser.add_argument(
490
+ "--image_drop_rate",
491
+ type=float,
492
+ default=0.05,
493
+ help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).",
494
+ )
495
+ parser.add_argument(
496
+ "--cond_drop_rate",
497
+ type=float,
498
+ default=0.05,
499
+ help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).",
500
+ )
501
+ parser.add_argument(
502
+ "--sanity_check",
503
+ action="store_true",
504
+ help=(
505
+ "sanity check"
506
+ ),
507
+ )
508
+ parser.add_argument(
509
+ "--validation_prompt",
510
+ type=str,
511
+ default=None,
512
+ nargs="+",
513
+ help=(
514
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
515
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
516
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
517
+ ),
518
+ )
519
+ parser.add_argument(
520
+ "--validation_image",
521
+ type=str,
522
+ default=None,
523
+ nargs="+",
524
+ help=(
525
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
526
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
527
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
528
+ " `--validation_image` that will be used with all `--validation_prompt`s."
529
+ ),
530
+ )
531
+ parser.add_argument(
532
+ "--num_validation_images",
533
+ type=int,
534
+ default=4,
535
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
536
+ )
537
+ parser.add_argument(
538
+ "--validation_steps",
539
+ type=int,
540
+ default=3000,
541
+ help=(
542
+ "Run validation every X steps. Validation consists of running the prompt"
543
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
544
+ " and logging the images."
545
+ ),
546
+ )
547
+ parser.add_argument(
548
+ "--tracker_project_name",
549
+ type=str,
550
+ default="instantir_stage1",
551
+ help=(
552
+ "The `project_name` argument passed to Accelerator.init_trackers for"
553
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
554
+ ),
555
+ )
556
+
557
+ if input_args is not None:
558
+ args = parser.parse_args(input_args)
559
+ else:
560
+ args = parser.parse_args()
561
+
562
+ # if args.dataset_name is None and args.train_data_dir is None and args.data_config_path is None:
563
+ # raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
564
+
565
+ if args.dataset_name is not None and args.train_data_dir is not None:
566
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
567
+
568
+ if args.text_drop_rate < 0 or args.text_drop_rate > 1:
569
+ raise ValueError("`--text_drop_rate` must be in the range [0, 1].")
570
+
571
+ if args.validation_prompt is not None and args.validation_image is None:
572
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
573
+
574
+ if args.validation_prompt is None and args.validation_image is not None:
575
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
576
+
577
+ if (
578
+ args.validation_image is not None
579
+ and args.validation_prompt is not None
580
+ and len(args.validation_image) != 1
581
+ and len(args.validation_prompt) != 1
582
+ and len(args.validation_image) != len(args.validation_prompt)
583
+ ):
584
+ raise ValueError(
585
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
586
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
587
+ )
588
+
589
+ if args.resolution % 8 != 0:
590
+ raise ValueError(
591
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
592
+ )
593
+
594
+ return args
595
+
596
+
597
+ def main(args):
598
+ if args.report_to == "wandb" and args.hub_token is not None:
599
+ raise ValueError(
600
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
601
+ " Please use `huggingface-cli login` to authenticate with the Hub."
602
+ )
603
+
604
+ logging_dir = Path(args.output_dir, args.logging_dir)
605
+
606
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
607
+ # due to pytorch#99272, MPS does not yet support bfloat16.
608
+ raise ValueError(
609
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
610
+ )
611
+
612
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
613
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
614
+ accelerator = Accelerator(
615
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
616
+ mixed_precision=args.mixed_precision,
617
+ log_with=args.report_to,
618
+ project_config=accelerator_project_config,
619
+ # kwargs_handlers=[kwargs],
620
+ )
621
+
622
+ # Make one log on every process with the configuration for debugging.
623
+ logging.basicConfig(
624
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
625
+ datefmt="%m/%d/%Y %H:%M:%S",
626
+ level=logging.INFO,
627
+ )
628
+ logger.info(accelerator.state, main_process_only=False)
629
+ if accelerator.is_local_main_process:
630
+ transformers.utils.logging.set_verbosity_warning()
631
+ diffusers.utils.logging.set_verbosity_info()
632
+ else:
633
+ transformers.utils.logging.set_verbosity_error()
634
+ diffusers.utils.logging.set_verbosity_error()
635
+
636
+ # If passed along, set the training seed now.
637
+ if args.seed is not None:
638
+ set_seed(args.seed)
639
+
640
+ # Handle the repository creation.
641
+ if accelerator.is_main_process:
642
+ if args.output_dir is not None:
643
+ os.makedirs(args.output_dir, exist_ok=True)
644
+
645
+ # Load scheduler and models
646
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
647
+ # Importance sampling.
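+ # Build a per-timestep probability ratio; it is later used as per-sample loss weights
+ # when `--importance_sampling` is enabled.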
648
+ list_of_candidates = np.arange(noise_scheduler.config.num_train_timesteps, dtype='float64')
649
+ prob_dist = importance_sampling_fn(list_of_candidates, noise_scheduler.config.num_train_timesteps, 0.5)
650
+ importance_ratio = prob_dist / prob_dist.sum() * noise_scheduler.config.num_train_timesteps
651
+ importance_ratio = torch.from_numpy(importance_ratio.copy()).float()
652
+
653
+ # Load the tokenizers
654
+ tokenizer = AutoTokenizer.from_pretrained(
655
+ args.pretrained_model_name_or_path,
656
+ subfolder="tokenizer",
657
+ revision=args.revision,
658
+ use_fast=False,
659
+ )
660
+ tokenizer_2 = AutoTokenizer.from_pretrained(
661
+ args.pretrained_model_name_or_path,
662
+ subfolder="tokenizer_2",
663
+ revision=args.revision,
664
+ use_fast=False,
665
+ )
666
+
667
+ # Text encoder and image encoder.
668
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
669
+ args.pretrained_model_name_or_path, args.revision
670
+ )
671
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
672
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
673
+ )
674
+ text_encoder = text_encoder_cls_one.from_pretrained(
675
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
676
+ )
677
+ text_encoder_2 = text_encoder_cls_two.from_pretrained(
678
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
679
+ )
680
+ if args.use_clip_encoder:
681
+ image_processor = CLIPImageProcessor()
682
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path)
683
+ else:
684
+ image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path)
685
+ image_encoder = AutoModel.from_pretrained(args.feature_extractor_path)
686
+
687
+ # VAE.
688
+ vae_path = (
689
+ args.pretrained_model_name_or_path
690
+ if args.pretrained_vae_model_name_or_path is None
691
+ else args.pretrained_vae_model_name_or_path
692
+ )
693
+ vae = AutoencoderKL.from_pretrained(
694
+ vae_path,
695
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
696
+ revision=args.revision,
697
+ variant=args.variant,
698
+ )
699
+
700
+ # UNet.
701
+ unet = UNet2DConditionModel.from_pretrained(
702
+ args.pretrained_model_name_or_path,
703
+ subfolder="unet",
704
+ revision=args.revision,
705
+ variant=args.variant
706
+ )
707
+
708
+ pipe = StableDiffusionXLPipeline.from_pretrained(
709
+ args.pretrained_model_name_or_path,
710
+ unet=unet,
711
+ text_encoder=text_encoder,
712
+ text_encoder_2=text_encoder_2,
713
+ vae=vae,
714
+ tokenizer=tokenizer,
715
+ tokenizer_2=tokenizer_2,
716
+ variant=args.variant
717
+ )
718
+
719
+ # Resampler for the image projection model in IP-Adapter.
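+ # Maps image-encoder features to `adapter_tokens` query tokens with width
+ # `unet.config.cross_attention_dim`, matching the UNet cross-attention layers.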
720
+ image_proj_model = Resampler(
721
+ dim=1280,
722
+ depth=4,
723
+ dim_head=64,
724
+ heads=20,
725
+ num_queries=args.adapter_tokens,
726
+ embedding_dim=image_encoder.config.hidden_size,
727
+ output_dim=unet.config.cross_attention_dim,
728
+ ff_mult=4
729
+ )
730
+
731
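+ # Install the IP-Adapter attention modules in the UNet and initialize them, together with
+ # the image projection model, from `adapter_ckpt.pt` under `--pretrained_adapter_model_path`.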
+ init_adapter_in_unet(
732
+ unet,
733
+ image_proj_model,
734
+ os.path.join(args.pretrained_adapter_model_path, 'adapter_ckpt.pt'),
735
+ adapter_tokens=args.adapter_tokens,
736
+ )
737
+
738
+ # Freeze all base model weights; only the IP-Adapter parameters are trained.
739
+ vae.requires_grad_(False)
740
+ text_encoder.requires_grad_(False)
741
+ text_encoder_2.requires_grad_(False)
742
+ unet.requires_grad_(False)
743
+ image_encoder.requires_grad_(False)
744
+
745
+ def unwrap_model(model):
746
+ model = accelerator.unwrap_model(model)
747
+ model = model._orig_mod if is_compiled_module(model) else model
748
+ return model
749
+
750
+ # `accelerate` 0.16.0 will have better support for customized saving
751
+ if args.save_only_adapter:
752
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
753
+ def save_model_hook(models, weights, output_dir):
754
+ if accelerator.is_main_process:
755
+ for model in models:
756
+ if isinstance(model, type(unwrap_model(unet))): # save adapter only
757
+ adapter_state_dict = OrderedDict()
758
+ adapter_state_dict["image_proj"] = model.encoder_hid_proj.image_projection_layers[0].state_dict()
759
+ adapter_state_dict["ip_adapter"] = torch.nn.ModuleList(model.attn_processors.values()).state_dict()
760
+ torch.save(adapter_state_dict, os.path.join(output_dir, "adapter_ckpt.pt"))
761
+
762
+ weights.pop()
763
+
764
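+ # Counterpart of the save hook: when resuming, restore only the adapter weights
+ # ("image_proj" and "ip_adapter") from `adapter_ckpt.pt`.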
+ def load_model_hook(models, input_dir):
765
+
766
+ while len(models) > 0:
767
+ # pop models so that they are not loaded again
768
+ model = models.pop()
769
+
770
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
771
+ adapter_state_dict = torch.load(os.path.join(input_dir, "adapter_ckpt.pt"), map_location="cpu")
772
+ if list(adapter_state_dict.keys()) != ["image_proj", "ip_adapter"]:
773
+ from module.ip_adapter.utils import revise_state_dict
774
+ adapter_state_dict = revise_state_dict(adapter_state_dict)
775
+ model.encoder_hid_proj.image_projection_layers[0].load_state_dict(adapter_state_dict["image_proj"], strict=True)
776
+ missing, unexpected = torch.nn.ModuleList(model.attn_processors.values()).load_state_dict(adapter_state_dict["ip_adapter"], strict=False)
777
+ if len(unexpected) > 0:
778
+ raise ValueError(f"Unexpected keys: {unexpected}")
779
+ if len(missing) > 0:
780
+ for mk in missing:
781
+ if "ln" not in mk:
782
+ raise ValueError(f"Missing keys: {missing}")
783
+
784
+ accelerator.register_save_state_pre_hook(save_model_hook)
785
+ accelerator.register_load_state_pre_hook(load_model_hook)
786
+
787
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
788
+ # as these models are only used for inference, keeping weights in full precision is not required.
789
+ weight_dtype = torch.float32
790
+ if accelerator.mixed_precision == "fp16":
791
+ weight_dtype = torch.float16
792
+ elif accelerator.mixed_precision == "bf16":
793
+ weight_dtype = torch.bfloat16
794
+
795
+ if args.enable_xformers_memory_efficient_attention:
796
+ if is_xformers_available():
797
+ import xformers
798
+
799
+ xformers_version = version.parse(xformers.__version__)
800
+ if xformers_version == version.parse("0.0.16"):
801
+ logger.warning(
802
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
803
+ )
804
+ unet.enable_xformers_memory_efficient_attention()
805
+ else:
806
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
807
+
808
+ if args.gradient_checkpointing:
809
+ unet.enable_gradient_checkpointing()
810
+ vae.enable_gradient_checkpointing()
811
+
812
+ # Enable TF32 for faster training on Ampere GPUs,
813
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
814
+ if args.allow_tf32:
815
+ torch.backends.cuda.matmul.allow_tf32 = True
816
+
817
+ if args.scale_lr:
818
+ args.learning_rate = (
819
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
820
+ )
821
+
822
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
823
+ if args.use_8bit_adam:
824
+ try:
825
+ import bitsandbytes as bnb
826
+ except ImportError:
827
+ raise ImportError(
828
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
829
+ )
830
+
831
+ optimizer_class = bnb.optim.AdamW8bit
832
+ else:
833
+ optimizer_class = torch.optim.AdamW
834
+
835
+ # Optimizer creation.
836
+ ip_params, non_ip_params = seperate_ip_params_from_unet(unet)
837
+ params_to_optimize = ip_params
838
+ optimizer = optimizer_class(
839
+ params_to_optimize,
840
+ lr=args.learning_rate,
841
+ betas=(args.adam_beta1, args.adam_beta2),
842
+ weight_decay=args.adam_weight_decay,
843
+ eps=args.adam_epsilon,
844
+ )
845
+
846
+ # Instantiate Loss.
847
+ losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r"))
848
+ diffusion_losses = list()
849
+ for loss_config in losses_configs.diffusion_losses:
850
+ logger.info(f"Loading diffusion loss: {loss_config.name}")
851
+ loss = namedtuple("loss", ["loss", "weight"])
852
+ loss_class = eval(loss_config.name)
853
+ diffusion_losses.append(loss(loss_class(visualize_every_k=loss_config.visualize_every_k,
854
+ dtype=weight_dtype,
855
+ accelerator=accelerator,
856
+ **loss_config.init_params), weight=loss_config.weight))
857
+
858
+ # SDXL additional condition that will be added to time embedding.
859
+ def compute_time_ids(original_size, crops_coords_top_left):
860
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
861
+ target_size = (args.resolution, args.resolution)
862
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
863
+ add_time_ids = torch.tensor([add_time_ids])
864
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
865
+ return add_time_ids
866
+
867
+ # Text prompt embeddings.
868
+ @torch.no_grad()
869
+ def compute_embeddings(batch, text_encoders, tokenizers, drop_idx=None, is_train=True):
870
+ prompt_batch = batch[args.caption_column]
871
+ if drop_idx is not None:
872
+ for i in range(len(prompt_batch)):
873
+ prompt_batch[i] = "" if drop_idx[i] else prompt_batch[i]
874
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
875
+ prompt_batch, text_encoders, tokenizers, is_train
876
+ )
877
+
878
+ add_time_ids = torch.cat(
879
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
880
+ )
881
+
882
+ prompt_embeds = prompt_embeds.to(accelerator.device)
883
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
884
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
885
+ sdxl_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
886
+
887
+ return prompt_embeds, sdxl_added_cond_kwargs
888
+
889
+ # Move pixels into latents.
890
+ @torch.no_grad()
891
+ def convert_to_latent(pixels):
892
+ model_input = vae.encode(pixels).latent_dist.sample()
893
+ model_input = model_input * vae.config.scaling_factor
894
+ if args.pretrained_vae_model_name_or_path is None:
895
+ model_input = model_input.to(weight_dtype)
896
+ return model_input
897
+
898
+ # Datasets and other data modules.
899
+ deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
900
+ compute_embeddings_fn = functools.partial(
901
+ compute_embeddings,
902
+ text_encoders=[text_encoder, text_encoder_2],
903
+ tokenizers=[tokenizer, tokenizer_2],
904
+ is_train=True,
905
+ )
906
+
907
+ datasets = []
908
+ datasets_name = []
909
+ datasets_weights = []
910
+ if args.data_config_path is not None:
911
+ data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r"))
912
+ for single_dataset in data_config.datasets:
913
+ datasets_weights.append(single_dataset.dataset_weight)
914
+ datasets_name.append(single_dataset.dataset_folder)
915
+ dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder)
916
+ image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
917
+ image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline)
918
+ datasets.append(image_dataset)
919
+ # TODO: Validation dataset
920
+ if data_config.val_dataset is not None:
921
+ val_dataset = get_train_dataset(dataset_name, dataset_dir, args, accelerator)
922
+ logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}")
923
+
924
+ # Mix training datasets.
925
+ sampler_train = None
926
+ if len(datasets) == 1:
927
+ train_dataset = datasets[0]
928
+ else:
929
+ # Weight each dataset.
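+ # Each sample is weighted by (len(train_dataset) / len(its dataset)) * dataset_weight,
+ # so datasets are drawn in their configured proportions regardless of size.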
930
+ train_dataset = torch.utils.data.ConcatDataset(datasets)
931
+ dataset_weights = []
932
+ for single_dataset, single_weight in zip(datasets, datasets_weights):
933
+ dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset))
934
+ sampler_train = torch.utils.data.WeightedRandomSampler(
935
+ weights=dataset_weights,
936
+ num_samples=len(dataset_weights)
937
+ )
938
+
939
+ train_dataloader = torch.utils.data.DataLoader(
940
+ train_dataset,
941
+ batch_size=args.train_batch_size,
942
+ sampler=sampler_train,
943
+ shuffle=True if sampler_train is None else False,
944
+ collate_fn=collate_fn,
945
+ num_workers=args.dataloader_num_workers
946
+ )
947
+
948
+ # We need to initialize the trackers we use, and also store our configuration.
949
+ # The trackers are initialized automatically on the main process.
950
+ if accelerator.is_main_process:
951
+ tracker_config = dict(vars(args))
952
+
953
+ # tensorboard cannot handle list types for config
954
+ tracker_config.pop("validation_prompt")
955
+ tracker_config.pop("validation_image")
956
+
957
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
958
+
959
+ # Scheduler and math around the number of training steps.
960
+ overrode_max_train_steps = False
961
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
962
+ if args.max_train_steps is None:
963
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
964
+ overrode_max_train_steps = True
965
+
966
+ lr_scheduler = get_scheduler(
967
+ args.lr_scheduler,
968
+ optimizer=optimizer,
969
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
970
+ num_training_steps=args.max_train_steps,
971
+ num_cycles=args.lr_num_cycles,
972
+ power=args.lr_power,
973
+ )
974
+
975
+ # Prepare everything with our `accelerator`.
976
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
977
+ unet, optimizer, train_dataloader, lr_scheduler
978
+ )
979
+
980
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
981
+ if args.pretrained_vae_model_name_or_path is None:
982
+ # The VAE is fp32 to avoid NaN losses.
983
+ vae.to(accelerator.device, dtype=torch.float32)
984
+ else:
985
+ vae.to(accelerator.device, dtype=weight_dtype)
986
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
987
+ text_encoder_2.to(accelerator.device, dtype=weight_dtype)
988
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
989
+ importance_ratio = importance_ratio.to(accelerator.device)
990
+ for non_ip_param in non_ip_params:
991
+ non_ip_param.data = non_ip_param.data.to(dtype=weight_dtype)
992
+ for ip_param in ip_params:
993
+ ip_param.requires_grad_(True)
994
+ unet.to(accelerator.device)
995
+
996
+ # Final check.
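+ # Trainable IP-Adapter parameters must stay in fp32; every frozen parameter should
+ # already have been cast to `weight_dtype`.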
997
+ for n, p in unet.named_parameters():
998
+ if p.requires_grad: assert p.dtype == torch.float32, n
999
+ else: assert p.dtype == weight_dtype, n
1000
+ if args.sanity_check:
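+ # Optionally resume from a checkpoint, run one validation pass on a single training batch
+ # (saving the images locally), then exit.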
1001
+ if args.resume_from_checkpoint:
1002
+ if args.resume_from_checkpoint != "latest":
1003
+ path = os.path.basename(args.resume_from_checkpoint)
1004
+ else:
1005
+ # Get the most recent checkpoint
1006
+ dirs = os.listdir(args.output_dir)
1007
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1008
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1009
+ path = dirs[-1] if len(dirs) > 0 else None
1010
+
1011
+ if path is None:
1012
+ accelerator.print(
1013
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1014
+ )
1015
+ args.resume_from_checkpoint = None
1016
+ initial_global_step = 0
1017
+ else:
1018
+ accelerator.print(f"Resuming from checkpoint {path}")
1019
+ accelerator.load_state(os.path.join(args.output_dir, path))
1020
+
1021
+ # Check input data
1022
+ batch = next(iter(train_dataloader))
1023
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
1024
+ images_log = log_validation(
1025
+ unwrap_model(unet), vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1026
+ noise_scheduler, image_encoder, image_processor, deg_pipeline,
1027
+ args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, is_final_validation=False, log_local=True
1028
+ )
1029
+ exit()
1030
+
1031
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1032
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1033
+ if overrode_max_train_steps:
1034
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1035
+ # Afterwards we recalculate our number of training epochs
1036
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1037
+
1038
+ # Train!
1039
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1040
+
1041
+ logger.info("***** Running training *****")
1042
+ logger.info(f" Num examples = {len(train_dataset)}")
1043
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1044
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1045
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1046
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1047
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1048
+ logger.info(f" Optimization steps per epoch = {num_update_steps_per_epoch}")
1049
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1050
+ global_step = 0
1051
+ first_epoch = 0
1052
+
1053
+ # Potentially load in the weights and states from a previous save
1054
+ if args.resume_from_checkpoint:
1055
+ if args.resume_from_checkpoint != "latest":
1056
+ path = os.path.basename(args.resume_from_checkpoint)
1057
+ else:
1058
+ # Get the most recent checkpoint
1059
+ dirs = os.listdir(args.output_dir)
1060
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1061
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1062
+ path = dirs[-1] if len(dirs) > 0 else None
1063
+
1064
+ if path is None:
1065
+ accelerator.print(
1066
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1067
+ )
1068
+ args.resume_from_checkpoint = None
1069
+ initial_global_step = 0
1070
+ else:
1071
+ accelerator.print(f"Resuming from checkpoint {path}")
1072
+ accelerator.load_state(os.path.join(args.output_dir, path))
1073
+ global_step = int(path.split("-")[1])
1074
+
1075
+ initial_global_step = global_step
1076
+ first_epoch = global_step // num_update_steps_per_epoch
1077
+ else:
1078
+ initial_global_step = 0
1079
+
1080
+ progress_bar = tqdm(
1081
+ range(0, args.max_train_steps),
1082
+ initial=initial_global_step,
1083
+ desc="Steps",
1084
+ # Only show the progress bar once on each machine.
1085
+ disable=not accelerator.is_local_main_process,
1086
+ )
1087
+
1088
+ trainable_models = [unet]
1089
+
1090
+ checkpoint_models = []
1094
+
1095
+ image_logs = None
1096
+ tic = time.time()
1097
+ for epoch in range(first_epoch, args.num_train_epochs):
1098
+ for step, batch in enumerate(train_dataloader):
1099
+ toc = time.time()
1100
+ io_time = toc - tic
1101
+ tic = toc
1102
+ for model in trainable_models + checkpoint_models:
1103
+ model.train()
1104
+ with accelerator.accumulate(*trainable_models):
1105
+ loss = torch.tensor(0.0)
1106
+
1107
+ # Drop conditions.
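+ # A single uniform draw per sample partitions [0, 1) into image-drop, text-drop and
+ # joint-drop intervals for classifier-free-guidance-style condition dropout.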
1108
+ rand_tensor = torch.rand(batch["images"].shape[0])
1109
+ drop_image_idx = rand_tensor < args.image_drop_rate
1110
+ drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate)
1111
+ drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate)
1112
+ drop_image_idx = drop_image_idx | drop_both_idx
1113
+ drop_text_idx = drop_text_idx | drop_both_idx
1114
+
1115
+ # Get LQ embeddings
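+ # LQ inputs are synthesized from the GT batch with the RealESRGAN degradation pipeline,
+ # then encoded into IP-Adapter image embeddings (with per-sample dropout applied).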
1116
+ with torch.no_grad():
1117
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
1118
+ lq_pt = image_processor(
1119
+ images=lq_img*0.5+0.5,
1120
+ do_rescale=False, return_tensors="pt"
1121
+ ).pixel_values
1122
+ image_embeds = prepare_training_image_embeds(
1123
+ image_encoder, image_processor,
1124
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
1125
+ device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature,
1126
+ idx_to_replace=drop_image_idx
1127
+ )
1128
+
1129
+ # Process text inputs.
1130
+ prompt_embeds_input, added_conditions = compute_embeddings_fn(batch, drop_idx=drop_text_idx)
1131
+ added_conditions["image_embeds"] = image_embeds
1132
+
1133
+ # Move inputs to latent space.
1134
+ gt_img = gt_img.to(dtype=vae.dtype)
1135
+ model_input = convert_to_latent(gt_img)
1136
+ if args.pretrained_vae_model_name_or_path is None:
1137
+ model_input = model_input.to(weight_dtype)
1138
+
1139
+ # Sample noise that we'll add to the latents.
1140
+ noise = torch.randn_like(model_input)
1141
+ bsz = model_input.shape[0]
1142
+
1143
+ # Sample a random timestep for each image.
1144
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device)
1145
+
1146
+ # Add noise to the model input according to the noise magnitude at each timestep
1147
+ # (this is the forward diffusion process)
1148
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1149
+ loss_weights = extract_into_tensor(importance_ratio, timesteps, noise.shape) if args.importance_sampling else None
1150
+
1151
+ toc = time.time()
1152
+ prepare_time = toc - tic
1153
+ tic = time.time()
1154
+
1155
+ model_pred = unet(
1156
+ noisy_model_input, timesteps,
1157
+ encoder_hidden_states=prompt_embeds_input,
1158
+ added_cond_kwargs=added_conditions,
1159
+ return_dict=False
1160
+ )[0]
1161
+
1162
+ diffusion_loss_arguments = {
1163
+ "target": noise,
1164
+ "predict": model_pred,
1165
+ "prompt_embeddings_input": prompt_embeds_input,
1166
+ "timesteps": timesteps,
1167
+ "weights": loss_weights,
1168
+ }
1169
+
1170
+ loss_dict = dict()
1171
+ for loss_config in diffusion_losses:
1172
+ non_weighted_loss = loss_config.loss(**diffusion_loss_arguments, accelerator=accelerator)
1173
+ loss = loss + non_weighted_loss * loss_config.weight
1174
+ loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
1175
+
1176
+ accelerator.backward(loss)
1177
+ if accelerator.sync_gradients:
1178
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
1179
+ optimizer.step()
1180
+ lr_scheduler.step()
1181
+ optimizer.zero_grad()
1182
+
1183
+ toc = time.time()
1184
+ forward_time = toc - tic
1185
+ tic = toc
1186
+
1187
+ # Checks if the accelerator has performed an optimization step behind the scenes
1188
+ if accelerator.sync_gradients:
1189
+ progress_bar.update(1)
1190
+ global_step += 1
1191
+
1192
+ if accelerator.is_main_process:
1193
+ if global_step % args.checkpointing_steps == 0:
1194
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1195
+ if args.checkpoints_total_limit is not None:
1196
+ checkpoints = os.listdir(args.output_dir)
1197
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1198
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1199
+
1200
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1201
+ if len(checkpoints) >= args.checkpoints_total_limit:
1202
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1203
+ removing_checkpoints = checkpoints[0:num_to_remove]
1204
+
1205
+ logger.info(
1206
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1207
+ )
1208
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1209
+
1210
+ for removing_checkpoint in removing_checkpoints:
1211
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1212
+ shutil.rmtree(removing_checkpoint)
1213
+
1214
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1215
+ accelerator.save_state(save_path)
1216
+ logger.info(f"Saved state to {save_path}")
1217
+
1218
+ if global_step % args.validation_steps == 0:
1219
+ image_logs = log_validation(unwrap_model(unet), vae,
1220
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1221
+ noise_scheduler, image_encoder, image_processor, deg_pipeline,
1222
+ args, accelerator, weight_dtype, global_step, lq_img, gt_img, is_final_validation=False)
1223
+
1224
+ logs = {}
1225
+ logs.update(loss_dict)
1226
+ logs.update({
1227
+ "lr": lr_scheduler.get_last_lr()[0],
1228
+ "io_time": io_time,
1229
+ "prepare_time": prepare_time,
1230
+ "forward_time": forward_time
1231
+ })
1232
+ progress_bar.set_postfix(**logs)
1233
+ accelerator.log(logs, step=global_step)
1234
+ tic = time.time()
1235
+
1236
+ if global_step >= args.max_train_steps:
1237
+ break
1238
+
1239
+ # Save the final training state with the trained adapter modules.
1240
+ accelerator.wait_for_everyone()
1241
+ if accelerator.is_main_process:
1242
+ accelerator.save_state(os.path.join(args.output_dir, "last"), safe_serialization=False)
1243
+ # Run a final round of validation using the in-memory models.
1245
+ image_logs = None
1246
+ if args.validation_image is not None:
1247
+ image_logs = log_validation(
1248
+ unwrap_model(unet), vae,
1249
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1250
+ noise_scheduler, image_encoder, image_processor, deg_pipeline,
1251
+ args, accelerator, weight_dtype, global_step,
1252
+ )
1253
+
1254
+ accelerator.end_training()
1255
+
1256
+
1257
+ if __name__ == "__main__":
1258
+ args = parse_args()
1259
+ main(args)