feat: add classification only option
Browse files
- detector/model.py +10 -2
- train.py +7 -0
detector/model.py
CHANGED
@@ -83,16 +83,21 @@ class ResNet101Regressor(nn.Module):
|
|
83 |
|
84 |
|
85 |
class FontDetectorLoss(nn.Module):
|
86 |
-
def __init__(
|
|
|
|
|
87 |
super().__init__()
|
88 |
self.category_loss = nn.CrossEntropyLoss()
|
89 |
self.regression_loss = nn.MSELoss()
|
90 |
self.lambda_font = lambda_font
|
91 |
self.lambda_direction = lambda_direction
|
92 |
self.lambda_regression = lambda_regression
|
|
|
93 |
|
94 |
def forward(self, y_hat, y):
|
95 |
font_cat = self.category_loss(y_hat[..., : config.FONT_COUNT], y[..., 0].long())
|
|
|
|
|
96 |
direction_cat = self.category_loss(
|
97 |
y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1].long()
|
98 |
)
|
@@ -130,6 +135,7 @@ class FontDetector(ptl.LightningModule):
|
|
130 |
lambda_font: float,
|
131 |
lambda_direction: float,
|
132 |
lambda_regression: float,
|
|
|
133 |
lr: float,
|
134 |
betas: Tuple[float, float],
|
135 |
num_warmup_iters: int,
|
@@ -138,7 +144,9 @@ class FontDetector(ptl.LightningModule):
|
|
138 |
):
|
139 |
super().__init__()
|
140 |
self.model = model
|
141 |
-
self.loss = FontDetectorLoss(
|
|
|
|
|
142 |
self.font_accur_train = torchmetrics.Accuracy(
|
143 |
task="multiclass", num_classes=config.FONT_COUNT
|
144 |
)
|
|
|
83 |
|
84 |
|
85 |
class FontDetectorLoss(nn.Module):
|
86 |
+
def __init__(
|
87 |
+
self, lambda_font, lambda_direction, lambda_regression, font_classification_only
|
88 |
+
):
|
89 |
super().__init__()
|
90 |
self.category_loss = nn.CrossEntropyLoss()
|
91 |
self.regression_loss = nn.MSELoss()
|
92 |
self.lambda_font = lambda_font
|
93 |
self.lambda_direction = lambda_direction
|
94 |
self.lambda_regression = lambda_regression
|
95 |
+
self.font_classfiication_only = font_classification_only
|
96 |
|
97 |
def forward(self, y_hat, y):
|
98 |
font_cat = self.category_loss(y_hat[..., : config.FONT_COUNT], y[..., 0].long())
|
99 |
+
if self.font_classfiication_only:
|
100 |
+
return self.lambda_font * font_cat
|
101 |
direction_cat = self.category_loss(
|
102 |
y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1].long()
|
103 |
)
|
|
|
135 |
lambda_font: float,
|
136 |
lambda_direction: float,
|
137 |
lambda_regression: float,
|
138 |
+
font_classification_only: bool,
|
139 |
lr: float,
|
140 |
betas: Tuple[float, float],
|
141 |
num_warmup_iters: int,
|
|
|
144 |
):
|
145 |
super().__init__()
|
146 |
self.model = model
|
147 |
+
self.loss = FontDetectorLoss(
|
148 |
+
lambda_font, lambda_direction, lambda_regression, font_classification_only
|
149 |
+
)
|
150 |
self.font_accur_train = torchmetrics.Accuracy(
|
151 |
task="multiclass", num_classes=config.FONT_COUNT
|
152 |
)
|
train.py
CHANGED
@@ -84,6 +84,12 @@ parser.add_argument(
|
|
84 |
default=get_current_tag(),
|
85 |
help="Model name (default: current tag)",
|
86 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
args = parser.parse_args()
|
89 |
|
@@ -177,6 +183,7 @@ detector = FontDetector(
|
|
177 |
lambda_font=lambda_font,
|
178 |
lambda_direction=lambda_direction,
|
179 |
lambda_regression=lambda_regression,
|
|
|
180 |
lr=lr,
|
181 |
betas=(b1, b2),
|
182 |
num_warmup_iters=num_warmup_iter,
|
|
|
84 |
default=get_current_tag(),
|
85 |
help="Model name (default: current tag)",
|
86 |
)
|
87 |
+
parser.add_argument(
|
88 |
+
"-f",
|
89 |
+
"--font-classification-only",
|
90 |
+
action="store_true",
|
91 |
+
help="Font classification only (default: False)",
|
92 |
+
)
|
93 |
|
94 |
args = parser.parse_args()
|
95 |
|
|
|
183 |
lambda_font=lambda_font,
|
184 |
lambda_direction=lambda_direction,
|
185 |
lambda_regression=lambda_regression,
|
186 |
+
font_classification_only=args.font_classification_only,
|
187 |
lr=lr,
|
188 |
betas=(b1, b2),
|
189 |
num_warmup_iters=num_warmup_iter,
|