jadechoghari commited on
Commit
3cb89bb
1 Parent(s): e6d384c

Update diffloss.py

Browse files
Files changed (1) hide show
  1. diffloss.py +8 -4
diffloss.py CHANGED
@@ -96,13 +96,17 @@ class TimestepEmbedder(nn.Module):
96
  # t_emb = self.mlp(t_freq)
97
  # return t_emb
98
  def forward(self, t):
99
- t = t.to(self.mlp.weight.device)
 
 
 
100
 
101
  t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
102
-
103
- t_freq = t_freq.to(self.mlp.weight.device)
104
-
105
  t_emb = self.mlp(t_freq)
 
106
 
107
  return t_emb
108
 
 
96
  # t_emb = self.mlp(t_freq)
97
  # return t_emb
98
  def forward(self, t):
99
+
100
+ device = next(self.mlp.parameters()).device
101
+
102
+ t = t.to(device)
103
 
104
  t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
105
+
106
+ t_freq = t_freq.to(device)
107
+
108
  t_emb = self.mlp(t_freq)
109
+
110
 
111
  return t_emb
112