chansung commited on
Commit
33cc221
1 Parent(s): 32d38c1

Update gen.py

Browse files
Files changed (1) hide show
  1. gen.py +1 -1
gen.py CHANGED
@@ -71,7 +71,7 @@ def get_pretrained_models(
71
  tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
72
  model_args.vocab_size = tokenizer.n_words
73
  torch.set_default_tensor_type(torch.cuda.HalfTensor)
74
- model = Transformer(model_args)
75
  torch.set_default_tensor_type(torch.FloatTensor)
76
  model.load_state_dict(checkpoint, strict=False)
77
 
 
71
  tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
72
  model_args.vocab_size = tokenizer.n_words
73
  torch.set_default_tensor_type(torch.cuda.HalfTensor)
74
+ model = Transformer(model_args).cuda().half()
75
  torch.set_default_tensor_type(torch.FloatTensor)
76
  model.load_state_dict(checkpoint, strict=False)
77