"""Command-line entry point for training a pseudo-3D UNet with Flax.

Builds a ``FlaxTrainerUNetPseudo3D`` from a pretrained model, loads a
multi-frame dataset and runs the trainer's training loop.
Run with ``--help`` to list the options.
"""
import os
from argparse import ArgumentParser, ArgumentTypeError
from typing import Optional, Tuple

import jax

# HACK: query the devices once, right after importing jax and before the
# project modules are imported. Per the original author this works around TPU
# communication locking up / racing. NOTE(review): root cause unconfirmed;
# keep this ahead of the project imports below.
_ = jax.device_count()

from flax_trainer import FlaxTrainerUNetPseudo3D  # noqa: E402
from dataset import load_dataset  # noqa: E402


# Accepted spellings for boolean command-line flags (compared lower-cased).
_TRUE_STRINGS = frozenset({'true', '1', 'yes', 'y', 't', 'on'})
_FALSE_STRINGS = frozenset({'false', '0', 'no', 'n', 'f', 'off'})

_DTYPE_CHOICES = ['float32', 'bfloat16', 'float16']


def _str2bool(value: str) -> bool:
    """Parse a boolean command-line flag value.

    Raises:
        ArgumentTypeError: for unrecognized values. A typo such as
            ``--from_pt ture`` is reported by argparse instead of
            silently being treated as ``False``.
    """
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise ArgumentTypeError(
        f'expected a boolean value (e.g. true/false, yes/no, 1/0), got {value!r}'
    )


def _default_num_workers(num_devices: int) -> int:
    """Return the data-loader worker count: two per device, leaving one CPU free."""
    # os.cpu_count() returns None when the CPU count cannot be determined.
    cpu_count = os.cpu_count() or 1
    return max(0, min(num_devices * 2, cpu_count - 1))


def train(
        dataset_path: str,
        model_path: str,
        output_dir: str,
        dataset_cache_dir: Optional[str] = None,
        from_pt: bool = True,
        convert2d: bool = False,
        only_temporal: bool = True,
        sample_size: Tuple[int, int] = (64, 64),
        lr: float = 5e-5,
        batch_size: int = 1,
        num_frames: int = 24,
        epochs: int = 10,
        warmup: float = 0.1,
        decay: float = 0.0,
        weight_decay: float = 1e-2,
        log_every_step: int = 50,
        save_every_epoch: int = 1,
        sample_every_epoch: int = 1,
        seed: int = 0,
        dtype: str = 'bfloat16',
        param_dtype: str = 'float32',
        use_memory_efficient_attention: bool = True,
        verbose: bool = True,
        use_wandb: bool = False
) -> None:
    """Initialise the trainer and dataset, then run training.

    Model options (``model_path``, ``from_pt``, ``convert2d``, ``sample_size``,
    ``seed``, ``dtype``, ``param_dtype``, ``use_memory_efficient_attention``,
    ``only_temporal``) are forwarded to ``FlaxTrainerUNetPseudo3D``.
    Optimisation and schedule options (``epochs``, ``lr``, ``warmup``,
    ``decay``, ``weight_decay``, ``log_every_step``, ``save_every_epoch``,
    ``sample_every_epoch``, ``output_dir``) are forwarded to
    ``trainer.train``. Their exact semantics are defined there.

    Args:
        dataset_path: dataset location passed to ``load_dataset``.
        model_path: pretrained model location. It is used both by the trainer
            and by ``load_dataset``.
        output_dir: where ``trainer.train`` writes its outputs.
        dataset_cache_dir: optional cache directory for ``load_dataset``.
        batch_size: per-device batch size. The data loader is built with
            ``batch_size * trainer.num_devices``.
        num_frames: frames per sample, given to both the loader and the trainer.
        verbose: print progress banners (also forwarded to the trainer).
        use_wandb: call ``trainer.enable_wandb()`` before training.

    Note:
        The command-line defaults for ``epochs`` (2), ``lr`` (1e-4) and
        ``log_every_step`` (250) differ from the keyword defaults here.
    """
    def log(message: str) -> None:
        if verbose:
            print(message)

    log('\n----------------')
    log('Init trainer')
    trainer = FlaxTrainerUNetPseudo3D(
            model_path = model_path,
            from_pt = from_pt,
            convert2d = convert2d,
            sample_size = sample_size,
            seed = seed,
            dtype = dtype,
            param_dtype = param_dtype,
            use_memory_efficient_attention = use_memory_efficient_attention,
            verbose = verbose,
            only_temporal = only_temporal
    )

    log('\n----------------')
    log('Init dataset')
    dataloader = load_dataset(
            dataset_path = dataset_path,
            model_path = model_path,
            cache_dir = dataset_cache_dir,
            # Global batch: per-device batch times the number of devices.
            batch_size = batch_size * trainer.num_devices,
            num_frames = num_frames,
            num_workers = _default_num_workers(trainer.num_devices),
            as_numpy = True,
            shuffle = True
    )

    log('\n----------------')
    log('Train')
    if use_wandb:
        trainer.enable_wandb()
    trainer.train(
            dataloader = dataloader,
            epochs = epochs,
            num_frames = num_frames,
            log_every_step = log_every_step,
            save_every_epoch = save_every_epoch,
            sample_every_epoch = sample_every_epoch,
            lr = lr,
            warmup = warmup,
            decay = decay,
            weight_decay = weight_decay,
            output_dir = output_dir
    )
    log('\n----------------')
    log('Done')


