import torch
from .llama_xformer import LlamaForCausalLM


def get_pretrained_llama_causal_model(pretrained_model_name_or_path=None, torch_dtype='fp16', **kwargs):
    """Load a LlamaForCausalLM checkpoint, mapping a dtype string to a torch dtype."""
    if torch_dtype == 'fp16' or torch_dtype == 'float16':
        torch_dtype = torch.float16
    elif torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
        torch_dtype = torch.bfloat16
    else:
        # Default to full precision for any unrecognized dtype string.
        torch_dtype = torch.float32
    model = LlamaForCausalLM.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        torch_dtype=torch_dtype,
        **kwargs,
    )
    return model
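

# Usage sketch (illustrative; the checkpoint path and dtype string below are
# assumptions, not part of the original file):
#
#     model = get_pretrained_llama_causal_model(
#         pretrained_model_name_or_path='path/to/llama-checkpoint',
#         torch_dtype='bf16',
#     )
#     model.eval()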