""" Main training script """ import argparse import copy import glob import os import random import functools import numpy as np import torch # torch.multiprocessing.set_sharing_strategy('file_system') import wandb from data2 import get_data from distributed import init_distributed_device, world_info_from_env from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, MixedPrecision, BackwardPrefetch, ShardingStrategy, FullStateDictConfig, CPUOffload, StateDictType, ) from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.distributed.fsdp.wrap import ( transformer_auto_wrap_policy, enable_wrap, wrap, ) from train_utils import train_one_epoch from transformers import ( get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, get_linear_schedule_with_warmup, ) from open_flamingo import create_model_and_transforms from torch.utils.tensorboard import SummaryWriter from torch.nn.parallel import DistributedDataParallel as DDP from torch.cuda.amp import GradScaler from torch.distributed.optim import ZeroRedundancyOptimizer import warnings warnings.filterwarnings("ignore") import logging logging.basicConfig( level=logging.INFO, format='%(asctime)s %(message)s', datefmt='%m/%d %I:%M:%S', ) class FakeDataloader: def __iter__(self): return self def __next__(self): return None def random_seed(seed=42, rank=0): torch.manual_seed(seed + rank) np.random.seed(seed + rank) random.seed(seed + rank) def get_grouped_params(model, args): params_with_wd, params_without_wd = [], [] def apply_decay(x): x = x.lower() return "norm" not in x and "bn" not in x and "bias" not in x and "embed" not in x and "wte" not in x and "flat_param" not in x for n, p in model.named_parameters(): # if p.requires_grad: if apply_decay(n): if torch.distributed.get_rank() == 0: logging.info(f"with wd: {n}") params_with_wd.append(p) else: if torch.distributed.get_rank() == 0: logging.info(f"without wd: {n}") params_without_wd.append(p) return [ {"params": params_with_wd, "weight_decay": args.weight_decay}, {"params": params_without_wd, "weight_decay": 0.0}, ] def lambda_policy_fn(module): if ( len(list(module.named_children())) == 0 and getattr(module, "weight", None) is not None and module.weight.requires_grad ): return True return False def lambda_auto_wrap_policy( module: torch.nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn, ) -> bool: """ A convenient auto wrap policy to wrap submodules based on an arbitrary user function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as a `wrapper_cls` unit. Return if a module should be wrapped during auto wrapping. The first three parameters are required by :func:`_recursive_wrap`. Args: module (nn.Module): Current module being considered. recurse (bool): If ``False``, then this function must decide whether ``module`` should be wrapped as an FSDP instance or not. If ``True``, then the function is still recursing down the module tree as a part of the DFS. nonwrapped_numel (int): Parameter numel not yet wrapped. lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then this module will be wrapped. 
""" if recurse: return True # always recurse return lambda_fn(module) def main(): parser = argparse.ArgumentParser() parser.add_argument("--vision_encoder_path", default="ViT-B-16", type=str) parser.add_argument("--vision_encoder_pretrained", default="laion2b_s34b_b88k", type=str) parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str) parser.add_argument( "--tokenizer_path", default="facebook/opt-1.3b", type=str, help="path to tokenizer", ) parser.add_argument( "--run_name", type=str, default="openflamingo3B", help="used to name saving directory and wandb run", ) parser.add_argument("--use_media_placement_augmentation", action="store_true") parser.add_argument("--offline", action="store_true") parser.add_argument("--num_steps", type=int, default=300000) parser.add_argument( "--logging_steps", type=int, default=10, help="log loss every n steps" ) # Sum of gradient optimization batch size parser.add_argument("--batch_size_mmc4", type=int, default=128) parser.add_argument("--batch_size_laion", type=int, default=128) parser.add_argument("--batch_size_pile", type=int, default=128) parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument( "--resume_from_checkpoint", type=str, help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states", default=None, ) parser.add_argument( "--delete_previous_checkpoint", action="store_true", help="delete previous checkpoint when saving new checkpoint", ) parser.add_argument( "--laion_shards", type=str, help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar", ) parser.add_argument( "--mmc4_shards", type=str, help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar", ) parser.add_argument( "--pile_shards", type=str, default=None, help="path to pile shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar", ) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--learning_rate", default=1e-4, type=float) parser.add_argument( "--lr_scheduler", default="constant", type=str, help="constant, linear, or cosine", ) parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0) parser.add_argument("--loss_multiplier_laion", type=float, default=1.0) parser.add_argument("--loss_multiplier_pile", type=float, default=1.0) parser.add_argument("--loss_multiplier_det", type=float, default=1.0) parser.add_argument("--loss_multiplier_rel", type=float, default=1.0) parser.add_argument("--loss_multiplier_attn", type=float, default=1.0) parser.add_argument("--warmup_steps", default=5000, type=int) # weight decay is only apply to YOLOX head if using FSDP # https://medium.com/@huanghaian123/optimize-and-accelerate-yolox-with-rtmdet-hyps-in-mmyolo-80fc06d61159 parser.add_argument("--weight_decay", default=0.05, type=float) parser.add_argument( "--precision", choices=["amp_fp16", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], default="fp32", help="Floating point precision.", ) # data args parser.add_argument("--workers", type=int, default=1) parser.add_argument("--dataset_resampled", action="store_true") # distributed training args parser.add_argument( "--dist-url", default="env://", type=str, help="url used to set up distributed training", ) parser.add_argument( "--dist-backend", default="nccl", type=str, help="distributed backend" ) parser.add_argument( "--horovod", default=False, action="store_true", help="Use horovod for distributed 
training.", ) parser.add_argument( "--no-set-device-rank", default=False, action="store_true", help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", ) # wandb args parser.add_argument("--report_to_wandb", default=False, action="store_true") parser.add_argument( "--wandb_project", type=str, ) parser.add_argument( "--wandb_entity", type=str, ) parser.add_argument( "--save_checkpoints_to_wandb", default=False, action="store_true", help="save checkpoints to wandb", ) parser.add_argument( "--checkpoint_activations", default=False, action="store_true", ) parser.add_argument( "--freeze_vision_encoder", default=False, action="store_true", ) parser.add_argument( "--mmc4_textsim_threshold", default=30, type=float, help="threshold for filtering images in mmc4 based on image-text similarity", ) parser.add_argument( "--location_token_num", default=1000, type=int, ) parser.add_argument( "--vis_embed_size", type=int, required=False, ) parser.add_argument( "--save_interval", default=1000, type=int, required=False, ) parser.add_argument( "--skip_delete_pattern", default=1500, type=int, required=False, ) parser.add_argument( "--ddp", default=False, action="store_true", ) parser.add_argument( "--pile_freq", default=1, type=int, required=False, ) parser.add_argument( "--restart", default=False, action="store_true", ) parser.add_argument( "--lora", default=False, action="store_true", ) parser.add_argument( "--lora_r", default=16, type=int, required=False, ) parser.add_argument( "--single", default=False, action="store_true", ) # Finetune parser.add_argument( "--instruct", default=False, action="store_true", ) parser.add_argument( "--fix-ffn", default=False, action="store_true", ) parser.add_argument( "--prob_ground", default=1.0, type=float, required=False, ) parser.add_argument( "--optimizer", default="adamw", type=str, required=False, ) parser.add_argument( "--add_visual_token", default=False, action="store_true", ) parser.add_argument( "--use_format_v2", default=False, action="store_true", ) parser.add_argument( "--use_sam", default=None, type=str, required=False, ) parser.add_argument( "--max-length", default=608, type=int, required=False, ) parser.add_argument( "--image-size", default=256, type=int, required=False, ) parser.add_argument( "--reset_llm", default=False, action="store_true", ) parser.add_argument( "--add_box", default=False, action="store_true", ) parser.add_argument( "--add_pe", default=False, action="store_true", ) parser.add_argument( "--only_grounded_sample", default=False, action="store_true", ) parser.add_argument( "--expand", default=False, action="store_true", ) parser.add_argument( "--delete_contained", default=False, action="store_true", ) parser.add_argument( "--relation", default=False, action="store_true", ) parser.add_argument( "--attn_reg", default="l1", type=str, required=False, ) parser.add_argument( "--enhance_data", default=False, action="store_true", ) parser.add_argument( "--no_visual", default=False, action="store_true", ) parser.add_argument( "--no_previsual", default=False, action="store_true", ) parser.add_argument( "--roi_align", default=False, action="store_true", ) parser.add_argument( "--roi_output_size", default=4, type=int, required=False, ) parser.add_argument( "--apply_mask", default=False, action="store_true", ) parser.add_argument( "--longer_previsual", default=False, action="store_true", ) args = parser.parse_args() assert not args.use_media_placement_augmentation, "Do not enable 
use_media_placement_augmentation" if args.no_previsual: assert args.no_visual, "no_previsual MUST come with no_visual" assert not args.enhance_data, "dont enable enhance_data" if args.offline: os.environ["WANDB_MODE"] = "offline" os.environ["TRANSFORMERS_OFFLINE"] = "1" args.local_rank, args.rank, args.world_size = world_info_from_env() print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}") device_id = init_distributed_device(args) random_seed(args.seed) model, image_processor, tokenizer, args.vis_embed_size = create_model_and_transforms( args.vision_encoder_path, args.vision_encoder_pretrained, args.lm_path, args.tokenizer_path if args.tokenizer_path else args.lm_path, use_local_files=args.offline, use_media_placement_augmentation=args.use_media_placement_augmentation, checkpoint_activations=args.checkpoint_activations, freeze_vision_encoder=args.freeze_vision_encoder, location_token_num=args.location_token_num, lora=args.lora, lora_r=args.lora_r, fix_ffn=args.fix_ffn, add_visual_token=args.add_visual_token, add_box=args.add_box, add_pe=args.add_pe, add_relation=args.relation, use_format_v2=args.use_format_v2, use_sam=args.use_sam, enhance_data=args.enhance_data, roi_align=args.roi_align, roi_output_size=args.roi_output_size, apply_mask=args.apply_mask, ) if args.reset_llm: llm_state_dict = model.lang_encoder.state_dict() if args.rank == 0: print(args) print(image_processor) random_seed(args.seed, args.rank) if args.rank == 0 and args.report_to_wandb: wandb.init( project=args.wandb_project, entity=args.wandb_entity, name=args.run_name, config=vars(args), ) device_id = args.rank % torch.cuda.device_count() if args.ddp: print("use ddp mode") model = model.to(device_id) model = DDP(model) else: fpSixteen = MixedPrecision( param_dtype=torch.float16, # Gradient communication precision. reduce_dtype=torch.float16, # Buffer precision. 
        # from transformers.models.opt.modeling_opt import OPTDecoderLayer
        from open_clip.transformer import ResidualAttentionBlock
        from open_flamingo.src.flamingo_lm import FlamingoLayer
        from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTAttention
        from segment_anything.modeling.image_encoder import Block
        transformer_layer_cls = [
            FlamingoLayer,
            ResidualAttentionBlock,
            Block,
        ]
        if args.fix_ffn:
            transformer_layer_cls.append(OPTAttention)
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=transformer_layer_cls,
        )
        if args.lora:
            from torch.distributed.fsdp.wrap import _or_policy
            lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
            auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, auto_wrap_policy])
            ignored_modules = [model.vision_encoder]
            # ignored_modules = None
        else:
            ignored_modules = [model.detection_head]
            # ignored_modules = None
        if args.add_pe:
            ignored_modules += [model.pos_enc]
        # if args.use_format_v2:
        #     ignored_modules += [model.lang_encoder.visual_guided_lm_head]
        model = FSDP(
            model,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=fpSixteen,
            device_id=torch.cuda.current_device(),
            ignored_modules=ignored_modules,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        )
        model = model.to(device_id)

    pile_dataset = None
    if args.instruct:
        laion_dataset = get_data(args, image_processor, tokenizer, "instruct")
    else:
        laion_dataset = get_data(args, image_processor, tokenizer, "ground_image_text")
    if args.pile_shards is not None:
        pile_dataset = get_data(args, image_processor, tokenizer, "pile")

    optim_groups = get_grouped_params(model, args)
    # optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
    if args.ddp:
        optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
        # optimizer = ZeroRedundancyOptimizer(
        #     optim_groups,
        #     optimizer_class=torch.optim.AdamW,
        #     lr=args.learning_rate,
        #     parameters_as_bucket_view=True,
        # )
    else:
        if args.optimizer == "adamw":
            print("use adamw")
            optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
        elif args.optimizer == "sgd":
            print("use sgd...")
            optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
        else:
            raise NotImplementedError

    total_training_steps = args.num_steps
    if args.rank == 0:
        logging.info(f"Total training steps: {total_training_steps}")

    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )

    if args.ddp:
        scaler = GradScaler()
    else:
        scaler = ShardedGradScaler()
    total_laion_token = 0
    total_pile_token = 0
    total_laion_sample = 0
    total_step = 0

    # check if a checkpoint exists for this run
    if os.path.exists(f"{args.run_name}"):
        checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
        if len(checkpoint_list) == 0:
            if args.rank == 0:
                logging.info(f"Found no checkpoints for run {args.run_name}.")
        else:
            args.resume_from_checkpoint = sorted(
                checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]
            if args.rank == 0:
                logging.info(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")
            args.restart = False
            if args.rank == 0:
                logging.info("do not restart because an existing checkpoint was found")
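
    # Resume logic below: checkpoints are expected to be dicts with at least
    # "model_state_dict" and "lr_scheduler_state_dict" (plus optional
    # "total_*" counters); optimizer / grad-scaler restoration is currently
    # commented out, so only weights, the LR schedule, and the counters are
    # restored on resume unless --restart is set.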
an existed checkpoint is found") if args.resume_from_checkpoint is not None: if args.rank == 0: logging.info(f"Loading checkpoint from {args.resume_from_checkpoint}") checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu") torch.distributed.barrier() if args.ddp: model.module.load_state_dict(checkpoint["model_state_dict"], strict=False) # sharded_osd = checkpoint['optimizer_state_dict'] else: with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT): if args.reset_llm: for key in checkpoint["model_state_dict"]: if key.startswith("lang_encoder"): if args.rank == 0: logging.info(f"reset {key}") llm_key = key.replace("lang_encoder.", "") checkpoint["model_state_dict"][key] = llm_state_dict[llm_key] model_state_dict = model.state_dict() for key in checkpoint["model_state_dict"].keys(): if model_state_dict[key].shape != checkpoint["model_state_dict"][key].shape: if args.rank == 0: logging.info(f'{key}: shape mismatched! {model_state_dict[key].shape} vs {checkpoint["model_state_dict"][key].shape}') checkpoint["model_state_dict"][key] = model_state_dict[key].clone() del model_state_dict model.load_state_dict(checkpoint["model_state_dict"], False) # sharded_osd = FSDP.shard_full_optim_state_dict(checkpoint['optimizer_state_dict'], model, optim_input=optim_groups) if not args.restart: # optimizer.load_state_dict(sharded_osd) lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"]) # scaler.load_state_dict(checkpoint["scaler_state_dict"]) total_laion_token = checkpoint.get("total_laion_token", 0) total_pile_token = checkpoint.get("total_pile_token", 0) total_laion_sample = checkpoint.get("total_laion_sample", 0) total_step = checkpoint.get("total_step", 0) if args.rank == 0: logging.info("load training statistics...") else: if args.rank == 0: logging.info("restart training / finetuning. only load model weight...") del checkpoint if args.reset_llm: del llm_state_dict torch.cuda.empty_cache() torch.distributed.barrier() model.train() if args.rank == 0: if not os.path.exists(args.run_name): os.makedirs(args.run_name) writer = SummaryWriter(log_dir=os.path.join(args.run_name, "tblog")) else: writer = None laion_dataset.set_epoch(total_step) laion_loader = laion_dataset.dataloader if pile_dataset is not None: pile_dataset.set_epoch(total_step) pile_loader = pile_dataset.dataloader else: pile_loader = FakeDataloader() train_one_epoch( args=args, model=model, tokenizer=tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler, laion_loader=laion_loader, pile_loader=pile_loader, device_id=device_id, writer=writer, scaler=scaler, optim_groups=optim_groups, total_laion_token=total_laion_token, total_pile_token=total_pile_token, total_laion_sample=total_laion_sample, total_step=total_step, ) if __name__ == "__main__": main()