# Hydra training config for the percussion model checkpoint.
# Source: commit 79ba331 by crlandsc ("added percussion checkpoint and config").
---
# ---- Run-level switches ----
seed: 12345
train: true
ignore_warnings: true
print_config: false

# ---- Paths ----
# The env-derived parts are appended to work_dir with no separator in
# between, so DIR_LOGS / DIR_DATA must begin with '/' to name a subdirectory.
work_dir: '${hydra:runtime.cwd}'
logs_dir: '${work_dir}${oc.env:DIR_LOGS}'
data_dir: '${work_dir}${oc.env:DIR_DATA}'
ckpt_dir: '${logs_dir}/runs/${now:%Y-%m-%d-%H-%M-%S}'

# ---- Shared values, pulled into the sections of this config via ${...} ----
module: main.module_base
batch_size: 16
accumulate_grad_batches: 2
num_workers: 16
sampling_rate: 44100
length: 32768
channels: 2
# Interpolated as the trainer's val_check_interval (validation frequency);
# the trainer's own log_every_n_steps is set separately.
log_every_n_steps: 500
model:
_target_: ${module}.Model
lr: 0.0001
lr_beta1: 0.95
lr_beta2: 0.999
lr_eps: 1.0e-06
lr_weight_decay: 0.001
ema_beta: 0.995
ema_power: 0.7
model:
_target_: main.DiffusionModel
net_t:
_target_: ${module}.UNetT
in_channels: 2
channels:
- 32
- 32
- 64
- 64
- 128
- 128
- 256
- 256
factors:
- 1
- 2
- 2
- 2
- 2
- 2
- 2
- 2
items:
- 2
- 2
- 2
- 2
- 2
- 2
- 4
- 4
attentions:
- 0
- 0
- 0
- 0
- 0
- 1
- 1
- 1
attention_heads: 8
attention_features: 64
# Data pipeline: Hydra instantiates the dataset (and its transforms) first,
# then passes it with the loader settings to the Datamodule.
datamodule:
  _target_: main.module_base.Datamodule
  dataset:
    _target_: audio_data_pytorch.WAVDataset
    # Hard-coded relative path (not derived from data_dir); presumably
    # resolved against the process working directory at runtime.
    path: ./data/percussion
    recursive: true
    sample_rate: ${sampling_rate}
    transforms:
      _target_: audio_data_pytorch.AllTransform
      crop_size: ${length}
      stereo: true
      source_rate: ${sampling_rate}
      target_rate: ${sampling_rate}
      loudness: -20
  val_split: 0.05
  # Root-level interpolations: these read the top-level batch_size and
  # num_workers, not themselves.
  batch_size: ${batch_size}
  num_workers: ${num_workers}
  pin_memory: true
# Lightning callbacks, one mapping per callback.
callbacks:
  rich_progress_bar:
    _target_: pytorch_lightning.callbacks.RichProgressBar
  # Keeps the single best checkpoint by lowest valid_loss, plus last.ckpt.
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: valid_loss
    save_top_k: 1
    save_last: true
    mode: min
    verbose: false
    # Note: writes under logs_dir/ckpts/, not the top-level ckpt_dir.
    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
    # Quoted so the braces are not parsed as a YAML flow mapping.
    filename: '{epoch:02d}-{valid_loss:.3f}'
  model_summary:
    _target_: pytorch_lightning.callbacks.RichModelSummary
    max_depth: 2
  audio_samples_logger:
    _target_: main.module_base.SampleLogger
    num_items: 4
    # Root-level interpolations of the shared audio settings.
    channels: ${channels}
    sampling_rate: ${sampling_rate}
    length: ${length}
    sampling_steps: [50]
    use_ema_model: true
# Experiment loggers.
loggers:
  wandb:
    _target_: pytorch_lightning.loggers.wandb.WandbLogger
    # oc.env with no default raises if the variable is unset, so both
    # WANDB_PROJECT and WANDB_ENTITY must be exported before running.
    project: ${oc.env:WANDB_PROJECT}
    entity: ${oc.env:WANDB_ENTITY}
    name: percussion_v0
    job_type: train
    group: ''
    save_dir: ${logs_dir}
# PyTorch Lightning Trainer arguments.
trainer:
  _target_: pytorch_lightning.Trainer
  # NOTE(review): `gpus` is deprecated since Lightning 1.7 and removed in 2.0
  # (use `devices`); kept unchanged for the Lightning version this targets.
  gpus: 1
  precision: 16
  accelerator: gpu
  min_epochs: 0
  # -1: no epoch limit; training runs until stopped externally.
  max_epochs: -1
  enable_model_summary: false
  # Trainer-local logging frequency (every step).
  log_every_n_steps: 1
  # Step-based validation instead of per-epoch. ${log_every_n_steps} is an
  # absolute interpolation, so it reads the top-level value, not the
  # trainer-local key above.
  check_val_every_n_epoch: null
  val_check_interval: ${log_every_n_steps}
  accumulate_grad_batches: ${accumulate_grad_batches}
# Checkpoint bundled with this config. NOTE(review): presumably read by the
# training/inference script to resume or load weights — confirm. The path is
# relative, so it depends on the working directory at load time.
ckpt: './logs/ckpts/2023-06-17-21-46-46/last.ckpt'