Spaces: Running on Zero
import argparse
import json
import os

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

from language.t5 import T5Embedder
from utils.distributed import init_distributed_mode

# Allow TF32 matmuls on Ampere+ GPUs: large throughput win at negligible
# precision cost for feature extraction.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Index of each caption variant inside a record's 'text' list; the variant
# is chosen on the command line via --caption-key.
CAPTION_KEY = {
    'blip': 0,
    'llava': 1,
    'llava_first': 2,
}
#################################################################################
#                          Training Helper Functions                            #
#################################################################################
class CustomDataset(Dataset):
    """Dataset of ``(caption, shard_name, line_index)`` triples read from
    ``.jsonl`` caption files.

    Files ``start``..``end`` (inclusive) of the sorted directory listing are
    scanned; non-``.jsonl`` entries are skipped after slicing. Each line of a
    kept file contributes one item, selecting the caption variant given by
    ``caption_key`` (an index into the record's ``'text'`` list).
    """

    def __init__(self, lst_dir, start, end, caption_key, trunc_caption=False):
        entries = []
        # NOTE(review): the start/end slice is applied to the *unfiltered*
        # sorted listing, then non-.jsonl names are dropped — preserved as-is.
        for lst_name in sorted(os.listdir(lst_dir))[start: end + 1]:
            if not lst_name.endswith('.jsonl'):
                continue
            file_path = os.path.join(lst_dir, lst_name)
            # Shard name = filename without extension; constant per file, so
            # compute it once outside the line loop.
            shard = file_path.split('/')[-1].split('.')[0]
            with open(file_path, 'r') as fh:
                for line_idx, line in enumerate(fh):
                    record = json.loads(line)
                    caption = record['text'][CAPTION_KEY[caption_key]]
                    if trunc_caption:
                        # Keep only the first sentence of the caption.
                        caption = caption.split('.')[0]
                    entries.append((caption, shard, line_idx))
        self.img_path_list = entries

    def __len__(self):
        return len(self.img_path_list)

    def __getitem__(self, index):
        caption, code_dir, code_name = self.img_path_list[index]
        return caption, code_dir, code_name
#################################################################################
#                                Training Loop                                  #
#################################################################################
def main(args):
    """Precompute T5 text embeddings for every caption and save each one as
    ``<t5_path>/<shard>/<line_idx>.npy`` (float32).

    Runs under torch.distributed: each rank processes its own slice of the
    dataset (batch size 1, no shuffling) so every caption is written exactly
    once across ranks.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # --- Distributed / device setup ---------------------------------------
    init_distributed_mode(args)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * world_size + rank  # distinct seed per rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    # Output root is created once, by rank 0 only.
    if rank == 0:
        os.makedirs(args.t5_path, exist_ok=True)

    # --- Data --------------------------------------------------------------
    print(f"Dataset is preparing...")
    dataset = CustomDataset(args.data_path, args.data_start, args.data_end,
                            args.caption_key, args.trunc_caption)
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
        seed=args.global_seed,
    )
    loader = DataLoader(
        dataset,
        batch_size=1,  # important! one caption per step so the mask trim below is valid
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    print(f"Dataset contains {len(dataset):,} images")

    # --- T5 encoder ---------------------------------------------------------
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    assert os.path.exists(args.t5_model_path)
    t5_xxl = T5Embedder(
        device=device,
        local_cache=True,
        cache_dir=args.t5_model_path,
        dir_or_name=args.t5_model_type,
        torch_dtype=precision,
    )

    # --- Embedding extraction ----------------------------------------------
    for caption, code_dir, code_name in loader:
        caption_embs, emb_masks = t5_xxl.get_text_embeddings(caption)
        # Trim trailing padding tokens; valid only because batch_size == 1.
        trimmed = caption_embs[:, :emb_masks.sum()]
        arr = trimmed.to(torch.float32).detach().cpu().numpy()
        out_dir = os.path.join(args.t5_path, code_dir[0])
        os.makedirs(out_dir, exist_ok=True)
        np.save(os.path.join(out_dir, '{}.npy'.format(code_name.item())), arr)
        print(code_name.item())

    dist.destroy_process_group()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Input caption-list directory and output embedding root.
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--t5-path", type=str, required=True)
    # Inclusive index range of shard files within the sorted directory listing.
    parser.add_argument("--data-start", type=int, required=True)
    parser.add_argument("--data-end", type=int, required=True)
    # Which caption variant to embed, and whether to keep only its first sentence.
    parser.add_argument("--caption-key", type=str, default='blip', choices=list(CAPTION_KEY.keys()))
    parser.add_argument("--trunc-caption", action='store_true', default=False)
    # T5 checkpoint location, variant, and compute precision.
    parser.add_argument("--t5-model-path", type=str, default='./pretrained_models/t5-ckpt')
    parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=24)
    args = parser.parse_args()
    main(args)