gyrojeff commited on
Commit
1750035
1 Parent(s): b06784f

fix: torch.compile

Browse files
Files changed (1) hide show
  1. train.py +3 -3
train.py CHANGED
@@ -142,6 +142,9 @@ elif args.model == "resnet101":
142
  else:
143
  raise NotImplementedError()
144
 
 
 
 
145
  detector = FontDetector(
146
  model=model,
147
  lambda_font=lambda_font,
@@ -154,8 +157,5 @@ detector = FontDetector(
154
  num_epochs=num_epochs,
155
  )
156
 
157
- if torch.__version__ >= "2.0":
158
- detector = torch.compile(detector)
159
-
160
  trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
161
  trainer.test(detector, datamodule=data_module)
 
142
  else:
143
  raise NotImplementedError()
144
 
145
+ if torch.__version__ >= "2.0":
146
+ model = torch.compile(model)
147
+
148
  detector = FontDetector(
149
  model=model,
150
  lambda_font=lambda_font,
 
157
  num_epochs=num_epochs,
158
  )
159
 
 
 
 
160
  trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
161
  trainer.test(detector, datamodule=data_module)