RuntimeError: mat1 and mat2 shapes cannot be multiplied (12x4096 and 1x8388608)
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
peft_model_id = "vincentmin/llama-2-7b-reward-oasst1"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSequenceClassification.from_pretrained("data/llama-2-7b-chat-hf", num_labels=1, load_in_4bit=True, torch_dtype=torch.float16)
model = PeftModel.from_pretrained(model, peft_model_id)
tokenizer = AutoTokenizer.from_pretrained("data/llama-2-7b-chat-hf", use_auth_token=True)
with torch.no_grad():
    reward = model(**tokenizer("prompter: hello world. assistant: foo bar", return_tensors='pt')).logits
When I run the last line:
reward = model(**tokenizer("prompter: hello world. assistant: foo bar", return_tensors='pt')).logits
I get this error:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (12x4096 and 1x8388608)
I don't know which step is wrong. Could you help me check it? Thanks!
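A side note on reading the error: the second shape looks like a 4-bit quantized weight being fed straight into a regular matmul. The arithmetic below assumes (this is my understanding of bitsandbytes, not something stated in this thread) that a 4096x4096 Llama-2-7B projection weight is stored packed, two 4-bit values per byte, as a flattened tensor. If that is right, a mismatch between library versions that skips the dequantizing forward pass would produce exactly this shape. The 12 in the first shape is simply the sequence length of the tokenized prompt.

```python
# Shapes copied from the error message:
# "mat1 and mat2 shapes cannot be multiplied (12x4096 and 1x8388608)"
mat1_shape = (12, 4096)    # (sequence length, hidden size) of the activations
mat2_shape = (1, 8388608)  # what the matmul actually received as the weight

hidden_size = 4096  # Llama-2-7B hidden size

# A square 4096x4096 projection weight has this many parameters...
num_params = hidden_size * hidden_size
# ...and, assuming two 4-bit values are packed into each byte,
# the packed storage holds half as many elements.
packed_elements = num_params // 2

print(num_params)                         # 16777216
print(packed_elements)                    # 8388608
print(packed_elements == mat2_shape[1])   # True
```

So the "weight" in the failing matmul has one element per packed byte of a 4096x4096 matrix, rather than the 4096x4096 shape a normal linear layer expects.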
I don't recognise the error. I just tried running
!pip install -q transformers git+ accelerate trl bitsandbytes einops
from huggingface_hub import notebook_login
notebook_login()  # log in so the gated Llama 2 weights can be downloaded
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
peft_model_id = "vincentmin/llama-2-7b-reward-oasst1"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path, num_labels=1, load_in_4bit=True, torch_dtype=torch.float16)
model = PeftModel.from_pretrained(model, peft_model_id)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path, use_auth_token=True)
with torch.no_grad():
    reward = model(**tokenizer("prompter: hello world. assistant: foo bar", return_tensors='pt')).logits
in Google Colab, and it ran without issues. Can you try installing the latest versions of these packages: transformers git+ accelerate trl bitsandbytes einops?
How did you download your model to "data/llama-2-7b-chat-hf"?
Note that you should install peft from git, since an important patch was merged in PR #755. That shouldn't affect your current issue, though.
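Since the fix here came down to package versions, it can help to confirm what is actually installed before and after upgrading. A minimal sketch using only the standard library's importlib.metadata; the package list is just the set mentioned in this thread, and installed_versions is a hypothetical helper name:

```python
from importlib.metadata import PackageNotFoundError, version

# Packages mentioned in this thread (peft included, since installing it
# from git is recommended above).
PACKAGES = ["transformers", "peft", "accelerate", "trl", "bitsandbytes", "einops"]


def installed_versions(names):
    """Return {package name: installed version string, or None if not installed}."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:>13}: {ver or 'not installed'}")
```

Running it in the failing environment and in the working Colab session makes it easy to spot which package differs.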
After installing the latest version of those packages, it worked!