Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,518 Bytes
2422035 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
# Modified from:
# taming-transformers: https://github.com/CompVis/taming-transformers
# muse-maskgit-pytorch: https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/vqgan_vae.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizer.tokenizer_image.lpips import LPIPS
from tokenizer.tokenizer_image.discriminator_patchgan import NLayerDiscriminator as PatchGANDiscriminator
from tokenizer.tokenizer_image.discriminator_stylegan import Discriminator as StyleGANDiscriminator
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss.

    Penalizes real logits below +1 and fake logits above -1, then averages
    the two mean penalties.

    Args:
        logits_real: discriminator logits for real samples.
        logits_fake: discriminator logits for generated samples.

    Returns:
        Scalar tensor ``0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake)))``.
    """
    real_penalty = F.relu(1.0 - logits_real).mean()
    fake_penalty = F.relu(1.0 + logits_fake).mean()
    return (real_penalty + fake_penalty) / 2
def vanilla_d_loss(logits_real, logits_fake):
    """Standard (log-sigmoid) GAN discriminator loss.

    Uses ``softplus(-x) == -log(sigmoid(x))`` for real logits and
    ``softplus(x) == -log(1 - sigmoid(x))`` for fake logits, averaging the
    two mean terms.

    Args:
        logits_real: discriminator logits for real samples.
        logits_fake: discriminator logits for generated samples.

    Returns:
        Scalar tensor with the averaged discriminator loss.
    """
    penalty_real = F.softplus(logits_real.neg()).mean()
    penalty_fake = F.softplus(logits_fake).mean()
    return (penalty_real + penalty_fake) / 2
def non_saturating_d_loss(logits_real, logits_fake):
    """Non-saturating (sigmoid cross-entropy) discriminator loss.

    Real logits are trained towards label 1 and fake logits towards label 0
    with binary cross-entropy on logits; the two mean terms are averaged.

    Args:
        logits_real: discriminator logits for real samples.
        logits_fake: discriminator logits for generated samples.

    Returns:
        Scalar tensor ``0.5 * (BCE(logits_real, 1) + BCE(logits_fake, 0))``.
    """
    # binary_cross_entropy_with_logits takes (input_logits, target_labels);
    # the logits must come first, otherwise the "loss" is linear in the
    # logits and yields a constant gradient. Default reduction is 'mean'.
    loss_real = F.binary_cross_entropy_with_logits(logits_real, torch.ones_like(logits_real))
    loss_fake = F.binary_cross_entropy_with_logits(logits_fake, torch.zeros_like(logits_fake))
    return 0.5 * (loss_real + loss_fake)
def hinge_gen_loss(logit_fake):
    """Generator hinge objective: the negated mean discriminator logit on fakes.

    Args:
        logit_fake: discriminator logits for generated samples.

    Returns:
        Scalar tensor ``-mean(logit_fake)``.
    """
    mean_logit = logit_fake.mean()
    return torch.neg(mean_logit)
def non_saturating_gen_loss(logit_fake):
    """Non-saturating generator loss: BCE of fake logits against label 1.

    Args:
        logit_fake: discriminator logits for generated samples.

    Returns:
        Scalar tensor ``mean(softplus(-logit_fake))``.
    """
    # Argument order is (input_logits, target_labels); default reduction is 'mean'.
    return F.binary_cross_entropy_with_logits(logit_fake, torch.ones_like(logit_fake))
def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Return ``value`` while ``global_step`` is below ``threshold``, else ``weight``.

    Used to keep a loss term switched off (or at a warm-up value) until a
    given training step.

    Args:
        weight: weight to use once ``global_step >= threshold``.
        global_step: current training step.
        threshold: first step at which ``weight`` takes effect.
        value: weight returned before ``threshold`` is reached.
    """
    return value if global_step < threshold else weight
