# audio2photoreal/model/cfg_sampler.py
# Uploaded by lybxin ("Upload folder using huggingface_hub"), commit 66b7c56 (verified).
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
# A wrapper model for Classifier-free guidance **SAMPLING** only
# https://arxiv.org/abs/2207.12598
class ClassifierFreeSampleModel(nn.Module):
    """Wrap a conditional model so every forward pass applies classifier-free guidance.

    The wrapped model is evaluated twice per call:
    - once with conditioning kept (``cond_drop_prob=0.0``);
    - once with conditioning dropped (``cond_drop_prob=1.0``).

    The two outputs are blended as::

        out_uncond + scale * (out_cond - out_uncond)

    ``scale`` is read from ``y["scale"]``. A scale of 1 reproduces the
    conditional output and a scale of 0 reproduces the unconditional one.

    This wrapper is meant for sampling only. It holds no training logic. A few
    attributes of the wrapped model are mirrored onto the wrapper, so code that
    reads them keeps working when it receives the wrapper instead.
    """

    def __init__(self, model):
        """Store ``model`` and mirror selected attributes from it.

        Args:
            model: the conditional model to run. It must expose these
                attributes:
                - ``nfeats``, ``cond_mode`` and ``add_frame_cond``, always;
                - ``resume_trans``, when ``add_frame_cond`` is not None;
                - ``transformer``, ``tokenizer`` and ``step``, when
                  ``resume_trans`` is not None.
                Its forward must accept ``(x, timesteps, y, cond_drop_prob=...)``.
        """
        super().__init__()
        self.model = model  # the actual model to run
        self.nfeats = self.model.nfeats
        self.cond_mode = self.model.cond_mode
        self.add_frame_cond = self.model.add_frame_cond
        # Mirror the frame-conditioning components only when the wrapped model
        # has them. The `and` short-circuits, so `resume_trans` is read only when
        # `add_frame_cond` is set, exactly as with the original nested ifs.
        if self.add_frame_cond is not None and self.model.resume_trans is not None:
            self.transformer = self.model.transformer
            self.tokenizer = self.model.tokenizer
            self.step = self.model.step

    def forward(self, x, timesteps, y=None):
        """Return the classifier-free-guided prediction for ``x`` at ``timesteps``.

        Args:
            x: model input, forwarded unchanged to the wrapped model.
            timesteps: diffusion timesteps, forwarded unchanged to the model.
            y: conditioning dict, forwarded to the model. It is required despite
                the ``None`` default, because ``y["scale"]`` supplies the
                guidance scale. The scale must hold either one value or one
                value per element along the output's first (batch) dimension.

        Returns:
            ``out_uncond + scale * (out - out_uncond)``. ``scale`` is reshaped
            to broadcast across every non-batch dimension of the model output.

        Raises:
            ValueError: if ``y`` is None. The check runs before the two
                (potentially expensive) model passes.
        """
        if y is None:
            raise ValueError(
                "ClassifierFreeSampleModel.forward requires a conditioning dict "
                "`y` containing the guidance 'scale'"
            )
        out = self.model(x, timesteps, y, cond_drop_prob=0.0)
        out_uncond = self.model(x, timesteps, y, cond_drop_prob=1.0)
        # Reshape the scale to (N, 1, ..., 1) so it matches the output's rank
        # and scales each batch element across all remaining dims. A hard-coded
        # view(-1, 1, 1) broadcasts correctly only for 3-D outputs; for 3-D
        # outputs the two forms are identical.
        scale = y["scale"].view(-1, *([1] * (out.dim() - 1)))
        return out_uncond + scale * (out - out_uncond)