gyrojeff commited on
Commit
98da805
1 Parent(s): edd29d3

fix: torch.compile cannot run on windows

Browse files
Files changed (1) hide show
  1. train.py +2 -2
train.py CHANGED
@@ -152,7 +152,7 @@ logger_unconditioned = TensorBoardLogger(
152
  save_dir=os.getcwd(), name="tensorboard", version=model_name
153
  )
154
 
155
- strategy = None if num_device == 1 else "ddp"
156
 
157
  trainer = ptl.Trainer(
158
  max_epochs=num_epochs,
@@ -189,7 +189,7 @@ elif args.model == "deepfont":
189
  else:
190
  raise NotImplementedError()
191
 
192
- if torch.__version__ >= "2.0":
193
  model = torch.compile(model)
194
 
195
  detector = FontDetector(
 
152
  save_dir=os.getcwd(), name="tensorboard", version=model_name
153
  )
154
 
155
+ strategy = "auto" if num_device == 1 else "ddp"
156
 
157
  trainer = ptl.Trainer(
158
  max_epochs=num_epochs,
 
189
  else:
190
  raise NotImplementedError()
191
 
192
+ if torch.__version__ >= "2.0" and os.name == "posix":
193
  model = torch.compile(model)
194
 
195
  detector = FontDetector(