# FAcodecV2 / gradient_reversal.py
# (uploaded by Plachta in commit a50ee15, "Upload 5 files"; 894 bytes)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.autograd import Function
import torch
from torch import nn
class GradientReversalFunction(Function):
    """Gradient reversal as a custom autograd function.

    The forward pass is the identity. The backward pass multiplies the
    incoming gradient by ``-alpha``. ``alpha`` never receives a gradient.

    This class is named distinctly from the ``GradientReversal`` module so the
    module does not shadow it. ``revgrad`` is this function's ``apply``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and stash ``alpha`` for the backward pass.

        Args:
            x: Input tensor.
            alpha: Reversal strength, either a tensor or a Python number.

        Returns:
            A tensor equal to ``x``.
        """
        # Only alpha is needed in backward. Saving x as well would keep the
        # activation alive until backward for no benefit.
        ctx.alpha_is_tensor = torch.is_tensor(alpha)
        if ctx.alpha_is_tensor:
            ctx.save_for_backward(alpha)
        else:
            # save_for_backward only accepts tensors, so plain numbers are
            # stored directly on ctx.
            ctx.alpha = alpha
        # view_as hands autograd a new tensor object instead of the input
        # itself; this is the usual idiom for identity custom Functions.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-alpha * grad_output`` for ``x`` and ``None`` for ``alpha``."""
        grad_input = None
        if ctx.needs_input_grad[0]:
            alpha = ctx.saved_tensors[0] if ctx.alpha_is_tensor else ctx.alpha
            grad_input = -alpha * grad_output
        return grad_input, None


revgrad = GradientReversalFunction.apply
class GradientReversal(nn.Module):
    """Module wrapper around ``revgrad``.

    The forward pass is the identity. During backward, the gradient flowing
    through this module is multiplied by ``-alpha``.

    Args:
        alpha: Reversal strength, as a Python number or a tensor.
    """

    def __init__(self, alpha):
        super().__init__()
        # Copy tensor inputs explicitly; torch.tensor(t) on a tensor warns.
        alpha = alpha.detach().clone() if torch.is_tensor(alpha) else torch.tensor(alpha)
        # A non-persistent buffer follows .to()/.cuda() with the module but
        # stays out of state_dict, matching the old plain-attribute behaviour
        # for checkpoints.
        self.register_buffer("alpha", alpha, persistent=False)

    def forward(self, x):
        """Apply gradient reversal to ``x`` using this module's ``alpha``."""
        return revgrad(x, self.alpha)

    def extra_repr(self):
        """Show ``alpha`` in the module's repr."""
        return f"alpha={self.alpha.tolist()}"