def _build_parser() -> ArgumentParser:
    """Build the command-line parser. Option names and defaults match the CLI."""
    parser = ArgumentParser()
    parser.add_argument('-v', '--verbose', type = _str2bool, default = True)
    parser.add_argument('-d', '--dataset_path', required = True)
    parser.add_argument('-m', '--model_path', required = True)
    parser.add_argument('-o', '--output_dir', required = True)
    parser.add_argument('-b', '--batch_size', type = int, default = 1)
    parser.add_argument('-f', '--num_frames', type = int, default = 24)
    parser.add_argument('-e', '--epochs', type = int, default = 2)
    parser.add_argument('--only_temporal', type = _str2bool, default = True)
    parser.add_argument('--dataset_cache_dir', type = str, default = None)
    parser.add_argument('--from_pt', type = _str2bool, default = True)
    parser.add_argument('--convert2d', type = _str2bool, default = False)
    parser.add_argument('--lr', type = float, default = 1e-4)
    parser.add_argument('--warmup', type = float, default = 0.1)
    parser.add_argument('--decay', type = float, default = 0.0)
    parser.add_argument('--weight_decay', type = float, default = 1e-2)
    parser.add_argument('--sample_size', type = int, nargs = 2, default = [64, 64])
    parser.add_argument('--log_every_step', type = int, default = 250)
    parser.add_argument('--save_every_epoch', type = int, default = 1)
    parser.add_argument('--sample_every_epoch', type = int, default = 1)
    parser.add_argument('--seed', type = int, default = 0)
    parser.add_argument('--use_memory_efficient_attention', type = _str2bool, default = True)
    parser.add_argument('--dtype', choices = _DTYPE_CHOICES, default = 'bfloat16')
    parser.add_argument('--param_dtype', choices = _DTYPE_CHOICES, default = 'float32')
    parser.add_argument('--wandb', type = _str2bool, default = False)
    return parser


def main() -> None:
    """Parse command-line arguments and run ``train``."""
    args = _build_parser().parse_args()
    # argparse yields a list for nargs=2; train() is annotated with a tuple.
    args.sample_size = tuple(args.sample_size)
    if args.verbose:
        print(args)
    train(
            dataset_path = args.dataset_path,
            model_path = args.model_path,
            from_pt = args.from_pt,
            convert2d = args.convert2d,
            only_temporal = args.only_temporal,
            output_dir = args.output_dir,
            dataset_cache_dir = args.dataset_cache_dir,
            batch_size = args.batch_size,
            num_frames = args.num_frames,
            epochs = args.epochs,
            lr = args.lr,
            warmup = args.warmup,
            decay = args.decay,
            weight_decay = args.weight_decay,
            sample_size = args.sample_size,
            seed = args.seed,
            dtype = args.dtype,
            param_dtype = args.param_dtype,
            use_memory_efficient_attention = args.use_memory_efficient_attention,
            log_every_step = args.log_every_step,
            save_every_epoch = args.save_every_epoch,
            sample_every_epoch = args.sample_every_epoch,
            verbose = args.verbose,
            use_wandb = args.wandb
    )


if __name__ == '__main__':
    main()