# Page-header residue from the hosting site (not part of the configuration): "Spaces: Runtime error / Runtime error"
# Global run settings.
base:
  name: "OpenSLUv1"
  # Referenced by model.decoder.intent_classifier.use_multi via "{base.multi_intent}".
  multi_intent: true
  train: true
  test: true
  device: cuda
  seed: 42
  epoch_num: 100
  batch_size: 16
  # Referenced by both classifiers via "{base.ignore_index}".
  ignore_index: -100
# Checkpoint locations.
model_manager:
  load_dir: null
  # NOTE(review): the directory name says "mix-snips" while dataset.dataset_name
  # in this file is "atis" — confirm this is the intended save path.
  save_dir: save/vanilla-mix-snips
# Evaluation settings.
evaluator:
  # best_key names one of the entries listed under `metric` below.
  best_key: EMA
  eval_by_epoch: true
  # eval_step: 1800
  metric:
    - intent_acc
    - intent_f1
    - slot_f1
    - EMA
# Dataset selection.
dataset:
  dataset_name: atis
# Tokenizer settings.
tokenizer:
  _tokenizer_name_: word_tokenizer
  _padding_side_: right
  _align_mode_: fast
  add_special_tokens: false
  max_length: 512
# Optimizer settings.
optimizer:
  _model_target_: torch.optim.Adam
  # NOTE(review): presumably tells the loader to build a partial and supply the
  # remaining arguments (e.g. the model parameters) later — confirm against the loader.
  _model_partial_: true
  lr: 0.001
  # Written with an explicit decimal point: YAML 1.1 loaders such as PyYAML
  # resolve a bare `1e-6` as a string rather than a float.
  weight_decay: 1.0e-6
# Learning-rate scheduler settings.
scheduler:
  _model_target_: transformers.get_scheduler
  # NOTE(review): presumably tells the loader to build a partial and supply the
  # remaining arguments later — confirm against the loader.
  _model_partial_: true
  name: "linear"
  num_warmup_steps: 0
# Model definition.
# NOTE(review): the nesting below was restored from the dotted reference paths
# used in this file (e.g. "{model.encoder.lstm.output_dim}",
# "{model.encoder.output_dim}"). Placement of the two classifier sections under
# `decoder`, and of the `return_*` flags at encoder level, is assumed — confirm.
model:
  _model_target_: model.OpenSLUModel

  encoder:
    _model_target_: model.encoder.AutoEncoder
    encoder_name: self-attention-lstm

    embedding:
      embedding_dim: 128
      dropout_rate: 0.4

    lstm:
      layer_num: 1
      bidirectional: true
      output_dim: 256
      dropout_rate: 0.4

    attention:
      hidden_dim: 1024
      output_dim: 128
      dropout_rate: 0.4

    # String expression built from the lstm and attention output_dim values above;
    # presumably evaluated by the config loader (256 + 128) — confirm.
    output_dim: "{model.encoder.lstm.output_dim} + {model.encoder.attention.output_dim}"
    return_with_input: true
    return_sentence_level_hidden: true

  decoder:
    _model_target_: model.decoder.BaseDecoder

    intent_classifier:
      _model_target_: model.decoder.classifier.LinearClassifier
      mode: "intent"
      # Takes its value from model.encoder.output_dim.
      input_dim: "{model.encoder.output_dim}"
      loss_fn:
        _model_target_: torch.nn.BCEWithLogitsLoss
      # Takes its value from base.multi_intent.
      use_multi: "{base.multi_intent}"
      multi_threshold: 0.5
      return_sentence_level: true
      # Takes its value from base.ignore_index.
      ignore_index: "{base.ignore_index}"

    slot_classifier:
      _model_target_: model.decoder.classifier.LinearClassifier
      mode: "slot"
      # Takes its value from model.encoder.output_dim.
      input_dim: "{model.encoder.output_dim}"
      use_multi: false
      multi_threshold: 0.5
      # Takes its value from base.ignore_index.
      ignore_index: "{base.ignore_index}"
      return_sentence_level: false