Spaces:
Sleeping
Sleeping
import torch | |
from .llama_xformer import LlamaForCausalLM | |
def get_pretrained_llama_causal_model(pretrained_model_name_or_path=None, torch_dtype='fp16', **kwargs): | |
if torch_dtype == 'fp16' or torch_dtype == 'float16': | |
torch_dtype = torch.float16 | |
elif torch_dtype == 'bf16' or torch_dtype == 'bfloat16': | |
torch_dtype = torch.bfloat16 | |
else: | |
torch_dtype == torch.float32 | |
model = LlamaForCausalLM.from_pretrained( | |
pretrained_model_name_or_path=pretrained_model_name_or_path, | |
torch_dtype=torch_dtype, | |
**kwargs, | |
) | |
return model | |