jadechoghari committed
Commit: 813b71e
Parent(s): 469d7c2
Update mar.py
mar.py
CHANGED
@@ -14,8 +14,8 @@ from diffloss import DiffLoss
 
 
 def mask_by_order(mask_len, order, bsz, seq_len):
-    masking = torch.zeros(bsz, seq_len)
-    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len)).bool()
+    masking = torch.zeros(bsz, seq_len).to(device)
+    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).to(device)).bool()
     return masking
 
 
@@ -156,7 +156,7 @@ class MAR(nn.Module):
             order = np.array(list(range(self.seq_len)))
             np.random.shuffle(order)
             orders.append(order)
-        orders = torch.Tensor(np.array(orders)).long()
+        orders = torch.Tensor(np.array(orders)).to(device).long()
         return orders
 
     def random_masking(self, x, orders):
@@ -180,7 +180,7 @@ class MAR(nn.Module):
         # random drop class embedding during training
         if self.training:
             drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
-            drop_latent_mask = drop_latent_mask.unsqueeze(-1).to(x.dtype)
+            drop_latent_mask = drop_latent_mask.unsqueeze(-1).to(device).to(x.dtype)
             class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
 
         x[:, :self.buffer_size] = class_embedding.unsqueeze(1)
@@ -262,8 +262,8 @@ class MAR(nn.Module):
     def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
 
         # init and sample generation orders
-        mask = torch.ones(bsz, self.seq_len)
-        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim)
+        mask = torch.ones(bsz, self.seq_len).to(device)
+        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).to(device)
         orders = self.sample_orders(bsz)
 
         indices = list(range(num_iter))
@@ -291,10 +291,10 @@ class MAR(nn.Module):
 
             # mask ratio for the next round, following MaskGIT and MAGE.
             mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
-            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)])
+            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).to(device)
 
             # masks out at least one for the next iteration
-            mask_len = torch.maximum(torch.Tensor([1]),
+            mask_len = torch.maximum(torch.Tensor([1]).to(device),
                                      torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
 
             # get masking for next iteration and locations to be predicted in this iteration
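Every changed line moves a freshly created tensor onto a module-level `device`, but none of the hunks above show where that name is defined; `mar.py` presumably sets it near its imports. Below is a minimal sketch, assuming a standard CUDA-if-available definition, together with a device-agnostic variant of `mask_by_order` that derives the target device from the `order` tensor instead of a global. Both the `device` assignment and the alternative function are illustrative assumptions, not part of this commit.

import torch

# Assumed module-level definition near the imports in mar.py; the commit itself
# never shows where `device` comes from.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative device-agnostic alternative: allocate new tensors directly on the
# device of an existing input instead of relying on a global `device`.
def mask_by_order(mask_len, order, bsz, seq_len):
    # `order` holds the per-sample generation order, shape (bsz, seq_len).
    masking = torch.zeros(bsz, seq_len, device=order.device)
    # Mark the first `mask_len` positions of each order as masked.
    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()],
                            src=torch.ones(bsz, seq_len, device=order.device)).bool()
    return masking

Passing `device=` at construction avoids allocating on the CPU and then copying with `.to(device)`, and it keeps the function correct even if callers run the model on a device other than the global one.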