marinone94 committed
Commit 8829a08 • Parent(s): 0533144

restructure main code

Browse files:
- prepare_dataset_lm.py +3 -1
- run_speech_recognition_ctc.py +330 -269
prepare_dataset_lm.py CHANGED
@@ -1,4 +1,6 @@
 """ Script to prepare and upload dataset for training Swedish n-gram LM to boost ASR. """
 
 # Check colab notebook to get started
-# https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Boosting_Wav2Vec2_with_n_grams_in_Transformers.ipynb#scrollTo=IrAzjWc3Ok2l
+# https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Boosting_Wav2Vec2_with_n_grams_in_Transformers.ipynb#scrollTo=IrAzjWc3Ok2l
+
+# Notebook train_n_gram_lm_with_KenLM.ipynb has actual code
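Note: the n-gram LM this script prepares data for is applied at CTC decoding time. A minimal sketch of that step with pyctcdecode, assuming a plain character vocabulary; the KenLM path in the comment is a placeholder, and the linked notebooks show the full flow:

import numpy as np
from pyctcdecode import build_ctcdecoder

labels = [""] + list("abcdefghijklmnopqrstuvwxyzåäö ")  # "" is the CTC blank; assumed vocab
decoder = build_ctcdecoder(labels)  # pass kenlm_model_path="path/to/5gram.arpa" to boost with the LM
logits = np.log(np.full((10, len(labels)), 1e-9))
for t, ch in enumerate("hej hej   "):
    logits[t, labels.index(ch)] = 0.0  # one confident character per frame
print(decoder.decode(logits))  # ~ "hej hej"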
run_speech_recognition_ctc.py CHANGED
@@ -13,7 +13,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 
-""" Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition"""
+""" Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition.
+
+TODO:
+* add docstring and complete code docs
+* update model card
+
+"""
 
 import datetime
 import functools
@@ -325,6 +331,7 @@ def create_vocabulary_from_data(
     unk_token: Optional[str] = None,
     pad_token: Optional[str] = None,
 ):
+
     # Given training and test labels create vocabulary
     def extract_all_chars(batch, vocab):
         all_text = " ".join(batch)
@@ -356,20 +363,18 @@ def create_vocabulary_from_data(
     return vocab_dict
 
 
-def main():
-    # See all possible arguments in src/transformers/training_args.py
-    # or by passing the --help flag to this script.
-    # We now keep distinct sets of args, for a cleaner separation of concerns.
+def set_log_config_and_level(local_rank):
 
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        # If we pass only one argument to the script and it's the path to a json file,
-        # let's parse it to get our arguments.
-        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
-    else:
-        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+    logger.setLevel(logging.INFO if is_main_process(local_rank) else logging.WARN)
+
+
+def log_to_wandb(training_args):
 
-    # TODO: Replace with check of wandb env vars
     try:
         repo_name = os.getcwd().split("/")[-1]
         run_name = f"{datetime.datetime.utcnow()}".replace(" ", "T")
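Note: the W&B run name built above is just the UTC timestamp with the date-time separator restored, e.g.:

import datetime
run_name = f"{datetime.datetime.utcnow()}".replace(" ", "T")
print(run_name)  # e.g. 2022-01-27T09:15:42.123456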
@@ -377,11 +382,12 @@ def main():
         wandb.login()
         training_args.report_to = ["wandb"]
         training_args.run_name = run_name
-    except Exception as e:
-        logger.warning(f"\nFailed logging in to wandb: {e}\nThis experiment will not be logged.\n")
-
+    except Exception as e:
+        logger.warning(f"\nFailed logging in to wandb: {e}\nThis experiment will not be logged.\n")
+
+
+def detect_last_checkpoint(training_args):
 
-    # Detecting last checkpoint.
     last_checkpoint = None
     if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
@@ -395,14 +401,10 @@ def main():
                 f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
             )
+    return last_checkpoint
 
-    # Setup logging
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        handlers=[logging.StreamHandler(sys.stdout)],
-    )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+
+def log_small_summary(training_args):
 
     # Log on each process the small summary:
     logger.warning(
@@ -414,54 +416,12 @@ def main():
     transformers.utils.logging.set_verbosity_info()
     logger.info("Training/evaluation parameters %s", training_args)
 
-    # Set seed before initializing model.
-    set_seed(training_args.seed)
 
-    # 1. First, let's load the dataset
-    raw_datasets = DatasetDict()
+def load_raw_datasets(training_args, data_args):
 
-    def common_cols(columns_a, columns_b):
-        col_a = set(columns_a)
-        col_b = set(columns_b)
-        return [col for col in col_a if col in col_b]
+    raw_datasets = DatasetDict()
 
     if training_args.do_train:
-
-        # Multiple datasets might need to be loaded from HF
-        # It assumes they all follow the common voice format
-        # for (dataset_name, dataset_config_name, train_split_name) in zip(
-        #     data_args.dataset_name.split(","),
-        #     data_args.dataset_config_name.split(","),
-        #     data_args.train_split_name.split(","),
-        # ):
-
-
-        #     if train_split_name != "None":
-        #         if "train" not in raw_datasets:
-        #             raw_datasets["train"] = load_dataset(
-        #                 dataset_name,
-        #                 dataset_config_name,
-        #                 split=train_split_name,
-        #                 use_auth_token=data_args.use_auth_token,
-        #             )
-        #             min_columns_train = raw_datasets["train"].column_names
-        #         else:
-        #             new_dataset = load_dataset(
-        #                 dataset_name,
-        #                 dataset_config_name,
-        #                 split=train_split_name,
-        #                 use_auth_token=data_args.use_auth_token,
-        #             )
-        #             raw_datasets["train"] = concatenate_datasets(
-        #                 [
-        #                     raw_datasets["train"],
-        #                     new_dataset
-        #                 ]
-        #             )
-        #             min_columns_train = common_cols(min_columns_train, new_dataset.column_names)
-        #     else:
-        #         logging.warning(f"{dataset_name} {dataset_config_name} train not loaded as split is {train_split_name}")
-
         raw_datasets["train"] = load_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
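Note: the commented-out block above sketches multi-dataset loading; in isolation the idea looks like this (the dataset names below are placeholders, and dropping non-shared columns is needed because concatenate_datasets expects identical schemas):

from datasets import concatenate_datasets, load_dataset

parts = [
    load_dataset("mozilla-foundation/common_voice_7_0", "sv-SE", split="train"),
    load_dataset("mozilla-foundation/common_voice_7_0", "sv-SE", split="validation"),
]
common = set(parts[0].column_names)
for part in parts[1:]:
    common &= set(part.column_names)
parts = [part.remove_columns([c for c in part.column_names if c not in common]) for part in parts]
merged = concatenate_datasets(parts)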
@@ -477,7 +437,7 @@ def main():
         )
 
         dataset_frequency = raw_datasets["train"].features[data_args.audio_column_name].sampling_rate
-
+        logger.info(f"Dataset sampling rate: {dataset_frequency}")
 
         if data_args.text_column_name not in raw_datasets["train"].column_names:
             raise ValueError(
@@ -488,48 +448,8 @@ def main():
 
         if data_args.max_train_samples is not None:
             raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
-        # other_columns_train = [col for col in raw_datasets["train"].column_names if col not in min_columns_train]
-        # raw_datasets["train"].remove_columns(other_columns_train)
-
-        # pd_train_head = raw_datasets["train"].select(range(10)).to_pandas()
-        # pd_train_tail = raw_datasets["train"].select(range(raw_datasets["train"].num_rows-10, raw_datasets["train"].num_rows)).to_pandas()
-        # pd_train = pd.concat([pd_train_head, pd_train_tail])
-        # print(pd_train["audio"])
 
     if training_args.do_eval:
-        # Multiple datasets might need to be loaded from HF
-        # It assumes they all follow the common voice format
-        # for (dataset_name, dataset_config_name, eval_split_name) in zip(
-        #     data_args.dataset_name.split(","),
-        #     data_args.dataset_config_name.split(","),
-        #     data_args.eval_split_name.split(","),
-        # ):
-
-        #     if eval_split_name != "None":
-        #         if "eval" not in raw_datasets:
-        #             raw_datasets["eval"] = load_dataset(
-        #                 dataset_name,
-        #                 dataset_config_name,
-        #                 split=eval_split_name,
-        #                 use_auth_token=data_args.use_auth_token,
-        #             )
-        #             min_columns_eval = raw_datasets["eval"].column_names
-        #         else:
-        #             new_dataset = load_dataset(
-        #                 dataset_name,
-        #                 dataset_config_name,
-        #                 split=eval_split_name,
-        #                 use_auth_token=data_args.use_auth_token,
-        #             )
-        #             raw_datasets["eval"] = concatenate_datasets(
-        #                 [
-        #                     raw_datasets["eval"],
-        #                     new_dataset
-        #                 ]
-        #             )
-        #             min_columns_eval = common_cols(min_columns_eval, new_dataset.column_names)
-        #     else:
-        #         logging.warning(f"{dataset_name} {dataset_config_name} eval not loaded as split is {eval_split_name}")
 
         try:
             raw_datasets["eval"] = load_dataset(
@@ -542,23 +462,16 @@ def main():
             split_dataset = raw_datasets["train"].train_test_split(test_size=0.1, seed=42)
             raw_datasets["train"] = split_dataset["train"]
             raw_datasets["eval"] = split_dataset["test"]
-
-            print("Sampled from training set")
+            logger.info("Eval set sampled from training set.\nTest size: 0.1\nSeed: 42")
 
         if data_args.max_eval_samples is not None:
             raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
-        # other_columns_eval = [col for col in raw_datasets["eval"].column_names if col not in min_columns_eval]
-        # raw_datasets["eval"].remove_columns(other_columns_eval)
 
-
-
-
-
+    return raw_datasets
+
+
+def clean_dataset(raw_datasets, training_args, data_args):
 
-    # 2. We remove some special characters from the datasets
-    # that make training complicated and do not help in transcribing the speech
-    # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
-    # that could be easily picked up by the model
     chars_to_ignore_regex = (
        f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
    )
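Note: the eval fallback above carves a held-out split when the dataset ships without one; the same call in isolation (the dataset name is a placeholder):

from datasets import load_dataset

train = load_dataset("mozilla-foundation/common_voice_7_0", "sv-SE", split="train")
split = train.train_test_split(test_size=0.1, seed=42)
train, eval_set = split["train"], split["test"]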
@@ -571,35 +484,30 @@ def main():
             return False
 
     def remove_special_characters(batch):
+
+        repl_dict = {
+            "\\\\Punkt": "",
+            "\\\\Komma": "",
+            "è": "e",
+            "é": "e",
+            "î": "i",
+            "ü": "u",
+            "ÿ": "y",
+            "ô": "o",
+            "\\": "",
+            "/": "",
+            "|": ""
+        }
+
         if chars_to_ignore_regex is not None:
-            batch["target_text"] = \
-                re.sub(chars_to_ignore_regex, "", batch[text_column_name]) \
-                .replace("\\\\Punkt", "") \
-                .replace("\\\\Komma", "") \
-                .replace("è", "e") \
-                .replace("é", "e") \
-                .replace("î", "i") \
-                .replace("ü", "u") \
-                .replace("ÿ", "y") \
-                .replace("ô", "o") \
-                .replace("\\", "") \
-                .replace("/", "") \
-                .replace("|", "") \
-                .lower() + " "
+            target_text = re.sub(chars_to_ignore_regex, "", batch[text_column_name])
         else:
-
-
-
-
-
-
-                .replace("ü", "u") \
-                .replace("ÿ", "y") \
-                .replace("ô", "o") \
-                .replace("\\", "") \
-                .replace("/", "") \
-                .replace("|", "") \
-                .lower() + " "
+            target_text = batch[text_column_name]
+
+        for orig, repl in repl_dict.items():
+            target_text = target_text.replace(orig, repl)
+
+        batch["target_text"] = target_text.lower() + " "
         return batch
 
     num_workers = data_args.preprocessing_num_workers
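Note: the replacement table introduced above behaves like the old backslash-continued chain, just easier to extend; a quick standalone check:

repl_dict = {"è": "e", "é": "e", "î": "i", "ü": "u", "ÿ": "y", "ô": "o", "\\": "", "/": "", "|": ""}

def normalize(text):
    for orig, repl in repl_dict.items():
        text = text.replace(orig, repl)
    return text.lower() + " "

print(normalize("Grüt/è"))  # -> "grute "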
@@ -617,23 +525,11 @@ def main():
         desc="remove single words, single chars and 'W O R D S'",
     )
 
-
-    word_delimiter_token = data_args.word_delimiter_token
-    unk_token = data_args.unk_token
-    pad_token = data_args.pad_token
+    return raw_datasets
 
-    # 3. Next, let's load the config as we might need it to create
-    # the tokenizer
-    # load config
-    config = AutoConfig.from_pretrained(
-        model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
-    )
 
-
-    # 4. Next, if no tokenizer file is defined,
-    # we create the vocabulary of the model by extracting all unique characters from
-    # the training and evaluation datasets
-    # We need to make sure that only first rank saves vocabulary
-    # make sure all processes wait until vocab is created
+def create_tokenizer_kwargs(raw_datasets, training_args, model_args, data_args, config):
+
     tokenizer_name_or_path = model_args.tokenizer_name_or_path
     tokenizer_kwargs = {}
     if tokenizer_name_or_path is None:
@@ -651,9 +547,9 @@ def main():
         os.makedirs(tokenizer_name_or_path, exist_ok=True)
         vocab_dict = create_vocabulary_from_data(
             raw_datasets,
-            word_delimiter_token=word_delimiter_token,
-            unk_token=unk_token,
-            pad_token=pad_token,
+            word_delimiter_token=data_args.word_delimiter_token,
+            unk_token=data_args.unk_token,
+            pad_token=data_args.pad_token,
         )
 
         # save vocab dict to be loaded into tokenizer
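Note: what create_vocabulary_from_data produces, in miniature (special tokens assumed from the script defaults: "|" as word delimiter, unk/pad appended last):

import json

texts = ["hej världen", "god morgon"]
vocab_dict = {ch: i for i, ch in enumerate(sorted(set("".join(texts))))}
vocab_dict["|"] = vocab_dict.pop(" ")  # the word delimiter token replaces the space
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
with open("vocab.json", "w") as f:
    json.dump(vocab_dict, f, ensure_ascii=False)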
@@ -665,61 +561,15 @@ def main():
         tokenizer_kwargs = {
             "config": config if config.tokenizer_class is not None else None,
             "tokenizer_type": config.model_type if config.tokenizer_class is None else None,
-            "unk_token": unk_token,
-            "pad_token": pad_token,
-            "word_delimiter_token": word_delimiter_token,
-        }
-
-    # 5. Now we can instantiate the feature extractor, tokenizer and model
-    # Note for distributed training, the .from_pretrained methods guarantee that only
-    # one local process can concurrently download model & vocab.
-
-    # load feature_extractor and tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(
-        tokenizer_name_or_path,
-        use_auth_token=data_args.use_auth_token,
-        **tokenizer_kwargs,
-    )
-    feature_extractor = AutoFeatureExtractor.from_pretrained(
-        model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
-    )
-
-    # adapt config
-    config.update(
-        {
-            "feat_proj_dropout": model_args.feat_proj_dropout,
-            "attention_dropout": model_args.attention_dropout,
-            "hidden_dropout": model_args.hidden_dropout,
-            "final_dropout": model_args.final_dropout,
-            "mask_time_prob": model_args.mask_time_prob,
-            "mask_time_length": model_args.mask_time_length,
-            "mask_feature_prob": model_args.mask_feature_prob,
-            "mask_feature_length": model_args.mask_feature_length,
-            "gradient_checkpointing": training_args.gradient_checkpointing,
-            "layerdrop": model_args.layerdrop,
-            "ctc_loss_reduction": model_args.ctc_loss_reduction,
-            "pad_token_id": tokenizer.pad_token_id,
-            "vocab_size": len(tokenizer),
-            "activation_dropout": model_args.activation_dropout,
+            "unk_token": data_args.unk_token,
+            "pad_token": data_args.pad_token,
+            "word_delimiter_token": data_args.word_delimiter_token,
         }
-    )
 
-
-    model = AutoModelForCTC.from_pretrained(
-        model_args.model_name_or_path,
-        cache_dir=model_args.cache_dir,
-        config=config,
-        use_auth_token=data_args.use_auth_token,
-    )
+    return tokenizer_kwargs
 
-    # freeze encoder
-    if model_args.freeze_feature_encoder:
-        model.freeze_feature_encoder()
 
-
-    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
-    # so that we just need to set the correct target sampling rate and normalize the input
-    # via the `feature_extractor`
+def vectorize_dataset(raw_datasets, feature_extractor, tokenizer, training_args, data_args):
 
     # make sure that dataset decodes audio with correct sampling rate
     dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
@@ -759,13 +609,13 @@ def main():
         vectorized_datasets["train"] = raw_datasets["train"].map(
             prepare_dataset,
             remove_columns=raw_datasets["train"].column_names,
-            num_proc=num_workers,
+            num_proc=data_args.preprocessing_num_workers,
             desc="preprocess datasets",
         )
         vectorized_datasets["eval"] = raw_datasets["eval"].map(
             prepare_dataset,
             remove_columns=raw_datasets["eval"].column_names,
-            num_proc=num_workers,
+            num_proc=data_args.preprocessing_num_workers,
             desc="preprocess datasets",
         )
 
@@ -775,13 +625,44 @@ def main():
         # filter data that is shorter than min_input_length
         vectorized_datasets = vectorized_datasets.filter(
             is_audio_in_length_range,
-            num_proc=num_workers,
+            num_proc=data_args.preprocessing_num_workers,
             input_columns=["input_length"],
         )
 
-    # 6. Next, we can prepare the training.
-    # Let's use word error rate (WER) as our evaluation metric,
-    # instantiate a data collator and the trainer
+
+def log_dataset_sample_on_wandb(vectorized_datasets, audio_column_name):
+
+    pd_train = vectorized_datasets["train"].select(range(10)).to_pandas()
+    pd_eval = vectorized_datasets["eval"].select(range(10)).to_pandas()
+
+    dict_log = {}
+    for i, audio in pd_train[audio_column_name].items():
+        dict_log[f"Training sample {i}"] = wandb.Audio(
+            audio["array"],
+            sample_rate=audio["sampling_rate"]
+        )
+    for i, audio in pd_eval[audio_column_name].items():
+        dict_log[f"Eval sample {i}"] = wandb.Audio(
+            audio["array"],
+            sample_rate=audio["sampling_rate"]
+        )
+
+    wandb.log({
+        "Training samples": pd_train.drop(labels=audio_column_name, axis=1),
+        "Eval samples": pd_eval.drop(labels=audio_column_name, axis=1),
+        "Audio samples": dict_log
+    })
+
+
+def prepare_training(
+    model,
+    vectorized_datasets,
+    feature_extractor,
+    tokenizer,
+    training_args,
+    data_args,
+    config
+):
 
     # Define evaluation metrics during training, *i.e.* word error rate, character error rate
     eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
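Note: the sample-logging helper above relies on wandb.Audio, whose keyword is sample_rate (not audio_rate), and on iterating the pandas Series with .items(). In isolation, assuming 16 kHz audio and offline mode so no account is needed (the project name is a placeholder):

import numpy as np
import wandb

run = wandb.init(project="wav2vec2-sv-debug", mode="offline")
waveform = np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)  # 1 s of 440 Hz at 16 kHz
wandb.log({"Training sample 0": wandb.Audio(waveform, sample_rate=16000)})
run.finish()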
@@ -790,12 +671,11 @@ def main():
     if data_args.dataset_seed is not None:
         vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=data_args.dataset_seed)
 
-
-
-
-
-
-
+    log_dataset_sample_on_wandb(
+        vectorized_datasets=vectorized_datasets,
+        audio_column_name=data_args.audio_column_name
+    )
+
     # for large datasets it is advised to run the preprocessing on a
     # single machine first with ``args.preprocessing_only`` since there will most likely
     # be a timeout when running the script in distributed mode.
@@ -815,9 +695,6 @@ def main():
         # we do not want to group tokens when computing the metrics
         label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
 
-        print(pred_str[:10])
-        print(label_str[:10])
-
         metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
 
        return metrics
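Note: compute_metrics above delegates to the metrics loaded via load_metric (current in the datasets library at the time of this commit; newer code would use the evaluate package). A quick sanity check of WER:

from datasets import load_metric

wer = load_metric("wer")
print(wer.compute(predictions=["hej världen"], references=["hej herden"]))  # 0.5: one of two words wrong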
@@ -845,7 +722,7 @@ def main():
     data_collator = DataCollatorCTCWithPadding(processor=processor)
 
     # Initialize Trainer
-    trainer = Trainer(
+    return Trainer(
         model=model,
         data_collator=data_collator,
         args=training_args,
@@ -855,48 +732,62 @@ def main():
         tokenizer=feature_extractor,
     )
 
-    # 8. Finally, we can start training
 
-    # Training
-    if training_args.do_train:
+def do_training(
+    trainer,
+    last_checkpoint,
+    vectorized_datasets,
+    model_args,
+    data_args
+):
 
-        # use last checkpoint if exist
-        if last_checkpoint is not None:
-            checkpoint = last_checkpoint
-        elif os.path.isdir(model_args.model_name_or_path):
-            checkpoint = model_args.model_name_or_path
-        else:
-            checkpoint = None
+    # use last checkpoint if it exists
+    if last_checkpoint is not None:
+        checkpoint = last_checkpoint
+    elif os.path.isdir(model_args.model_name_or_path):
+        checkpoint = model_args.model_name_or_path
+    else:
+        checkpoint = None
 
-        train_result = trainer.train(resume_from_checkpoint=checkpoint)
-        trainer.save_model()
+    train_result = trainer.train(resume_from_checkpoint=checkpoint)
+    trainer.save_model()
 
-        metrics = train_result.metrics
-        max_train_samples = (
-            data_args.max_train_samples
-            if data_args.max_train_samples is not None
-            else len(vectorized_datasets["train"])
-        )
-        metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
+    metrics = train_result.metrics
+    max_train_samples = (
+        data_args.max_train_samples
+        if data_args.max_train_samples is not None
+        else len(vectorized_datasets["train"])
+    )
+    metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
 
-        trainer.log_metrics("train", metrics)
-        trainer.save_metrics("train", metrics)
-        trainer.save_state()
+    trainer.log_metrics("train", metrics)
+    trainer.save_metrics("train", metrics)
+    trainer.save_state()
 
-    # Evaluation
-    results = {}
-    if training_args.do_eval:
-        logger.info("*** Evaluate ***")
-        metrics = trainer.evaluate()
-        max_eval_samples = (
-            data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
-        )
-        metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
+    return trainer
+
+
+def do_eval(
+    trainer,
+    vectorized_datasets,
+    data_args
+):
+
+    logger.info("*** Evaluate ***")
+    metrics = trainer.evaluate()
+    max_eval_samples = (
+        data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
+    )
+    metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
+
+    trainer.log_metrics("eval", metrics)
+    trainer.save_metrics("eval", metrics)
+
+    return trainer
 
-        trainer.log_metrics("eval", metrics)
-        trainer.save_metrics("eval", metrics)
 
-    # Write model card and (optionally) push to hub
+def log_results(trainer, training_args, model_args, data_args):
+
     config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
     kwargs = {
         "finetuned_from": model_args.model_name_or_path,
@@ -912,9 +803,179 @@ def main():
         trainer.push_to_hub(**kwargs)
     else:
         trainer.create_model_card(**kwargs)
-
-    return results
 
 
+def inst_model_tokenizer_feature_extractor(
+    tokenizer_kwargs,
+    training_args,
+    model_args,
+    data_args,
+    config
+):
+
+    # load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.tokenizer_name_or_path,
+        use_auth_token=data_args.use_auth_token,
+        **tokenizer_kwargs,
+    )
+
+    # load feature extractor
+    feature_extractor = AutoFeatureExtractor.from_pretrained(
+        model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
+    )
+
+    # adapt config
+    config.update(
+        {
+            "feat_proj_dropout": model_args.feat_proj_dropout,
+            "attention_dropout": model_args.attention_dropout,
+            "hidden_dropout": model_args.hidden_dropout,
+            "final_dropout": model_args.final_dropout,
+            "mask_time_prob": model_args.mask_time_prob,
+            "mask_time_length": model_args.mask_time_length,
+            "mask_feature_prob": model_args.mask_feature_prob,
+            "mask_feature_length": model_args.mask_feature_length,
+            "gradient_checkpointing": training_args.gradient_checkpointing,
+            "layerdrop": model_args.layerdrop,
+            "ctc_loss_reduction": model_args.ctc_loss_reduction,
+            "pad_token_id": tokenizer.pad_token_id,
+            "vocab_size": len(tokenizer),
+            "activation_dropout": model_args.activation_dropout,
+        }
+    )
+
+    # load model
+    model = AutoModelForCTC.from_pretrained(
+        model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        config=config,
+        use_auth_token=data_args.use_auth_token,
+    )
+
+    # freeze encoder
+    if model_args.freeze_feature_encoder:
+        model.freeze_feature_encoder()
+
+    return model, tokenizer, feature_extractor, config
+
+
+def main():
+
+    # 0. Parse arguments
+    # See all possible arguments in src/transformers/training_args.py
+    # or by passing the --help flag to this script.
+    # We now keep distinct sets of args, for a cleaner separation of concerns.
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    # 1. Set logging
+    set_log_config_and_level(local_rank=training_args.local_rank)
+    log_to_wandb(training_args=training_args)
+    log_small_summary(training_args=training_args)
+
+    # 2. Set random seed
+    set_seed(training_args.seed)
+
+    # 3. First, let's load the dataset
+    raw_datasets = load_raw_datasets(training_args=training_args, data_args=data_args)
+
+    # 4. We remove some special characters from the datasets
+    # that make training complicated and do not help in transcribing the speech
+    # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
+    # that could be easily picked up by the model
+    raw_datasets = clean_dataset(
+        raw_datasets=raw_datasets,
+        training_args=training_args,
+        data_args=data_args
+    )
+
+    # 5. Next, let's load the config as we might need it to create the tokenizer
+    config = AutoConfig.from_pretrained(
+        model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
+    )
+
+    # 6. Next, if no tokenizer file is defined,
+    # we create the vocabulary of the model by extracting all unique characters from
+    # the training and evaluation datasets
+    # We need to make sure that only first rank saves vocabulary
+    # make sure all processes wait until vocab is created
+    tokenizer_kwargs = create_tokenizer_kwargs(
+        raw_datasets=raw_datasets,
+        training_args=training_args,
+        model_args=model_args,
+        data_args=data_args,
+        config=config
+    )
+
+    # 7. Now we can instantiate the feature extractor, tokenizer and model
+    # Note for distributed training, the .from_pretrained methods guarantee that only
+    # one local process can concurrently download model & vocab.
+    model, tokenizer, feature_extractor, config = inst_model_tokenizer_feature_extractor(
+        tokenizer_kwargs=tokenizer_kwargs,
+        training_args=training_args,
+        model_args=model_args,
+        data_args=data_args,
+        config=config
+    )
+
+    # 8. Now we preprocess the datasets including loading the audio, resampling and normalization
+    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
+    # so that we just need to set the correct target sampling rate and normalize the input
+    # via the `feature_extractor`
+    vectorized_datasets = vectorize_dataset(
+        raw_datasets=raw_datasets,
+        feature_extractor=feature_extractor,
+        tokenizer=tokenizer,
+        training_args=training_args,
+        data_args=data_args
+    )
+
+    # 9. Next, we can prepare the training.
+    # Let's use word error rate (WER) as our evaluation metric,
+    # instantiate a data collator and the trainer
+    trainer = prepare_training(
+        model=model,
+        vectorized_datasets=vectorized_datasets,
+        feature_extractor=feature_extractor,
+        tokenizer=tokenizer,
+        training_args=training_args,
+        data_args=data_args,
+        config=config
+    )
+
+    # 10. Train model
+    last_checkpoint = detect_last_checkpoint(training_args=training_args)
+    if training_args.do_train:
+        trainer = do_training(
+            trainer=trainer,
+            last_checkpoint=last_checkpoint,
+            vectorized_datasets=vectorized_datasets,
+            model_args=model_args,
+            data_args=data_args
+        )
+
+    # 11. Eval model
+    if training_args.do_eval:
+        trainer = do_eval(
+            trainer=trainer,
+            vectorized_datasets=vectorized_datasets,
+            data_args=data_args
+        )
+
+    # 12. Push to hub and update model card
+    log_results(
+        trainer=trainer,
+        training_args=training_args,
+        model_args=model_args,
+        data_args=data_args
+    )
+
+
 if __name__ == "__main__":
     main()