jadechoghari commited on
Commit
e6d384c
1 Parent(s): f1d908c

Update diffloss.py

Browse files
Files changed (1) hide show
  1. diffloss.py +7 -7
diffloss.py CHANGED
@@ -96,15 +96,15 @@ class TimestepEmbedder(nn.Module):
96
  # t_emb = self.mlp(t_freq)
97
  # return t_emb
98
  def forward(self, t):
99
- t = t.to(self.mlp.weight.device)
100
-
101
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
102
 
103
- t_freq = t_freq.to(self.mlp.weight.device)
 
 
 
 
104
 
105
- t_emb = self.mlp(t_freq)
106
-
107
- return t_emb
108
 
109
 
110
  class ResBlock(nn.Module):
 
96
  # t_emb = self.mlp(t_freq)
97
  # return t_emb
98
  def forward(self, t):
99
+ t = t.to(self.mlp.weight.device)
 
 
100
 
101
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
102
+
103
+ t_freq = t_freq.to(self.mlp.weight.device)
104
+
105
+ t_emb = self.mlp(t_freq)
106
 
107
+ return t_emb
 
 
108
 
109
 
110
  class ResBlock(nn.Module):