jadechoghari commited on
Commit
7038305
1 Parent(s): 72fd365

add device

Browse files
Files changed (1) hide show
  1. diffloss.py +2 -2
diffloss.py CHANGED
@@ -35,12 +35,12 @@ class DiffLoss(nn.Module):
35
  def sample(self, z, temperature=1.0, cfg=1.0):
36
  # diffusion loss sampling
37
  if not cfg == 1.0:
38
- noise = torch.randn(z.shape[0] // 2, self.in_channels)
39
  noise = torch.cat([noise, noise], dim=0)
40
  model_kwargs = dict(c=z, cfg_scale=cfg)
41
  sample_fn = self.net.forward_with_cfg
42
  else:
43
- noise = torch.randn(z.shape[0], self.in_channels)
44
  model_kwargs = dict(c=z)
45
  sample_fn = self.net.forward
46
 
 
35
  def sample(self, z, temperature=1.0, cfg=1.0):
36
  # diffusion loss sampling
37
  if not cfg == 1.0:
38
+ noise = torch.randn(z.shape[0] // 2, self.in_channels).to(device)
39
  noise = torch.cat([noise, noise], dim=0)
40
  model_kwargs = dict(c=z, cfg_scale=cfg)
41
  sample_fn = self.net.forward_with_cfg
42
  else:
43
+ noise = torch.randn(z.shape[0], self.in_channels).to(device)
44
  model_kwargs = dict(c=z)
45
  sample_fn = self.net.forward
46