""" Main training script """
import argparse
import copy
import glob
import os
import random
import functools
import numpy as np
import torch
# torch.multiprocessing.set_sharing_strategy('file_system')
import wandb
from data2 import get_data
from distributed import init_distributed_device, world_info_from_env
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
FullStateDictConfig,
CPUOffload,
StateDictType,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
from train_utils import train_one_epoch
from transformers import (
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from open_flamingo import create_model_and_transforms
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler
from torch.distributed.optim import ZeroRedundancyOptimizer
import warnings
warnings.filterwarnings("ignore")
import logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s %(message)s',
datefmt='%m/%d %I:%M:%S',
)
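# Placeholder dataloader used when no --pile_shards are provided: it iterates
# forever and yields None in place of real pile batches, so the pile branch of
# the training loop still has something to consume.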
class FakeDataloader:
def __iter__(self):
return self
def __next__(self):
return None
def random_seed(seed=42, rank=0):
torch.manual_seed(seed + rank)
np.random.seed(seed + rank)
random.seed(seed + rank)
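# Split parameters into two optimizer groups: weight decay is applied only to
# parameters whose names do not look like norms, biases, or embeddings (see
# apply_decay below); everything else gets weight_decay=0.0.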
def get_grouped_params(model, args):
params_with_wd, params_without_wd = [], []
def apply_decay(x):
x = x.lower()
return "norm" not in x and "bn" not in x and "bias" not in x and "embed" not in x and "wte" not in x and "flat_param" not in x
for n, p in model.named_parameters():
# if p.requires_grad:
if apply_decay(n):
if torch.distributed.get_rank() == 0:
logging.info(f"with wd: {n}")
params_with_wd.append(p)
else:
if torch.distributed.get_rank() == 0:
logging.info(f"without wd: {n}")
params_without_wd.append(p)
return [
{"params": params_with_wd, "weight_decay": args.weight_decay},
{"params": params_without_wd, "weight_decay": 0.0},
]
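# Used together with the LoRA path below: wrap every leaf module that still has
# a trainable weight as its own FSDP unit (combined with the transformer policy
# via _or_policy).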
def lambda_policy_fn(module):
if (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
):
return True
return False
def lambda_auto_wrap_policy(
module: torch.nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn,
) -> bool:
"""
A convenient auto wrap policy to wrap submodules based on an arbitrary user
    function. If ``lambda_fn(submodule) == True``, the submodule will be wrapped as
    a ``wrapper_cls`` unit.
    Returns whether a module should be wrapped during auto wrapping.
The first three parameters are required by :func:`_recursive_wrap`.
Args:
module (nn.Module): Current module being considered.
recurse (bool): If ``False``, then this function must decide whether
``module`` should be wrapped as an FSDP instance or not. If
``True``, then the function is still recursing down the module
tree as a part of the DFS.
nonwrapped_numel (int): Parameter numel not yet wrapped.
lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
this module will be wrapped.
"""
if recurse:
return True # always recurse
return lambda_fn(module)
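# Lambda-based auto-wrap policy: always recurse into children, and wrap a given
# module iff lambda_fn(module) returns True. It is partially applied with
# lambda_policy_fn above and OR-ed with transformer_auto_wrap_policy when
# --lora is enabled.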
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--vision_encoder_path", default="ViT-B-16", type=str)
parser.add_argument("--vision_encoder_pretrained", default="laion2b_s34b_b88k", type=str)
parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
parser.add_argument(
"--tokenizer_path",
default="facebook/opt-1.3b",
type=str,
help="path to tokenizer",
)
parser.add_argument(
"--run_name",
type=str,
default="openflamingo3B",
help="used to name saving directory and wandb run",
)
parser.add_argument("--use_media_placement_augmentation", action="store_true")
parser.add_argument("--offline", action="store_true")
parser.add_argument("--num_steps", type=int, default=300000)
parser.add_argument(
"--logging_steps", type=int, default=10, help="log loss every n steps"
)
    # Total batch size for gradient optimization (summed across gradient accumulation steps)
parser.add_argument("--batch_size_mmc4", type=int, default=128)
parser.add_argument("--batch_size_laion", type=int, default=128)
parser.add_argument("--batch_size_pile", type=int, default=128)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states",
default=None,
)
parser.add_argument(
"--delete_previous_checkpoint",
action="store_true",
help="delete previous checkpoint when saving new checkpoint",
)
parser.add_argument(
"--laion_shards",
type=str,
help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
)
parser.add_argument(
"--mmc4_shards",
type=str,
help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
)
parser.add_argument(
"--pile_shards",
type=str,
default=None,
help="path to pile shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--learning_rate", default=1e-4, type=float)
parser.add_argument(
"--lr_scheduler",
default="constant",
type=str,
help="constant, linear, or cosine",
)
parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0)
parser.add_argument("--loss_multiplier_laion", type=float, default=1.0)
parser.add_argument("--loss_multiplier_pile", type=float, default=1.0)
parser.add_argument("--loss_multiplier_det", type=float, default=1.0)
parser.add_argument("--loss_multiplier_rel", type=float, default=1.0)
parser.add_argument("--loss_multiplier_attn", type=float, default=1.0)
parser.add_argument("--warmup_steps", default=5000, type=int)
    # weight decay is only applied to the YOLOX head when using FSDP
# https://medium.com/@huanghaian123/optimize-and-accelerate-yolox-with-rtmdet-hyps-in-mmyolo-80fc06d61159
parser.add_argument("--weight_decay", default=0.05, type=float)
parser.add_argument(
"--precision",
choices=["amp_fp16", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
default="fp32",
help="Floating point precision.",
)
# data args
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--dataset_resampled", action="store_true")
# distributed training args
parser.add_argument(
"--dist-url",
default="env://",
type=str,
help="url used to set up distributed training",
)
parser.add_argument(
"--dist-backend", default="nccl", type=str, help="distributed backend"
)
parser.add_argument(
"--horovod",
default=False,
action="store_true",
help="Use horovod for distributed training.",
)
parser.add_argument(
"--no-set-device-rank",
default=False,
action="store_true",
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
)
# wandb args
parser.add_argument("--report_to_wandb", default=False, action="store_true")
parser.add_argument(
"--wandb_project",
type=str,
)
parser.add_argument(
"--wandb_entity",
type=str,
)
parser.add_argument(
"--save_checkpoints_to_wandb",
default=False,
action="store_true",
help="save checkpoints to wandb",
)
parser.add_argument(
"--checkpoint_activations",
default=False,
action="store_true",
)
parser.add_argument(
"--freeze_vision_encoder",
default=False,
action="store_true",
)
parser.add_argument(
"--mmc4_textsim_threshold",
default=30,
type=float,
help="threshold for filtering images in mmc4 based on image-text similarity",
)
parser.add_argument(
"--location_token_num",
default=1000,
type=int,
)
parser.add_argument(
"--vis_embed_size",
type=int,
required=False,
)
parser.add_argument(
"--save_interval",
default=1000,
type=int,
required=False,
)
parser.add_argument(
"--skip_delete_pattern",
default=1500,
type=int,
required=False,
)
parser.add_argument(
"--ddp",
default=False,
action="store_true",
)
parser.add_argument(
"--pile_freq",
default=1,
type=int,
required=False,
)
parser.add_argument(
"--restart",
default=False,
action="store_true",
)
parser.add_argument(
"--lora",
default=False,
action="store_true",
)
parser.add_argument(
"--lora_r",
default=16,
type=int,
required=False,
)
parser.add_argument(
"--single",
default=False,
action="store_true",
)
# Finetune
parser.add_argument(
"--instruct",
default=False,
action="store_true",
)
parser.add_argument(
"--fix-ffn",
default=False,
action="store_true",
)
parser.add_argument(
"--prob_ground",
default=1.0,
type=float,
required=False,
)
parser.add_argument(
"--optimizer",
default="adamw",
type=str,
required=False,
)
parser.add_argument(
"--add_visual_token",
default=False,
action="store_true",
)
parser.add_argument(
"--use_format_v2",
default=False,
action="store_true",
)
parser.add_argument(
"--use_sam",
default=None,
type=str,
required=False,
)
parser.add_argument(
"--max-length",
default=608,
type=int,
required=False,
)
parser.add_argument(
"--image-size",
default=256,
type=int,
required=False,
)
parser.add_argument(
"--reset_llm",
default=False,
action="store_true",
)
parser.add_argument(
"--add_box",
default=False,
action="store_true",
)
parser.add_argument(
"--add_pe",
default=False,
action="store_true",
)
parser.add_argument(
"--only_grounded_sample",
default=False,
action="store_true",
)
parser.add_argument(
"--expand",
default=False,
action="store_true",
)
parser.add_argument(
"--delete_contained",
default=False,
action="store_true",
)
parser.add_argument(
"--relation",
default=False,
action="store_true",
)
parser.add_argument(
"--attn_reg",
default="l1",
type=str,
required=False,
)
parser.add_argument(
"--enhance_data",
default=False,
action="store_true",
)
parser.add_argument(
"--no_visual",
default=False,
action="store_true",
)
parser.add_argument(
"--no_previsual",
default=False,
action="store_true",
)
parser.add_argument(
"--roi_align",
default=False,
action="store_true",
)
parser.add_argument(
"--roi_output_size",
default=4,
type=int,
required=False,
)
parser.add_argument(
"--apply_mask",
default=False,
action="store_true",
)
parser.add_argument(
"--longer_previsual",
default=False,
action="store_true",
)
args = parser.parse_args()
assert not args.use_media_placement_augmentation, "Do not enable use_media_placement_augmentation"
if args.no_previsual:
assert args.no_visual, "no_previsual MUST come with no_visual"
    assert not args.enhance_data, "do not enable enhance_data"
if args.offline:
os.environ["WANDB_MODE"] = "offline"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
args.local_rank, args.rank, args.world_size = world_info_from_env()
print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
device_id = init_distributed_device(args)
random_seed(args.seed)
model, image_processor, tokenizer, args.vis_embed_size = create_model_and_transforms(
args.vision_encoder_path,
args.vision_encoder_pretrained,
args.lm_path,
args.tokenizer_path if args.tokenizer_path else args.lm_path,
use_local_files=args.offline,
use_media_placement_augmentation=args.use_media_placement_augmentation,
checkpoint_activations=args.checkpoint_activations,
freeze_vision_encoder=args.freeze_vision_encoder,
location_token_num=args.location_token_num,
lora=args.lora,
lora_r=args.lora_r,
fix_ffn=args.fix_ffn,
add_visual_token=args.add_visual_token,
add_box=args.add_box,
add_pe=args.add_pe,
add_relation=args.relation,
use_format_v2=args.use_format_v2,
use_sam=args.use_sam,
enhance_data=args.enhance_data,
roi_align=args.roi_align,
roi_output_size=args.roi_output_size,
apply_mask=args.apply_mask,
)
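    # create_model_and_transforms also returns the number of visual embedding
    # tokens per image, which is stored back on args (overriding any
    # --vis_embed_size passed on the command line).
    # When --reset_llm is set, keep a copy of the freshly loaded language-model
    # weights so the LM part of a checkpoint can be reset to them further below.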
if args.reset_llm:
llm_state_dict = model.lang_encoder.state_dict()
if args.rank == 0:
print(args)
print(image_processor)
random_seed(args.seed, args.rank)
if args.rank == 0 and args.report_to_wandb:
wandb.init(
project=args.wandb_project,
entity=args.wandb_entity,
name=args.run_name,
config=vars(args),
)
device_id = args.rank % torch.cuda.device_count()
if args.ddp:
print("use ddp mode")
model = model.to(device_id)
model = DDP(model)
else:
fpSixteen = MixedPrecision(
param_dtype=torch.float16,
# Gradient communication precision.
reduce_dtype=torch.float16,
# Buffer precision.
# buffer_dtype=torch.float16,
)
# from transformers.models.opt.modeling_opt import OPTDecoderLayer
from open_clip.transformer import ResidualAttentionBlock
from open_flamingo.src.flamingo_lm import FlamingoLayer
from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTAttention
from segment_anything.modeling.image_encoder import Block
        transformer_layer_cls = [
FlamingoLayer,
ResidualAttentionBlock,
Block,
]
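        # Transformer blocks that FSDP wraps as individual units: Flamingo
        # decoder layers, open_clip transformer blocks, and SAM image-encoder
        # blocks (OPTAttention is appended below when --fix-ffn is set).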
if args.fix_ffn:
transformer_layer_cls.append(OPTAttention)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=transformer_layer_cls,
)
if args.lora:
from torch.distributed.fsdp.wrap import _or_policy
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, auto_wrap_policy])
ignored_modules = [model.vision_encoder]
# ignored_modules = None
else:
ignored_modules = [model.detection_head]
# ignored_modules = None
if args.add_pe:
ignored_modules += [model.pos_enc]
# if args.use_format_v2:
# ignored_modules += [model.lang_encoder.visual_guided_lm_head]
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=fpSixteen,
device_id=torch.cuda.current_device(),
ignored_modules=ignored_modules,
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
)
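        # SHARD_GRAD_OP shards gradients and optimizer state but keeps full
        # parameters on every rank between steps (ZeRO-2 style); the modules in
        # ignored_modules (vision encoder or detection head, plus the positional
        # encoder with --add_pe) are left unmanaged by FSDP.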
model = model.to(device_id)
pile_dataset = None
if args.instruct:
laion_dataset = get_data(args, image_processor, tokenizer, "instruct")
else:
laion_dataset = get_data(args, image_processor, tokenizer, "ground_image_text")
if args.pile_shards is not None:
pile_dataset = get_data(args, image_processor, tokenizer, "pile")
optim_groups = get_grouped_params(model, args)
# optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
if args.ddp:
optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
# optimizer = ZeroRedundancyOptimizer(
# optim_groups,
# optimizer_class=torch.optim.AdamW,
# lr=args.learning_rate,
# parameters_as_bucket_view=True,
# )
else:
if args.optimizer == "adamw":
print("use adamw")
optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
elif args.optimizer == "sgd":
print("use sgd...")
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
else:
raise NotImplementedError
total_training_steps = args.num_steps
if args.rank == 0:
logging.info(f"Total training steps: {total_training_steps}")
if args.lr_scheduler == "linear":
lr_scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=total_training_steps,
)
elif args.lr_scheduler == "cosine":
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=total_training_steps,
)
else:
lr_scheduler = get_constant_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps
)
if args.ddp:
scaler = GradScaler()
else:
scaler = ShardedGradScaler()
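    # FSDP needs its ShardedGradScaler for fp16 loss scaling across shards; the
    # plain torch.cuda.amp GradScaler is only used on the DDP path.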
total_laion_token = 0
total_pile_token = 0
total_laion_sample = 0
total_step = 0
# check if a checkpoint exists for this run
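    # Checkpoints are expected to be named <run_name>/checkpoint_<step>.pt; the
    # one with the highest step is picked, and auto-resuming disables --restart.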
if os.path.exists(f"{args.run_name}"):
checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
if len(checkpoint_list) == 0:
if args.rank == 0:
logging.info(f"Found no checkpoints for run {args.run_name}.")
else:
args.resume_from_checkpoint = sorted(
checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
)[-1]
if args.rank == 0:
logging.info(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")
args.restart = False
if args.rank == 0:
logging.info("do not restart because an existed checkpoint is found")
if args.resume_from_checkpoint is not None:
if args.rank == 0:
logging.info(f"Loading checkpoint from {args.resume_from_checkpoint}")
checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
torch.distributed.barrier()
if args.ddp:
model.module.load_state_dict(checkpoint["model_state_dict"], strict=False)
# sharded_osd = checkpoint['optimizer_state_dict']
else:
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
if args.reset_llm:
for key in checkpoint["model_state_dict"]:
if key.startswith("lang_encoder"):
if args.rank == 0:
logging.info(f"reset {key}")
llm_key = key.replace("lang_encoder.", "")
checkpoint["model_state_dict"][key] = llm_state_dict[llm_key]
model_state_dict = model.state_dict()
for key in checkpoint["model_state_dict"].keys():
if model_state_dict[key].shape != checkpoint["model_state_dict"][key].shape:
if args.rank == 0:
logging.info(f'{key}: shape mismatched! {model_state_dict[key].shape} vs {checkpoint["model_state_dict"][key].shape}')
checkpoint["model_state_dict"][key] = model_state_dict[key].clone()
del model_state_dict
                model.load_state_dict(checkpoint["model_state_dict"], strict=False)
# sharded_osd = FSDP.shard_full_optim_state_dict(checkpoint['optimizer_state_dict'], model, optim_input=optim_groups)
if not args.restart:
# optimizer.load_state_dict(sharded_osd)
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
# scaler.load_state_dict(checkpoint["scaler_state_dict"])
total_laion_token = checkpoint.get("total_laion_token", 0)
total_pile_token = checkpoint.get("total_pile_token", 0)
total_laion_sample = checkpoint.get("total_laion_sample", 0)
total_step = checkpoint.get("total_step", 0)
if args.rank == 0:
logging.info("load training statistics...")
else:
if args.rank == 0:
logging.info("restart training / finetuning. only load model weight...")
del checkpoint
if args.reset_llm:
del llm_state_dict
torch.cuda.empty_cache()
torch.distributed.barrier()
model.train()
if args.rank == 0:
if not os.path.exists(args.run_name):
os.makedirs(args.run_name)
writer = SummaryWriter(log_dir=os.path.join(args.run_name, "tblog"))
else:
writer = None
laion_dataset.set_epoch(total_step)
laion_loader = laion_dataset.dataloader
if pile_dataset is not None:
pile_dataset.set_epoch(total_step)
pile_loader = pile_dataset.dataloader
else:
pile_loader = FakeDataloader()
train_one_epoch(
args=args,
model=model,
tokenizer=tokenizer,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
laion_loader=laion_loader,
pile_loader=pile_loader,
device_id=device_id,
writer=writer,
scaler=scaler,
optim_groups=optim_groups,
total_laion_token=total_laion_token,
total_pile_token=total_pile_token,
total_laion_sample=total_laion_sample,
total_step=total_step,
)
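# Example launch (a sketch; the script name, shard paths, and run name are
# placeholders rather than values shipped with this file):
#
#   torchrun --nnodes=1 --nproc_per_node=8 train.py \
#       --lm_path facebook/opt-1.3b --tokenizer_path facebook/opt-1.3b \
#       --laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
#       --mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
#       --batch_size_laion 128 --batch_size_mmc4 128 \
#       --run_name openflamingo3B --report_to_wandb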
if __name__ == "__main__":
main()