jadechoghari
commited on
Commit
•
7038305
1
Parent(s):
72fd365
add device
Browse files- diffloss.py +2 -2
diffloss.py
CHANGED
@@ -35,12 +35,12 @@ class DiffLoss(nn.Module):
|
|
35 |
def sample(self, z, temperature=1.0, cfg=1.0):
|
36 |
# diffusion loss sampling
|
37 |
if not cfg == 1.0:
|
38 |
-
noise = torch.randn(z.shape[0] // 2, self.in_channels)
|
39 |
noise = torch.cat([noise, noise], dim=0)
|
40 |
model_kwargs = dict(c=z, cfg_scale=cfg)
|
41 |
sample_fn = self.net.forward_with_cfg
|
42 |
else:
|
43 |
-
noise = torch.randn(z.shape[0], self.in_channels)
|
44 |
model_kwargs = dict(c=z)
|
45 |
sample_fn = self.net.forward
|
46 |
|
|
|
35 |
def sample(self, z, temperature=1.0, cfg=1.0):
|
36 |
# diffusion loss sampling
|
37 |
if not cfg == 1.0:
|
38 |
+
noise = torch.randn(z.shape[0] // 2, self.in_channels).to(device)
|
39 |
noise = torch.cat([noise, noise], dim=0)
|
40 |
model_kwargs = dict(c=z, cfg_scale=cfg)
|
41 |
sample_fn = self.net.forward_with_cfg
|
42 |
else:
|
43 |
+
noise = torch.randn(z.shape[0], self.in_channels).to(device)
|
44 |
model_kwargs = dict(c=z)
|
45 |
sample_fn = self.net.forward
|
46 |
|