fix: torch.compile cannot run on windows
Browse files
train.py
CHANGED
@@ -152,7 +152,7 @@ logger_unconditioned = TensorBoardLogger(
|
|
152 |
save_dir=os.getcwd(), name="tensorboard", version=model_name
|
153 |
)
|
154 |
|
155 |
-
strategy =
|
156 |
|
157 |
trainer = ptl.Trainer(
|
158 |
max_epochs=num_epochs,
|
@@ -189,7 +189,7 @@ elif args.model == "deepfont":
|
|
189 |
else:
|
190 |
raise NotImplementedError()
|
191 |
|
192 |
-
if torch.__version__ >= "2.0":
|
193 |
model = torch.compile(model)
|
194 |
|
195 |
detector = FontDetector(
|
|
|
152 |
save_dir=os.getcwd(), name="tensorboard", version=model_name
|
153 |
)
|
154 |
|
155 |
+
strategy = "auto" if num_device == 1 else "ddp"
|
156 |
|
157 |
trainer = ptl.Trainer(
|
158 |
max_epochs=num_epochs,
|
|
|
189 |
else:
|
190 |
raise NotImplementedError()
|
191 |
|
192 |
+
if torch.__version__ >= "2.0" and os.name == "posix":
|
193 |
model = torch.compile(model)
|
194 |
|
195 |
detector = FontDetector(
|