|
import torch |
|
from peft import PeftModel |
|
from transformers import LlamaForCausalLM |
|
from transformers import LlamaTokenizer |
|
|
|
# Load a fp16 LLaMA-13B base model and apply a Japanese LoRA adapter on top.
# NOTE(review): the "decapoda-research" LLaMA repos were community re-uploads and
# have since been removed/gated on the Hugging Face Hub — confirm this repo ID
# still resolves, or point BASE_MODEL at an accessible LLaMA checkpoint.
BASE_MODEL: str = "decapoda-research/llama-13b-hf"

# LoRA adapter weights fine-tuned for Japanese, applied over BASE_MODEL below.
LORA_WEIGHTS: str = "izumi-lab/llama-13b-japanese-lora-v0-1ep"

# Tokenizer comes from the base checkpoint (LoRA does not change the vocab).
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

model = LlamaForCausalLM.from_pretrained(

    BASE_MODEL,

    # Explicitly disabled: weights stay in fp16 (torch_dtype below), no
    # bitsandbytes 8-bit quantization.
    load_in_8bit=False,

    torch_dtype=torch.float16,

    # Let accelerate shard/place layers across whatever devices are available.
    device_map="auto",

)

# Wrap the base model with the LoRA adapter; the result is a PeftModel whose
# forward pass includes the adapter deltas.
# NOTE(review): `use_auth_token` is deprecated in newer transformers /
# huggingface_hub releases in favor of `token=` — confirm against the pinned
# library versions before changing. It also forces an HF login even though the
# adapter repo appears to be public.
model = PeftModel.from_pretrained(

    model, LORA_WEIGHTS, torch_dtype=torch.float16, use_auth_token=True

)
|
|