from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from torch.optim import AdamW
from transformers import AutoTokenizer

from xtuner.dataset import ConcatDataset
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.utils import PROMPT_TEMPLATE
from xtuner.dataset.map_fns import template_map_fn_factory

from third_parts.mmdet.models.losses import DiceLoss, CrossEntropyLoss

from peft import LoraConfig

from projects.llava_sam2.models.internvl import InternVL_Slowfast
from projects.llava_sam2.models import VideoLLaVASAMModel, SAM2TrainRunner, VideoLLaVASAMModel_zero3
from projects.llava_sam2.datasets import VideoReVOSDataset, VideoMeVISDataset, VideoRefYoutubeVOSDataset, video_lisa_collate_fn, VideoSAM2Dataset
from projects.llava_sam2.datasets import VideoChatUniViDataset
from projects.llava_sam2.datasets import RefCOCOgGCGDataset, OpenPsgGCGDataset, FlickrGCGDataset, GranDfGCGDataset, OspreyDataset, OspreyDescriptionDataset, OspreyShortDescriptionDataset
from projects.llava_sam2.datasets import LLaVADataset
from projects.llava_sam2.datasets import ReferSegmDataset
from projects.llava_sam2.models.preprocess.image_resize import DirectResize
#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
path = './pretrained/InternVL2_5-4B'
pretrained_pth = None

# Data
prompt_template = PROMPT_TEMPLATE.phi3_chat
max_length = 8192

# Scheduler & Optimizer
batch_size = 2  # per_device
accumulative_counts = 4
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
# official setting: 1024 (global batch size) -> lr 4e-5
# lr = 1e-6
lr = 4e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1  # grad clip
warmup_ratio = 0.05
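# Effective global batch size = batch_size * accumulative_counts * num_gpus,
# e.g. 2 * 4 * 8 = 64 on 8 GPUs (the GPU count depends on how training is launched).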

# Save
save_steps = 1000
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)
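
# Tokens added to the tokenizer (assumed semantics, following LISA/GLaMM-style
# designs): [SEG] is the mask-query token whose hidden state is decoded into
# masks by SAM2; <p>...</p> wrap grounded phrases in GCG outputs; <vp>...</vp>
# wrap visual-prompt (region) inputs.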
special_tokens = ['[SEG]', '<p>', '</p>', '<vp>', '</vp>']

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=path,
    trust_remote_code=True,
    padding_side='right')
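
# Assumption based on target_length=1024: DirectResize rescales images to the
# 1024-px input size expected by SAM2's grounding encoder, independently of the
# MLLM's own vision preprocessing.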
extra_image_processor = dict(
    type=DirectResize,
    target_length=1024,
)
#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
model = dict(
    type=VideoLLaVASAMModel_zero3,
    special_tokens=special_tokens,
    frozen_sam2_decoder=False,
    mllm=dict(
        type=InternVL_Slowfast,
        model_path=path,
        freeze_llm=True,
        freeze_visual_encoder=True,
        llm_lora=dict(
            type=LoraConfig,
            r=128,
            lora_alpha=256,
            lora_dropout=0.05,
            bias='none',
            task_type='CAUSAL_LM'),
        special_tokens=special_tokens,
    ),
    tokenizer=tokenizer,
    grounding_encoder=dict(
        type=SAM2TrainRunner,
    ),
    loss_mask=dict(
        type=CrossEntropyLoss,
        use_sigmoid=True,
        reduction='mean',
        loss_weight=2.0),
    loss_dice=dict(
        type=DiceLoss,
        use_sigmoid=True,
        activate=True,
        reduction='mean',
        naive_dice=True,
        eps=1.0,
        loss_weight=0.5),
    pretrained_pth=pretrained_pth,
    loss_sample_points=True,
    # loss_sample_points=False,
    bs=batch_size,
)
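
# Mask supervision is roughly 2.0 * sigmoid-CE + 0.5 * Dice with the weights above.
# loss_sample_points=True suggests the losses are computed on sampled points rather
# than full masks (a Mask2Former-style memory saving; assumption from the flag name).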
#######################################################################
#                    PART 3  Dataset & Dataloader                     #
#######################################################################
VIDEO_DATAS = './data/video_datas/'
IMG_DATAS = './data/image_datas/'

############### video res
data_root_revos = './data/video_datas/revos/'
video_revos_image_folder = data_root_revos
video_revos_expression_file = data_root_revos + 'meta_expressions_train_.json'
video_revos_mask_file = data_root_revos + 'mask_dict.json'

data_root_mevis = './data/video_datas/mevis/train/'
video_mevis_image_folder = data_root_mevis + 'JPEGImages'
video_mevis_expression_file = data_root_mevis + 'meta_expressions.json'
video_mevis_mask_file = data_root_mevis + 'mask_dict.json'

data_root_refytvos = './data/video_datas/rvos/'
video_refytvos_image_folder = data_root_refytvos + 'train/JPEGImages/'
video_refytvos_expression_file = data_root_refytvos + 'meta_expressions/train/meta_expressions.json'
video_refytvos_mask_file = data_root_refytvos + 'mask_dict.pkl'
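
# Common kwargs for the video datasets below: `sampled_frames` frames are drawn
# per clip, `repeats` oversamples a dataset within one epoch, and `lazy=True`
# defers loading/tokenization until items are fetched (assumed from xtuner
# conventions).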
video_revos_dataset = dict(
    type=VideoReVOSDataset,
    image_folder=video_revos_image_folder,
    expression_file=video_revos_expression_file,
    mask_file=video_revos_mask_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=10,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

video_mevis_dataset = dict(
    type=VideoMeVISDataset,
    image_folder=video_mevis_image_folder,
    expression_file=video_mevis_expression_file,
    mask_file=video_mevis_mask_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=4,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

video_refytvos_dataset = dict(
    type=VideoRefYoutubeVOSDataset,
    image_folder=video_refytvos_image_folder,
    expression_file=video_refytvos_expression_file,
    mask_file=video_refytvos_mask_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=4,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)
################### Video chat
data_root_video_chatunivi = VIDEO_DATAS + 'video_vlm/video_chat/'
video_chatunivi_image_folder = data_root_video_chatunivi + 'Activity_Videos/'
video_chatunivi_json_file = data_root_video_chatunivi + 'video_chat.json'
video_qa_dataset = dict(
    type=VideoChatUniViDataset,
    image_folder=video_chatunivi_image_folder,
    json_file=video_chatunivi_json_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

################## image chat
llava_vqa_dataset = dict(
    type=LLaVADataset,
    tokenizer=tokenizer,
    data_path='data/llava_data/LLaVA-Instruct-150K/llava_v1_5_mix665k.json',
    prompt_template=prompt_template,
    special_tokens=special_tokens,
    image_folder='data/llava_data/llava_images/',
)

################## image res
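# The three ReferSegmDataset configs below differ only in data_root and split
# file: refs(unc).p is the UNC partition used for RefCOCO/RefCOCO+, refs(umd).p
# the UMD partition used for RefCOCOg. num_classes_per_sample presumably caps the
# number of referring expressions sampled per image.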
refcoco_segm_dataset = dict(
    type=ReferSegmDataset,
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    data_root='data/ref_seg/refcoco',
    data_prefix=dict(img_path='coco2014/train2014/'),
    ann_file='instances.json',
    split_file='refs(unc).p',
    prompt_template=prompt_template,
    num_classes_per_sample=5,
    max_length=max_length,
)
refcoco_plus_segm_dataset = dict(
    type=ReferSegmDataset,
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    data_root='data/ref_seg/refcoco+',
    data_prefix=dict(img_path='coco2014/train2014/'),
    ann_file='instances.json',
    split_file='refs(unc).p',
    prompt_template=prompt_template,
    num_classes_per_sample=5,
    max_length=max_length,
)
refcocog_segm_dataset = dict(
    type=ReferSegmDataset,
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    data_root='data/ref_seg/refcocog',
    data_prefix=dict(img_path='coco2014/train2014/'),
    ann_file='instances.json',
    split_file='refs(umd).p',
    prompt_template=prompt_template,
    num_classes_per_sample=5,
    max_length=max_length,
)
# image gcg datas
glamm_data_root = './data/glamm_data/'

refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'

grandf_image_path = glamm_data_root + 'images/grandf/train/'
grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'

flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'

psg_image_path = glamm_data_root + 'images/coco2017/'
psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
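
# GCG (Grounded Conversation Generation) annotations in GLaMM format: captions
# whose phrases are grounded to masks, supervising the <p>...</p> [SEG] output style.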
glamm_refcocog_dataset = dict(
    type=RefCOCOgGCGDataset,
    image_folder=refcocog_image_path,
    data_path=refcocog_ann_file,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=1,
)

glamm_grandf_dataset = dict(
    type=GranDfGCGDataset,
    data_path=grandf_ann_file,
    image_folder=grandf_image_path,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=10,
)

glamm_psg_dataset = dict(
    type=OpenPsgGCGDataset,
    data_path=psg_ann_file,
    image_folder=psg_image_path,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=1,
)

glamm_flickr_dataset = dict(
    type=FlickrGCGDataset,
    data_path=flickr_ann_file,
    image_folder=flickr_image_path,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=1,
)
# sam2 data
data_sam2_folder = VIDEO_DATAS + 'segmentation_datasets/sam_v_full/'
data_sam2_expression_file = './whole_pesudo_cap_v3/sam_v_final_v3.json'

video_sam2_dataset = dict(
    type=VideoSAM2Dataset,
    sam2_folder=data_sam2_folder,
    expression_file=data_sam2_expression_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=4,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
    select_number=5,
)
# osprey
data_osprey_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_conversation.json'
data_osprey_image_folders = [
    IMG_DATAS + 'coco/train2014/',
    IMG_DATAS + 'coco/val2014/',
    IMG_DATAS + 'coco/train2017/',
    IMG_DATAS + 'coco/val2017/',
]
image_osprey_dataset = dict(
    type=OspreyDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_detail_description_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_detail_description.json'
image_osprey_description_dataset = dict(
    type=OspreyDescriptionDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_detail_description_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_short_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_short_form.json'
image_osprey_short_dataset = dict(
    type=OspreyShortDescriptionDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_short_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_part_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_part_level.json'
image_osprey_part_dataset = dict(
    type=OspreyDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_part_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_positive_neg_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_lvis_positive_negative.json'
image_osprey_positive_neg_dataset = dict(
    type=OspreyDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_positive_neg_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)
train_dataset = dict(
    type=ConcatDataset, datasets=[
        # sem seg
        # semantic_seg_ade20k_dataset,
        # ref seg (listed 4x to oversample referring segmentation in the mix)
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        # image qa
        llava_vqa_dataset,
        # video res
        video_mevis_dataset, video_revos_dataset, video_refytvos_dataset,
        # video chat
        video_qa_dataset,
        # sam2 pseudo-labeled data
        video_sam2_dataset,
        # gcg data
        glamm_psg_dataset,
        glamm_grandf_dataset,
        glamm_flickr_dataset,
        glamm_refcocog_dataset,
        # visual prompt
        image_osprey_dataset, image_osprey_description_dataset,
        image_osprey_part_dataset, image_osprey_short_dataset,
        image_osprey_positive_neg_dataset,
    ]
)
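
# LengthGroupedSampler groups samples of similar `modality_length` so batches need
# less padding; the grouping granularity is one optimizer step's worth of samples
# per device (batch_size * accumulative_counts).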
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=video_lisa_collate_fn)
)
#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='bfloat16'
)
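
# AMP training in bfloat16 with a dynamic loss scale; gradients accumulate over
# accumulative_counts (4) micro-batches per optimizer step and are clipped to
# max_norm=1.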
# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]
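
# Schedule shape: linear warmup over the first warmup_ratio (5%) of training, then
# cosine decay to eta_min=0.0; convert_to_iter_based=True turns the fractional
# epoch boundaries into iteration counts.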

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    # dict(type=DatasetInfoHook, tokenizer=tokenizer),
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer=False,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)
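
# Note: save_optimizer=False keeps checkpoints small (weights only) at the cost of
# not being able to resume optimizer state from them.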

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)