Would it work well with sequence lengths > 2048, like other MPT models?
Possibly, although we haven't tested this extensively. Let us know if you find that it works well!
I've tried the following code and it doesn't work for me with any sequence length other than 2048 (the same code works for mosaicml/mpt-7b-instruct, though):
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
device = f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cpu'
print(f'Selected device is: {device}')
model_name = "nomic-ai/gpt4all-mpt"
config = AutoConfig.from_pretrained(
    model_name,
    trust_remote_code=True
)
# to use the optimized triton implementation of FlashAttention, load the model with attn_impl='triton' and move the model to bfloat16
# config.attn_config['attn_impl'] = 'triton'
config.init_device = device
# config.max_seq_len = 2048
# update the maximum sequence length used during inference to 3072
config.max_seq_len = 3072
print(config)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)
model.eval()
I got the following error:
RuntimeError: Error(s) in loading state_dict for MPTForCausalLM:
size mismatch for transformer.wpe.weight: copying a param with shape torch.Size([2048, 4096]) from checkpoint, the shape in current model is torch.Size([3072, 4096]).
You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
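The shape in that message suggests this checkpoint ships a learned positional-embedding table (transformer.wpe) with exactly 2048 rows, so simply raising config.max_seq_len can't line up with the saved weights. If a longer window is really needed, a rough, untested sketch would be to load at the native 2048 and then grow the table by hand (the attribute names below are just my guess from the error message, and quality beyond position 2048 is not guaranteed):

import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "nomic-ai/gpt4all-mpt"
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.max_seq_len = 2048  # keep the shape the checkpoint was saved with

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)

new_len = 3072
with torch.no_grad():
    old = model.transformer.wpe.weight  # assumed shape [2048, d_model]
    wpe = torch.nn.Embedding(new_len, old.shape[1], dtype=old.dtype, device=old.device)
    wpe.weight.zero_()                 # extra, untrained positions start at zero
    wpe.weight[:old.shape[0]] = old    # keep the 2048 trained positions
    model.transformer.wpe = wpe
model.config.max_seq_len = new_len     # let the forward pass accept longer inputs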
Setting ignore_mismatched_sizes=True still doesn't fix it. Instead, I got a different error:
File /opt/anaconda3/lib/python3.9/site-packages/transformers/modeling_utils.py:3031, in PreTrainedModel._load_pretrained_model.<locals>._find_mismatched_keys(state_dict, model_state_dict, loaded_keys, add_prefix_to_model, remove_prefix_from_model, ignore_mismatched_sizes)
   3025 elif add_prefix_to_model:
   3026     # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
   3027     model_key = ".".join(checkpoint_key.split(".")[1:])
   3029 if (
   3030     model_key in model_state_dict
-> 3031     and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
   3032 ):
   3033     mismatched_keys.append(
   3034         (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
   3035     )
   3036     del state_dict[checkpoint_key]

KeyError: 'transformer.blocks.11.ffn.down_proj.weight'
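One way to narrow down that KeyError (a diagnostic sketch, not a fix; it assumes the repo ships a single pytorch_model.bin and that the remote code honors init_device='meta') is to diff the parameter names in the checkpoint against those of a freshly built model, since the error looks like the saved weights and the current remote modeling code disagree about key names:

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "nomic-ai/gpt4all-mpt"
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.init_device = "meta"  # build the module skeleton without allocating real weights

model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
model_keys = set(model.state_dict().keys())

ckpt_path = hf_hub_download(model_name, "pytorch_model.bin")  # adjust if the repo is sharded
ckpt_keys = set(torch.load(ckpt_path, map_location="cpu").keys())

print("in checkpoint but not in model:", sorted(ckpt_keys - model_keys))
print("in model but not in checkpoint:", sorted(model_keys - ckpt_keys))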
By the way, this model also doesn't support the optimized triton implementation of FlashAttention the way mosaicml/mpt-7b-instruct does. If you turn it on via config.attn_config['attn_impl'] = 'triton', you get the same KeyError: 'transformer.blocks.11.ffn.down_proj.weight' error.
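For what it's worth, that KeyError points at a feed-forward weight rather than anything attention-related, so the triton switch itself may not be the real problem. A quick check of what attention settings and sequence length the repo's config actually ships (nothing more than a sanity check):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("nomic-ai/gpt4all-mpt", trust_remote_code=True)
print(config.attn_config)  # default attn_impl and related flags
print(config.max_seq_len)  # the sequence length the checkpoint was saved with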
@zpn any chance you could shed some light on the possible cause of this error? Thanks a lot~