Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,041 Bytes
2422035 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import numpy as np
import argparse
import os
import json
from utils.distributed import init_distributed_mode
from language.t5 import T5Embedder
# Maps each caption-key name to the position of that caption inside the
# 'text' list of a JSONL record (see CustomDataset.__init__).
CAPTION_KEY = {
    'blip': 0,
    'llava': 1,
    'llava_first': 2,
}


#################################################################################
#                                Dataset Helpers                                #
#################################################################################
class CustomDataset(Dataset):
    """Flat, in-memory index of captions read from a directory of ``.jsonl`` files.

    Each item is a ``(caption, code_dir, line_idx)`` tuple where ``code_dir`` is
    the source file name up to its first ``'.'`` and ``line_idx`` is the
    zero-based line number of the record inside that file.

    Args:
        lst_dir: Directory that is listed for input files.
        start: Index of the first entry to use in the *sorted* directory listing.
        end: Index of the last entry to use (inclusive).
        caption_key: One of the keys of ``CAPTION_KEY``; selects which element
            of each record's ``'text'`` list is used as the caption.
        trunc_caption: If true, keep only the text before the first ``'.'``.

    Raises:
        KeyError: If ``caption_key`` is not a key of ``CAPTION_KEY``.
    """

    def __init__(self, lst_dir, start, end, caption_key, trunc_caption=False):
        # Resolve the caption position once instead of once per record.
        caption_idx = CAPTION_KEY[caption_key]

        img_path_list = []
        # NOTE: the [start, end] window is applied to the full sorted listing
        # *before* non-.jsonl entries are filtered out, so such entries still
        # occupy an index.
        for lst_name in sorted(os.listdir(lst_dir))[start: end + 1]:
            if not lst_name.endswith('.jsonl'):
                continue
            # Derived from the bare file name (not the joined path) so the
            # result does not depend on the platform's path separator.
            code_dir = lst_name.split('.')[0]
            file_path = os.path.join(lst_dir, lst_name)
            # Explicit UTF-8: do not let the locale decide how captions decode.
            with open(file_path, 'r', encoding='utf-8') as file:
                for line_idx, line in enumerate(file):
                    if not line.strip():
                        # Tolerate blank lines; line_idx still counts them so
                        # the indices of real records are unchanged.
                        continue
                    caption = json.loads(line)['text'][caption_idx]
                    if trunc_caption:
                        caption = caption.split('.')[0]
                    img_path_list.append((caption, code_dir, line_idx))
        self.img_path_list = img_path_list

    def __len__(self):
        """Return the number of indexed captions."""
        return len(self.img_path_list)

    def __getitem__(self, index):
        """Return the ``(caption, code_dir, line_idx)`` tuple at ``index``."""
        caption, code_dir, code_name = self.img_path_list[index]
        return caption, code_dir, code_name
#################################################################################
#                            Feature Extraction Loop                            #
#################################################################################
def main(args):
    """
    Extract T5 text embeddings for every caption and save them as ``.npy`` files.

    Each process of the distributed job handles its own shard of the dataset.
    For every sample the embedding is written to
    ``<args.t5_path>/<code_dir>/<line_idx>.npy`` as a float32 array.

    Args:
        args: Parsed command-line namespace. This function reads
            ``global_seed``, ``t5_path``, ``data_path``, ``data_start``,
            ``data_end``, ``caption_key``, ``trunc_caption``, ``num_workers``,
            ``precision``, ``t5_model_path`` and ``t5_model_type``; it is also
            handed to ``init_distributed_mode``, which may read further fields.

    Raises:
        AssertionError: If CUDA is unavailable or ``args.t5_model_path`` does
            not exist.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # Setup DDP:
    # presumably init_distributed_mode initialises the default process group;
    # dist.get_rank() below requires that it has -- confirm in utils.distributed.
    init_distributed_mode(args)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * world_size + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    # Setup a feature folder:
    if rank == 0:
        os.makedirs(args.t5_path, exist_ok=True)

    # Setup data:
    print("Dataset is preparing...")
    dataset = CustomDataset(
        args.data_path,
        args.data_start,
        args.data_end,
        args.caption_key,
        args.trunc_caption,
    )
    # NOTE(review): with its default drop_last=False, DistributedSampler pads
    # the index list so all ranks get the same count; a few samples may thus be
    # embedded (and written) by more than one rank -- see the sampler docs.
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
        seed=args.global_seed,
    )
    loader = DataLoader(
        dataset,
        # important! The loop below sums the whole mask and calls .item() on
        # the line index, both of which only make sense for a single sample.
        batch_size=1,
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    print(f"Dataset contains {len(dataset):,} images")

    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    assert os.path.exists(args.t5_model_path)
    t5_xxl = T5Embedder(
        device=device,
        local_cache=True,
        cache_dir=args.t5_model_path,
        dir_or_name=args.t5_model_type,
        torch_dtype=precision,
    )

    # Output directories already created by this process, so makedirs runs
    # once per directory rather than once per sample.
    created_dirs = set()

    # The embeddings are only saved, never back-propagated through, so make
    # sure no autograd state is recorded regardless of what the embedder does.
    with torch.no_grad():
        for caption, code_dir, code_name in loader:
            caption_embs, emb_masks = t5_xxl.get_text_embeddings(caption)
            # Keep the first emb_masks.sum() positions along dim 1
            # (assumes padding sits at the end of the sequence -- TODO confirm).
            valid_caption_embs = caption_embs[:, :emb_masks.sum()]
            x = valid_caption_embs.to(torch.float32).detach().cpu().numpy()

            out_dir = os.path.join(args.t5_path, code_dir[0])
            if out_dir not in created_dirs:
                os.makedirs(out_dir, exist_ok=True)
                created_dirs.add(out_dir)

            line_idx = code_name.item()
            np.save(os.path.join(out_dir, '{}.npy'.format(line_idx)), x)
            print(line_idx)

    dist.destroy_process_group()
if __name__ == "__main__":
    # Command-line entry point: every option is forwarded to main() through
    # the parsed namespace. Flags, types, defaults and choices are unchanged;
    # only help text has been added.
    parser = argparse.ArgumentParser(
        description="Extract T5 text embeddings for captions stored in .jsonl files."
    )
    parser.add_argument("--data-path", type=str, required=True,
                        help="directory containing the input .jsonl files")
    parser.add_argument("--t5-path", type=str, required=True,
                        help="output directory for the saved .npy embeddings")
    parser.add_argument("--data-start", type=int, required=True,
                        help="index of the first entry to process in the sorted listing of --data-path")
    parser.add_argument("--data-end", type=int, required=True,
                        help="index of the last entry to process (inclusive)")
    parser.add_argument("--caption-key", type=str, default='blip', choices=list(CAPTION_KEY.keys()),
                        help="which caption of each record's 'text' list to embed")
    parser.add_argument("--trunc-caption", action='store_true', default=False,
                        help="keep only the text before the first '.' of each caption")
    parser.add_argument("--t5-model-path", type=str, default='./pretrained_models/t5-ckpt',
                        help="existing directory passed to the T5 embedder as its cache_dir")
    parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl',
                        help="model directory or name passed to the T5 embedder")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"],
                        help="torch dtype for the T5 model; 'none' selects float32")
    parser.add_argument("--global-seed", type=int, default=0,
                        help="base seed; each rank seeds torch with global_seed * world_size + rank")
    parser.add_argument("--num-workers", type=int, default=24,
                        help="number of DataLoader worker processes")
    args = parser.parse_args()
    main(args)
|