gyrojeff commited on
Commit
3631068
1 Parent(s): 964201e

feat: add model and ptl training loop

Browse files
Files changed (1) hide show
  1. detector/model.py +170 -0
detector/model.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchmetrics
2
+ from . import config
3
+
4
+ from typing import Tuple, Dict, List, Any
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torchvision
9
+ import torch.nn as nn
10
+ import pytorch_lightning as ptl
11
+
12
+
13
+ class ResNet18Regressor(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+ self.model = torchvision.models.resnet18(pretrained=False)
17
+ self.model.fc = nn.Linear(512, config.FONT_COUNT + 12)
18
+
19
+ def forward(self, X):
20
+ X = self.model(X)
21
+ # [0, 1]
22
+ X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].sigmoid()
23
+ return X
24
+
25
+
26
+ class FontDetectorLoss(nn.Module):
27
+ def __init__(self, lambda_font, lambda_direction, lambda_regression):
28
+ super().__init__()
29
+ self.category_loss = nn.CrossEntropyLoss()
30
+ self.regression_loss = nn.MSELoss()
31
+ self.lambda_font = lambda_font
32
+ self.lambda_direction = lambda_direction
33
+ self.lambda_regression = lambda_regression
34
+
35
+ def forward(self, y_hat, y):
36
+ font_cat = self.category_loss(y_hat[..., : config.FONT_COUNT], y[..., 0].long())
37
+ direction_cat = self.category_loss(
38
+ y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1].long()
39
+ )
40
+ regression = self.regression_loss(
41
+ y_hat[..., config.FONT_COUNT + 2 :], y[..., 2:]
42
+ )
43
+ return (
44
+ self.lambda_font * font_cat
45
+ + self.lambda_direction * direction_cat
46
+ + self.lambda_regression * regression
47
+ )
48
+
49
+
50
+ class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
51
+ def __init__(self, optimizer, warmup, max_iters):
52
+ self.warmup = warmup
53
+ self.max_num_iters = max_iters
54
+ super().__init__(optimizer)
55
+
56
+ def get_lr(self):
57
+ lr_factor = self.get_lr_factor(epoch=self.last_epoch)
58
+ return [base_lr * lr_factor for base_lr in self.base_lrs]
59
+
60
+ def get_lr_factor(self, epoch):
61
+ lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
62
+ if epoch <= self.warmup:
63
+ lr_factor *= epoch * 1.0 / self.warmup
64
+ return lr_factor
65
+
66
+
67
+ class FontDetector(ptl.LightningModule):
68
+ def __init__(
69
+ self,
70
+ model: nn.Module,
71
+ lambda_font: float,
72
+ lambda_direction: float,
73
+ lambda_regression: float,
74
+ lr: float,
75
+ betas: Tuple[float, float],
76
+ num_warmup_iters: int,
77
+ num_iters: int,
78
+ ):
79
+ super().__init__()
80
+ self.model = model
81
+ self.loss = FontDetectorLoss(lambda_font, lambda_direction, lambda_regression)
82
+ self.font_accur_train = torchmetrics.Accuracy(
83
+ task="multiclass", num_classes=config.FONT_COUNT
84
+ )
85
+ self.direction_accur_train = torchmetrics.Accuracy(
86
+ task="multiclass", num_classes=2
87
+ )
88
+ self.font_accur_val = torchmetrics.Accuracy(
89
+ task="multiclass", num_classes=config.FONT_COUNT
90
+ )
91
+ self.direction_accur_val = torchmetrics.Accuracy(
92
+ task="multiclass", num_classes=2
93
+ )
94
+ self.lr = lr
95
+ self.betas = betas
96
+ self.num_warmup_iters = num_warmup_iters
97
+ self.num_iters = num_iters
98
+
99
+ def forward(self, x):
100
+ return self.model(x)
101
+
102
+ def training_step(
103
+ self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
104
+ ) -> Dict[str, Any]:
105
+ X, y = batch
106
+ y_hat = self.forward(X)
107
+ loss = self.loss(y_hat, y)
108
+ self.log("train_loss", loss, prog_bar=True)
109
+ return {"loss": loss, "pred": y_hat, "target": y}
110
+
111
+ def training_step_end(self, outputs):
112
+ y_hat = outputs["pred"]
113
+ y = outputs["target"]
114
+ self.log(
115
+ "train_font_accur",
116
+ self.font_accur_train(y_hat[..., : config.FONT_COUNT], y[..., 0]),
117
+ )
118
+ self.log(
119
+ "train_direction_accur",
120
+ self.direction_accur_train(
121
+ y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
122
+ ),
123
+ )
124
+
125
+ def training_epoch_end(self, outputs) -> None:
126
+ self.font_accur_train.reset()
127
+ self.direction_accur_train.reset()
128
+
129
+ def validation_step(
130
+ self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
131
+ ) -> Dict[str, Any]:
132
+ X, y = batch
133
+ y_hat = self.forward(X)
134
+ loss = self.loss(y_hat, y)
135
+ self.log("val_loss", loss, prog_bar=True)
136
+ self.font_accur_val.update(y_hat[..., : config.FONT_COUNT], y[..., 0])
137
+ self.direction_accur_val.update(
138
+ y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
139
+ )
140
+ return {"loss": loss, "pred": y_hat, "target": y}
141
+
142
+ def validation_epoch_end(self, outputs):
143
+ self.log("val_font_accur", self.font_accur_val.compute())
144
+ self.log("val_direction_accur", self.direction_accur_val.compute())
145
+ self.font_accur_val.reset()
146
+ self.direction_accur_val.reset()
147
+
148
+ def configure_optimizers(self):
149
+ optimizer = torch.optim.Adam(
150
+ self.model.parameters(), lr=self.lr, betas=self.betas
151
+ )
152
+ self.scheduler = CosineWarmupScheduler(
153
+ optimizer, self.num_warmup_iters, self.num_iters
154
+ )
155
+ return optimizer
156
+
157
+ def optimizer_step(
158
+ self,
159
+ epoch: int,
160
+ batch_idx: int,
161
+ optimizer,
162
+ optimizer_idx: int = 0,
163
+ *args,
164
+ **kwargs
165
+ ):
166
+ super().optimizer_step(
167
+ epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs
168
+ )
169
+ self.log("lr", self.scheduler.get_last_lr()[0])
170
+ self.scheduler.step()