Spaces:
Runtime error
Runtime error
Update gen.py
Browse files
gen.py
CHANGED
@@ -71,7 +71,7 @@ def get_pretrained_models(
|
|
71 |
tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
|
72 |
model_args.vocab_size = tokenizer.n_words
|
73 |
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
74 |
-
model = Transformer(model_args)
|
75 |
torch.set_default_tensor_type(torch.FloatTensor)
|
76 |
model.load_state_dict(checkpoint, strict=False)
|
77 |
|
|
|
71 |
tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
|
72 |
model_args.vocab_size = tokenizer.n_words
|
73 |
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
74 |
+
model = Transformer(model_args).cuda().half()
|
75 |
torch.set_default_tensor_type(torch.FloatTensor)
|
76 |
model.load_state_dict(checkpoint, strict=False)
|
77 |
|