fix: torch.compile
Browse files
train.py
CHANGED
@@ -142,6 +142,9 @@ elif args.model == "resnet101":
|
|
142 |
else:
|
143 |
raise NotImplementedError()
|
144 |
|
|
|
|
|
|
|
145 |
detector = FontDetector(
|
146 |
model=model,
|
147 |
lambda_font=lambda_font,
|
@@ -154,8 +157,5 @@ detector = FontDetector(
|
|
154 |
num_epochs=num_epochs,
|
155 |
)
|
156 |
|
157 |
-
if torch.__version__ >= "2.0":
|
158 |
-
detector = torch.compile(detector)
|
159 |
-
|
160 |
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|
161 |
trainer.test(detector, datamodule=data_module)
|
|
|
142 |
else:
|
143 |
raise NotImplementedError()
|
144 |
|
145 |
+
if torch.__version__ >= "2.0":
|
146 |
+
model = torch.compile(model)
|
147 |
+
|
148 |
detector = FontDetector(
|
149 |
model=model,
|
150 |
lambda_font=lambda_font,
|
|
|
157 |
num_epochs=num_epochs,
|
158 |
)
|
159 |
|
|
|
|
|
|
|
160 |
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|
161 |
trainer.test(detector, datamodule=data_module)
|