SEED-LLaMA / models /model_tools.py
sjzhao's picture
update demo
bd63939
raw
history blame
593 Bytes
import torch
from .llama_xformer import LlamaForCausalLM
def get_pretrained_llama_causal_model(pretrained_model_name_or_path=None, torch_dtype='fp16', **kwargs):
    """Load a pretrained LLaMA causal-LM, resolving a string dtype spec first.

    Args:
        pretrained_model_name_or_path: Model identifier or local path, forwarded
            to ``LlamaForCausalLM.from_pretrained``.
        torch_dtype: Dtype spec string. ``'fp16'``/``'float16'`` ->
            ``torch.float16``; ``'bf16'``/``'bfloat16'`` -> ``torch.bfloat16``;
            anything else falls back to ``torch.float32``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        The loaded ``LlamaForCausalLM`` model instance.
    """
    if torch_dtype in ('fp16', 'float16'):
        torch_dtype = torch.float16
    elif torch_dtype in ('bf16', 'bfloat16'):
        torch_dtype = torch.bfloat16
    else:
        # BUG FIX: the original wrote `torch_dtype == torch.float32` (comparison,
        # a no-op), so unrecognized specs leaked the raw string into
        # from_pretrained instead of falling back to float32.
        torch_dtype = torch.float32
    model = LlamaForCausalLM.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        torch_dtype=torch_dtype,
        **kwargs,
    )
    return model