""" Main training script """ | |
import argparse | |
import copy | |
import glob | |
import os | |
import random | |
import functools | |
import numpy as np | |
import torch | |
# torch.multiprocessing.set_sharing_strategy('file_system') | |
import wandb | |
from data2 import get_data | |
from distributed import init_distributed_device, world_info_from_env | |
from torch.distributed.fsdp import ( | |
FullyShardedDataParallel as FSDP, | |
MixedPrecision, | |
BackwardPrefetch, | |
ShardingStrategy, | |
FullStateDictConfig, | |
CPUOffload, | |
StateDictType, | |
) | |
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler | |
from torch.distributed.fsdp.wrap import ( | |
transformer_auto_wrap_policy, | |
enable_wrap, | |
wrap, | |
) | |
from train_utils import train_one_epoch | |
from transformers import ( | |
get_constant_schedule_with_warmup, | |
get_cosine_schedule_with_warmup, | |
get_linear_schedule_with_warmup, | |
) | |
from open_flamingo import create_model_and_transforms | |
from torch.utils.tensorboard import SummaryWriter | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
from torch.cuda.amp import GradScaler | |
from torch.distributed.optim import ZeroRedundancyOptimizer | |
import warnings | |
warnings.filterwarnings("ignore") | |
import logging | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s %(message)s', | |
datefmt='%m/%d %I:%M:%S', | |
) | |
class FakeDataloader:
    def __iter__(self):
        return self

    def __next__(self):
        return None


def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)
def get_grouped_params(model, args):
    params_with_wd, params_without_wd = [], []

    def apply_decay(x):
        # Exclude norm, batch-norm, bias, embedding, and FSDP flat parameters
        # from weight decay.
        x = x.lower()
        return (
            "norm" not in x
            and "bn" not in x
            and "bias" not in x
            and "embed" not in x
            and "wte" not in x
            and "flat_param" not in x
        )

    for n, p in model.named_parameters():
        # if p.requires_grad:
        if apply_decay(n):
            if torch.distributed.get_rank() == 0:
                logging.info(f"with wd: {n}")
            params_with_wd.append(p)
        else:
            if torch.distributed.get_rank() == 0:
                logging.info(f"without wd: {n}")
            params_without_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": args.weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]
def lambda_policy_fn(module):
    if (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    ):
        return True
    return False


def lambda_auto_wrap_policy(
    module: torch.nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn,
) -> bool:
    """
    A convenient auto wrap policy to wrap submodules based on an arbitrary user
    function. If ``lambda_fn(submodule) == True``, the submodule will be wrapped
    as a ``wrapper_cls`` unit.

    Return if a module should be wrapped during auto wrapping.

    The first three parameters are required by :func:`_recursive_wrap`.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``False``, then this function must decide whether
            ``module`` should be wrapped as an FSDP instance or not. If
            ``True``, then the function is still recursing down the module
            tree as a part of the DFS.
        nonwrapped_numel (int): Parameter numel not yet wrapped.
        lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
            this module will be wrapped.
    """
    if recurse:
        return True  # always recurse
    return lambda_fn(module)
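# lambda_auto_wrap_policy above appears to mirror the helper of the same name in
# torch.distributed.fsdp.wrap. In main() below it is combined with
# transformer_auto_wrap_policy via _or_policy when --lora is set, so that leaf
# modules with trainable weights (e.g. LoRA adapters) get their own FSDP units.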
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vision_encoder_path", default="ViT-B-16", type=str)
    parser.add_argument("--vision_encoder_pretrained", default="laion2b_s34b_b88k", type=str)
    parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
    parser.add_argument(
        "--tokenizer_path",
        default="facebook/opt-1.3b",
        type=str,
        help="path to tokenizer",
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default="openflamingo3B",
        help="used to name saving directory and wandb run",
    )
    parser.add_argument("--use_media_placement_augmentation", action="store_true")
    parser.add_argument("--offline", action="store_true")
    parser.add_argument("--num_steps", type=int, default=300000)
    parser.add_argument(
        "--logging_steps", type=int, default=10, help="log loss every n steps"
    )
    # Batch size per gradient-optimization step for each data source
    # (i.e. summed over gradient accumulation).
    parser.add_argument("--batch_size_mmc4", type=int, default=128)
    parser.add_argument("--batch_size_laion", type=int, default=128)
    parser.add_argument("--batch_size_pile", type=int, default=128)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states",
        default=None,
    )
    parser.add_argument(
        "--delete_previous_checkpoint",
        action="store_true",
        help="delete previous checkpoint when saving new checkpoint",
    )
    parser.add_argument(
        "--laion_shards",
        type=str,
        help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument(
        "--mmc4_shards",
        type=str,
        help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument(
        "--pile_shards",
        type=str,
        default=None,
        help="path to pile shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument(
        "--lr_scheduler",
        default="constant",
        type=str,
        help="constant, linear, or cosine",
    )
    parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_laion", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_pile", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_det", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_rel", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_attn", type=float, default=1.0)
    parser.add_argument("--warmup_steps", default=5000, type=int)
    # Weight decay is only applied to the YOLOX head when using FSDP; see
    # https://medium.com/@huanghaian123/optimize-and-accelerate-yolox-with-rtmdet-hyps-in-mmyolo-80fc06d61159
    parser.add_argument("--weight_decay", default=0.05, type=float)
    parser.add_argument(
        "--precision",
        choices=["amp_fp16", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
        default="fp32",
        help="Floating point precision.",
    )
    # data args
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--dataset_resampled", action="store_true")
    # distributed training args
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training.",
    )
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
    )
    # wandb args
    parser.add_argument("--report_to_wandb", default=False, action="store_true")
    parser.add_argument(
        "--wandb_project",
        type=str,
    )
    parser.add_argument(
        "--wandb_entity",
        type=str,
    )
    parser.add_argument(
        "--save_checkpoints_to_wandb",
        default=False,
        action="store_true",
        help="save checkpoints to wandb",
    )
    parser.add_argument(
        "--checkpoint_activations",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--freeze_vision_encoder",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--mmc4_textsim_threshold",
        default=30,
        type=float,
        help="threshold for filtering images in mmc4 based on image-text similarity",
    )
    parser.add_argument(
        "--location_token_num",
        default=1000,
        type=int,
    )
    parser.add_argument(
        "--vis_embed_size",
        type=int,
        required=False,
    )
    parser.add_argument(
        "--save_interval",
        default=1000,
        type=int,
        required=False,
    )
    parser.add_argument(
        "--skip_delete_pattern",
        default=1500,
        type=int,
        required=False,
    )
    parser.add_argument(
        "--ddp",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--pile_freq",
        default=1,
        type=int,
        required=False,
    )
    parser.add_argument(
        "--restart",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--lora",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--lora_r",
        default=16,
        type=int,
        required=False,
    )
    parser.add_argument(
        "--single",
        default=False,
        action="store_true",
    )
    # Finetune
    parser.add_argument(
        "--instruct",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--fix-ffn",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--prob_ground",
        default=1.0,
        type=float,
        required=False,
    )
    parser.add_argument(
        "--optimizer",
        default="adamw",
        type=str,
        required=False,
    )
    parser.add_argument(
        "--add_visual_token",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--use_format_v2",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--use_sam",
        default=None,
        type=str,
        required=False,
    )
    parser.add_argument(
        "--max-length",
        default=608,
        type=int,
        required=False,
    )
    parser.add_argument(
        "--image-size",
        default=256,
        type=int,
        required=False,
    )
    parser.add_argument(
        "--reset_llm",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--add_box",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--add_pe",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--only_grounded_sample",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--expand",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--delete_contained",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--relation",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--attn_reg",
        default="l1",
        type=str,
        required=False,
    )
    parser.add_argument(
        "--enhance_data",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--no_visual",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--no_previsual",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--roi_align",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--roi_output_size",
        default=4,
        type=int,
        required=False,
    )
    parser.add_argument(
        "--apply_mask",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--longer_previsual",
        default=False,
        action="store_true",
    )
    args = parser.parse_args()

    assert not args.use_media_placement_augmentation, "Do not enable use_media_placement_augmentation"
    if args.no_previsual:
        assert args.no_visual, "no_previsual MUST come with no_visual"
    assert not args.enhance_data, "do not enable enhance_data"

    if args.offline:
        os.environ["WANDB_MODE"] = "offline"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
    args.local_rank, args.rank, args.world_size = world_info_from_env()
    print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
    device_id = init_distributed_device(args)
    random_seed(args.seed)
    model, image_processor, tokenizer, args.vis_embed_size = create_model_and_transforms(
        args.vision_encoder_path,
        args.vision_encoder_pretrained,
        args.lm_path,
        args.tokenizer_path if args.tokenizer_path else args.lm_path,
        use_local_files=args.offline,
        use_media_placement_augmentation=args.use_media_placement_augmentation,
        checkpoint_activations=args.checkpoint_activations,
        freeze_vision_encoder=args.freeze_vision_encoder,
        location_token_num=args.location_token_num,
        lora=args.lora,
        lora_r=args.lora_r,
        fix_ffn=args.fix_ffn,
        add_visual_token=args.add_visual_token,
        add_box=args.add_box,
        add_pe=args.add_pe,
        add_relation=args.relation,
        use_format_v2=args.use_format_v2,
        use_sam=args.use_sam,
        enhance_data=args.enhance_data,
        roi_align=args.roi_align,
        roi_output_size=args.roi_output_size,
        apply_mask=args.apply_mask,
    )
    if args.reset_llm:
        llm_state_dict = model.lang_encoder.state_dict()
    if args.rank == 0:
        print(args)
        print(image_processor)

    random_seed(args.seed, args.rank)

    if args.rank == 0 and args.report_to_wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.run_name,
            config=vars(args),
        )
    device_id = args.rank % torch.cuda.device_count()
    if args.ddp:
        print("use ddp mode")
        model = model.to(device_id)
        model = DDP(model)
    else:
        fpSixteen = MixedPrecision(
            param_dtype=torch.float16,
            # Gradient communication precision.
            reduce_dtype=torch.float16,
            # Buffer precision.
            # buffer_dtype=torch.float16,
        )
        # from transformers.models.opt.modeling_opt import OPTDecoderLayer
        from open_clip.transformer import ResidualAttentionBlock
        from open_flamingo.src.flamingo_lm import FlamingoLayer
        from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTAttention
        from segment_anything.modeling.image_encoder import Block
        transformer_layer_cls = [
            FlamingoLayer,
            ResidualAttentionBlock,
            Block,
        ]
        if args.fix_ffn:
            transformer_layer_cls.append(OPTAttention)
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=transformer_layer_cls,
        )
        if args.lora:
            from torch.distributed.fsdp.wrap import _or_policy
            lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
            auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, auto_wrap_policy])
            ignored_modules = [model.vision_encoder]
            # ignored_modules = None
        else:
            ignored_modules = [model.detection_head]
            # ignored_modules = None
        if args.add_pe:
            ignored_modules += [model.pos_enc]
        # if args.use_format_v2:
        #     ignored_modules += [model.lang_encoder.visual_guided_lm_head]
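        # A brief note on the FSDP configuration below: SHARD_GRAD_OP shards
        # gradients and optimizer state (ZeRO-2 style) while keeping parameters
        # unsharded between forward and backward; the fp16 MixedPrecision config
        # above applies to parameters and gradient reduction, and anything in
        # ignored_modules stays outside FSDP's sharding and precision handling.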
        model = FSDP(
            model,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=fpSixteen,
            device_id=torch.cuda.current_device(),
            ignored_modules=ignored_modules,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        )
        model = model.to(device_id)
    pile_dataset = None
    if args.instruct:
        laion_dataset = get_data(args, image_processor, tokenizer, "instruct")
    else:
        laion_dataset = get_data(args, image_processor, tokenizer, "ground_image_text")
    if args.pile_shards is not None:
        pile_dataset = get_data(args, image_processor, tokenizer, "pile")

    optim_groups = get_grouped_params(model, args)
    # optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
    if args.ddp:
        optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
        # optimizer = ZeroRedundancyOptimizer(
        #     optim_groups,
        #     optimizer_class=torch.optim.AdamW,
        #     lr=args.learning_rate,
        #     parameters_as_bucket_view=True,
        # )
    else:
        if args.optimizer == "adamw":
            print("use adamw")
            optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
        elif args.optimizer == "sgd":
            print("use sgd...")
            optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
        else:
            raise NotImplementedError
    total_training_steps = args.num_steps

    if args.rank == 0:
        logging.info(f"Total training steps: {total_training_steps}")

    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )

    if args.ddp:
        scaler = GradScaler()
    else:
        scaler = ShardedGradScaler()

    total_laion_token = 0
    total_pile_token = 0
    total_laion_sample = 0
    total_step = 0
    # check if a checkpoint exists for this run
    if os.path.exists(f"{args.run_name}"):
        checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
        if len(checkpoint_list) == 0:
            if args.rank == 0:
                logging.info(f"Found no checkpoints for run {args.run_name}.")
        else:
            args.resume_from_checkpoint = sorted(
                checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]
            if args.rank == 0:
                logging.info(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")
            args.restart = False
            if args.rank == 0:
                logging.info("do not restart because an existing checkpoint was found")

    if args.resume_from_checkpoint is not None:
        if args.rank == 0:
            logging.info(f"Loading checkpoint from {args.resume_from_checkpoint}")
        checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
        torch.distributed.barrier()
        if args.ddp:
            model.module.load_state_dict(checkpoint["model_state_dict"], strict=False)
            # sharded_osd = checkpoint['optimizer_state_dict']
        else:
            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
                if args.reset_llm:
                    for key in checkpoint["model_state_dict"]:
                        if key.startswith("lang_encoder"):
                            if args.rank == 0:
                                logging.info(f"reset {key}")
                            llm_key = key.replace("lang_encoder.", "")
                            checkpoint["model_state_dict"][key] = llm_state_dict[llm_key]
                model_state_dict = model.state_dict()
                for key in checkpoint["model_state_dict"].keys():
                    if model_state_dict[key].shape != checkpoint["model_state_dict"][key].shape:
                        if args.rank == 0:
                            logging.info(f'{key}: shape mismatched! {model_state_dict[key].shape} vs {checkpoint["model_state_dict"][key].shape}')
                        checkpoint["model_state_dict"][key] = model_state_dict[key].clone()
                del model_state_dict
                model.load_state_dict(checkpoint["model_state_dict"], strict=False)
            # sharded_osd = FSDP.shard_full_optim_state_dict(checkpoint['optimizer_state_dict'], model, optim_input=optim_groups)
        if not args.restart:
            # optimizer.load_state_dict(sharded_osd)
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
            # scaler.load_state_dict(checkpoint["scaler_state_dict"])
            total_laion_token = checkpoint.get("total_laion_token", 0)
            total_pile_token = checkpoint.get("total_pile_token", 0)
            total_laion_sample = checkpoint.get("total_laion_sample", 0)
            total_step = checkpoint.get("total_step", 0)
            if args.rank == 0:
                logging.info("load training statistics...")
        else:
            if args.rank == 0:
                logging.info("restart training / finetuning: only loading model weights...")
        del checkpoint
        if args.reset_llm:
            del llm_state_dict
        torch.cuda.empty_cache()
        torch.distributed.barrier()
    model.train()

    if args.rank == 0:
        if not os.path.exists(args.run_name):
            os.makedirs(args.run_name)
        writer = SummaryWriter(log_dir=os.path.join(args.run_name, "tblog"))
    else:
        writer = None

    laion_dataset.set_epoch(total_step)
    laion_loader = laion_dataset.dataloader
    if pile_dataset is not None:
        pile_dataset.set_epoch(total_step)
        pile_loader = pile_dataset.dataloader
    else:
        pile_loader = FakeDataloader()

    train_one_epoch(
        args=args,
        model=model,
        tokenizer=tokenizer,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        laion_loader=laion_loader,
        pile_loader=pile_loader,
        device_id=device_id,
        writer=writer,
        scaler=scaler,
        optim_groups=optim_groups,
        total_laion_token=total_laion_token,
        total_pile_token=total_pile_token,
        total_laion_sample=total_laion_sample,
        total_step=total_step,
    )


if __name__ == "__main__":
    main()
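# Example launch (a sketch with hypothetical paths and script name; adjust to
# your setup). The default --dist-url env:// means a torchrun-style launcher
# that sets RANK / LOCAL_RANK / WORLD_SIZE environment variables is assumed:
#
#   torchrun --nnodes=1 --nproc_per_node=8 train.py \
#       --run_name openflamingo3B \
#       --laion_shards "/path/to/laion/shard-{0000..0999}.tar" \
#       --pile_shards "/path/to/pile/shard-{0000..0999}.tar" \
#       --batch_size_laion 128 --batch_size_pile 128 \
#       --lr_scheduler cosine --warmup_steps 5000 --num_steps 300000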