gyrojeff commited on
Commit
5c43f60
1 Parent(s): daa52ce

feat: add test loop and fix training accur

Browse files
Files changed (2) hide show
  1. detector/model.py +28 -6
  2. train.py +1 -0
detector/model.py CHANGED
@@ -91,6 +91,12 @@ class FontDetector(ptl.LightningModule):
91
  self.direction_accur_val = torchmetrics.Accuracy(
92
  task="multiclass", num_classes=2
93
  )
 
 
 
 
 
 
94
  self.lr = lr
95
  self.betas = betas
96
  self.num_warmup_iters = num_warmup_iters
@@ -106,11 +112,7 @@ class FontDetector(ptl.LightningModule):
106
  y_hat = self.forward(X)
107
  loss = self.loss(y_hat, y)
108
  self.log("train_loss", loss, prog_bar=True, sync_dist=True)
109
- return {"loss": loss, "pred": y_hat, "target": y}
110
-
111
- def training_step_end(self, outputs):
112
- y_hat = outputs["pred"]
113
- y = outputs["target"]
114
  self.log(
115
  "train_font_accur",
116
  self.font_accur_train(y_hat[..., : config.FONT_COUNT], y[..., 0]),
@@ -123,6 +125,7 @@ class FontDetector(ptl.LightningModule):
123
  ),
124
  sync_dist=True,
125
  )
 
126
 
127
  def on_train_epoch_end(self) -> None:
128
  self.log("train_font_accur", self.font_accur_train.compute(), sync_dist=True)
@@ -143,7 +146,7 @@ class FontDetector(ptl.LightningModule):
143
  self.direction_accur_val.update(
144
  y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
145
  )
146
- return {"loss": loss, "pred": y_hat, "target": y}
147
 
148
  def on_validation_epoch_end(self):
149
  self.log("val_font_accur", self.font_accur_val.compute(), sync_dist=True)
@@ -153,6 +156,25 @@ class FontDetector(ptl.LightningModule):
153
  self.font_accur_val.reset()
154
  self.direction_accur_val.reset()
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def configure_optimizers(self):
157
  optimizer = torch.optim.Adam(
158
  self.model.parameters(), lr=self.lr, betas=self.betas
 
91
  self.direction_accur_val = torchmetrics.Accuracy(
92
  task="multiclass", num_classes=2
93
  )
94
+ self.font_accur_test = torchmetrics.Accuracy(
95
+ task="multiclass", num_classes=config.FONT_COUNT
96
+ )
97
+ self.direction_accur_test = torchmetrics.Accuracy(
98
+ task="multiclass", num_classes=2
99
+ )
100
  self.lr = lr
101
  self.betas = betas
102
  self.num_warmup_iters = num_warmup_iters
 
112
  y_hat = self.forward(X)
113
  loss = self.loss(y_hat, y)
114
  self.log("train_loss", loss, prog_bar=True, sync_dist=True)
115
+ # accur
 
 
 
 
116
  self.log(
117
  "train_font_accur",
118
  self.font_accur_train(y_hat[..., : config.FONT_COUNT], y[..., 0]),
 
125
  ),
126
  sync_dist=True,
127
  )
128
+ return {"loss": loss}
129
 
130
  def on_train_epoch_end(self) -> None:
131
  self.log("train_font_accur", self.font_accur_train.compute(), sync_dist=True)
 
146
  self.direction_accur_val.update(
147
  y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
148
  )
149
+ return {"loss": loss}
150
 
151
  def on_validation_epoch_end(self):
152
  self.log("val_font_accur", self.font_accur_val.compute(), sync_dist=True)
 
156
  self.font_accur_val.reset()
157
  self.direction_accur_val.reset()
158
 
159
+ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
160
+ X, y = batch
161
+ y_hat = self.forward(X)
162
+ loss = self.loss(y_hat, y)
163
+ self.log("test_loss", loss, prog_bar=True, sync_dist=True)
164
+ self.font_accur_test.update(y_hat[..., : config.FONT_COUNT], y[..., 0])
165
+ self.direction_accur_test.update(
166
+ y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
167
+ )
168
+ return {"loss": loss}
169
+
170
+ def on_test_epoch_end(self) -> None:
171
+ self.log("test_font_accur", self.font_accur_test.compute(), sync_dist=True)
172
+ self.log(
173
+ "test_direction_accur", self.direction_accur_test.compute(), sync_dist=True
174
+ )
175
+ self.font_accur_test.reset()
176
+ self.direction_accur_test.reset()
177
+
178
  def configure_optimizers(self):
179
  optimizer = torch.optim.Adam(
180
  self.model.parameters(), lr=self.lr, betas=self.betas
train.py CHANGED
@@ -73,3 +73,4 @@ detector = FontDetector(
73
  )
74
 
75
  trainer.fit(detector, datamodule=data_module)
 
 
73
  )
74
 
75
  trainer.fit(detector, datamodule=data_module)
76
+ trainer.test(detector, datamodule=data_module)