import time
from contextlib import suppress
import numpy as np
import torch
from tqdm import tqdm
import datetime
import os
import gc
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from torch.utils.tensorboard import SummaryWriter
import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d %I:%M:%S',
)


def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == "bf16":
        cast_dtype = torch.bfloat16
    elif precision == "fp16":
        cast_dtype = torch.float16
    return cast_dtype


def get_autocast(precision):
    if precision == "amp_fp16":
        return lambda: torch.cuda.amp.autocast(dtype=torch.float16)
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        return suppress


def get_sync(model, flag):
    if flag:
        return suppress
    else:
        return lambda: model.no_sync()


def train_one_epoch(
    args,
    model,
    laion_loader,
    pile_loader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    writer: SummaryWriter,
    optim_groups,
    scaler,
    total_laion_token: int,
    total_pile_token: int,
    total_laion_sample: int,
    total_step: int,
):
    world_size = torch.distributed.get_world_size()
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    if args.add_box:
        box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
        endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
        endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
    if args.use_format_v2:
        prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
        previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
    if args.rank == 0:
        logging.info(f"train from: {total_step} step")
    model.train()

    # loop through dataloader
    last_logging_step = total_step
    last_save_step = total_step
    for num_steps, (batch_laion, batch_pile) in tqdm(
        enumerate(zip(laion_loader, pile_loader)),
        disable=args.rank != 0 or "SLURM_PROCID" in os.environ,
        total=args.num_steps * args.gradient_accumulation_steps,
        initial=total_step * args.gradient_accumulation_steps,
    ):
        #### LAION FORWARD PASS ####
        images = (
            batch_laion[0]
            .to(device_id, dtype=cast_dtype, non_blocking=True)
            .unsqueeze(1)
            .unsqueeze(1)
        )
        image_nums = batch_laion[1]
        image_start_index_list = batch_laion[2]
        # TODO: OPT model: input_ids is not started with while input_ids2 is?
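        # batch_laion layout, as consumed below (an informal summary of this loop, not a
        # guaranteed interface): [0] images, [1] number of images per sample,
        # [2] per-sample image start indices, [3] input_ids, [4] attention_mask,
        # [5] list of added bbox tensors, [6] relations (only read when args.relation is set).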
        input_ids = batch_laion[3].to(device_id, non_blocking=True).long()
        attention_mask = batch_laion[4].to(device_id, dtype=cast_dtype, non_blocking=True)
        added_bbox_list = [x.to(device_id) for x in batch_laion[5]]  # list object
        total_laion_token += int(attention_mask.sum().long()) * world_size
        total_laion_sample += sum(image_nums) * world_size

        labels = input_ids.clone()
        if args.add_box:
            labels[input_ids == visual_token_id] = -100
            labels[input_ids == box_token_id] = -100
            labels[input_ids == endofattr_token_id] = -100
            if args.use_format_v2:
                labels[input_ids == previsual_token_id] = -100
                labels[input_ids == prebox_token_id] = -100
                labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
                labels[torch.roll(input_ids == box_token_id, 1)] = -100
        labels[:, 0] = -100
        labels[input_ids == tokenizer.pad_token_id] = -100
        labels[input_ids == media_token_id] = -100
        labels[input_ids == endofmedia_token_id] = -100
        labels = labels.to(device_id)
        current_laion_num = input_ids.shape[0]

        #### PILE FORWARD PASS ####
        if batch_pile is not None and batch_pile[0] is not None and batch_pile[1] is not None:
            input_ids2 = batch_pile[0].to(device_id, non_blocking=True).long()
            attention_mask2 = batch_pile[1].to(device_id, dtype=cast_dtype, non_blocking=True)
            input_length = input_ids.shape[-1]

            input_ids2 = torch.cat([input_ids2, torch.ones((input_ids2.shape[0], input_length - input_ids2.shape[1]), device=input_ids2.device, dtype=input_ids2.dtype) * tokenizer.pad_token_id], dim=-1)
            attention_mask2 = torch.cat([attention_mask2, torch.zeros((attention_mask2.shape[0], input_length - attention_mask2.shape[1]), device=attention_mask2.device, dtype=attention_mask2.dtype)], dim=-1)

            labels2 = input_ids2.clone()
            labels2[labels2 == tokenizer.pad_token_id] = -100
            labels2[:, 0] = -100
            labels2 = labels2.to(device_id)

            if (num_steps != 0 and num_steps % args.pile_freq == 0) or args.pile_freq == 1:
                image_nums = image_nums + [0] * len(input_ids2)
                image_start_index_list = image_start_index_list + [[]] * len(input_ids2)
                input_ids = torch.cat([input_ids, input_ids2], dim=0)
                attention_mask = torch.cat([attention_mask, attention_mask2], dim=0)
                labels = torch.cat([labels, labels2], dim=0)
                total_pile_token += int(attention_mask2.sum().long()) * world_size
            else:
                del input_ids2
                del attention_mask2
                del labels2

        if args.instruct:
            answer_token_id = tokenizer(" Answer").input_ids[0]
            answer_token_loc = (input_ids == answer_token_id).nonzero()
            for batch_idx, idx in answer_token_loc:
                labels[batch_idx][:idx + 2] = -100

        if args.relation and not args.instruct:
            relations = batch_laion[6]
        else:
            relations = None

        if len(added_bbox_list) == 0:
            added_bbox_list = None

        update_flag = (num_steps != 0 and num_steps % args.gradient_accumulation_steps == 0) or args.gradient_accumulation_steps == 1
        # do_sync = get_sync(model, update_flag)
        with autocast():
            # modify:
            # /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/codegen/modeling_codegen.py
            # /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/opt/modeling_opt.py
            # CrossEntropyLoss(reduction="none")
            outputs = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                image_nums=image_nums,
                image_start_index_list=image_start_index_list,
                added_bbox_list=added_bbox_list,
                add_box=args.add_box,
                relations=relations,
            )
            loss_total = outputs.loss.reshape(labels.shape[0], -1)
            loss_sample = loss_total.sum(-1) / (loss_total != 0).sum(-1)
            loss_sample_for_laion = loss_sample[:current_laion_num]
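            # loss_sample above turns the per-token loss map (the model is expected to
            # compute CrossEntropyLoss(reduction="none"); see the note before the forward
            # call) into one mean loss per sample; the NaN check below drops samples whose
            # caption loss diverged instead of poisoning the whole batch.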
            nan_mask = torch.isnan(loss_sample_for_laion)
            if nan_mask.sum() > 0:
                logging.warning(f"caption NaN: {nan_mask}")
            if nan_mask.sum() == len(loss_sample_for_laion) or not model.valid:
                logging.info("WARNING: skip this caption loss due to some error")
                loss_laion = torch.tensor(0.0).cuda()
            else:
                loss_laion = loss_sample_for_laion[~nan_mask].mean()
            loss_caption = loss_laion
            divided_loss_laion = loss_laion / args.gradient_accumulation_steps

            if current_laion_num != loss_sample.shape[0]:
                loss_pile = loss_sample[current_laion_num:].mean()
            else:
                loss_pile = torch.tensor(0.0).cuda()
            divided_loss_pile = loss_pile / args.gradient_accumulation_steps

            if "detection_losses" in outputs:
                loss_det = outputs["detection_losses"]["loss"]
                loss_iou = outputs["detection_losses"]["loss_iou"]
                loss_obj = outputs["detection_losses"]["loss_obj"]
                loss_cls = outputs["detection_losses"]["loss_cls"]
            else:
                loss_det = torch.tensor(0.0).cuda()
                loss_iou = torch.tensor(0.0).cuda()
                loss_obj = torch.tensor(0.0).cuda()
                loss_cls = torch.tensor(0.0).cuda()

            if "loss_dict" in outputs:
                visual_loss_iou = outputs["loss_dict"][0]["loss_iou"]
                previsual_loss_iou = outputs["loss_dict"][1]["loss_iou"]
                visual_loss_obj = outputs["loss_dict"][0]["loss_obj"]
                previsual_loss_obj = outputs["loss_dict"][1]["loss_obj"]
            else:
                visual_loss_iou = torch.tensor(0.0).cuda()
                previsual_loss_iou = torch.tensor(0.0).cuda()
                visual_loss_obj = torch.tensor(0.0).cuda()
                previsual_loss_obj = torch.tensor(0.0).cuda()

            divided_loss_det = loss_det / args.gradient_accumulation_steps
            loss_rel = outputs.get("rel_loss", torch.tensor(0.0).cuda())
            divided_loss_rel = loss_rel / args.gradient_accumulation_steps

            loss = (
                divided_loss_laion * args.loss_multiplier_laion
                + divided_loss_pile * args.loss_multiplier_pile
                + divided_loss_det * args.loss_multiplier_det
                + divided_loss_rel * args.loss_multiplier_rel
            )

        scaler.scale(loss).backward()

        # for logging only
        loss = (
            loss_laion * args.loss_multiplier_laion
            + loss_pile * args.loss_multiplier_pile
            + loss_det * args.loss_multiplier_det
            + loss_rel * args.loss_multiplier_rel
        ).detach()
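        # With gradient accumulation, backward() above used the divided losses, so gradients
        # simply add up across micro-batches; the optimizer itself only steps when
        # update_flag is true, i.e. every args.gradient_accumulation_steps iterations.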
        # step optimizer and log
        if update_flag:
            #### MASK GRADIENTS FOR EMBEDDINGS ####
            # Note (anas): Do not apply weight decay to embeddings as it will break this function.
            # ! not an important point
            # if args.ddp:
            #     def mask_embedding(m):
            #         if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
            #             zero_mask = torch.zeros_like(m.weight.grad)
            #             zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
            #             zero_mask[endofmedia_token_id] = torch.ones_like(zero_mask[endofmedia_token_id])
            #             m.weight.grad = m.weight.grad * zero_mask
            #     model.apply(mask_embedding)

            total_step += 1
            scaler.unscale_(optimizer)
            if args.ddp:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            else:
                model.clip_grad_norm_(1.0)
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad()
            # https://github.com/facebookresearch/fairscale/issues/627
            model.zero_grad(set_to_none=True)

            if args.rank == 0 and total_step % args.logging_steps == 0 and total_step != last_logging_step:
                last_logging_step = total_step
                global_step = total_step
                lr = optimizer.param_groups[0]["lr"]
                writer.add_scalar("lr", lr, global_step)
                writer.add_scalar("scale", scaler.get_scale(), global_step)
                writer.add_scalar("loss_groundcaption", loss_laion.item(), global_step)
                writer.add_scalar("loss_laion", loss_caption.item(), global_step)
                writer.add_scalar("loss_pile", loss_pile.item(), global_step)
                writer.add_scalar("loss", loss.item(), global_step)
                writer.add_scalar("loss_det", loss_det.item(), global_step)
                writer.add_scalar("loss_iou", loss_iou.item(), global_step)
                writer.add_scalar("loss_obj", loss_obj.item(), global_step)
                writer.add_scalar("loss_cls", loss_cls.item(), global_step)
                if loss_rel.item() != 0:
                    writer.add_scalar("loss_rel", loss_rel.item(), global_step)
                if args.use_format_v2:
                    writer.add_scalar("loss_iou_visual", visual_loss_iou.item(), global_step)
                    writer.add_scalar("loss_obj_visual", visual_loss_obj.item(), global_step)
                    writer.add_scalar("loss_iou_previsual", previsual_loss_iou.item(), global_step)
                    writer.add_scalar("loss_obj_previsual", previsual_loss_obj.item(), global_step)

                global_sample_num = total_laion_sample
                writer.add_scalar("loss_groundcaption_vs_sample_num", loss_laion.item(), global_sample_num)
                writer.add_scalar("loss_laion_vs_sample_num", loss_caption.item(), global_sample_num)
                writer.add_scalar("loss_pile_vs_sample_num", loss_pile.item(), global_sample_num)
                writer.add_scalar("loss_vs_sample_num", loss.item(), global_sample_num)
                writer.add_scalar("loss_det_vs_sample_num", loss_det.item(), global_sample_num)
                writer.add_scalar("loss_iou_vs_sample_num", loss_iou.item(), global_sample_num)
                writer.add_scalar("loss_obj_vs_sample_num", loss_obj.item(), global_sample_num)
                if loss_rel.item() != 0:
                    writer.add_scalar("loss_rel_vs_sample_num", loss_rel.item(), global_sample_num)
                writer.add_scalar("lr_vs_sample_num", optimizer.param_groups[0]["lr"], global_sample_num)

                writer.add_scalar("loss_groundcaption_vs_token", loss_laion.item(), total_laion_token)
                writer.add_scalar("loss_laion_vs_token", loss_caption.item(), total_laion_token)
                writer.add_scalar("loss_pile_vs_token", loss_pile.item(), total_pile_token)
                writer.add_scalar("loss_det_vs_token", loss_det.item(), total_laion_token)
                writer.add_scalar("loss_iou_vs_token", loss_iou.item(), total_laion_token)
                writer.add_scalar("loss_obj_vs_token", loss_obj.item(), total_laion_token)
                writer.add_scalar("loss_cls_vs_token", loss_cls.item(), total_laion_token)
                if loss_rel.item() != 0:
                    writer.add_scalar("loss_rel_vs_token", loss_rel.item(), total_laion_token)

                total_token = total_laion_token + total_pile_token
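                # cumulative sample/token counters, plotted against the global step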
                writer.add_scalar("sample_num", global_sample_num, global_step)
                writer.add_scalar("total_laion_token", total_laion_token, global_step)
                writer.add_scalar("total_pile_token", total_pile_token, global_step)
                writer.add_scalar("total_token", total_token, global_step)
                logging.info(
                    f"[{global_step}][{total_laion_sample}][{total_token}]. total: {loss.item():.3f} // laion: {loss_caption.item():.3f} // pile: {loss_pile.item():.3f} // iou: {loss_iou.item():.4f} // obj: {loss_obj.item():.4f} // previsual_obj: {previsual_loss_obj.item():.4f} // visual_obj: {visual_loss_obj.item():.4f} // previsual_iou: {previsual_loss_iou.item():.4f} // visual_iou: {visual_loss_iou.item():.4f} // lr: {lr:.2e} // scale: {scaler.get_scale()}"
                )

            if total_step % args.save_interval == 0 and total_step != last_save_step:
                last_save_step = total_step
                torch.distributed.barrier()
                if args.ddp:
                    cpu_state = model.state_dict()
                    # if args.rank == 0:
                    #     optimizer_state = optimizer.state_dict()
                else:
                    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
                    with FSDP.state_dict_type(
                        model, StateDictType.FULL_STATE_DICT, save_policy
                    ):
                        cpu_state = model.state_dict()
                    torch.distributed.barrier()
                    # https://pytorch.org/docs/1.12/fsdp.html
                    # need to pass optim_groups as optim_input
                    # optimizer_state = FSDP.full_optim_state_dict(model, optimizer, optim_input=optim_groups)
                if args.rank == 0:
                    checkpoint_dict = {
                        "model_state_dict": cpu_state,
                        # "optimizer_state_dict": optimizer_state,
                        "lr_scheduler_state_dict": lr_scheduler.state_dict(),
                        "scaler_state_dict": scaler.state_dict(),
                        "total_pile_token": total_pile_token,
                        "total_laion_token": total_laion_token,
                        "total_laion_sample": total_laion_sample,
                        "total_step": total_step,
                    }
                    logging.info(f"Saving checkpoint to {args.run_name}/checkpoint_{total_step}.pt")
                    torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{total_step}.pt")
                    del checkpoint_dict
                    if args.delete_previous_checkpoint and total_step - args.save_interval > 0 and (total_step - args.save_interval) % args.skip_delete_pattern != 0:
                        try:
                            os.remove(f"{args.run_name}/checkpoint_{total_step - args.save_interval}.pt")
                        except OSError:
                            pass
                torch.distributed.barrier()


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
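

# The quick self-check below is an illustrative addition (not part of the original
# training code): it exercises AverageMeter with a couple of weighted updates so the
# running-average logic can be verified without launching a training run.
if __name__ == "__main__":
    meter = AverageMeter()
    meter.update(1.0, n=2)  # two samples with value 1.0
    meter.update(4.0, n=1)  # one sample with value 4.0
    # expected: val=4.0, count=3, avg = (1.0 * 2 + 4.0 * 1) / 3 = 2.0
    assert abs(meter.avg - 2.0) < 1e-9
    print(f"AverageMeter ok: val={meter.val}, count={meter.count}, avg={meter.avg:.3f}")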