SunderAli17 committed on
Commit 6f884cb
1 Parent(s): 73cf0ec

Create losses.py

Files changed (1)
  1. losses/losses.py +463 -0
losses/losses.py ADDED
@@ -0,0 +1,463 @@
+ import torch
+ import wandb
+ import cv2
+ import torch.nn.functional as F
+ import numpy as np
+ from facenet_pytorch import MTCNN
+ from torchvision import transforms
+ from dreamsim import dreamsim
+ from einops import rearrange
+ import kornia.augmentation as K
+ import lpips
+
+ from pretrained_models.arcface import Backbone
+ from utils.vis_utils import add_text_to_image
+ from utils.utils import extract_faces_and_landmarks
+ import clip
+
+
+ class Loss():
+     """
+     General-purpose loss class.
+     Mainly handles dtype and visualize_every_k.
+     Keeps the current iteration counter, mainly for visualization purposes.
+     """
+     def __init__(self, visualize_every_k=-1, dtype=torch.float32, accelerator=None, **kwargs):
+         self.visualize_every_k = visualize_every_k
+         self.iteration = -1
+         self.dtype = dtype
+         self.accelerator = accelerator
+
+     def __call__(self, **kwargs):
+         self.iteration += 1
+         return self.forward(**kwargs)
+
+
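The base class above only tracks the dtype, the accelerator, and an iteration counter; subclasses implement forward() and are invoked through __call__ with keyword arguments. A minimal sketch of how a subclass could be defined and used in a training step, assuming the Loss class and the torch import above (the ZeroLoss class and the weight value are illustrative, not part of this file):

class ZeroLoss(Loss):
    # Trivial subclass: ignores its inputs and returns a zero scalar loss.
    def forward(self, **kwargs):
        return torch.tensor(0.0)

loss_fn, weight = ZeroLoss(visualize_every_k=100), 1.0          # hypothetical config values
predict, target = torch.randn(2, 4, 64, 64), torch.randn(2, 4, 64, 64)
total_loss = weight * loss_fn(predict=predict, target=target)   # __call__ bumps .iteration, then calls forward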
+ class L1Loss(Loss):
+     """
+     Simple L1 loss between the predicted and target tensors.
+
+     Args:
+         predict (torch.Tensor): The predicted tensor, e.g. pixel values decoded from a 1-step LCM prediction with the VAE decoder.
+         target (torch.Tensor): The target tensor, e.g. the input image to the encoder.
+     """
+     def forward(
+         self,
+         predict: torch.Tensor,
+         target: torch.Tensor,
+         **kwargs
+     ) -> torch.Tensor:
+         return F.l1_loss(predict, target, reduction="mean")
+
+
+ class DreamSIMLoss(Loss):
+     """DreamSIM loss between predicted_pixel_values and encoder_pixel_values.
+     DreamSIM (https://dreamsim-nights.github.io/) is similar to LPIPS but is trained on a dataset of human perceptual similarity judgments.
+     DreamSIM expects RGB images of size 224x224 with values in [0, 1], so we rescale the inputs from [-1, 1] to [0, 1] and resize them to 224x224.
+     Args:
+         predicted_pixel_values (torch.Tensor): The predicted pixel values using 1-step LCM and the VAE decoder.
+         encoder_pixel_values (torch.Tensor): The input image to the encoder.
+     """
+     def __init__(self, device: str = 'cuda:0', **kwargs):
+         super().__init__(**kwargs)
+         self.model, _ = dreamsim(pretrained=True, device=device)
+         self.model.to(dtype=self.dtype, device=device)
+         self.model = self.accelerator.prepare(self.model)
+         self.transforms = transforms.Compose([
+             transforms.Lambda(lambda x: (x + 1) / 2),
+             transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC)])
+
+     def forward(
+         self,
+         predicted_pixel_values: torch.Tensor,
+         encoder_pixel_values: torch.Tensor,
+         **kwargs,
+     ) -> torch.Tensor:
+         # `Tensor.to` is not in-place, so the result has to be reassigned.
+         predicted_pixel_values = predicted_pixel_values.to(dtype=self.dtype)
+         encoder_pixel_values = encoder_pixel_values.to(dtype=self.dtype)
+         return self.model(self.transforms(predicted_pixel_values), self.transforms(encoder_pixel_values)).mean()
+
+
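The only coupling between the trainer's [-1, 1] image range and DreamSIM's expected 224x224, [0, 1] input is the small transform pipeline above. A standalone sketch of that preprocessing on a dummy batch (torch and torchvision only; no DreamSIM weights are loaded):

import torch
from torchvision import transforms

to_dreamsim_input = transforms.Compose([
    transforms.Lambda(lambda x: (x + 1) / 2),  # rescale [-1, 1] -> [0, 1]
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
])

batch = torch.rand(2, 3, 512, 512) * 2 - 1     # dummy images in [-1, 1]
prepared = to_dreamsim_input(batch)
print(prepared.shape)                          # torch.Size([2, 3, 224, 224]), values roughly in [0, 1]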
+ class LPIPSLoss(Loss):
+     """LPIPS perceptual loss between the predicted and target images.
+     Args:
+         predict (torch.Tensor): The predicted pixel values using 1-step LCM and the VAE decoder.
+         target (torch.Tensor): The input image to the encoder.
+     """
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self.model = lpips.LPIPS(net='vgg')
+         self.model.to(dtype=self.dtype, device=self.accelerator.device)
+         self.model = self.accelerator.prepare(self.model)
+
+     def forward(self, predict, target, **kwargs):
+         # `Tensor.to` is not in-place, so the result has to be reassigned.
+         predict = predict.to(dtype=self.dtype)
+         target = target.to(dtype=self.dtype)
+         return self.model(predict, target).mean()
+
+
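Worth noting for readers of the class above: lpips.LPIPS expects inputs in [-1, 1] by default (there is a normalize=True flag for [0, 1] inputs), which matches the [-1, 1] range used throughout this file, so no rescaling is applied. A minimal standalone sketch:

import torch
import lpips

loss_fn = lpips.LPIPS(net='vgg')            # downloads the VGG weights on first use
a = torch.rand(1, 3, 256, 256) * 2 - 1      # dummy images already in [-1, 1]
b = torch.rand(1, 3, 256, 256) * 2 - 1
distance = loss_fn(a, b).mean()             # scalar perceptual distance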
+ class LCMVisualization(Loss):
+     """Dummy loss used to visualize the LCM outputs.
+     Args:
+         predicted_pixel_values (torch.Tensor): The predicted pixel values using 1-step LCM and the VAE decoder.
+         pixel_values (torch.Tensor): The input image to the decoder.
+         encoder_pixel_values (torch.Tensor): The input image to the encoder.
+     """
+     def forward(
+         self,
+         predicted_pixel_values: torch.Tensor,
+         pixel_values: torch.Tensor,
+         encoder_pixel_values: torch.Tensor,
+         timesteps: torch.Tensor,
+         **kwargs,
+     ) -> torch.Tensor:
+         if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
+             predicted_pixel_values = rearrange(predicted_pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
+             pixel_values = rearrange(pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
+             encoder_pixel_values = rearrange(encoder_pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
+             image = np.hstack([encoder_pixel_values, pixel_values, predicted_pixel_values])
+             for tracker in self.accelerator.trackers:
+                 if tracker.name == 'wandb':
+                     tracker.log({"TrainVisualization": wandb.Image(image, caption=f"Encoder Input Image, Decoder Input Image, Predicted LCM Image. Timesteps {timesteps.cpu().tolist()}")})
+         return torch.tensor(0.0)
+
+
+ class L2Loss(Loss):
+     """
+     Regular diffusion (MSE) loss between the predicted noise and the target noise.
+     Args:
+         predict (torch.Tensor): noise predicted by the diffusion model.
+         target (torch.Tensor): actual noise added to the image.
+         weights (torch.Tensor, optional): per-element weighting applied to the squared error before averaging.
+     """
+     def forward(
+         self,
+         predict: torch.Tensor,
+         target: torch.Tensor,
+         weights: torch.Tensor = None,
+         **kwargs
+     ) -> torch.Tensor:
+         if weights is not None:
+             loss = (predict.float() - target.float()).pow(2) * weights
+             return loss.mean()
+         return F.mse_loss(predict.float(), target.float(), reduction="mean")
+
+
+ class HuberLoss(Loss):
+     """Pseudo-Huber loss between the predicted and target tensors:
+     sqrt((predict - target)^2 + huber_c^2) - huber_c, computed elementwise.
+     Args:
+         predict (torch.Tensor): The predicted tensor.
+         target (torch.Tensor): The target tensor.
+         weights (torch.Tensor, optional): per-element weighting applied before averaging.
+     """
+     def __init__(self, huber_c=0.001, **kwargs):
+         super().__init__(**kwargs)
+         self.huber_c = huber_c
+
+     def forward(
+         self,
+         predict: torch.Tensor,
+         target: torch.Tensor,
+         weights: torch.Tensor = None,
+         **kwargs
+     ) -> torch.Tensor:
+         loss = torch.sqrt((predict.float() - target.float()) ** 2 + self.huber_c**2) - self.huber_c
+         if weights is not None:
+             return (loss * weights).mean()
+         return loss.mean()
+
+
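The expression above is roughly quadratic (about r^2 / (2 * huber_c)) for residuals much smaller than huber_c and roughly linear (about |r|) for residuals much larger than it. A small numerical sketch of those two regimes (torch only):

import torch

def pseudo_huber(residual: torch.Tensor, c: float = 0.001) -> torch.Tensor:
    # Same elementwise formula as HuberLoss.forward above.
    return torch.sqrt(residual ** 2 + c ** 2) - c

small, large = torch.tensor(1e-4), torch.tensor(10.0)
print(pseudo_huber(small) / (small ** 2 / (2 * 0.001)))   # ~1.0: quadratic regime, loss ~ r^2 / (2c)
print(pseudo_huber(large) / large)                        # ~1.0: linear regime, loss ~ |r|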
+ class WeightedNoiseLoss(Loss):
+     """
+     Weighted diffusion loss between predicted noise and target noise.
+     Args:
+         predict (torch.Tensor): noise predicted by the diffusion model.
+         target (torch.Tensor): actual noise added to the image.
+         weights (torch.Tensor): weighting for each batch item. Can be used to e.g. zero out the loss for InstantID training if keypoint extraction fails.
+     """
+     def forward(
+         self,
+         predict: torch.Tensor,
+         target: torch.Tensor,
+         weights,
+         **kwargs
+     ) -> torch.Tensor:
+         return F.mse_loss(predict.float() * weights, target.float() * weights, reduction="mean")
+
+
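One subtlety worth flagging: because the weights are applied inside mse_loss, the squared error is effectively scaled by weights squared, unlike the weighted branch of L2Loss above, which scales it by the weights directly. With binary keep/drop weights the two coincide. A sketch with a hypothetical [B, 1, 1, 1] weight layout that zeroes out one failed sample:

import torch
import torch.nn.functional as F

predict, target = torch.randn(4, 4, 64, 64), torch.randn(4, 4, 64, 64)
weights = torch.tensor([1.0, 1.0, 0.0, 1.0]).view(-1, 1, 1, 1)   # drop sample index 2

weighted_mse = F.mse_loss(predict * weights, target * weights, reduction="mean")
# Equivalent formulation making the weights**2 scaling explicit:
explicit = ((predict - target) ** 2 * weights ** 2).mean()
assert torch.allclose(weighted_mse, explicit, atol=1e-6)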
+ class IDLoss(Loss):
+     """
+     Use a pretrained face-recognition model to extract identity features from the face of the predicted image and of the target image.
+     The ArcFace backbone expects 112x112 crops, so we detect and crop the face with MTCNN and resize the crop to 112x112.
+     The loss is 1 minus the cosine similarity between the two feature vectors.
+     Note that the backbone outputs are L2-normalized, so their dot product equals the cosine similarity.
+     """
+     def __init__(self, pretrained_arcface_path: str, skip_not_found=True, **kwargs):
+         super().__init__(**kwargs)
+         assert pretrained_arcface_path is not None, "please pass `pretrained_arcface_path` in the losses config. You can download the pretrained model from " \
+             "https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing"
+         self.mtcnn = MTCNN(device=self.accelerator.device)
+         self.mtcnn.forward = self.mtcnn.detect
+         self.facenet_input_size = 112  # Has to be 112; no weights are available for 224-sized inputs.
+         self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
+         self.facenet.load_state_dict(torch.load(pretrained_arcface_path))
+         self.face_pool = torch.nn.AdaptiveAvgPool2d((self.facenet_input_size, self.facenet_input_size))
+         self.facenet.requires_grad_(False)
+         self.facenet.eval()
+         self.facenet.to(device=self.accelerator.device, dtype=self.dtype)  # not implemented for half precision
+         self.face_pool.to(device=self.accelerator.device, dtype=self.dtype)  # not implemented for half precision
+         self.visualization_resize = transforms.Resize((self.facenet_input_size, self.facenet_input_size), interpolation=transforms.InterpolationMode.BICUBIC)
+         self.reference_facial_points = np.array([[38.29459953, 51.69630051],
+                                                  [72.53179932, 51.50139999],
+                                                  [56.02519989, 71.73660278],
+                                                  [41.54930115, 92.3655014],
+                                                  [70.72990036, 92.20410156]
+                                                  ])  # Original points are for a 112x96 crop; 8 was added to the x axis to make it 112x112.
+         self.facenet, self.face_pool, self.mtcnn = self.accelerator.prepare(self.facenet, self.face_pool, self.mtcnn)
+
+         self.skip_not_found = skip_not_found
+
+     def extract_feats(self, x: torch.Tensor):
+         """
+         Extract identity features from the face crop using the ArcFace backbone.
+         """
+         x = self.face_pool(x)
+         x_feats = self.facenet(x)
+
+         return x_feats
+
+     def forward(
+         self,
+         predicted_pixel_values: torch.Tensor,
+         encoder_pixel_values: torch.Tensor,
+         timesteps: torch.Tensor,
+         **kwargs
+     ):
+         encoder_pixel_values = encoder_pixel_values.to(dtype=self.dtype)
+         predicted_pixel_values = predicted_pixel_values.to(dtype=self.dtype)
+
+         predicted_pixel_values_face, predicted_invalid_indices = extract_faces_and_landmarks(predicted_pixel_values, mtcnn=self.mtcnn)
+         with torch.no_grad():
+             encoder_pixel_values_face, source_invalid_indices = extract_faces_and_landmarks(encoder_pixel_values, mtcnn=self.mtcnn)
+
+         if self.skip_not_found:
+             valid_indices = []
+             for i in range(predicted_pixel_values.shape[0]):
+                 if i not in predicted_invalid_indices and i not in source_invalid_indices:
+                     valid_indices.append(i)
+         else:
+             valid_indices = list(range(predicted_pixel_values.shape[0]))
+
+         valid_indices = torch.tensor(valid_indices).to(device=predicted_pixel_values.device)
+
+         if len(valid_indices) == 0:
+             loss = (predicted_pixel_values_face * 0.0).mean()  # A zero loss that still depends on the prediction, so `backward()` can free its computation graph.
+             if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
+                 self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss)
+             return loss
+
+         with torch.no_grad():
+             pixel_values_feats = self.extract_feats(encoder_pixel_values_face[valid_indices])
+
+         predicted_pixel_values_feats = self.extract_feats(predicted_pixel_values_face[valid_indices])
+         loss = 1 - torch.einsum("bi,bi->b", pixel_values_feats, predicted_pixel_values_feats)
+
+         if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
+             self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss)
+         return loss.mean()
+
+     def visualize(
+         self,
+         predicted_pixel_values: torch.Tensor,
+         encoder_pixel_values: torch.Tensor,
+         predicted_pixel_values_face: torch.Tensor,
+         encoder_pixel_values_face: torch.Tensor,
+         timesteps: torch.Tensor,
+         valid_indices: torch.Tensor,
+         loss: torch.Tensor,
+     ) -> None:
+         small_predicted_pixel_values = rearrange(self.visualization_resize(predicted_pixel_values), "n c h w -> (n h) w c").detach().cpu().numpy()
+         small_pixel_values = rearrange(self.visualization_resize(encoder_pixel_values), "n c h w -> (n h) w c").detach().cpu().numpy()
+         small_predicted_pixel_values_face = rearrange(self.visualization_resize(predicted_pixel_values_face), "n c h w -> (n h) w c").detach().cpu().numpy()
+         small_pixel_values_face = rearrange(self.visualization_resize(encoder_pixel_values_face), "n c h w -> (n h) w c").detach().cpu().numpy()
+
+         small_predicted_pixel_values = add_text_to_image(((small_predicted_pixel_values * 0.5 + 0.5) * 255).astype(np.uint8), "Pred Images", add_below=False)
+         small_pixel_values = add_text_to_image(((small_pixel_values * 0.5 + 0.5) * 255).astype(np.uint8), "Target Images", add_below=False)
+         small_predicted_pixel_values_face = add_text_to_image(((small_predicted_pixel_values_face * 0.5 + 0.5) * 255).astype(np.uint8), "Pred Faces", add_below=False)
+         small_pixel_values_face = add_text_to_image(((small_pixel_values_face * 0.5 + 0.5) * 255).astype(np.uint8), "Target Faces", add_below=False)
+
+         final_image = np.hstack([small_predicted_pixel_values, small_pixel_values, small_predicted_pixel_values_face, small_pixel_values_face])
+         for tracker in self.accelerator.trackers:
+             if tracker.name == 'wandb':
+                 tracker.log({"IDLoss Visualization": wandb.Image(final_image, caption=f"loss: {loss.cpu().tolist()} timesteps: {timesteps.cpu().tolist()}, valid_indices: {valid_indices.cpu().tolist()}")})
+
+
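The core of the identity loss is the last few lines of forward: both ArcFace embeddings are L2-normalized, so the batched dot product is exactly the cosine similarity and the loss is 1 minus it. A standalone sketch with random stand-in embeddings (512 is assumed here as the ArcFace embedding size):

import torch
import torch.nn.functional as F

source_feats = F.normalize(torch.randn(4, 512), dim=-1)   # stand-in for extract_feats(encoder faces)
pred_feats = F.normalize(torch.randn(4, 512), dim=-1)     # stand-in for extract_feats(predicted faces)

cos_sim = torch.einsum("bi,bi->b", source_feats, pred_feats)   # dot product == cosine similarity on unit vectors
assert torch.allclose(cos_sim, F.cosine_similarity(source_feats, pred_feats, dim=-1), atol=1e-6)
id_loss = (1 - cos_sim).mean()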
+ class ImageAugmentations(torch.nn.Module):
+     # Standard image augmentations used with the CLIP loss to discourage adversarial outputs.
+     def __init__(self, output_size, augmentations_number, p=0.7):
+         super().__init__()
+         self.output_size = output_size
+         self.augmentations_number = augmentations_number
+
+         self.augmentations = torch.nn.Sequential(
+             K.RandomAffine(degrees=15, translate=0.1, p=p, padding_mode="border"),  # type: ignore
+             K.RandomPerspective(0.7, p=p),
+         )
+
+         self.avg_pool = torch.nn.AdaptiveAvgPool2d((self.output_size, self.output_size))
+
+         self.device = None
+
+     def forward(self, input):
+         """Extends the input batch with augmentations.
+         If the input consists of images [I1, I2], the extended augmented output
+         will be [I1_resized, I2_resized, I1_aug1, I2_aug1, I1_aug2, I2_aug2, ...]
+         Args:
+             input (torch.Tensor): input batch of shape [batch, C, H, W]
+         Returns:
+             torch.Tensor: updated batch of shape [batch * augmentations_number, C, H, W]
+         """
+         # We want to multiply the number of images in the batch, in contrast to regular augmentations
+         # that do not change the number of samples in the batch.
+         resized_images = self.avg_pool(input)
+         resized_images = torch.tile(resized_images, dims=(self.augmentations_number, 1, 1, 1))
+
+         batch_size = input.shape[0]
+         # We want at least one non-augmented image.
+         non_augmented_batch = resized_images[:batch_size]
+         augmented_batch = self.augmentations(resized_images[batch_size:])
+         updated_batch = torch.cat([non_augmented_batch, augmented_batch], dim=0)
+
+         return updated_batch
+
+
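A quick shape check of the batch-extension behaviour described in the docstring, assuming the ImageAugmentations class above (and its kornia dependency) is importable:

import torch

aug = ImageAugmentations(output_size=224, augmentations_number=4)
images = torch.rand(2, 3, 512, 512)
out = aug(images)
assert out.shape == (2 * 4, 3, 224, 224)
# out[:2] are only resized copies; out[2:] are random affine/perspective views of the same two images.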
+ class CLIPLoss(Loss):
+     def __init__(self, augmentations_number: int = 4, **kwargs):
+         super().__init__(**kwargs)
+
+         self.clip_model, clip_preprocess = clip.load("ViT-B/16", device=self.accelerator.device, jit=False)
+
+         self.clip_model.device = None
+
+         self.clip_model.eval().requires_grad_(False)
+
+         self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] +  # un-normalize from [-1.0, 1.0] (SD output) to [0, 1]
+                                              clip_preprocess.transforms[:2] +   # CLIP's resize and center-crop
+                                              clip_preprocess.transforms[4:])    # skip the PIL-to-tensor steps and apply CLIP's normalization
+
+         self.clip_size = self.clip_model.visual.input_resolution
+
+         self.clip_normalize = transforms.Normalize(
+             mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
+         )
+
+         self.image_augmentations = ImageAugmentations(output_size=self.clip_size,
+                                                       augmentations_number=augmentations_number)
+
+         self.clip_model, self.image_augmentations = self.accelerator.prepare(self.clip_model, self.image_augmentations)
+
+     def forward(self, decoder_prompts, predicted_pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
+
+         if not isinstance(decoder_prompts, list):
+             decoder_prompts = [decoder_prompts]
+
+         tokens = clip.tokenize(decoder_prompts).to(predicted_pixel_values.device)
+         image = self.preprocess(predicted_pixel_values)
+
+         logits_per_image, _ = self.clip_model(image, tokens)
+
+         # Keep only the logit of each image against its own prompt.
+         logits_per_image = torch.diagonal(logits_per_image)
+
+         # CLIP's logit scale is ~100, so dividing maps the logits back to (approximately) cosine similarities.
+         return (1. - logits_per_image / 100).mean()
+
+
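A small algebraic sketch of why forward divides the diagonal logits by 100: CLIP's learned logit scale is clamped at 100, so diag(logits)/100 is (approximately) the cosine similarity between each image and its own prompt. Random stand-in embeddings are used below (512 is the ViT-B/16 embedding width); no CLIP weights are loaded:

import torch
import torch.nn.functional as F

image_features = F.normalize(torch.randn(4, 512), dim=-1)
text_features = F.normalize(torch.randn(4, 512), dim=-1)
logit_scale = 100.0   # exp of CLIP's learned temperature, clamped at 100

logits_per_image = logit_scale * image_features @ text_features.t()
loss_from_logits = (1.0 - torch.diagonal(logits_per_image) / 100).mean()
loss_from_cosine = (1.0 - F.cosine_similarity(image_features, text_features, dim=-1)).mean()
assert torch.allclose(loss_from_logits, loss_from_cosine, atol=1e-5)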
373
+ class DINOLoss(Loss):
374
+ def __init__(
375
+ self,
376
+ dino_model,
377
+ dino_preprocess,
378
+ output_hidden_states: bool = False,
379
+ center_momentum: float = 0.9,
380
+ student_temp: float = 0.1,
381
+ teacher_temp: float = 0.04,
382
+ warmup_teacher_temp: float = 0.04,
383
+ warmup_teacher_temp_epochs: int = 30,
384
+ **kwargs):
385
+ super().__init__(**kwargs)
386
+
387
+ self.dino_model = dino_model
388
+ self.output_hidden_states = output_hidden_states
389
+ self.rescale_factor = dino_preprocess.rescale_factor
390
+
391
+ # Un-normalize from [-1.0, 1.0] (SD output) to [0, 1].
392
+ self.preprocess = transforms.Compose(
393
+ [
394
+ transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
395
+ transforms.Resize(size=256),
396
+ transforms.CenterCrop(size=(224, 224)),
397
+ transforms.Normalize(mean=dino_preprocess.image_mean, std=dino_preprocess.image_std)
398
+ ]
399
+ )
400
+
401
+ self.student_temp = student_temp
402
+ self.teacher_temp = teacher_temp
403
+ self.center_momentum = center_momentum
404
+ self.center = torch.zeros(1, 257, 1024).to(self.accelerator.device, dtype=self.dtype)
405
+
406
+ # TODO: add temp, now fixed to 0.04
407
+ # we apply a warm up for the teacher temperature because
408
+ # a too high temperature makes the training instable at the beginning
409
+ # self.teacher_temp_schedule = np.concatenate((
410
+ # np.linspace(warmup_teacher_temp,
411
+ # teacher_temp, warmup_teacher_temp_epochs),
412
+ # np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
413
+ # ))
414
+
415
+ self.dino_model = self.accelerator.prepare(self.dino_model)
416
+
417
+ def forward(
418
+ self,
419
+ target: torch.Tensor,
420
+ predict: torch.Tensor,
421
+ weights: torch.Tensor = None,
422
+ **kwargs) -> torch.Tensor:
423
+
424
+ predict = self.preprocess(predict)
425
+ target = self.preprocess(target)
426
+
427
+ encoder_input = torch.cat([target, predict]).to(self.dino_model.device, dtype=self.dino_model.dtype)
428
+
429
+ if self.output_hidden_states:
430
+ raise ValueError("Output hidden states not supported for DINO loss.")
431
+ image_enc_hidden_states = self.dino_model(encoder_input, output_hidden_states=True).hidden_states[-2]
432
+ else:
433
+ image_enc_hidden_states = self.dino_model(encoder_input).last_hidden_state
434
+
435
+ teacher_output, student_output = image_enc_hidden_states.chunk(2, dim=0) # [B, 257, 1024]
436
+
437
+ student_out = student_output.float() / self.student_temp
438
+
439
+ # teacher centering and sharpening
440
+ # temp = self.teacher_temp_schedule[epoch]
441
+ temp = self.teacher_temp
442
+ teacher_out = F.softmax((teacher_output.float() - self.center) / temp, dim=-1)
443
+ teacher_out = teacher_out.detach()
444
+
445
+ loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1, keepdim=True)
446
+ # self.update_center(teacher_output)
447
+
448
+ if weights is not None:
449
+ loss = loss * weights
450
+ return loss.mean()
451
+ return loss.mean()
452
+
453
+ @torch.no_grad()
454
+ def update_center(self, teacher_output):
455
+ """
456
+ Update center used for teacher output.
457
+ """
458
+ batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
459
+ self.accelerator.reduce(batch_center, reduction="sum")
460
+ batch_center = batch_center / (len(teacher_output) * self.accelerator.num_processes)
461
+
462
+ # ema update
463
+ self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)