marinone94 committed on
Commit
833e02b
1 Parent(s): 96a2519

fix script

Browse files
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -165,10 +165,16 @@ class DataTrainingArguments:
165
  Arguments pertaining to what data we are going to input our model for training and eval.
166
  """
167
 
168
- dataset_name: str = field(
169
  default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
170
  )
171
- dataset_config_name: Optional[str] = field(
 
 
 
 
 
 
172
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
173
  )
174
  text_column: Optional[str] = field(
@@ -529,17 +535,17 @@ def main():
529
  )
530
 
531
  raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
532
-
533
  if data_args.audio_column_name not in raw_datasets_features:
534
  raise ValueError(
535
- f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
536
  "Make sure to set `--audio_column_name` to the correct audio column - one of "
537
  f"{', '.join(raw_datasets_features)}."
538
  )
539
 
540
- if data_args.text_column_name not in raw_datasets_features:
541
  raise ValueError(
542
- f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
543
  "Make sure to set `--text_column_name` to the correct text column - one of "
544
  f"{', '.join(raw_datasets_features)}."
545
  )
@@ -600,7 +606,6 @@ def main():
600
  max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
601
  min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
602
  audio_column_name = data_args.audio_column_name
603
- text_column_name = data_args.text_column_name
604
  model_input_name = feature_extractor.model_input_names[0]
605
  do_lower_case = data_args.do_lower_case
606
  do_remove_punctuation = data_args.do_remove_punctuation
@@ -761,13 +766,13 @@ def main():
761
  "tasks": "automatic-speech-recognition",
762
  "tags": "whisper-event",
763
  }
764
- if data_args.dataset_name is not None:
765
- kwargs["dataset_tags"] = data_args.dataset_name
766
  if data_args.dataset_config_name is not None:
767
- kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
768
  else:
769
- kwargs["dataset"] = data_args.dataset_name
770
- if "common_voice" in data_args.dataset_name:
771
  kwargs["language"] = data_args.dataset_config_name[:2]
772
  if model_args.model_index_name is not None:
773
  kwargs["model_name"] = model_args.model_index_name
 
165
  Arguments pertaining to what data we are going to input our model for training and eval.
166
  """
167
 
168
+ dataset_train_name: str = field(
169
  default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
170
  )
171
+ dataset_train_config_name: Optional[str] = field(
172
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
173
+ )
174
+ dataset_eval_name: str = field(
175
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
176
+ )
177
+ dataset_eval_config_name: Optional[str] = field(
178
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
179
  )
180
  text_column: Optional[str] = field(
 
535
  )
536
 
537
  raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
538
+ text_column_name = data_args.text_column_name.split(",")[0]
539
  if data_args.audio_column_name not in raw_datasets_features:
540
  raise ValueError(
541
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_train_name}'. "
542
  "Make sure to set `--audio_column_name` to the correct audio column - one of "
543
  f"{', '.join(raw_datasets_features)}."
544
  )
545
 
546
+ if text_column_name not in raw_datasets_features:
547
  raise ValueError(
548
+ f"--text_column_name {text_column_name} not found in dataset '{data_args.dataset_train_name}'. "
549
  "Make sure to set `--text_column_name` to the correct text column - one of "
550
  f"{', '.join(raw_datasets_features)}."
551
  )
 
606
  max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
607
  min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
608
  audio_column_name = data_args.audio_column_name
 
609
  model_input_name = feature_extractor.model_input_names[0]
610
  do_lower_case = data_args.do_lower_case
611
  do_remove_punctuation = data_args.do_remove_punctuation
 
766
  "tasks": "automatic-speech-recognition",
767
  "tags": "whisper-event",
768
  }
769
+ if data_args.dataset_train_name is not None:
770
+ kwargs["dataset_tags"] = data_args.dataset_train_name
771
  if data_args.dataset_config_name is not None:
772
+ kwargs["dataset"] = f"{data_args.dataset_train_name} {data_args.dataset_config_name}"
773
  else:
774
+ kwargs["dataset"] = data_args.dataset_train_name
775
+ if "common_voice" in data_args.dataset_train_name:
776
  kwargs["language"] = data_args.dataset_config_name[:2]
777
  if model_args.model_index_name is not None:
778
  kwargs["model_name"] = model_args.model_index_name