Load example fails on GPUs with no bfloat16 support
#1 opened by RASMUS
Change the load example for GPUs with no bfloat16 support to something like:
import torch
import transformers

branch = "200B"
model = transformers.AutoModelForCausalLM.from_pretrained(
    "LumiOpen/Viking-7B",
    # Fall back to float16 on GPUs that lack bfloat16 support.
    torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    revision=branch,
)
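
For machines without a CUDA GPU at all, a slightly fuller sketch of the same idea; the float32 CPU fallback is an assumption, not part of the original example:

import torch
import transformers

# Pick the widest dtype the hardware supports: bfloat16 on newer GPUs,
# float16 on older CUDA GPUs, float32 on CPU-only machines (assumption).
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
elif torch.cuda.is_available():
    dtype = torch.float16
else:
    dtype = torch.float32

model = transformers.AutoModelForCausalLM.from_pretrained(
    "LumiOpen/Viking-7B",
    torch_dtype=dtype,
    revision="200B",
)

Checking torch.cuda.is_available() before torch.cuda.is_bf16_supported() avoids touching the CUDA runtime on machines where it is not present.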
ah, thanks!
RASMUS changed discussion status to closed