Axolotl fine tuning Jamba-1.5-Mini
#7 by coranholmes - opened
Could you please also provide an example YAML for using Axolotl to SFT Jamba-1.5-Mini, like the one provided for Large? Thank you so much.
Hi!
For QLoRA + FSDP you can use the same guide as in the Large model card; just change the model name and adjust the batch settings according to your hardware. It requires ~70GB of GPU memory in total.
base_model: ai21labs/AI21-Jamba-1.5-Mini
tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
use_tensorboard: true
datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    chat_template: jamba
    drop_system_message: true
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-mini-fsdp-qlora-ft
save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
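# targets the attention (q/k/v/o), MLP (gate/up/down) and Mamba mixer (in/out/x) projection layers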
lora_target_modules: [down_proj,gate_proj,in_proj,k_proj,o_proj,out_proj,q_proj,up_proj,v_proj,x_proj]
lora_target_linear: false
gradient_accumulation_steps: 4 # change according to your hardware
micro_batch_size: 4 # change according to your hardware
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
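With the config above saved to a YAML file (the filename jamba-mini-qlora.yaml below is just an example), training can then be launched the standard Axolotl way, e.g.:
accelerate launch -m axolotl.cli.train jamba-mini-qlora.yaml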
Thanks for your reply.
I tried to train Jamba-1.5-Mini using Axolotl but ran into some problems. I have posted an issue here. It would be of great help if you could offer some advice. Thank you in advance.
Hi,
Can you specify the transformers version you are using?
You can try the specific commit mentioned in the Large model card under the QLoRA + FSDP fine-tuning section:
pip install git+https://github.com/xgal/transformers@897f80665c37c531b7803f92655dbc9b3a593fe7
It should be fixed in transformers version >= 4.44.2
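To double-check which transformers version is installed in your environment, you can run, for example:
python -c "import transformers; print(transformers.__version__)"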
coranholmes changed discussion status to closed