marinone94 commited on
Commit
8829a08
1 Parent(s): 0533144

restructure main code

Browse files
Files changed (2) hide show
  1. prepare_dataset_lm.py +3 -1
  2. run_speech_recognition_ctc.py +330 -269
prepare_dataset_lm.py CHANGED
@@ -1,4 +1,6 @@
1
  """ Script to prepare and upload dataset for training Swedish n-gram LM to boost ASR. """
2
 
3
  # Check colab notebook to get started
4
- # https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Boosting_Wav2Vec2_with_n_grams_in_Transformers.ipynb#scrollTo=IrAzjWc3Ok2l
 
 
 
1
  """ Script to prepare and upload dataset for training Swedish n-gram LM to boost ASR. """
2
 
3
  # Check colab notebook to get started
4
+ # https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Boosting_Wav2Vec2_with_n_grams_in_Transformers.ipynb#scrollTo=IrAzjWc3Ok2l
5
+
6
+ # Notebook train_n_gram_lm_with_KenLM.ipynb has actual code
run_speech_recognition_ctc.py CHANGED
@@ -13,7 +13,13 @@
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
  # See the License for the specific language governing permissions and
15
 
16
- """ Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition"""
 
 
 
 
 
 
17
 
18
  import datetime
19
  import functools
@@ -325,6 +331,7 @@ def create_vocabulary_from_data(
325
  unk_token: Optional[str] = None,
326
  pad_token: Optional[str] = None,
327
  ):
 
328
  # Given training and test labels create vocabulary
329
  def extract_all_chars(batch, vocab):
330
  all_text = " ".join(batch)
@@ -356,20 +363,18 @@ def create_vocabulary_from_data(
356
  return vocab_dict
357
 
358
 
359
- def main():
360
- # See all possible arguments in src/transformers/training_args.py
361
- # or by passing the --help flag to this script.
362
- # We now keep distinct sets of args, for a cleaner separation of concerns.
363
 
364
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
365
- if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
366
- # If we pass only one argument to the script and it's the path to a json file,
367
- # let's parse it to get our arguments.
368
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
369
- else:
370
- model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
 
371
 
372
- # TODO: Replace with check of wandb env vars
373
  try:
374
  repo_name = os.getcwd().split("/")[-1]
375
  run_name = f"{datetime.datetime.utcnow()}".replace(" ", "T")
@@ -377,11 +382,12 @@ def main():
377
  wandb.login()
378
  training_args.report_to = ["wandb"]
379
  training_args.run_name = run_name
380
- # wandb.init()
381
- except:
382
- pass
 
 
383
 
384
- # Detecting last checkpoint.
385
  last_checkpoint = None
386
  if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
387
  last_checkpoint = get_last_checkpoint(training_args.output_dir)
@@ -395,14 +401,10 @@ def main():
395
  f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
396
  "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
397
  )
 
398
 
399
- # Setup logging
400
- logging.basicConfig(
401
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
402
- datefmt="%m/%d/%Y %H:%M:%S",
403
- handlers=[logging.StreamHandler(sys.stdout)],
404
- )
405
- logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
406
 
407
  # Log on each process the small summary:
