# ################################
# Model: Whisper + TLTR + Audio_Proj + Llama 3
# Authors: Yingzhi Wang 2024
# ################################
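# Pipeline wired below: speech -> frozen Whisper encoder (all hidden layers)
# -> temporal average pooling -> TLTR layer aggregation -> projection into the
# LLM embedding space -> frozen Llama 3 with LoRA adapters.
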
# HuggingFace hub ID of the LLAMA3 model and its save folder
llama_hub: meta-llama/Meta-Llama-3-8B-Instruct # alternative: lmsys/vicuna-7b-v1.5
llama3_folder: llama3_checkpoint
# llama generation config
num_beams: 3
max_new_tokens: 400
top_k: 500
top_p: 0.95
temperature: 0.1
repetition_penalty: 1.1
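# (all of the above are forwarded to the llama3 wrapper below)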
# lora config
lora_dropout: 0.05
lora_alpha: 16
r: 8
bias: "none"
task_type: "CAUSAL_LM"
lora_target_modules: ["q_proj", "v_proj"]
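# Rank-8 adapters (scaling alpha=16) are injected into the attention q/v
# projections only; the rest of the LLM stays frozen.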
# HuggingFace hub ID of the Whisper model and its save folder
whisper_hub: openai/whisper-large
whisper_folder: whisper_checkpoint
freeze_whisper: True
whisper_output_dim: 1280
# average pooling
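# With Whisper's 1500-frame (30 s) encoder output, a kernel of 20 leaves
# 75 frames per layer.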
pooling_kernel: 20
# Audio Tagging model
tltr_layers: 32
llama_hidden_size: 4096
# Masks
audio_padding_mask: !name:speechbrain.dataio.dataio.length_to_mask
text_padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask
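
# Frozen Whisper encoder; output_all_hiddens exposes every encoder layer's
# hidden states (d_model 1280) so TLTR can aggregate across layers.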
whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
    source: !ref <whisper_hub>
    freeze: !ref <freeze_whisper>
    save_path: !ref <whisper_folder>
    encoder_only: True
    output_all_hiddens: True

avg_pool: !new:speechbrain.nnet.pooling.Pooling1d
    pool_type: "avg"
    kernel_size: !ref <pooling_kernel>
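
# TLTR (time- and layer-wise Transformer, an audio-tagging model) fuses the
# 32 Whisper encoder layers into a single sequence; its weights are loaded by
# the pretrainer below and kept frozen.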
tltr: !new:speechbrain.lobes.models.TLTR.AT_MODEL
    n_layer: !ref <tltr_layers>
    rep_dim: !ref <whisper_output_dim>
    freeze: True
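
# Projects the 1280-dim audio representation into the 4096-dim embedding
# space of Llama-3-8B.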
audio_proj: !new:speechbrain.lobes.models.TLTR.AudioProjection
    input_size: !ref <whisper_output_dim>
    hidden_size: !ref <llama_hidden_size>

# LLAMA3 model
# Loaded through SpeechBrain's LLAMA2 HuggingFace wrapper (reused here for a
# Llama 3 checkpoint); the base LLM stays frozen and LoRA adapters are
# attached via PEFT.
# llama3: null
llama3: !new:speechbrain.lobes.models.huggingface_transformers.llama2.LLAMA2
    source: !ref <llama_hub>
    freeze: True
    save_path: !ref <llama3_folder>
    max_new_tokens: !ref <max_new_tokens>
    num_beams: !ref <num_beams>
    top_k: !ref <top_k>
    top_p: !ref <top_p>
    temperature: !ref <temperature>
    repetition_penalty: !ref <repetition_penalty>
    with_peft: True
    lora_alpha: !ref <lora_alpha>
    lora_dropout: !ref <lora_dropout>
    r: !ref <r>
    bias: !ref <bias>
    task_type: !ref <task_type>
    lora_target_modules: !ref <lora_target_modules>
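
# Submodules registered with the Brain class (device placement, checkpointing).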
modules:
    tltr: !ref <tltr>
    audio_proj: !ref <audio_proj>
    llama3: !ref <llama3>
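
# tltr and audio_proj are bundled into one ModuleList so their weights can be
# loaded together as the "model" loadable below.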
model: !new:torch.nn.ModuleList
    - [!ref <tltr>, !ref <audio_proj>]
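
# Fetches pretrained weights for the loadables before training/inference starts.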
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    loadables:
        llama3: !ref <llama3>
        model: !ref <model>