wondervictor commited on
Commit
fa0ea4a
·
verified ·
1 Parent(s): ec32474

Update autoregressive/models/gpt_t2i.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/gpt_t2i.py +1 -1
autoregressive/models/gpt_t2i.py CHANGED
@@ -430,7 +430,7 @@ class Transformer(nn.Module):
430
  token_embeddings = self.cls_embedding(cond_idx, train=self.training)
431
  token_embeddings = token_embeddings[:,:self.cls_token_num]
432
  if condition is not None:
433
- condition_embeddings = self.condition_mlp(condition.to(torch.bfloat16),train=self.training)
434
  self.condition_token = condition_embeddings
435
 
436
  else: # decode_n_tokens(kv cache) in inference
 
430
  token_embeddings = self.cls_embedding(cond_idx, train=self.training)
431
  token_embeddings = token_embeddings[:,:self.cls_token_num]
432
  if condition is not None:
433
+ condition_embeddings = self.condition_mlp(condition,train=self.training)#.to(torch.bfloat16),train=self.training)
434
  self.condition_token = condition_embeddings
435
 
436
  else: # decode_n_tokens(kv cache) in inference