408
  logger.warning(
@@ -414,54 +416,12 @@ def main():
414
  transformers.utils.logging.set_verbosity_info()
415
  logger.info("Training/evaluation parameters %s", training_args)
416
 
417
- # Set seed before initializing model.
418
- set_seed(training_args.seed)
419
 
420
- # 1. First, let's load the dataset
421
- raw_datasets = DatasetDict()
422
 
423
- def common_cols(columns_a, columns_b):
424
- col_a = set(columns_a)
425
- col_b = set(columns_b)
426
- return [col for col in col_a if col in col_b]
427
 
428
  if training_args.do_train:
429
-
430
- # Multiple datasets might need to be loaded from HF
431
- # It assumes they all follow the common voice format
432
- # for (dataset_name, dataset_config_name, train_split_name) in zip(
433
- # data_args.dataset_name.split(","),
434
- # data_args.dataset_config_name.split(","),
435
- # data_args.train_split_name.split(","),
436
- # ):
437
-
438
-
439
- # if train_split_name != "None":
440
- # if "train" not in raw_datasets:
441
- # raw_datasets["train"] = load_dataset(
442
- # dataset_name,
443
- # dataset_config_name,
444
- # split=train_split_name,
445
- # use_auth_token=data_args.use_auth_token,
446
- # )
447
- # min_columns_train = raw_datasets["train"].column_names
448
- # else:
449
- # new_dataset = load_dataset(
450
- # dataset_name,
451
- # dataset_config_name,
452
- # split=train_split_name,
453
- # use_auth_token=data_args.use_auth_token,
454
- # )
455
- # raw_datasets["train"] = concatenate_datasets(
456
- # [
457
- # raw_datasets["train"],
458
- # new_dataset
459
- # ]
460
- # )
461
- # min_columns_train = common_cols(min_columns_train, new_dataset.column_names)
462
- # else:
463
- # logging.warning(f"{dataset_name} {dataset_config_name} train not loaded as split is {train_split_name}")
464
-
465
  raw_datasets["train"] = load_dataset(
466
  data_args.dataset_name,
467
  data_args.dataset_config_name,
@@ -477,7 +437,7 @@ def main():
477
  )
478
 
479
  dataset_frequency = raw_datasets["train"].features[data_args.audio_column_name].sampling_rate
480
- print(f"Dataset sampling rate: {dataset_frequency}")
481
 
482
  if data_args.text_column_name not in raw_datasets["train"].column_names:
483
  raise ValueError(
@@ -488,48 +448,8 @@ def main():
488
 
489
  if data_args.max_train_samples is not None:
490
  raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
491
- # other_columns_train = [col for col in raw_datasets["train"].column_names if col not in min_columns_train]
492
- # raw_datasets["train"].remove_columns(other_columns_train)
493
-
494
- # pd_train_head = raw_datasets["train"].select(range(10)).to_pandas()
495
- # pd_train_tail = raw_datasets["train"].select(range(raw_datasets["train"].num_rows-10, raw_datasets["train"].num_rows)).to_pandas()
496
- # pd_train = pd.concat([pd_train_head, pd_train_tail])
497
- # print(pd_train["audio"])
498
 
499
  if training_args.do_eval:
500
- # Multiple datasets might need to be loaded from HF
501
- # It assumes they all follow the common voice format
502
- # for (dataset_name, dataset_config_name, eval_split_name) in zip(
503
- # data_args.dataset_name.split(","),
504
- # data_args.dataset_config_name.split(","),
505
- # data_args.eval_split_name.split(","),
506
- # ):
507
-
508
- # if eval_split_name != "None":
509
- # if "eval" not in raw_datasets:
510
- # raw_datasets["eval"] = load_dataset(
511
- # dataset_name,
512
- # dataset_config_name,
513
- # split=eval_split_name,
514
- # use_auth_token=data_args.use_auth_token,
515
- # )
516
- # min_columns_eval = raw_datasets["eval"].column_names
517
- # else:
518
- # new_dataset = load_dataset(
519
- # dataset_name,
520
- # dataset_config_name,
521
- # split=eval_split_name,
522
- # use_auth_token=data_args.use_auth_token,
523
- # )
524
- # raw_datasets["eval"] = concatenate_datasets(
525
- # [
526
- # raw_datasets["eval"],
527
- # new_dataset
528
- # ]
529
- # )
530
- # min_columns_eval = common_cols(min_columns_eval, new_dataset.column_names)
531
- # else:
532
- # logging.warning(f"{dataset_name} {dataset_config_name} eval not loaded as split is {eval_split_name}")
533
 
534
  try:
535
  raw_datasets["eval"] = load_dataset(
@@ -542,23 +462,16 @@ def main():
542
  split_dataset = raw_datasets["train"].train_test_split(test_size=0.1, seed=42)
543
  raw_datasets["train"] = split_dataset["train"]
544
  raw_datasets["eval"] = split_dataset["test"]
545
- print(raw_datasets["eval"])
546
- print("Sampled from training set")
547
 
548
  if data_args.max_eval_samples is not None:
549
  raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
550
- # other_columns_eval = [col for col in raw_datasets["eval"].column_names if col not in min_columns_eval]
551
- # raw_datasets["eval"].remove_columns(other_columns_eval)
552
 
553
- # pd_eval_head = raw_datasets["eval"].select(range(10)).to_pandas()
554
- # pd_eval_tail = raw_datasets["eval"].select(range(raw_datasets["eval"].num_rows-10, raw_datasets["eval"].num_rows)).to_pandas()
555
- # pd_eval = pd.concat([pd_eval_head, pd_eval_tail])
556
- # print(pd_eval["audio"])
557
 
558
- # 2. We remove some special characters from the datasets
559
- # that make training complicated and do not help in transcribing the speech
560
- # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
561
- # that could be easily picked up by the model
562
  chars_to_ignore_regex = (
563
  f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
564
  )
@@ -571,35 +484,30 @@ def main():
571
  return False
572
 
573
  def remove_special_characters(batch):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
  if chars_to_ignore_regex is not None:
575
- batch["target_text"] = \
576
- re.sub(chars_to_ignore_regex, "", batch[text_column_name]) \
577
- .replace("\\\\Punkt", "") \
578
- .replace("\\\\Komma", "") \
579
- .replace("è", "e") \
580
- .replace("é", "e") \
581
- .replace("î", "i") \
582
- .replace("ü", "u") \
583
- .replace("ÿ", "y") \
584
- .replace("ô", "o") \
585
- .replace("\\", "") \
586
- .replace("/", "") \
587
- .replace("|", "") \
588
- .lower() + " "
589
  else:
590
- batch["target_text"] = batch[text_column_name] \
591
- .replace("\\\\Punkt", "") \
592
- .replace("\\\\Komma", "") \
593
- .replace("è", "e") \
594
- .replace("é", "e") \
595
- .replace("î", "i") \
596
- .replace("ü", "u") \
597
- .replace("ÿ", "y") \
598
- .replace("ô", "o") \
599
- .replace("\\", "") \
600
- .replace("/", "") \
601
- .replace("|", "") \
602
- .lower() + " "
603
  return batch
604
 
605
  num_workers = data_args.preprocessing_num_workers
@@ -617,23 +525,11 @@ def main():
617
  desc="remove single words, single chars and 'W O R D S'",
618
  )
619
 
620
- # save special tokens for tokenizer
621
- word_delimiter_token = data_args.word_delimiter_token
622
- unk_token = data_args.unk_token
623
- pad_token = data_args.pad_token
624
 
625
- # 3. Next, let's load the config as we might need it to create
626
- # the tokenizer
627
- # load config
628
- config = AutoConfig.from_pretrained(
629
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
630
- )
631
 
632
- # 4. Next, if no tokenizer file is defined,
633
- # we create the vocabulary of the model by extracting all unique characters from
634
- # the training and evaluation datasets
635
- # We need to make sure that only first rank saves vocabulary
636
- # make sure all processes wait until vocab is created
637
  tokenizer_name_or_path = model_args.tokenizer_name_or_path
638
  tokenizer_kwargs = {}
639
  if tokenizer_name_or_path is None:
@@ -651,9 +547,9 @@ def main():
651
  os.makedirs(tokenizer_name_or_path, exist_ok=True)
652
  vocab_dict = create_vocabulary_from_data(
653
  raw_datasets,
654
- word_delimiter_token=word_delimiter_token,
655
- unk_token=unk_token,
656
- pad_token=pad_token,
657
  )
658
 
659
  # save vocab dict to be loaded into tokenizer
@@ -665,61 +561,15 @@ def main():
665
  tokenizer_kwargs = {
666
  "config": config if config.tokenizer_class is not None else None,
667
  "tokenizer_type": config.model_type if config.tokenizer_class is None else None,
668
- "unk_token": unk_token,
669
- "pad_token": pad_token,
670
- "word_delimiter_token": word_delimiter_token,
671
- }
672
-
673
- # 5. Now we can instantiate the feature extractor, tokenizer and model
674
- # Note for distributed training, the .from_pretrained methods guarantee that only
675
- # one local process can concurrently download model & vocab.
676
-
677
- # load feature_extractor and tokenizer
678
- tokenizer = AutoTokenizer.from_pretrained(
679
- tokenizer_name_or_path,
680
- use_auth_token=data_args.use_auth_token,
681
- **tokenizer_kwargs,
682
- )
683
- feature_extractor = AutoFeatureExtractor.from_pretrained(
684
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
685
- )
686
-
687
- # adapt config
688
- config.update(
689
- {
690
- "feat_proj_dropout": model_args.feat_proj_dropout,
691
- "attention_dropout": model_args.attention_dropout,
692
- "hidden_dropout": model_args.hidden_dropout,
693
- "final_dropout": model_args.final_dropout,
694
- "mask_time_prob": model_args.mask_time_prob,
695
- "mask_time_length": model_args.mask_time_length,
696
- "mask_feature_prob": model_args.mask_feature_prob,
697
- "mask_feature_length": model_args.mask_feature_length,
698
- "gradient_checkpointing": training_args.gradient_checkpointing,
699
- "layerdrop": model_args.layerdrop,
700
- "ctc_loss_reduction": model_args.ctc_loss_reduction,
701
- "pad_token_id": tokenizer.pad_token_id,
702
- "vocab_size": len(tokenizer),
703
- "activation_dropout": model_args.activation_dropout,
704
  }
705
- )
706
 
707
- # create model
708
- model = AutoModelForCTC.from_pretrained(
709
- model_args.model_name_or_path,
710
- cache_dir=model_args.cache_dir,
711
- config=config,
712
- use_auth_token=data_args.use_auth_token,
713
- )
714
 
715
- # freeze encoder
716
- if model_args.freeze_feature_encoder:
717
- model.freeze_feature_encoder()
718
 
719
- # 6. Now we preprocess the datasets including loading the audio, resampling and normalization
720
- # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
721
- # so that we just need to set the correct target sampling rate and normalize the input
722
- # via the `feature_extractor`
723
 
724
  # make sure that dataset decodes audio with correct sampling rate
725
  dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
@@ -759,13 +609,13 @@ def main():
759
  vectorized_datasets["train"] = raw_datasets["train"].map(
760
  prepare_dataset,
761
  remove_columns=raw_datasets["train"].column_names,
762
- num_proc=num_workers,
763
  desc="preprocess datasets",
764
  )
765
  vectorized_datasets["eval"] = raw_datasets["eval"].map(
766
  prepare_dataset,
767
  remove_columns=raw_datasets["eval"].column_names,
768
- num_proc=num_workers,
769
  desc="preprocess datasets",
770
  )
771
 
@@ -775,13 +625,44 @@ def main():
775
  # filter data that is shorter than min_input_length
776
  vectorized_datasets = vectorized_datasets.filter(
777
  is_audio_in_length_range,
778
- num_proc=num_workers,
779
  input_columns=["input_length"],
780
  )
781
 
782
- # 7. Next, we can prepare the training.
783
- # Let's use word error rate (WER) as our evaluation metric,
784
- # instantiate a data collator and the trainer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
785
 
786
  # Define evaluation metrics during training, *i.e.* word error rate, character error rate
787
  eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
@@ -790,12 +671,11 @@ def main():
790
  if data_args.dataset_seed is not None:
791
  vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=data_args.dataset_seed)
792
 
793
- # TODO: Log sample of datasets in the right way (see wandb docs)
794
- pd_train = vectorized_datasets["train"].select(range(10)).to_pandas()
795
- pd_eval = vectorized_datasets["eval"].select(range(10)).to_pandas()
796
- # wandb.log({"train_sample": pd_train})
797
- # wandb.log({"eval_sample": pd_eval})
798
-
799
  # for large datasets it is advised to run the preprocessing on a
800
  # single machine first with ``args.preprocessing_only`` since there will mostly likely
801
  # be a timeout when running the script in distributed mode.
@@ -815,9 +695,6 @@ def main():
815
  # we do not want to group tokens when computing the metrics
816
  label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
817
 
818
- print(pred_str[:10])
819
- print(label_str[:10])
820
-
821
  metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
822
 
823
  return metrics
@@ -845,7 +722,7 @@ def main():
845
  data_collator = DataCollatorCTCWithPadding(processor=processor)
846
 
847
  # Initialize Trainer
848
- trainer = Trainer(
849
  model=model,
850
  data_collator=data_collator,
851
  args=training_args,
@@ -855,48 +732,62 @@ def main():
855
  tokenizer=feature_extractor,
856
  )
857
 
858
- # 8. Finally, we can start training
859
 
860
- # Training
861
- if training_args.do_train:
 
 
 
 
 
862
 
863
- # use last checkpoint if exist
864
- if last_checkpoint is not None:
865
- checkpoint = last_checkpoint
866
- elif os.path.isdir(model_args.model_name_or_path):
867
- checkpoint = model_args.model_name_or_path
868
- else:
869
- checkpoint = None
870
 
871
- train_result = trainer.train(resume_from_checkpoint=checkpoint)
872
- trainer.save_model()
873
 
874
- metrics = train_result.metrics
875
- max_train_samples = (
876
- data_args.max_train_samples
877
- if data_args.max_train_samples is not None
878
- else len(vectorized_datasets["train"])
879
- )
880
- metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
881
 
882
- trainer.log_metrics("train", metrics)
883
- trainer.save_metrics("train", metrics)
884
- trainer.save_state()
885
 
886
- # Evaluation
887
- results = {}
888
- if training_args.do_eval:
889
- logger.info("*** Evaluate ***")
890
- metrics = trainer.evaluate()
891
- max_eval_samples = (
892
- data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
893
- )
894
- metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
 
 
 
 
 
 
 
 
 
 
 
895
 
896
- trainer.log_metrics("eval", metrics)
897
- trainer.save_metrics("eval", metrics)
898
 
899
- # Write model card and (optionally) push to hub
 
900
  config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
901
  kwargs = {
902
  "finetuned_from": model_args.model_name_or_path,
@@ -912,9 +803,179 @@ def main():
912
  trainer.push_to_hub(**kwargs)
913
  else:
914
  trainer.create_model_card(**kwargs)
915
-
916
- return results
917
 
918
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
919
  if __name__ == "__main__":
920
  main()
 
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
  # See the License for the specific language governing permissions and
15
 
16
+ """ Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition.
17
+
18
+ TODO:
19
+ * add docstring and complete code docs
20
+ * update model card
21
+
22
+ """
23
 
24
  import datetime
25
  import functools
 
331
  unk_token: Optional[str] = None,
332
  pad_token: Optional[str] = None,
333
  ):
334
+
335
  # Given training and test labels create vocabulary
336
  def extract_all_chars(batch, vocab):
337
  all_text = " ".join(batch)
 
363
  return vocab_dict
364
 
365
 
366
+ def set_log_config_and_level(local_rank):
 
 
 
367
 
368
+ logging.basicConfig(
369
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
370
+ datefmt="%m/%d/%Y %H:%M:%S",
371
+ handlers=[logging.StreamHandler(sys.stdout)],
372
+ )
373
+ logger.setLevel(logging.INFO if is_main_process(local_rank) else logging.WARN)
374
+
375
+
376
+ def log_to_wandb(training_args):
377
 
 
378
  try:
379
  repo_name = os.getcwd().split("/")[-1]
380
  run_name = f"{datetime.datetime.utcnow()}".replace(" ", "T")
 
382
  wandb.login()
383
  training_args.report_to = ["wandb"]
384
  training_args.run_name = run_name
385
+ except Exception as e:
386
+ logger.warning(f"\nFailed logging in to wandb: {e}\nThis experiment will not be logged.\n")
387
+
388
+
389
+ def detect_last_checkpoint(training_args):
390
 
 
391
  last_checkpoint = None
392
  if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
393
  last_checkpoint = get_last_checkpoint(training_args.output_dir)
 
401
  f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
402
  "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
403
  )
