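"""Training and evaluation entry point: builds the model and preprocessing transforms,
sets up (optionally distributed) training with an AdamW optimizer and LR scheduler,
optionally resumes from a checkpoint, runs train_iters, and evaluates on the
validation datasets (preference, regional, ranking, ImageReward)."""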
import glob
import json
import logging
import os
import re
import subprocess
import sys
import random
from datetime import datetime

import numpy as np
import torch
from torch import optim
from torch.cuda.amp import GradScaler

try:
    import torch.utils.tensorboard as tensorboard
except ImportError:
    tensorboard = None

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

from .open_clip import create_model_and_transforms, trace_model, get_tokenizer
from .data import get_data, PreferenceDataset, RegionDataset, RankingDataset, ImageRewardDataset, HPDDataset
from .distributed import is_master, init_distributed_device, broadcast_object, barrier
from .logger import setup_logging
from .params import parse_args
from .scheduler import cosine_lr, const_lr, const_lr_cooldown
from .train import evaluate_ranking, train_iters, evaluate_preference, evaluate_regional, unwrap_model
from .file_utils import pt_load, save_ckpt, start_sync_process, remote_sync


LATEST_CHECKPOINT_NAME = "latest.pt"


def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)


def natural_key(string_):
    """See http://www.codinghorror.com/blog/archives/001018.html"""
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def get_latest_checkpoint(path: str, remote: bool):
    # as written, this glob recurses, so it can pick up checkpoints across multiple sub-folders
    if remote:
        result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        print(result)
        if result.returncode == 1:
            return None
        checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\n')[:-1]]
    else:
        checkpoints = glob.glob(path + '**/*.pt', recursive=True)
    if checkpoints:
        checkpoints = sorted(checkpoints, key=natural_key)
        return checkpoints[-1]
    return None
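

# Run every registered validation dataloader through the evaluation routine matching
# its dataset type, collecting the resulting metrics into a single dict.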
def do_eval(data, model, args, out_dict=None):
    if out_dict is None:
        out_dict = {}
    for d in data['val']:
        if isinstance(d.dataloader.dataset, PreferenceDataset):
            out_dict['pref_acc'] = evaluate_preference(model, d, args)
        if isinstance(d.dataloader.dataset, RegionDataset):
            out_dict['iou'] = evaluate_regional(model, d, args)
        if isinstance(d.dataloader.dataset, RankingDataset):
            out_dict['ranking_acc'] = evaluate_ranking(model, d, args)
        if isinstance(d.dataloader.dataset, ImageRewardDataset):
            out_dict['ImageReward_acc'] = evaluate_ranking(model, d, args)
    return out_dict
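

# `rank` is the per-process index when spawned for single-node multi-GPU training;
# it is None when the script is invoked directly from the command line below.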
def main(rank, args):
    if rank is not None:
        assert int(os.environ['WORLD_SIZE']) <= 8, "currently only support single node training"
        os.environ['LOCAL_RANK'] = str(rank)
        os.environ['RANK'] = str(rank)

    if torch.cuda.is_available():
        # This enables tf32 on Ampere GPUs which is only 8% slower than
        # float16 and almost as accurate as float32
        # This was a default in pytorch until 1.12
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    # fully initialize distributed device environment
    device = init_distributed_device(args)

    # get the name of the experiment
    if args.name is None:
        # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
        model_name_safe = args.model.replace('/', '-')
        date_str = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
        if args.distributed:
            # sync date_str from master to all ranks
            date_str = broadcast_object(args, date_str)
        args.name = '-'.join([
            date_str,
            f"model_{model_name_safe}",
            f"lr_{args.lr}",
            f"b_{args.batch_size}",
            f"j_{args.workers}",
            f"p_{args.precision}",
        ])
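
    # Per-experiment logging setup; an existing log file for this name is treated as a
    # collision unless resuming from the latest checkpoint.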
    resume_latest = args.resume == 'latest'
    log_base_path = os.path.join(args.logs, args.name)
    args.log_path = None
    if is_master(args, local=args.log_local):
        os.makedirs(log_base_path, exist_ok=True)
        log_filename = f'out-{args.rank}' if args.log_local else 'out.log'
        args.log_path = os.path.join(log_base_path, log_filename)
        if os.path.exists(args.log_path) and not resume_latest:
            print(
                f"Error. Experiment {args.name} already exists. Use --name to specify a new experiment."
            )
            return -1

    # Setup text logger
    args.log_level = logging.DEBUG if args.debug else logging.INFO
    setup_logging(args.log_path, args.log_level)
    # Setup tensorboard, checkpoint logging
    args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
    args.checkpoint_path = os.path.join(log_base_path, "checkpoints")
    if is_master(args):
        args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else ''
        for dirname in [args.tensorboard_path, args.checkpoint_path]:
            if dirname:
                os.makedirs(dirname, exist_ok=True)
    else:
        args.tensorboard_path = ''

    if resume_latest:
        resume_from = None
        checkpoint_path = args.checkpoint_path
        # If using remote_sync, need to check the remote instead of the local checkpoints folder.
        if args.remote_sync is not None:
            checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints")
            if args.save_most_recent:
                print('Error. Cannot use save-most-recent with remote_sync and resume latest.')
                return -1
            if args.remote_sync_protocol != 's3':
                print('Error. Sync protocol not supported when using resume latest.')
                return -1
        if is_master(args):
            # Checking for existing checkpoint via master rank only. It is possible for
            # different rank processes to see different files if a shared file-system is under
            # stress, however it's very difficult to fully work around such situations.
            if args.save_most_recent:
                # if --save-most-recent flag is set, look for latest at a fixed filename
                resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME)
                if not os.path.exists(resume_from):
                    # If no latest checkpoint has been saved yet, don't try to resume
                    resume_from = None
            else:
                # otherwise, list checkpoint dir contents and pick the newest checkpoint
                resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None)
            if resume_from:
                logging.info(f'Found latest resume checkpoint at {resume_from}.')
            else:
                logging.info(f'No latest resume checkpoint found in {checkpoint_path}.')
        if args.distributed:
            # sync found checkpoint path to all ranks
            resume_from = broadcast_object(args, resume_from)
        args.resume = resume_from
    # start the sync process if remote-sync is not None
    remote_sync_process = None
    if is_master(args) and args.remote_sync is not None:
        # first make sure it works
        result = remote_sync(
            os.path.join(args.logs, args.name),
            os.path.join(args.remote_sync, args.name),
            args.remote_sync_protocol
        )
        if result:
            logging.info('remote sync successful.')
        else:
            logging.info('Error: remote sync failed. Exiting.')
            return -1
        # if all looks good, start a process to do this every args.remote_sync_frequency seconds
        remote_sync_process = start_sync_process(
            args.remote_sync_frequency,
            os.path.join(args.logs, args.name),
            os.path.join(args.remote_sync, args.name),
            args.remote_sync_protocol
        )
        remote_sync_process.start()
    if args.precision == 'fp16':
        logging.warning(
            'It is recommended to use AMP mixed-precision instead of FP16. '
            'FP16 support needs further verification and tuning, especially for train.')

    if args.horovod:
        logging.info(
            f'Running in horovod mode with multiple processes / nodes. Device: {args.device}. '
            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
    elif args.distributed:
        logging.info(
            f'Running in distributed mode with multiple processes. Device: {args.device}. '
            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
    else:
        logging.info(f'Running with a single process. Device {args.device}.')
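
    # Build the model and its train/val preprocessing transforms; optionally build a second
    # model to distill from when args.distill_model and args.distill_pretrained are both set.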
    dist_model = None
    args.distill = args.distill_model is not None and args.distill_pretrained is not None
    if args.distill:
        # FIXME: support distillation with grad accum.
        assert args.accum_freq == 1
        # FIXME: support distillation with coca.
        assert 'coca' not in args.model.lower()

    if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1:
        # arg is nargs, single (square) image size list -> int
        args.force_image_size = args.force_image_size[0]
    random_seed(args.seed, 0)
    model, preprocess_train, preprocess_val = create_model_and_transforms(
        args.model,
        args.pretrained,
        precision=args.precision,
        device=device,
        jit=args.torchscript,
        force_quick_gelu=args.force_quick_gelu,
        force_custom_text=args.force_custom_text,
        force_patch_dropout=args.force_patch_dropout,
        force_image_size=args.force_image_size,
        pretrained_image=args.pretrained_image,
        image_mean=args.image_mean,
        image_std=args.image_std,
        light_augmentation=args.light_augmentation,
        aug_cfg=args.aug_cfg,
        output_dict=True,
        with_score_predictor='rating' in args.dataset_type or args.no_text_condition,
        with_region_predictor='regional' in args.dataset_type
    )
    if args.distill:
        # FIXME: currently assumes the model you're distilling from has the same tokenizer & transforms.
        dist_model, _, _ = create_model_and_transforms(
            args.distill_model,
            args.distill_pretrained,
            device=device,
            precision=args.precision,
            output_dict=True,
        )
    random_seed(args.seed, args.rank)

    if args.trace:
        model = trace_model(model, batch_size=args.batch_size, device=device)

    if args.lock_image:
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        model.lock_image_tower(
            unlocked_groups=args.lock_image_unlocked_groups,
            freeze_bn_stats=args.lock_image_freeze_bn_stats)
    if args.lock_text:
        model.lock_text_tower(
            unlocked_layers=args.lock_text_unlocked_layers,
            freeze_layer_norm=args.lock_text_freeze_layer_norm)

    if args.grad_checkpointing:
        model.set_grad_checkpointing()

    if is_master(args):
        logging.info("Model:")
        logging.info(f"{str(model)}")
        logging.info("Params:")
        params_file = os.path.join(args.logs, args.name, "params.txt")
        with open(params_file, "w") as f:
            for name in sorted(vars(args)):
                val = getattr(args, name)
                logging.info(f"  {name}: {val}")
                f.write(f"{name}: {val}\n")

    if args.distributed and not args.horovod:
        if args.use_bn_sync:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        ddp_args = {}
        if args.ddp_static_graph:
            # this doesn't exist in older PyTorch, arg only added if enabled
            ddp_args['static_graph'] = True
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True, **ddp_args)
        if args.distill:
            dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args)
    # create optimizer and scaler
    optimizer = None
    scaler = None

    if args.train_data or args.dataset_type == "synthetic":
        assert not args.trace, 'Cannot train with traced model'

        exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
        include = lambda n, p: not exclude(n, p)

        named_parameters = list(model.named_parameters())
        gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
        rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

        optimizer = optim.AdamW(
            [
                {"params": gain_or_bias_params, "weight_decay": 0.},
                {"params": rest_params, "weight_decay": args.wd},
            ],
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            eps=args.eps,
        )
        if args.horovod:
            optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
            hvd.broadcast_parameters(model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        scaler = GradScaler() if args.precision == "amp" else None
    # optionally resume from a checkpoint
    start_iterations = 0
    if args.resume is not None:
        checkpoint = pt_load(args.resume, map_location='cpu')
        if 'iterations' in checkpoint:
            # resuming a train checkpoint w/ iteration count and optimizer state
            start_iterations = checkpoint["iterations"]
            sd = checkpoint["state_dict"]
            if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
                sd = {k[len('module.'):]: v for k, v in sd.items()}
            model.load_state_dict(sd)
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint["optimizer"])
            if scaler is not None and 'scaler' in checkpoint:
                scaler.load_state_dict(checkpoint['scaler'])
            logging.info(f"=> resuming checkpoint '{args.resume}' (iterations {start_iterations})")
        else:
            # loading a bare (model only) checkpoint for fine-tune or evaluation
            model.load_state_dict(checkpoint)
            logging.info(f"=> loaded checkpoint '{args.resume}' (iterations {start_iterations})")

    # initialize datasets
    data = get_data(args, (preprocess_train, preprocess_val), epoch=0, tokenizer=get_tokenizer(args.model))
    assert len(data), 'At least one train or eval dataset must be specified.'
    # create scheduler if train
    scheduler = None
    if 'train' in data and optimizer is not None:
        total_steps = (args.iterations // args.world_size) * args.world_size
        if args.lr_scheduler == "cosine":
            scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
        elif args.lr_scheduler == "const":
            scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps)
        elif args.lr_scheduler == "const-cooldown":
            assert args.epochs_cooldown is not None
            cooldown_steps = (args.iters_cooldown // args.world_size) * args.world_size
            scheduler = const_lr_cooldown(
                optimizer, args.lr, args.warmup, total_steps,
                cooldown_steps, args.lr_cooldown_power, args.lr_cooldown_end)
        else:
            logging.error(
                f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.')
            exit(1)

    # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
    args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args)
    writer = None
    if args.save_logs and args.tensorboard:
        assert tensorboard is not None, "Please install tensorboard."
        writer = tensorboard.SummaryWriter(args.tensorboard_path)
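
    # If no training split was requested, run a single evaluation pass and return its metrics.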
    out_dict = {}
    if 'train' not in data:
        out_dict = do_eval(data, model, args, out_dict=out_dict)
        return out_dict

    iterations = args.iterations - start_iterations
    if is_master(args):
        logging.info(f'Start training for {iterations} iterations '
                     f'with sample ratio {args.train_data_sample_ratio}.')

    # train first args.start_eval_iters to stabilize the model
    train_iters(model, data, iterations, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer)
    barrier(args)

    # final eval after training
    if 'val' in data:
        out_dict = do_eval(data, model, args, out_dict=out_dict)
        if is_master(args):
            logging.info(
                f"finished iterations [ {iterations} / {iterations} ] "
                f"rank acc {out_dict['ranking_acc']} "
            )
    if args.save_path is not None:
        save_ckpt(args, model, scaler, optimizer)
    barrier(args)
    # run a final sync.
    if remote_sync_process is not None:
        logging.info('Final remote sync.')
        remote_sync_process.terminate()
        result = remote_sync(
            os.path.join(args.logs, args.name),
            os.path.join(args.remote_sync, args.name),
            args.remote_sync_protocol
        )
        if result:
            logging.info('Final remote sync successful.')
        else:
            logging.info('Final remote sync failed.')

    if is_master(args):
        with open("result.json", "w") as f:
            json.dump(out_dict, f)
    return out_dict


if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    main(None, args)