gyrojeff commited on
Commit
fd9442f
1 Parent(s): 8e068be

fix: refine num workers and cli

Browse files
Files changed (1) hide show
  1. train.py +5 -3
train.py CHANGED
@@ -13,13 +13,15 @@ torch.set_float32_matmul_precision("high")
13
 
14
  parser = argparse.ArgumentParser()
15
  parser.add_argument("-d", "--devices", nargs="*", type=int, default=[0])
 
16
 
17
  args = parser.parse_args()
18
 
19
  devices = args.devices
 
20
 
21
- final_batch_size = 128
22
- single_device_num_workers = 24
23
 
24
 
25
  lr = 0.0001
@@ -41,7 +43,7 @@ log_every_n_steps = 100
41
  num_device = len(devices)
42
 
43
  data_module = FontDataModule(
44
- batch_size=final_batch_size // num_device,
45
  num_workers=single_device_num_workers,
46
  pin_memory=True,
47
  train_shuffle=True,
 
13
 
14
  parser = argparse.ArgumentParser()
15
  parser.add_argument("-d", "--devices", nargs="*", type=int, default=[0])
16
+ parser.add_argument("-b", "--single-batch-size", type=int, default=64)
17
 
18
  args = parser.parse_args()
19
 
20
  devices = args.devices
21
+ single_batch_size = args.single_batch_size
22
 
23
+ total_num_workers = os.cpu_count()
24
+ single_device_num_workers = total_num_workers // len(devices)
25
 
26
 
27
  lr = 0.0001
 
43
  num_device = len(devices)
44
 
45
  data_module = FontDataModule(
46
+ batch_size=single_batch_size,
47
  num_workers=single_device_num_workers,
48
  pin_memory=True,
49
  train_shuffle=True,