SunderAli17 committed on
Commit
7f1b096
1 Parent(s): 6802c18

Create train_utils.py

Files changed (1)
  1. utils/train_utils.py +360 -0
utils/train_utils.py ADDED
@@ -0,0 +1,360 @@
import argparse
import contextlib
import time
import gc
import logging
import math
import os
import random
import jsonlines
import functools
import shutil
import pyrallis
import itertools
from pathlib import Path
from collections import namedtuple, OrderedDict

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from datasets import load_dataset
from packaging import version
from PIL import Image
from losses.losses import *
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm

# Module-level logger; the dataset helpers below report column fallbacks through it.
logger = get_logger(__name__)


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    from transformers import PretrainedConfig
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")

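# Usage sketch (illustrative only; the checkpoint and variable names are assumptions, on the
# premise of an SDXL-style base model with two text encoders):
#
#   text_encoder_cls_one = import_model_class_from_model_name_or_path(
#       "stabilityai/stable-diffusion-xl-base-1.0", revision=None
#   )
#   text_encoder_cls_two = import_model_class_from_model_name_or_path(
#       "stabilityai/stable-diffusion-xl-base-1.0", revision=None, subfolder="text_encoder_2"
#   )
#   text_encoder_one = text_encoder_cls_one.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder"
#   )
#   text_encoder_two = text_encoder_cls_two.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2"
#   )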

def get_train_dataset(dataset_name, dataset_dir, args, accelerator):
    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can
    # concurrently download the dataset.
    dataset = load_dataset(
        dataset_name,
        data_dir=dataset_dir,
        cache_dir=os.path.join(dataset_dir, ".cache"),
        num_proc=4,
        split="train",
    )

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset.column_names

    # Get the column names for input/target.
    if args.image_column is None:
        args.image_column = column_names[0]
        logger.info(f"image column defaulting to {column_names[0]}")
    else:
        image_column = args.image_column
        if image_column not in column_names:
            logger.warning(f"dataset {dataset_name} has no column {image_column}")

    if args.caption_column is None:
        args.caption_column = column_names[1]
        logger.info(f"caption column defaulting to {column_names[1]}")
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            logger.warning(f"dataset {dataset_name} has no column {caption_column}")

    if args.conditioning_image_column is None:
        args.conditioning_image_column = column_names[2]
        logger.info(f"conditioning image column defaulting to {column_names[2]}")
    else:
        conditioning_image_column = args.conditioning_image_column
        if conditioning_image_column not in column_names:
            logger.warning(f"dataset {dataset_name} has no column {conditioning_image_column}")

    with accelerator.main_process_first():
        train_dataset = dataset.shuffle(seed=args.seed)
        if args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(args.max_train_samples))
    return train_dataset


def prepare_train_dataset(dataset, accelerator, deg_pipeline, centralize=False):

    # Data augmentations.
    # NOTE: the flip/rotation decisions are drawn once per call, so the same augmentation is
    # applied to every image produced by this dataset.
    hflip = deg_pipeline.augment_opt['use_hflip'] and random.random() < 0.5
    vflip = deg_pipeline.augment_opt['use_rot'] and random.random() < 0.5
    rot90 = deg_pipeline.augment_opt['use_rot'] and random.random() < 0.5
    augment_transforms = []
    if hflip:
        augment_transforms.append(transforms.RandomHorizontalFlip(p=1.0))
    if vflip:
        augment_transforms.append(transforms.RandomVerticalFlip(p=1.0))
    if rot90:
        # FIXME
        augment_transforms.append(transforms.RandomRotation(degrees=(90, 90)))
    torch_transforms = [transforms.ToTensor()]
    if centralize:
        # to [-1, 1]
        torch_transforms.append(transforms.Normalize([0.5], [0.5]))

    training_size = deg_pipeline.degrade_opt['gt_size']
    image_transforms = transforms.Compose(augment_transforms)
    train_transforms = transforms.Compose(torch_transforms)
    train_resize = transforms.Resize(training_size, interpolation=transforms.InterpolationMode.BILINEAR)
    train_crop = transforms.RandomCrop(training_size)

    def preprocess_train(examples):
        # NOTE: relies on a module-level `args` (providing `image_column`) set by the training script.
        raw_images = []
        for img_data in examples[args.image_column]:
            raw_images.append(Image.open(img_data).convert("RGB"))

        # Image stack.
        images = []
        original_sizes = []
        crop_top_lefts = []
        # Degradation kernels stack.
        kernel = []
        kernel2 = []
        sinc_kernel = []

        for raw_image in raw_images:
            raw_image = image_transforms(raw_image)
            original_sizes.append((raw_image.height, raw_image.width))

            # Resize smaller edge.
            raw_image = train_resize(raw_image)
            # Crop to training size.
            y1, x1, h, w = train_crop.get_params(raw_image, (training_size, training_size))
            raw_image = crop(raw_image, y1, x1, h, w)
            crop_top_left = (y1, x1)
            crop_top_lefts.append(crop_top_left)
            image = train_transforms(raw_image)

            images.append(image)
            k, k2, sk = deg_pipeline.get_kernel()
            kernel.append(k)
            kernel2.append(k2)
            sinc_kernel.append(sk)

        examples["images"] = images
        examples["original_sizes"] = original_sizes
        examples["crop_top_lefts"] = crop_top_lefts
        examples["kernel"] = kernel
        examples["kernel2"] = kernel2
        examples["sinc_kernel"] = sinc_kernel

        return examples

    with accelerator.main_process_first():
        dataset = dataset.with_transform(preprocess_train)

    return dataset


def collate_fn(examples):
    images = torch.stack([example["images"] for example in examples])
    images = images.to(memory_format=torch.contiguous_format).float()
    kernel = torch.stack([example["kernel"] for example in examples])
    kernel = kernel.to(memory_format=torch.contiguous_format).float()
    kernel2 = torch.stack([example["kernel2"] for example in examples])
    kernel2 = kernel2.to(memory_format=torch.contiguous_format).float()
    sinc_kernel = torch.stack([example["sinc_kernel"] for example in examples])
    sinc_kernel = sinc_kernel.to(memory_format=torch.contiguous_format).float()
    original_sizes = [example["original_sizes"] for example in examples]
    crop_top_lefts = [example["crop_top_lefts"] for example in examples]

    # Fall back to an empty prompt when the caption column is missing.
    # NOTE: relies on a module-level `args` (providing `caption_column`) set by the training script.
    prompts = []
    for example in examples:
        prompts.append(example[args.caption_column] if args.caption_column in example else "")

    return {
        "images": images,
        "text": prompts,
        "kernel": kernel,
        "kernel2": kernel2,
        "sinc_kernel": sinc_kernel,
        "original_sizes": original_sizes,
        "crop_top_lefts": crop_top_lefts,
    }

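# Data-pipeline sketch (illustrative; the argument values and the `deg_pipeline` object are
# assumptions about the calling training script, not fixed by this file):
#
#   train_dataset = get_train_dataset(args.dataset_name, args.dataset_dir, args, accelerator)
#   train_dataset = prepare_train_dataset(train_dataset, accelerator, deg_pipeline)
#   train_dataloader = torch.utils.data.DataLoader(
#       train_dataset,
#       shuffle=True,
#       collate_fn=collate_fn,
#       batch_size=args.train_batch_size,
#       num_workers=args.dataloader_num_workers,
#   )
#
# Each batch then carries "images", "text", the three degradation kernels, and the
# "original_sizes"/"crop_top_lefts" metadata used for SDXL-style size conditioning.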

def encode_prompt(prompt_batch, text_encoders, tokenizers, is_train=True):
    prompt_embeds_list = []

    captions = []
    for caption in prompt_batch:
        if isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # We are only interested in the pooled output of the final text encoder.
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds

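# Usage sketch (illustrative; the encoder/tokenizer variables are hypothetical names for objects
# loaded elsewhere, passed in SDXL order so the last entry is the projection encoder):
#
#   prompt_embeds, pooled_prompt_embeds = encode_prompt(
#       batch["text"],
#       text_encoders=[text_encoder_one, text_encoder_two],
#       tokenizers=[tokenizer_one, tokenizer_two],
#       is_train=True,
#   )
#
# `prompt_embeds` concatenates the penultimate hidden states of all encoders along the channel
# dimension; `pooled_prompt_embeds` comes from the last encoder in the list.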

def importance_sampling_fn(t, max_t, alpha):
    """Importance Sampling Function f(t)"""
    return 1 / max_t * (1 - alpha * np.cos(np.pi * t / max_t))

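# Sketch of how f(t) can be used (an assumption about the surrounding training loop, which is not
# defined in this file): evaluate it over all timesteps, normalize it into a probability vector,
# and draw timesteps from it instead of sampling them uniformly.
#
#   num_train_timesteps = 1000
#   weights = importance_sampling_fn(np.arange(num_train_timesteps), num_train_timesteps, alpha=0.5)
#   probs = weights / weights.sum()
#   timesteps = np.random.choice(num_train_timesteps, size=batch_size, p=probs)
#
# With alpha > 0 this biases sampling toward larger t, since the cosine term suppresses small t.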

def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

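# Typical diffusion-style usage (illustrative; `alphas_cumprod`, `timesteps` and `latents` are
# assumed to come from the noise scheduler and the training batch):
#
#   sqrt_alpha_prod = extract_into_tensor(alphas_cumprod.sqrt(), timesteps, latents.shape)
#   # -> shape (B, 1, 1, 1), ready to broadcast against a (B, C, H, W) latent batch.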

def tensor_to_pil(images):
    """
    Convert an image tensor or a batch of image tensors in [-1, 1] to PIL image(s).
    """
    images = (images + 1) / 2
    images_np = images.detach().cpu().numpy()
    if images_np.ndim == 4:
        images_np = np.transpose(images_np, (0, 2, 3, 1))
    elif images_np.ndim == 3:
        images_np = np.transpose(images_np, (1, 2, 0))
        images_np = images_np[None, ...]
    images_np = (images_np * 255).round().astype("uint8")
    if images_np.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images_np]
    else:
        pil_images = [Image.fromarray(image[:, :, :3]) for image in images_np]

    return pil_images


def save_np_to_image(img_np, save_dir):
    # Expects a (B, C, H, W) float array in [0, 1]; saves only the first image of the batch.
    img_np = np.transpose(img_np, (0, 2, 3, 1))
    img_np = (img_np * 255).astype(np.uint8)
    img_np = Image.fromarray(img_np[0])
    img_np.save(save_dir)


def seperate_SFT_params_from_unet(unet):
    params = []
    non_params = []
    for name, param in unet.named_parameters():
        if "SFT" in name:
            params.append(param)
        else:
            non_params.append(param)
    return params, non_params


def seperate_lora_params_from_unet(unet):
    keys = []
    frozen_keys = []
    for name, param in unet.named_parameters():
        if "lora" in name:
            keys.append(param)
        else:
            frozen_keys.append(param)
    return keys, frozen_keys


def seperate_ip_params_from_unet(unet):
    ip_params = []
    non_ip_params = []
    for name, param in unet.named_parameters():
        if "encoder_hid_proj." in name or "_ip." in name:
            ip_params.append(param)
        elif "attn" in name and "processor" in name:
            # Attention-processor params that match neither "ip" nor "ln" fall through and are
            # collected in neither list.
            if "ip" in name or "ln" in name:
                ip_params.append(param)
        else:
            non_ip_params.append(param)
    return ip_params, non_ip_params


def seperate_ref_params_from_unet(unet):
    ip_params = []
    non_ip_params = []
    for name, param in unet.named_parameters():
        if "encoder_hid_proj." in name or "_ip." in name:
            ip_params.append(param)
        elif "attn" in name and "processor" in name:
            if "ip" in name or "ln" in name:
                ip_params.append(param)
        elif "extract" in name:
            ip_params.append(param)
        else:
            non_ip_params.append(param)
    return ip_params, non_ip_params


def seperate_ip_modules_from_unet(unet):
    ip_modules = []
    non_ip_modules = []
    for name, module in unet.named_modules():
        if "encoder_hid_proj" in name or "attn2.processor" in name:
            ip_modules.append(module)
        else:
            non_ip_modules.append(module)
    return ip_modules, non_ip_modules


def seperate_SFT_keys_from_unet(unet):
    keys = []
    non_keys = []
    for name, param in unet.named_parameters():
        if "SFT" in name:
            keys.append(name)
        else:
            non_keys.append(name)
    return keys, non_keys


def seperate_ip_keys_from_unet(unet):
    keys = []
    non_keys = []
    for name, param in unet.named_parameters():
        if "encoder_hid_proj." in name or "_ip." in name:
            keys.append(name)
        elif "attn" in name and "processor" in name:
            if "ip" in name or "ln" in name:
                keys.append(name)
        else:
            non_keys.append(name)
    return keys, non_keys
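# Usage sketch for the seperate_* helpers (illustrative; the optimizer hyperparameters are
# assumptions, not values fixed by this file): split the UNet into trainable adapter parameters
# and frozen backbone parameters, then optimize only the former.
#
#   ip_params, non_ip_params = seperate_ip_params_from_unet(unet)
#   for p in non_ip_params:
#       p.requires_grad_(False)
#   optimizer = torch.optim.AdamW(ip_params, lr=1e-4, weight_decay=1e-2)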