404
+ return last_checkpoint
405
 
406
+
407
+ def log_small_sumary(training_args):
 
 
 
 
 
408
 
409
  # Log on each process the small summary:
410
  logger.warning(
 
416
  transformers.utils.logging.set_verbosity_info()
417
  logger.info("Training/evaluation parameters %s", training_args)
418
 
 
 
419
 
420
+ def load_dataset(training_args, data_args):
 
421
 
422
+ raw_datasets = DatasetDict()
 
 
 
423
 
424
  if training_args.do_train:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  raw_datasets["train"] = load_dataset(
426
  data_args.dataset_name,
427
  data_args.dataset_config_name,
 
437
  )
438
 
439
  dataset_frequency = raw_datasets["train"].features[data_args.audio_column_name].sampling_rate
440
+ logger.info(f"Dataset sampling rate: {dataset_frequency}")
441
 
442
  if data_args.text_column_name not in raw_datasets["train"].column_names:
443
  raise ValueError(
 
448
 
449
  if data_args.max_train_samples is not None:
450
  raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
 
 
 
 
 
 
 
451
 
452
  if training_args.do_eval:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
 
454
  try:
455
  raw_datasets["eval"] = load_dataset(
 
462
  split_dataset = raw_datasets["train"].train_test_split(test_size=0.1, seed=42)
463
  raw_datasets["train"] = split_dataset["train"]
464
  raw_datasets["eval"] = split_dataset["test"]
465
+ logger.info("Eval training set sampled from training set.\nTest size: 0.1\nSeed: 42")
 
466
 
467
  if data_args.max_eval_samples is not None:
468
  raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
 
 
469
 
470
+ return raw_datasets
471
+
472
+
473
+ def clean_dataset(raw_datasets, training_args, data_args):
474
 
 
 
 
 
475
  chars_to_ignore_regex = (
476
  f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
477
  )
 
484
  return False
485
 
486
  def remove_special_characters(batch):
487
+
488
+ repl_dict = {
489
+ "\\\\Punkt": "",
490
+ "\\\\Komma": "",
491
+ "è": "e",
492
+ "é": "e",
493
+ "î": "i",
494
+ "ü": "u",
495
+ "ÿ": "y",
496
+ "ô": "o",
497
+ "\\": "",
498
+ "/": "",
499
+ "|": ""
500
+ }
501
+
502
  if chars_to_ignore_regex is not None:
503
+ target_text = re.sub(chars_to_ignore_regex, "", batch[text_column_name])
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  else:
505
+ target_text = batch[text_column_name]
506
+
507
+ for orig, repl in repl_dict.items():
508
+ target_text = target_text.replace(orig, repl)
509
+
510
+ batch["target_text"] = target_text.lower() + " "
 
 
 
 
 
 
 
511
  return batch
512
 
513
  num_workers = data_args.preprocessing_num_workers
 
525
  desc="remove single words, single chars and 'W O R D S'",
526
  )
527
 
528
+ return raw_datasets
 
 
 
529
 
 
 
 
 
 
 
530
 
531
+ def create_tokenizer_kwargs(raw_datasets, training_args, model_args, data_args, config):
532
+
 
 
 
533
  tokenizer_name_or_path = model_args.tokenizer_name_or_path
534
  tokenizer_kwargs = {}
535
  if tokenizer_name_or_path is None:
 
547
  os.makedirs(tokenizer_name_or_path, exist_ok=True)
548
  vocab_dict = create_vocabulary_from_data(
549
  raw_datasets,
550
+ word_delimiter_token=data_args.word_delimiter_token,
551
+ unk_token=data_args.unk_token,
552
+ pad_token=data_args.pad_token,
553
  )
554
 
555
  # save vocab dict to be loaded into tokenizer
 
561
  tokenizer_kwargs = {
562
  "config": config if config.tokenizer_class is not None else None,
563
  "tokenizer_type": config.model_type if config.tokenizer_class is None else None,
564
+ "unk_token": data_args.unk_token,
565
+ "pad_token": data_args.pad_token,
566
+ "word_delimiter_token": data_args.word_delimiter_token,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
567
  }
 
568
 
569
+ return tokenizer_kwargs
 
 
 
 
 
 
570
 
 
 
 
571
 
572
+ def vectorize_dataset(raw_datasets, feature_extractor, tokenizer, training_args, data_args):
 
 
 
573
 
574
  # make sure that dataset decodes audio with correct sampling rate
575
  dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
 
609
  vectorized_datasets["train"] = raw_datasets["train"].map(
610
  prepare_dataset,
611
  remove_columns=raw_datasets["train"].column_names,
612
+ num_proc=data_args.preprocessing_num_workers,
613
  desc="preprocess datasets",
614
  )
615
  vectorized_datasets["eval"] = raw_datasets["eval"].map(
616
  prepare_dataset,
617
  remove_columns=raw_datasets["eval"].column_names,
618
+ num_proc=data_args.preprocessing_num_workers,
619
  desc="preprocess datasets",
620
  )
621
 
 
625
  # filter data that is shorter than min_input_length
626
  vectorized_datasets = vectorized_datasets.filter(
627
  is_audio_in_length_range,
628
+ num_proc=data_args.preprocessing_num_workers,
629
  input_columns=["input_length"],
630
  )
631
 
632
+
633
+ def log_dataset_sample_on_wandb(vectorized_datasets, audio_column_name):
634
+
635
+ pd_train = vectorized_datasets["train"].select(range(10)).to_pandas()
636
+ pd_eval = vectorized_datasets["eval"].select(range(10)).to_pandas()
637
+
638
+ dict_log = {}
639
+ for i, audio in pd_train[audio_column_name]:
640
+ dict_log[f"Training sample {i}"] = wandb.Audio(
641
+ audio["array"],
642
+ audio_rate=audio["sampling_rate"]
643
+ )
644
+ for i, audio in pd_eval[audio_column_name]:
645
+ dict_log[f"Eval sample {i}"] = wandb.Audio(
646
+ audio["array"],
647
+ audio_rate=audio["sampling_rate"]
648
+ )
649
+
650
+ wandb.log({
651
+ "Training samples": pd_train.drop(labels=audio_column_name, axis=1),
652
+ "Eval samples": pd_eval.drop(labels=audio_column_name, axis=1),
653
+ "Audio samples": dict_log
654
+ })
655
+
656
+
657
+ def prepare_training(
658
+ model,
659
+ vectorized_datasets,
660
+ feature_extractor,
661
+ tokenizer,
662
+ training_args,
663
+ data_args,
664
+ config
665
+ ):
666
 
667
  # Define evaluation metrics during training, *i.e.* word error rate, character error rate
668
  eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
 
671
  if data_args.dataset_seed is not None:
672
  vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=data_args.dataset_seed)
673
 
674
+ log_dataset_sample_on_wandb(
675
+ vectorized_datasets=vectorized_datasets,
676
+ audio_column_name=data_args.audio_column_name
677
+ )
678
+
 
679
  # for large datasets it is advised to run the preprocessing on a
680
  # single machine first with ``args.preprocessing_only`` since there will mostly likely
681
  # be a timeout when running the script in distributed mode.
 
695
  # we do not want to group tokens when computing the metrics
696
  label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
697
 
 
 
 
698
  metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
699
 
700
  return metrics
 
722
  data_collator = DataCollatorCTCWithPadding(processor=processor)
723
 
724
  # Initialize Trainer
725
+ return Trainer(
726
  model=model,
727
  data_collator=data_collator,
728
  args=training_args,
 
732
  tokenizer=feature_extractor,
733
  )
734
 
 
735
 
736
+ def do_training(
737
+ trainer,
738
+ last_checkpoint,
739
+ vectorized_datasets,
740
+ model_args,
741
+ data_args
742
+ ):
743
 
744
+ # use last checkpoint if exist
745
+ if last_checkpoint is not None:
746
+ checkpoint = last_checkpoint
747
+ elif os.path.isdir(model_args.model_name_or_path):
748
+ checkpoint = model_args.model_name_or_path
749
+ else:
750
+ checkpoint = None
751
 
752
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
753
+ trainer.save_model()
754
 
755
+ metrics = train_result.metrics
756
+ max_train_samples = (
757
+ data_args.max_train_samples
758
+ if data_args.max_train_samples is not None
759
+ else len(vectorized_datasets["train"])
760
+ )
761
+ metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
762
 
763
+ trainer.log_metrics("train", metrics)
764
+ trainer.save_metrics("train", metrics)
765
+ trainer.save_state()
766
 
767
+ return trainer
768
+
769
+
770
+ def do_eval(
771
+ trainer,
772
+ vectorized_datasets,
773
+ data_args
774
+ ):
775
+
776
+ logger.info("*** Evaluate ***")
777
+ metrics = trainer.evaluate()
778
+ max_eval_samples = (
779
+ data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
780
+ )
781
+ metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
782
+
783
+ trainer.log_metrics("eval", metrics)
784
+ trainer.save_metrics("eval", metrics)
785
+
786
+ return trainer
787
 
 
 
788
 
789
+ def log_results(trainer, training_args, model_args, data_args):
790
+
791
  config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
792
  kwargs = {
793
  "finetuned_from": model_args.model_name_or_path,
 
803
  trainer.push_to_hub(**kwargs)
804
  else:
805
  trainer.create_model_card(**kwargs)
 
 
806
 
807
 
808
+ def inst_model_tokenizer_feature_extractor(
809
+ tokenizer_kwargs,
810
+ training_args,
811
+ model_args,
812
+ data_args,
813
+ config
814
+ ):
815
+
816
+ # load tokenizer
817
+ tokenizer = AutoTokenizer.from_pretrained(
818
+ model_args.tokenizer_name_or_path,
819
+ use_auth_token=data_args.use_auth_token,
820
+ **tokenizer_kwargs,
821
+ )
822
+
823
+ # load feature extractor
824
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
825
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
826
+ )
827
+
828
+ # adapt config
829
+ config.update(
830
+ {
831
+ "feat_proj_dropout": model_args.feat_proj_dropout,
832
+ "attention_dropout": model_args.attention_dropout,
833
+ "hidden_dropout": model_args.hidden_dropout,
834
+ "final_dropout": model_args.final_dropout,
835
+ "mask_time_prob": model_args.mask_time_prob,
836
+ "mask_time_length": model_args.mask_time_length,
837
+ "mask_feature_prob": model_args.mask_feature_prob,
838
+ "mask_feature_length": model_args.mask_feature_length,
839
+ "gradient_checkpointing": training_args.gradient_checkpointing,
840
+ "layerdrop": model_args.layerdrop,
841
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
842
+ "pad_token_id": tokenizer.pad_token_id,
843
+ "vocab_size": len(tokenizer),
844
+ "activation_dropout": model_args.activation_dropout,
845
+ }
846
+ )
847
+
848
+ # load model
849
+ model = AutoModelForCTC.from_pretrained(
850
+ model_args.model_name_or_path,
851
+ cache_dir=model_args.cache_dir,
852
+ config=config,
853
+ use_auth_token=data_args.use_auth_token,
854
+ )
855
+
856
+ # freeze encoder
857
+ if model_args.freeze_feature_encoder:
858
+ model.freeze_feature_encoder()
859
+
860
+ return model, tokenizer, feature_extractor, config
861
+
862
+
863
+ def main():
864
+
865
+ # 0. Parse arguments
866
+ # See all possible arguments in src/transformers/training_args.py
867
+ # or by passing the --help flag to this script.
868
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
869
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
870
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
871
+ # If we pass only one argument to the script and it's the path to a json file,
872
+ # let's parse it to get our arguments.
873
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
874
+ else:
875
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
876
+
877
+ # 1. Set logging
878
+ set_log_config_and_level(local_rank=training_args.local_rank)
879
+ training_args = log_to_wandb(training_args=training_args)
880
+ log_small_sumary(training_args=training_args)
881
+
882
+ # 2. Set random seed
883
+ set_seed(training_args.seed)
884
+
885
+ # 3. First, let's load the dataset
886
+ raw_datasets = load_dataset(training_args=training_args, data_args=data_args)
887
+
888
+ # 4. We remove some special characters from the datasets
889
+ # that make training complicated and do not help in transcribing the speech
890
+ # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
891
+ # that could be easily picked up by the model
892
+ raw_datasets = clean_dataset(
893
+ raw_datasets=raw_datasets,
894
+ training_args=training_args,
895
+ data_args=data_args
896
+ )
897
+
898
+ # 5. Next, let's load the config as we might need it to create the tokenizer
899
+ config = AutoConfig.from_pretrained(
900
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
901
+ )
902
+
903
+ # 6. Next, if no tokenizer file is defined,
904
+ # we create the vocabulary of the model by extracting all unique characters from
905
+ # the training and evaluation datasets
906
+ # We need to make sure that only first rank saves vocabulary
907
+ # make sure all processes wait until vocab is created
908
+ tokenizer_kwargs = create_tokenizer_kwargs(
909
+ raw_datasets=raw_datasets,
910
+ training_args=training_args,
911
+ model_args=model_args,
912
+ data_args=data_args,
913
+ config=config
914
+ )
915
+
916
+ # 7. Now we can instantiate the feature extractor, tokenizer and model
917
+ # Note for distributed training, the .from_pretrained methods guarantee that only
918
+ # one local process can concurrently download model & vocab.
919
+ model, tokenizer, feature_extractor, config = inst_model_tokenizer_feature_extractor(
920
+ tokenizer_kwargs=tokenizer_kwargs,
921
+ training_args=training_args,
922
+ model_args=model_args,
923
+ data_args=data_args,
924
+ config=config
925
+ )
926
+
927
+ # 8. Now we preprocess the datasets including loading the audio, resampling and normalization
928
+ # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
929
+ # so that we just need to set the correct target sampling rate and normalize the input
930
+ # via the `feature_extractor`
931
+ vectorized_datasets = vectorize_dataset(
932
+ raw_datasets=raw_datasets,
933
+ feature_extractor=feature_extractor,
934
+ tokenizer=tokenizer,
935
+ training_args=training_args,
936
+ data_args=data_args
937
+ )
938
+
939
+ # 9. Next, we can prepare the training.
940
+ # Let's use word error rate (WER) as our evaluation metric,
941
+ # instantiate a data collator and the trainer
942
+ trainer = prepare_training(
943
+ model=model,
944
+ vectorized_datasets=vectorized_datasets,
945
+ feature_extractor=feature_extractor,
946
+ tokenizer=tokenizer,
947
+ training_args=training_args,
948
+ data_args=data_args,
949
+ config=config
950
+ )
951
+
952
+ # 10. Train model
953
+ last_checkpoint = detect_last_checkpoint(training_args=training_args)
954
+ if training_args.do_train:
955
+ trainer = do_training(
956
+ trainer=trainer,
957
+ last_checkpoint=last_checkpoint,
958
+ vectorized_datasets=vectorized_datasets,
959
+ model_args=model_args,
960
+ data_args=data_args
961
+ )
962
+
963
+ # 11. Eval model
964
+ if training_args.do_eval:
965
+ trainer = do_eval(
966
+ trainer=trainer,
967
+ vectorized_datasets=vectorized_datasets,
968
+ data_args=data_args
969
+ )
970
+
971
+ # 12. Push to hub and update model card
972
+ log_results(
973
+ trainer=trainer,
974
+ training_args=training_args,
975
+ model_args=model_args,
976
+ data_args=data_args
977
+ )
978
+
979
+
980
  if __name__ == "__main__":
981
  main()