LLaVA-LLaMA-3.1 Model Card

Model Details

LLaVA is an open-source chatbot trained by fine-tuning an LLM on multimodal instruction-following data. This release is a LLaVA model that uses meta-llama/Meta-Llama-3.1-8B-Instruct as its language model.

License Notices

This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses, including but not limited to the OpenAI Terms of Use for the dataset and the specific licenses for base language models for checkpoints trained using the dataset (e.g. Llama-1/2 community license for LLaMA-2 and Vicuna-v1.5, Tongyi Qianwen LICENSE AGREEMENT and META LLAMA 3 COMMUNITY LICENSE AGREEMENT). This project does not impose any additional constraints beyond those stipulated in the original licenses. Furthermore, users are reminded to ensure that their use of the dataset and checkpoints is in compliance with all applicable laws and regulations.

Training

Training Procedure

Before training, the vision encoder is initialized from pretrained LLaVA-1.6 checkpoints, and the LLM is initialized from LLaMA-3.1-8B-Instruct.
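
All stages train the projector, which maps vision-encoder features into the LLM's embedding space. The formula below sketches that mapping. It assumes the projector is the two-layer MLP with GELU used in LLaVA-1.5/1.6, which this card does not state explicitly. Here g is the vision encoder, X_v is the input image, z_k is the k-th patch feature in g(X_v), and d is the LLM hidden size (4096 for Llama-3.1-8B):

```latex
\mathbf{z}_k \in g(\mathbf{X}_v) \subset \mathbb{R}^{d_v}, \qquad
\mathbf{h}_k = \mathbf{W}_2\,\mathrm{GELU}\!\left(\mathbf{W}_1\mathbf{z}_k + \mathbf{b}_1\right) + \mathbf{b}_2,
\qquad \mathbf{W}_1 \in \mathbb{R}^{d \times d_v},\ \mathbf{W}_2 \in \mathbb{R}^{d \times d}
```

Each h_k has the same width as a text-token embedding, so the visual tokens can be placed in the LLM's input sequence next to the text tokens.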

Training Data

  • Stage1: We train the parameters of the projector on 3K image-text pairs from llava-pretrain.
  • Stage2: We train the parameters of the projector and the LLM on the 558K llava-pretrain data (after filtering out incomplete items, 270K are used).
  • Stage3: We instruction-tune the parameters of the projector and the LLM on the llava-instruct-150K data.

Training Details

Training takes about 70 hours on 8 NVIDIA A100-80GB GPUs (this may vary with hardware). All training and evaluation were conducted with the MS-SWIFT framework.

The MS-SWIFT training scripts for each stage are as follows:

  • Stage1: Training the projector on 3K llava-pretrain data:

    DATASET_ENABLE_CACHE=1 swift sft \
        --model_type llava1_6-llama3_1-8b-instruct \
        --dataset llava-pretrain#3000 \
        --batch_size 1 \
        --gradient_accumulation_steps 16 \
        --warmup_ratio 0.03 \
        --learning_rate 1e-5 \
        --sft_type full \
        --freeze_parameters_ratio 1 \
        --additional_trainable_parameters multi_modal_projector  
    
  • Stage2: Training the projector and language model on the llava-pretrain dataset:

    DATASET_ENABLE_CACHE=1 NPROC_PER_NODE=8 swift sft \
        --model_type llava1_6-llama3_1-8b-instruct \
        --resume_from_checkpoint <the_ckpt_in_stage1> \
        --resume_only_model true \
        --dataset llava-pretrain \
        --batch_size 2 \
        --gradient_accumulation_steps 16 \
        --warmup_ratio 0.03 \
        --learning_rate 2e-5 \
        --deepspeed default-zero3 \
        --sft_type full \
        --freeze_parameters_ratio 0 
    
  • Stage3: Instruction-tuning the projector and language model on llava-instruct-150k. After stages 1 and 2, the model can already understand and describe images, but it still needs instruction tuning to acquire instruction-following ability:

    DATASET_ENABLE_CACHE=1 NPROC_PER_NODE=8 swift sft \
        --model_type llava1_6-llama3_1-8b-instruct \
        --resume_from_checkpoint <the_ckpt_in_stage2> \
        --resume_only_model true \
        --dataset llava-instruct-150k \
        --num_train_epochs 2 \
        --batch_size 2 \
        --gradient_accumulation_steps 16 \
        --warmup_ratio 0.03 \
        --learning_rate 1e-5 \
        --deepspeed zero3-offload \
        --sft_type full \
        --freeze_parameters_ratio 0 
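
The three scripts differ in per-device batch size and process count, so comparing them is easier with the effective (global) batch size per optimizer step. The shell sketch below makes two assumptions. First, `--batch_size` is the per-device batch size. Second, Stage1 runs as a single data-parallel process because it does not set `NPROC_PER_NODE`.

```shell
#!/usr/bin/env bash
# Effective (global) batch size per optimizer step:
#   per-device batch size x gradient accumulation steps x data-parallel processes
# Assumption: --batch_size in the scripts above is the per-device batch size.
effective_batch() {
  local per_device=$1 grad_accum=$2 nproc=$3
  echo $(( per_device * grad_accum * nproc ))
}

# Values copied from the stage scripts. Assumption: Stage1 uses one process
# because NPROC_PER_NODE is not set in its command.
echo "Stage1: $(effective_batch 1 16 1)"
echo "Stage2: $(effective_batch 2 16 8)"
echo "Stage3: $(effective_batch 2 16 8)"
```

Under these assumptions, Stage1 uses an effective batch of 16, and Stages 2 and 3 use 256. For the roughly 270K samples in Stage2, that comes to about 1,055 optimizer steps per epoch.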
    

Evaluation

We evaluate the model on the TextVQA benchmark, also using MS-SWIFT. It achieves an overall accuracy of 60.482, which is roughly comparable to other LLaVA models.

    CUDA_VISIBLE_DEVICES=0 swift eval \
        --ckpt_dir <the_ckpt_in_stage3> \
        --eval_dataset TextVQA_VAL
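
For context on the accuracy figure: TextVQA provides 10 human answers per question and is conventionally scored with the VQA accuracy metric. The formula below assumes the MS-SWIFT evaluator follows that convention, which this card does not state. A normalized prediction a for a question with human answers a_1, ..., a_10 scores:

```latex
\mathrm{Acc}(a) = \frac{1}{10}\sum_{i=1}^{10}
  \min\!\left(1,\ \frac{\bigl|\{\, j \ne i : a_j = a \,\}\bigr|}{3}\right)
```

Overall accuracy is the mean of Acc over all questions, reported as a percentage. For example, a prediction that matches 2 of the 10 human answers scores (2 · 1/3 + 8 · 2/3) / 10 = 0.6, and a prediction that matches 4 or more scores 1.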

Model size: 8.35B params (Safetensors, tensor type BF16)