File size: 1,189 Bytes
66b7c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn


# A wrapper model for Classifier-free guidance **SAMPLING** only
# https://arxiv.org/abs/2207.12598
class ClassifierFreeSampleModel(nn.Module):
    """Wrap a model so that each forward pass applies classifier-free guidance.

    The wrapped model is called twice per step, once with
    ``cond_drop_prob=0.0`` and once with ``cond_drop_prob=1.0``, and the two
    outputs are blended using the guidance scale found in ``y["scale"]``.
    Intended for sampling only.
    """

    def __init__(self, model):
        """Store ``model`` and mirror the attributes that callers read from it.

        Args:
            model: The module to run. It must expose ``nfeats``,
                ``cond_mode`` and ``add_frame_cond``. When
                ``add_frame_cond`` is not ``None`` it must also expose
                ``resume_trans`` and ``step``, plus ``transformer`` and
                ``tokenizer`` when ``resume_trans`` is not ``None``.
        """
        super().__init__()
        self.model = model  # model is the actual model to run
        # Re-export the wrapped model's attributes so the wrapper can be used
        # wherever code reads them directly off the model object.
        self.nfeats = self.model.nfeats
        self.cond_mode = self.model.cond_mode
        self.add_frame_cond = self.model.add_frame_cond
        if self.add_frame_cond is not None:
            if self.model.resume_trans is not None:
                self.transformer = self.model.transformer
                self.tokenizer = self.model.tokenizer
            self.step = self.model.step

    def forward(self, x, timesteps, y=None):
        """Return the guided output ``uncond + scale * (cond - uncond)``.

        Args:
            x: Input forwarded unchanged to the wrapped model.
            timesteps: Timesteps forwarded unchanged to the wrapped model.
            y: Conditioning mapping forwarded to the wrapped model. It must
                contain a ``"scale"`` tensor (presumably one guidance scale
                per batch element -- confirm against the caller).

        Returns:
            The blend of the two model outputs, with the same shape as the
            wrapped model's output.

        Raises:
            TypeError: If ``y`` is ``None``; the guidance scale is required.
            KeyError: If ``y`` has no ``"scale"`` entry.
        """
        if y is None:
            # Fail before running two forward passes; the scale lookup below
            # could never succeed without a conditioning dict.
            raise TypeError(
                "ClassifierFreeSampleModel.forward requires `y` containing a "
                "'scale' tensor, got None"
            )

        out = self.model(x, timesteps, y, cond_drop_prob=0.0)
        out_uncond = self.model(x, timesteps, y, cond_drop_prob=1.0)

        # Reshape the scale to (-1, 1, ..., 1) with as many trailing singleton
        # axes as the output needs, so it broadcasts over every non-leading
        # axis. For a 3-D output this is exactly `.view(-1, 1, 1)`.
        scale = y["scale"].view(-1, *([1] * (out.dim() - 1)))
        return out_uncond + (scale * (out - out_uncond))