gyrojeff commited on
Commit
afbe904
1 Parent(s): ac3ee6a

feat: add classification only option

Browse files
Files changed (2) hide show
  1. detector/model.py +10 -2
  2. train.py +7 -0
detector/model.py CHANGED
@@ -83,16 +83,21 @@ class ResNet101Regressor(nn.Module):
83
 
84
 
85
  class FontDetectorLoss(nn.Module):
86
- def __init__(self, lambda_font, lambda_direction, lambda_regression):
 
 
87
  super().__init__()
88
  self.category_loss = nn.CrossEntropyLoss()
89
  self.regression_loss = nn.MSELoss()
90
  self.lambda_font = lambda_font
91
  self.lambda_direction = lambda_direction
92
  self.lambda_regression = lambda_regression
 
93
 
94
  def forward(self, y_hat, y):
95
  font_cat = self.category_loss(y_hat[..., : config.FONT_COUNT], y[..., 0].long())
 
 
96
  direction_cat = self.category_loss(
97
  y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1].long()
98
  )
@@ -130,6 +135,7 @@ class FontDetector(ptl.LightningModule):
130
  lambda_font: float,
131
  lambda_direction: float,
132
  lambda_regression: float,
 
133
  lr: float,
134
  betas: Tuple[float, float],
135
  num_warmup_iters: int,
@@ -138,7 +144,9 @@ class FontDetector(ptl.LightningModule):
138
  ):
139
  super().__init__()
140
  self.model = model
141
- self.loss = FontDetectorLoss(lambda_font, lambda_direction, lambda_regression)
 
 
142
  self.font_accur_train = torchmetrics.Accuracy(
143
  task="multiclass", num_classes=config.FONT_COUNT
144
  )
 
83
 
84
 
85
  class FontDetectorLoss(nn.Module):
86
+ def __init__(
87
+ self, lambda_font, lambda_direction, lambda_regression, font_classification_only
88
+ ):
89
  super().__init__()
90
  self.category_loss = nn.CrossEntropyLoss()
91
  self.regression_loss = nn.MSELoss()
92
  self.lambda_font = lambda_font
93
  self.lambda_direction = lambda_direction
94
  self.lambda_regression = lambda_regression
95
+ self.font_classfiication_only = font_classification_only
96
 
97
  def forward(self, y_hat, y):
98
  font_cat = self.category_loss(y_hat[..., : config.FONT_COUNT], y[..., 0].long())
99
+ if self.font_classfiication_only:
100
+ return self.lambda_font * font_cat
101
  direction_cat = self.category_loss(
102
  y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1].long()
103
  )
 
135
  lambda_font: float,
136
  lambda_direction: float,
137
  lambda_regression: float,
138
+ font_classification_only: bool,
139
  lr: float,
140
  betas: Tuple[float, float],
141
  num_warmup_iters: int,
 
144
  ):
145
  super().__init__()
146
  self.model = model
147
+ self.loss = FontDetectorLoss(
148
+ lambda_font, lambda_direction, lambda_regression, font_classification_only
149
+ )
150
  self.font_accur_train = torchmetrics.Accuracy(
151
  task="multiclass", num_classes=config.FONT_COUNT
152
  )
train.py CHANGED
@@ -84,6 +84,12 @@ parser.add_argument(
84
  default=get_current_tag(),
85
  help="Model name (default: current tag)",
86
  )
 
 
 
 
 
 
87
 
88
  args = parser.parse_args()
89
 
@@ -177,6 +183,7 @@ detector = FontDetector(
177
  lambda_font=lambda_font,
178
  lambda_direction=lambda_direction,
179
  lambda_regression=lambda_regression,
 
180
  lr=lr,
181
  betas=(b1, b2),
182
  num_warmup_iters=num_warmup_iter,
 
84
  default=get_current_tag(),
85
  help="Model name (default: current tag)",
86
  )
87
+ parser.add_argument(
88
+ "-f",
89
+ "--font-classification-only",
90
+ action="store_true",
91
+ help="Font classification only (default: False)",
92
+ )
93
 
94
  args = parser.parse_args()
95
 
 
183
  lambda_font=lambda_font,
184
  lambda_direction=lambda_direction,
185
  lambda_regression=lambda_regression,
186
+ font_classification_only=args.font_classification_only,
187
  lr=lr,
188
  betas=(b1, b2),
189
  num_warmup_iters=num_warmup_iter,