class VQLoss(nn.Module):
    """Training loss for a VQ image tokenizer: reconstruction, perceptual,
    adversarial and codebook terms.

    ``forward`` serves two optimizers: ``optimizer_idx == 0`` returns the
    generator (autoencoder) loss and ``optimizer_idx == 1`` returns the
    discriminator loss.

    Args:
        disc_start: global step before which the adversarial terms get weight 0.
        disc_loss: discriminator objective: "hinge", "vanilla" or "non-saturating".
        disc_dim: base channel count (``ndf``) of the PatchGAN discriminator.
        disc_type: "patchgan" or "stylegan".
        image_size: image resolution passed to the StyleGAN discriminator.
        disc_num_layers: number of layers of the PatchGAN discriminator.
        disc_in_channels: channel count of images fed to the discriminator.
        disc_weight: weight of the adversarial terms once ``disc_start`` is reached.
        disc_adaptive_weight: if True, scale the generator adversarial loss by the
            ratio of gradient norms (reconstruction+perceptual vs adversarial)
            at ``last_layer``.
        gen_adv_loss: generator adversarial objective: "hinge" or "non-saturating".
        reconstruction_loss: pixel loss: "l1" or "l2".
        reconstruction_weight: weight of the pixel loss.
        codebook_weight: weight applied to the sum of the codebook loss terms.
        perceptual_weight: weight of the LPIPS perceptual loss.

    Raises:
        ValueError: if any string option is not recognized.
    """

    def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_type='patchgan', image_size=256,
                 disc_num_layers=3, disc_in_channels=3, disc_weight=1.0, disc_adaptive_weight=False,
                 gen_adv_loss='hinge', reconstruction_loss='l2', reconstruction_weight=1.0,
                 codebook_weight=1.0, perceptual_weight=1.0,
                 ):
        super().__init__()
        # Validate every option before building any sub-network. Explicit
        # raises are used because asserts are stripped under `python -O`.
        if disc_type not in ("patchgan", "stylegan"):
            raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")
        if disc_loss not in ("hinge", "vanilla", "non-saturating"):
            raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
        if gen_adv_loss not in ("hinge", "non-saturating"):
            raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")
        if reconstruction_loss not in ("l1", "l2"):
            raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")

        # discriminator
        if disc_type == "patchgan":
            self.discriminator = PatchGANDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        else:
            self.discriminator = StyleGANDiscriminator(
                input_nc=disc_in_channels,
                image_size=image_size,
            )
        self.disc_loss = {
            "hinge": hinge_d_loss,
            "vanilla": vanilla_d_loss,
            "non-saturating": non_saturating_d_loss,
        }[disc_loss]
        self.discriminator_iter_start = disc_start
        self.disc_weight = disc_weight
        self.disc_adaptive_weight = disc_adaptive_weight

        # generator adversarial loss
        self.gen_adv_loss = {
            "hinge": hinge_gen_loss,
            "non-saturating": non_saturating_gen_loss,
        }[gen_adv_loss]

        # perceptual loss
        # NOTE(review): .eval() is applied only here; a later `self.train()` on
        # this module would switch LPIPS back to train mode. Whether that matters
        # depends on the LPIPS implementation (e.g. dropout) — confirm.
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        # reconstruction loss
        self.rec_loss = {"l1": F.l1_loss, "l2": F.mse_loss}[reconstruction_loss]
        self.rec_weight = reconstruction_weight

        # codebook loss
        self.codebook_weight = codebook_weight

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        """Return ``||d nll/d last_layer|| / (||d g/d last_layer|| + 1e-4)``.

        The result is clamped to ``[0, 1e4]`` and detached, so it acts as a
        constant scale on the adversarial term. The graph is retained so that
        the caller can still backpropagate through both losses.
        """
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        return torch.clamp(d_weight, 0.0, 1e4).detach()

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, global_step, last_layer=None,
                logger=None, log_every=100):
        """Compute the generator (``optimizer_idx=0``) or discriminator (``1``) loss.

        Args:
            codebook_loss: indexable; items 0, 1 and 2 (logged as vq, commit and
                entropy loss) are summed, scaled by ``codebook_weight`` and added
                to the generator loss. Item 3 is only logged (codebook usage).
            inputs: real images.
            reconstructions: decoded images, same shape as ``inputs``.
            optimizer_idx: 0 for the generator step, 1 for the discriminator step.
            global_step: current step; drives the discriminator warm-up and
                the logging cadence.
            last_layer: parameter used for the adaptive adversarial weight;
                required when ``disc_adaptive_weight`` is enabled.
            logger: object with an ``info`` method; logging is skipped when None.
            log_every: log every ``log_every`` steps (no logging if <= 0).

        Returns:
            Scalar loss tensor for the selected optimizer.

        Raises:
            ValueError: for an unknown ``optimizer_idx``, or when adaptive
                weighting is enabled and ``last_layer`` is None.
        """
        if optimizer_idx not in (0, 1):
            raise ValueError(f"optimizer_idx must be 0 or 1, got {optimizer_idx!r}.")

        inputs = inputs.contiguous()
        reconstructions = reconstructions.contiguous()
        should_log = logger is not None and log_every > 0 and global_step % log_every == 0
        # Adversarial terms are zeroed until `discriminator_iter_start`.
        disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)

        # generator update
        if optimizer_idx == 0:
            weighted_rec_loss = self.rec_weight * self.rec_loss(inputs, reconstructions)
            weighted_p_loss = self.perceptual_weight * torch.mean(self.perceptual_loss(inputs, reconstructions))

            logits_fake = self.discriminator(reconstructions)
            generator_adv_loss = self.gen_adv_loss(logits_fake)
            if self.disc_adaptive_weight:
                if last_layer is None:
                    raise ValueError("last_layer is required when disc_adaptive_weight is enabled.")
                nll_loss = weighted_rec_loss + weighted_p_loss
                disc_adaptive_weight = self.calculate_adaptive_weight(nll_loss, generator_adv_loss,
                                                                      last_layer=last_layer)
            else:
                disc_adaptive_weight = 1
            weighted_adv_loss = disc_adaptive_weight * disc_weight * generator_adv_loss

            codebook_total = codebook_loss[0] + codebook_loss[1] + codebook_loss[2]
            loss = weighted_rec_loss + weighted_p_loss + weighted_adv_loss + \
                self.codebook_weight * codebook_total

            if should_log:
                logger.info(f"(Generator) rec_loss: {weighted_rec_loss:.4f}, perceptual_loss: {weighted_p_loss:.4f}, "
                            f"vq_loss: {codebook_loss[0]:.4f}, commit_loss: {codebook_loss[1]:.4f}, entropy_loss: {codebook_loss[2]:.4f}, "
                            f"codebook_usage: {codebook_loss[3]:.4f}, generator_adv_loss: {weighted_adv_loss:.4f}, "
                            f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}")
            return loss

        # discriminator update: detach so no gradient reaches the generator.
        logits_real = self.discriminator(inputs.detach())
        logits_fake = self.discriminator(reconstructions.detach())
        d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake)
        if should_log:
            mean_real = logits_real.detach().mean()
            mean_fake = logits_fake.detach().mean()
            logger.info(f"(Discriminator) "
                        f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, "
                        f"logits_real: {mean_real:.4f}, logits_fake: {mean_fake:.4f}")
        return d_adversarial_loss