marinone94 committed on
Commit
38706e1
•
1 Parent(s): ba980b2

shuffle dataset - fix training params.

Browse files
Files changed (2) hide show
  1. run.sh +7 -4
  2. run_speech_recognition_ctc.py +8 -0
run.sh CHANGED
@@ -6,18 +6,21 @@ python run_speech_recognition_ctc.py \
6
  --eval_split_name="test,None" \
7
  --output_dir="./" \
8
  --overwrite_output_dir \
9
- --num_train_epochs="5" \
10
  --per_device_train_batch_size="16" \
11
  --per_device_eval_batch_size="16" \
12
  --gradient_accumulation_steps="4" \
13
  --learning_rate="7.5e-5" \
14
- --warmup_steps="10000" \
15
  --length_column_name="input_length" \
16
- --evaluation_strategy="epoch" \
17
- --save_strategy="epoch" \
 
 
18
  --text_column_name="sentence" \
19
  --chars_to_ignore , ? . ! \- \; \: \" β€œ % β€˜ ” οΏ½ β€” ’ … – \
20
  --logging_steps="100" \
 
21
  --layerdrop="0.0" \
22
  --activation_dropout="0.1" \
23
  --save_total_limit="3" \
 
6
  --eval_split_name="test,None" \
7
  --output_dir="./" \
8
  --overwrite_output_dir \
9
+ --num_train_epochs="3" \
10
  --per_device_train_batch_size="16" \
11
  --per_device_eval_batch_size="16" \
12
  --gradient_accumulation_steps="4" \
13
  --learning_rate="7.5e-5" \
14
+ --warmup_ratio="0.02" \
15
  --length_column_name="input_length" \
16
+ --evaluation_strategy="steps" \
17
+ --save_strategy="steps" \
18
+ --eval_steps="250" \
19
+ --save_steps="250" \
20
  --text_column_name="sentence" \
21
  --chars_to_ignore , ? . ! \- \; \: \" β€œ % β€˜ ” οΏ½ β€” ’ … – \
22
  --logging_steps="100" \
23
+ --dataset_seed="42" \
24
  --layerdrop="0.0" \
25
  --activation_dropout="0.1" \
26
  --save_total_limit="3" \
run_speech_recognition_ctc.py CHANGED
@@ -252,6 +252,10 @@ class DataTrainingArguments:
252
  " input audio to a sequence of phoneme sequences."
253
  },
254
  )
 
 
 
 
255
 
256
 
257
  @dataclass
@@ -743,6 +747,10 @@ def main():
743
  # Define evaluation metrics during training, *i.e.* word error rate, character error rate
744
  eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
745
 
 
 
 
 
746
  # for large datasets it is advised to run the preprocessing on a
747
  # single machine first with ``args.preprocessing_only`` since there will most likely
748
  # be a timeout when running the script in distributed mode.
 
252
  " input audio to a sequence of phoneme sequences."
253
  },
254
  )
255
+ dataset_seed: Optional[int] = field(
256
+ default=None,
257
+ metadata={"help": "Seed for shuffling training data"},
258
+ )
259
 
260
 
261
  @dataclass
 
747
  # Define evaluation metrics during training, *i.e.* word error rate, character error rate
748
  eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
749
 
750
+ # If dataset_seed is set, shuffle train
751
+ if data_args.dataset_seed is not None:
752
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=data_args.dataset_seed)
753
+
754
  # for large datasets it is advised to run the preprocessing on a
755
  # single machine first with ``args.preprocessing_only`` since there will most likely
756
  # be a timeout when running the script in distributed mode.