Spaces:
Runtime error
Runtime error
File size: 1,548 Bytes
fd43906 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import torch
import torch.nn.functional as F
from diffusers import VQDiffusionScheduler
from .test_schedulers import SchedulerCommonTest
class VQDiffusionSchedulerTest(SchedulerCommonTest):
    """Scheduler tests for ``VQDiffusionScheduler``.

    The shared test logic comes from ``SchedulerCommonTest``. This subclass
    supplies the scheduler config plus a dummy sample and a dummy model
    suited to a discrete scheduler:
    - the sample is a tensor of integer class indices;
    - the model returns per-class log-probabilities.
    """

    scheduler_classes = (VQDiffusionScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config dict, updated with ``kwargs``.

        Keyword arguments override (or add to) the defaults
        ``num_vec_classes=4097`` and ``num_train_timesteps=100``.
        """
        config = {
            "num_vec_classes": 4097,
            "num_train_timesteps": 100,
        }
        config.update(kwargs)
        return config

    def dummy_sample(self, num_vec_classes):
        """Return a random batch of class indices.

        The result is an integer tensor of shape ``(4, 64)``: a batch of 4,
        with an 8x8 grid flattened to 64 positions. Values are drawn from
        ``[0, num_vec_classes)``; the upper bound of ``torch.randint`` is
        exclusive.
        """
        batch_size = 4
        height = 8
        width = 8
        return torch.randint(0, num_vec_classes, (batch_size, height * width))

    @property
    def dummy_sample_deter(self):
        """Not available for this scheduler; always raises ``AssertionError``.

        Samples here are integer class indices produced by
        ``dummy_sample(num_vec_classes)``, so no fixed deterministic sample
        is provided.
        """
        # Explicit raise rather than ``assert False``: assert statements are
        # removed under ``python -O``. That would make this property silently
        # return ``None`` instead of failing loudly.
        raise AssertionError(
            "dummy_sample_deter is not supported for VQDiffusionScheduler; "
            "use dummy_sample(num_vec_classes) instead"
        )

    def dummy_model(self, num_vec_classes):
        """Return a stand-in model that produces random per-class log-probabilities.

        The returned callable takes ``(sample, t, *args)``, where ``sample``
        is a 2-D ``(batch_size, num_latent_pixels)`` tensor. It ignores ``t``
        and any extra positional arguments.

        It returns a float32 tensor of shape
        ``(batch_size, num_vec_classes - 1, num_latent_pixels)``. That tensor
        is log-softmax normalized over dim 1, the class dimension.
        """

        def model(sample, t, *args):
            batch_size, num_latent_pixels = sample.shape
            # One class fewer than ``num_vec_classes``.
            # NOTE(review): presumably the model never predicts the mask
            # class -- confirm against VQDiffusionScheduler.
            logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
            # log_softmax runs in float64, then the result is cast back to
            # float32. NOTE(review): presumably for numerical precision -- confirm.
            return F.log_softmax(logits.double(), dim=1).float()

        return model

    def test_timesteps(self):
        """Run the shared config check for several ``num_train_timesteps`` values."""
        for timesteps in [2, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_num_vec_classes(self):
        """Run the shared config check for several ``num_vec_classes`` values."""
        for num_vec_classes in [5, 100, 1000, 4000]:
            self.check_over_configs(num_vec_classes=num_vec_classes)

    def test_time_indices(self):
        """Run the shared forward check at several time steps."""
        for t in [0, 50, 99]:
            self.check_over_forward(time_step=t)

    def test_add_noise_device(self):
        """Disabled for this scheduler: overrides the inherited test with a no-op."""
        # NOTE(review): presumably skipped because the inherited test relies on
        # sample fixtures that this class replaces with
        # ``dummy_sample(num_vec_classes)`` -- confirm the reason upstream.
|