speechbrain
PyTorch
English
speech-llm
audio-llm
File size: 2,509 Bytes
76f5df4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# ################################
# Model: Whisper + TLTR + Audio_Proj + LLaMa3
# Authors: Yingzhi Wang 2024
# ################################

# Hub identifier of the LLM (passed as `source` to the `llama3` object below)
# and the local folder passed as its `save_path`.
# NOTE(review): the trailing comment looks like an alternative checkpoint id
# kept by the author — confirm before switching.
llama_hub: meta-llama/Meta-Llama-3-8B-Instruct # lmsys/vicuna-7b-v1.5
llama3_folder: llama3_checkpoint

# LLM text-generation settings. Each value is forwarded under the same name
# to the `llama3` constructor below.
# NOTE(review): top_k / top_p / temperature normally only take effect when
# sampling is enabled — confirm how the LLM wrapper combines them with
# num_beams.
num_beams: 3
max_new_tokens: 400
top_k: 500
top_p: 0.95
temperature: 0.1
repetition_penalty: 1.1

# LoRA adapter settings. Each value is forwarded under the same name to the
# `llama3` constructor below, which is created with `with_peft: True`.
# NOTE(review): the names mirror the fields of peft's LoraConfig — confirm
# that the LLM wrapper passes them through unchanged.
lora_dropout: 0.05
lora_alpha: 16
r: 8
bias: "none"
task_type: "CAUSAL_LM"
lora_target_modules: ["q_proj", "v_proj"]

# Whisper settings: hub identifier, local save folder and freeze flag are
# passed to the `whisper` object below as `source`, `save_path` and `freeze`.
whisper_hub: openai/whisper-large
whisper_folder: whisper_checkpoint
freeze_whisper: True
# Reused below as `rep_dim` of `tltr` and `input_size` of `audio_proj`.
# NOTE(review): presumably the hidden size of the Whisper encoder — keep it
# in sync with `whisper_hub` (TODO confirm).
whisper_output_dim: 1280

# Kernel size of the average-pooling layer (`kernel_size` of `avg_pool` below).
pooling_kernel: 20

# Audio tagging model (TLTR): value passed as `n_layer` to `tltr` below.
# NOTE(review): presumably has to match the number of Whisper encoder layers
# whose hidden states are returned — confirm.
tltr_layers: 32
# Output size of the audio projection (`hidden_size` of `audio_proj` below).
# NOTE(review): presumably the embedding width of the LLM in `llama_hub` —
# confirm if the LLM is changed.
llama_hidden_size: 4096

# Mask helpers. `!name:` stores a reference to the function (nothing is called
# while the file is loaded); the code using this config invokes them later.
audio_padding_mask: !name:speechbrain.dataio.dataio.length_to_mask
text_padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask

# Whisper model built from the settings above, configured as encoder-only
# and with `output_all_hiddens` enabled.
# NOTE(review): `whisper` is not listed under `modules` or in the pretrainer
# loadables below — presumably the calling code handles it separately; confirm.
whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
    source: !ref <whisper_hub>
    freeze: !ref <freeze_whisper>
    save_path: !ref <whisper_folder>
    encoder_only: True
    output_all_hiddens: True

# Average pooling with kernel size `pooling_kernel`.
# NOTE(review): no pooling axis or stride is set here, so the Pooling1d
# defaults apply — confirm they match the expected feature layout.
avg_pool: !new:speechbrain.nnet.pooling.Pooling1d
    pool_type: "avg"
    kernel_size: !ref <pooling_kernel>

# TLTR audio-tagging model: `tltr_layers` layers of `whisper_output_dim`-sized
# representations, instantiated with `freeze: True`.
# NOTE(review): requires a speechbrain build that ships
# `speechbrain.lobes.models.TLTR` — confirm it is available at load time.
tltr: !new:speechbrain.lobes.models.TLTR.AT_MODEL
    n_layer: !ref <tltr_layers>
    rep_dim: !ref <whisper_output_dim>
    freeze: True

# Audio projection from `whisper_output_dim` to `llama_hidden_size`.
# NOTE(review): presumably maps audio features into the LLM embedding space —
# confirm against the AudioProjection implementation.
audio_proj: !new:speechbrain.lobes.models.TLTR.AudioProjection
    input_size: !ref <whisper_output_dim>
    hidden_size: !ref <llama_hidden_size>

# LLaMA-3 model. It is instantiated through the `llama2.LLAMA2` wrapper class
# with the checkpoint named by `llama_hub`.
# NOTE(review): this relies on a LLAMA2 wrapper that accepts the generation
# and LoRA keyword arguments listed below — confirm against the installed
# speechbrain version.
# Alternative left by the author (NOTE(review): presumably to load this file
# without building the LLM — confirm):
# llama3: null
llama3: !new:speechbrain.lobes.models.huggingface_transformers.llama2.LLAMA2
    # Checkpoint location and freezing
    source: !ref <llama_hub>
    save_path: !ref <llama3_folder>
    freeze: True
    # Generation settings (values defined in the generation config section)
    max_new_tokens: !ref <max_new_tokens>
    num_beams: !ref <num_beams>
    top_k: !ref <top_k>
    top_p: !ref <top_p>
    temperature: !ref <temperature>
    repetition_penalty: !ref <repetition_penalty>
    # LoRA adapter settings (values defined in the lora config section)
    with_peft: True
    r: !ref <r>
    lora_alpha: !ref <lora_alpha>
    lora_dropout: !ref <lora_dropout>
    bias: !ref <bias>
    task_type: !ref <task_type>
    lora_target_modules: !ref <lora_target_modules>

# Objects exposed to the calling code under the `modules` key: the TLTR model,
# the audio projection and the LLM defined above.
modules:
    tltr: !ref <tltr>
    audio_proj: !ref <audio_proj>
    llama3: !ref <llama3>

# TLTR and audio projection grouped into a single ModuleList, so that the
# pretrainer can load them together as one `model` entry.
model: !new:torch.nn.ModuleList
    - [!ref <tltr>, !ref <audio_proj>]

# Pretrainer that loads parameters into the two loadables below: the LLM
# (`llama3`) and the TLTR + audio-projection group (`model`).
# NOTE(review): no `paths` are given here, so checkpoint locations are
# presumably supplied by the calling code — confirm.
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    loadables:
        llama3: !ref <llama3>
        model: !ref <model>