File size: 2,773 Bytes
79ba331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Global run settings. Several of these keys are referenced elsewhere in this
# config via ${...} interpolation (e.g. ${sampling_rate}, ${batch_size}).
seed: 12345
train: true
ignore_warnings: true
print_config: false
# Hydra's runtime cwd, i.e. the directory the job was launched from.
work_dir: ${hydra:runtime.cwd}
# NOTE(review): DIR_LOGS / DIR_DATA are appended directly to work_dir with no
# separator, so they presumably start with '/' (e.g. "/logs") — confirm.
# No default is given to oc.env, so the variables presumably must be set.
logs_dir: ${work_dir}${oc.env:DIR_LOGS}
data_dir: ${work_dir}${oc.env:DIR_DATA}
# NOTE(review): not referenced elsewhere in this file; presumably consumed by
# the launcher (e.g. as the Hydra run dir) — confirm before removing.
ckpt_dir: ${logs_dir}/runs/${now:%Y-%m-%d-%H-%M-%S}
# Python module that provides Model / UNetT (see model._target_).
module: main.module_base
batch_size: 16
accumulate_grad_batches: 2
num_workers: 16
# Audio format keys, referenced by the dataset transforms and/or the sample
# logger callback.
sampling_rate: 44100
length: 32768
channels: 2
# Despite its name, this value is used below as trainer.val_check_interval;
# trainer.log_every_n_steps is set separately inside the trainer section.
log_every_n_steps: 500
# LightningModule config: hyperparameters passed to ${module}.Model plus the
# nested diffusion network config. _target_ follows the top-level `module` key.
model:
  _target_: ${module}.Model
  lr: 0.0001
  lr_beta1: 0.95
  lr_beta2: 0.999
  lr_eps: 1.0e-06
  lr_weight_decay: 0.001
  ema_beta: 0.995
  ema_power: 0.7
  model:
    _target_: main.DiffusionModel
    net_t:
      _target_: ${module}.UNetT
    # ${channels} is an absolute (root-level) interpolation, not the sibling
    # `channels` list below. This keeps the network input width in sync with
    # the audio channel count used by the sample logger, instead of being a
    # second hard-coded 2.
    in_channels: ${channels}
    # Per-level settings, 8 entries each. NOTE(review): presumably one entry
    # per UNet level, so the four lists must stay the same length — confirm
    # against UNetT.
    channels: [32, 32, 64, 64, 128, 128, 256, 256]
    factors: [1, 2, 2, 2, 2, 2, 2, 2]
    items: [2, 2, 2, 2, 2, 2, 4, 4]
    attentions: [0, 0, 0, 0, 0, 1, 1, 1]
    attention_heads: 8
    attention_features: 64
# Data pipeline: a recursive WAV dataset whose transform settings reuse the
# top-level audio keys (sampling_rate, length).
datamodule:
  # NOTE(review): hard-coded instead of ${module}.Datamodule, so it does not
  # follow overrides of the top-level `module` key — confirm that is intended.
  _target_: main.module_base.Datamodule
  dataset:
    _target_: audio_data_pytorch.WAVDataset
    # Anchored to the launch directory (work_dir). A bare ./data/percussion
    # would resolve against the process cwd, which Hydra may change for the
    # run. When the cwd is unchanged the two are identical.
    path: ${work_dir}/data/percussion
    recursive: true
    sample_rate: ${sampling_rate}
    transforms:
      _target_: audio_data_pytorch.AllTransform
      crop_size: ${length}
      # NOTE(review): keep consistent with the top-level `channels: 2`.
      stereo: true
      source_rate: ${sampling_rate}
      target_rate: ${sampling_rate}
      loudness: -20
  val_split: 0.05
  batch_size: ${batch_size}
  num_workers: ${num_workers}
  pin_memory: true
# Lightning callbacks. The entries keep their original order in case the
# consumer instantiates them in mapping order.
callbacks:
  rich_progress_bar:
    _target_: pytorch_lightning.callbacks.RichProgressBar
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    # Track the lowest valid_loss, keep one best checkpoint plus last.ckpt.
    monitor: valid_loss
    mode: min
    save_top_k: 1
    save_last: true
    verbose: false
    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
    filename: '{epoch:02d}-{valid_loss:.3f}'
  model_summary:
    # Presumably replaces the built-in summary, which the trainer section
    # disables with enable_model_summary: false.
    _target_: pytorch_lightning.callbacks.RichModelSummary
    max_depth: 2
  audio_samples_logger:
    _target_: main.module_base.SampleLogger
    num_items: 4
    sampling_rate: ${sampling_rate}
    channels: ${channels}
    length: ${length}
    sampling_steps: [50]
    use_ema_model: true
# Experiment loggers. The W&B project and entity are read from the environment.
loggers:
  wandb:
    _target_: pytorch_lightning.loggers.wandb.WandbLogger
    project: ${oc.env:WANDB_PROJECT}
    entity: ${oc.env:WANDB_ENTITY}
    group: ""
    name: percussion_v0
    job_type: train
    save_dir: ${logs_dir}
# Keyword arguments for pytorch_lightning.Trainer.
trainer:
  _target_: pytorch_lightning.Trainer
  # `gpus` is deprecated since Lightning 1.7 and removed in 2.0. Together with
  # accelerator: gpu, `devices: 1` is the supported equivalent.
  accelerator: gpu
  devices: 1
  precision: 16
  min_epochs: 0
  max_epochs: -1
  # A RichModelSummary callback is configured in the callbacks section instead.
  enable_model_summary: false
  log_every_n_steps: 1
  # With check_val_every_n_epoch: null, validation is scheduled only by
  # val_check_interval (the top-level log_every_n_steps key), not per epoch.
  check_val_every_n_epoch: null
  val_check_interval: ${log_every_n_steps}
  accumulate_grad_batches: ${accumulate_grad_batches}
# Checkpoint path (presumably used to resume training). It is anchored to the
# launch directory (work_dir) so it does not depend on the process cwd, which
# Hydra may change for the run.
# NOTE(review): appears to be a last.ckpt from an earlier run's
# model_checkpoint output — update per experiment.
ckpt: ${work_dir}/logs/ckpts/2023-06-17-21-46-46/last.ckpt