MikkoLipsanen
committed on
Commit
•
dc41f2f
1
Parent(s):
8713ab2
Update train_trocr.py
Browse files- train_trocr.py +8 -20
train_trocr.py
CHANGED
@@ -7,23 +7,22 @@ from evaluate import load
|
|
7 |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator, AdamW
|
8 |
import torchvision.transforms as transforms
|
9 |
from augments import RandAug, RandRotate
|
10 |
-
#import torch_optimizer as optim
|
11 |
|
12 |
parser = argparse.ArgumentParser('arguments for the code')
|
13 |
|
14 |
parser.add_argument('--root_path', type=str, default="",
|
15 |
help='Root path to data files.')
|
16 |
-
parser.add_argument('--tr_data_path', type=str, default="/
|
17 |
help='Path to .csv file containing the training data.')
|
18 |
-
parser.add_argument('--val_data_path', type=str, default="/
|
19 |
help='Path to .csv file containing the validation data.')
|
20 |
-
parser.add_argument('--output_path', type=str, default="./output/
|
21 |
help='Path for saving training results.')
|
22 |
-
parser.add_argument('--model_path', type=str, default="/
|
23 |
help='Path to trocr model')
|
24 |
-
parser.add_argument('--processor_path', type=str, default="/
|
25 |
help='Path to trocr processor')
|
26 |
-
parser.add_argument('--epochs', type=int, default=
|
27 |
help='Training epochs.')
|
28 |
parser.add_argument('--batch_size', type=int, default=16,
|
29 |
help='Training batch size.')
|
@@ -34,13 +33,6 @@ parser.add_argument('--augment', type=int, default=0,
|
|
34 |
|
35 |
args = parser.parse_args()
|
36 |
|
37 |
-
# nohup python train_trocr.py > logs/taulukkosolut_no_aug_1beam_07092024.txt 2>&1 &
|
38 |
-
# echo $! > logs/save_pid.txt
|
39 |
-
|
40 |
-
#image_size = (224,224)
|
41 |
-
#resized_images = []
|
42 |
-
# run using 2 GPUs: torchrun --nproc_per_node=2 train_trocr.py
|
43 |
-
|
44 |
# Initialize processor and model
|
45 |
processor = TrOCRProcessor.from_pretrained(args.processor_path)
|
46 |
model = VisionEncoderDecoderModel.from_pretrained(args.model_path)
|
@@ -53,8 +45,6 @@ wer_metric = load("wer")
|
|
53 |
# Load train and validation data to dataframes
|
54 |
train_df = pd.read_csv(args.tr_data_path)
|
55 |
val_df = pd.read_csv(args.val_data_path)
|
56 |
-
#train_df = train_df.iloc[:10]
|
57 |
-
#val_df = val_df.iloc[:5]
|
58 |
|
59 |
# Reset the indices to start from zero
|
60 |
train_df.reset_index(drop=True, inplace=True)
|
@@ -129,7 +119,6 @@ model.config.num_beams = 1
|
|
129 |
# Set arguments for model training
|
130 |
# For all arguments see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
|
131 |
training_args = Seq2SeqTrainingArguments(
|
132 |
-
learning_rate=2.779367469510554e-05,
|
133 |
predict_with_generate=True,
|
134 |
eval_strategy="epoch",
|
135 |
save_strategy="epoch",
|
@@ -162,7 +151,7 @@ def compute_metrics(pred):
|
|
162 |
return {"cer": cer, "wer": wer}
|
163 |
|
164 |
|
165 |
-
#
|
166 |
# For all parameters see: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
|
167 |
trainer = Seq2SeqTrainer(
|
168 |
model=model,
|
@@ -178,5 +167,4 @@ trainer = Seq2SeqTrainer(
|
|
178 |
trainer.train()
|
179 |
#trainer.train(resume_from_checkpoint = True)
|
180 |
model.save_pretrained(args.output_path)
|
181 |
-
processor.save_pretrained(args.output_path + "/processor")
|
182 |
-
|
|
|
7 |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator, AdamW
|
8 |
import torchvision.transforms as transforms
|
9 |
from augments import RandAug, RandRotate
|
|
|
10 |
|
11 |
parser = argparse.ArgumentParser('arguments for the code')
|
12 |
|
13 |
parser.add_argument('--root_path', type=str, default="",
|
14 |
help='Root path to data files.')
|
15 |
+
parser.add_argument('--tr_data_path', type=str, default="/path/to/train_data.csv",
|
16 |
help='Path to .csv file containing the training data.')
|
17 |
+
parser.add_argument('--val_data_path', type=str, default="/path/to/val_data.csv",
|
18 |
help='Path to .csv file containing the validation data.')
|
19 |
+
parser.add_argument('--output_path', type=str, default="./output/path/",
|
20 |
help='Path for saving training results.')
|
21 |
+
parser.add_argument('--model_path', type=str, default="/model/path/",
|
22 |
help='Path to trocr model')
|
23 |
+
parser.add_argument('--processor_path', type=str, default="/processor/path/",
|
24 |
help='Path to trocr processor')
|
25 |
+
parser.add_argument('--epochs', type=int, default=15,
|
26 |
help='Training epochs.')
|
27 |
parser.add_argument('--batch_size', type=int, default=16,
|
28 |
help='Training batch size.')
|
|
|
33 |
|
34 |
args = parser.parse_args()
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
# Initialize processor and model
|
37 |
processor = TrOCRProcessor.from_pretrained(args.processor_path)
|
38 |
model = VisionEncoderDecoderModel.from_pretrained(args.model_path)
|
|
|
45 |
# Load train and validation data to dataframes
|
46 |
train_df = pd.read_csv(args.tr_data_path)
|
47 |
val_df = pd.read_csv(args.val_data_path)
|
|
|
|
|
48 |
|
49 |
# Reset the indices to start from zero
|
50 |
train_df.reset_index(drop=True, inplace=True)
|
|
|
119 |
# Set arguments for model training
|
120 |
# For all arguments see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
|
121 |
training_args = Seq2SeqTrainingArguments(
|
|
|
122 |
predict_with_generate=True,
|
123 |
eval_strategy="epoch",
|
124 |
save_strategy="epoch",
|
|
|
151 |
return {"cer": cer, "wer": wer}
|
152 |
|
153 |
|
154 |
+
# Instantiate trainer
|
155 |
# For all parameters see: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
|
156 |
trainer = Seq2SeqTrainer(
|
157 |
model=model,
|
|
|
167 |
trainer.train()
|
168 |
#trainer.train(resume_from_checkpoint = True)
|
169 |
model.save_pretrained(args.output_path)
|
170 |
+
processor.save_pretrained(args.output_path + "/processor")
|
|