MikkoLipsanen committed
Commit dc41f2f
1 Parent(s): 8713ab2

Update train_trocr.py

Files changed (1)
  1. train_trocr.py +8 -20
train_trocr.py CHANGED
@@ -7,23 +7,22 @@ from evaluate import load
 from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator, AdamW
 import torchvision.transforms as transforms
 from augments import RandAug, RandRotate
-#import torch_optimizer as optim
 
 parser = argparse.ArgumentParser('arguments for the code')
 
 parser.add_argument('--root_path', type=str, default="",
                     help='Root path to data files.')
-parser.add_argument('--tr_data_path', type=str, default="/data/htr/taulukkosolut/train/trocr/train_new_anno.csv",
+parser.add_argument('--tr_data_path', type=str, default="/path/to/train_data.csv",
                     help='Path to .csv file containing the training data.')
-parser.add_argument('--val_data_path', type=str, default="/data/htr/taulukkosolut/val/trocr/val_new_anno.csv",
+parser.add_argument('--val_data_path', type=str, default="/path/to/val_data.csv",
                     help='Path to .csv file containing the validation data.')
-parser.add_argument('--output_path', type=str, default="./output/no_aug_1beam_07092024/",
+parser.add_argument('--output_path', type=str, default="./output/path/",
                     help='Path for saving training results.')
-parser.add_argument('--model_path', type=str, default="/4tb_01/models/htr/supermalli/202405_fp16/",
+parser.add_argument('--model_path', type=str, default="/model/path/",
                     help='Path to trocr model')
-parser.add_argument('--processor_path', type=str, default="/4tb_01/models/htr/supermalli/202405_fp16/processor",
+parser.add_argument('--processor_path', type=str, default="/processor/path/",
                     help='Path to trocr processor')
-parser.add_argument('--epochs', type=int, default=20,
+parser.add_argument('--epochs', type=int, default=15,
                     help='Training epochs.')
 parser.add_argument('--batch_size', type=int, default=16,
                     help='Training epochs.')
@@ -34,13 +33,6 @@ parser.add_argument('--augment', type=int, default=0,
 
 args = parser.parse_args()
 
-# nohup python train_trocr.py > logs/taulukkosolut_no_aug_1beam_07092024.txt 2>&1 &
-# echo $! > logs/save_pid.txt
-
-#image_size = (224,224)
-#resized_images = []
-# run using 2 GPUs: torchrun --nproc_per_node=2 train_trocr.py
-
 # Initialize processor and model
 processor = TrOCRProcessor.from_pretrained(args.processor_path)
 model = VisionEncoderDecoderModel.from_pretrained(args.model_path)
@@ -53,8 +45,6 @@ wer_metric = load("wer")
 # Load train and validation data to dataframes
 train_df = pd.read_csv(args.tr_data_path)
 val_df = pd.read_csv(args.val_data_path)
-#train_df = train_df.iloc[:10]
-#val_df = val_df.iloc[:5]
 
 # Reset the indices to start from zero
 train_df.reset_index(drop=True, inplace=True)
@@ -129,7 +119,6 @@ model.config.num_beams = 1
 # Set arguments for model training
 # For all argumenst see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
 training_args = Seq2SeqTrainingArguments(
-    learning_rate=2.779367469510554e-05,
     predict_with_generate=True,
     eval_strategy="epoch",
     save_strategy="epoch",
@@ -162,7 +151,7 @@ def compute_metrics(pred):
     return {"cer": cer, "wer": wer}
 
 
-# instantiate trainer
+# Instantiate trainer
 # For all parameters see: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
 trainer = Seq2SeqTrainer(
     model=model,
@@ -178,5 +167,4 @@ trainer = Seq2SeqTrainer(
 trainer.train()
 #trainer.train(resume_from_checkpoint = True)
 model.save_pretrained(args.output_path)
-processor.save_pretrained(args.output_path + "/processor")
-
+processor.save_pretrained(args.output_path + "/processor")
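
With the placeholder defaults introduced by this commit, a typical single-GPU run of the updated script would look roughly like the command below; every path is illustrative and should point at your own CSV annotation files and your TrOCR model and processor directories:

python train_trocr.py \
    --tr_data_path /path/to/train_data.csv \
    --val_data_path /path/to/val_data.csv \
    --model_path /model/path/ \
    --processor_path /processor/path/ \
    --output_path ./output/path/ \
    --epochs 15 \
    --batch_size 16

The comments removed in this commit also show the same script being launched on two GPUs with torchrun --nproc_per_node=2 train_trocr.py.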
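
For context on the cer and wer values returned in the compute_metrics hunk above: with wer_metric = load("wer") (visible in a hunk header) and, by assumption, a matching cer_metric = load("cer") from the evaluate library, the metric function passed to Seq2SeqTrainer is usually along the lines of the sketch below. The full function body lies outside the changed hunks, so this is an illustrative reconstruction that reuses the script's processor, cer_metric and wer_metric names, not the file's exact code:

def compute_metrics(pred):
    # Generated token ids and reference label ids from the trainer's predict step
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    # Decode predictions to strings; restore the pad token where labels were masked with -100
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
    # Character and word error rates computed by the `evaluate` metrics loaded earlier in the script
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer, "wer": wer}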