chendl's picture
Add application file
0b7b08a
raw
history blame
867 Bytes
# Stub bindings: every collaborator of the training-loop sketch below is
# bound to None here so this excerpt parses on its own; in the real
# project these names come from the model/optimizer/data modules.
build_model = ZeroRedundancyOptimizer = GradScaler = None
laion_loader = pile_loader = None
autocast = zero_embedding_gradient = None
torch = None
lr_scheduler = get_cosine_schedule_with_warmup = None
# Mixed-dataset AMP training-loop sketch (DDP model, ZeRO-style optimizer):
# each iteration accumulates scaled gradients from one LAION batch and one
# Pile batch, then performs a single clipped optimizer step.
#
# NOTE(review): the original excerpt had all indentation stripped and was
# not valid Python; structure restored here per the evident control flow.
ddp_model = build_model(...)
optimizer = ZeroRedundancyOptimizer(...)
lr_scheduler = get_cosine_schedule_with_warmup(...)
scaler = GradScaler()

# zip() stops at the shorter loader, so one epoch covers min(len(laion),
# len(pile)) paired batches.
for batch_laion, batch_pile in zip(laion_loader, pile_loader):
    # Forward pass under autocast (mixed precision); backward is taken
    # outside the autocast region, on the scaled loss, so fp16 gradients
    # don't underflow.
    with autocast():
        loss_laion = ddp_model(batch_laion)
    scaler.scale(loss_laion).backward()

    # Second forward/backward accumulates Pile gradients on top of the
    # LAION gradients before the single optimizer step below.
    with autocast():
        loss_pile = ddp_model(batch_pile)
    scaler.scale(loss_pile).backward()

    # Presumably zeroes gradients of (some) embedding rows to freeze them
    # during the update — TODO confirm against the helper's definition.
    zero_embedding_gradient()

    # Unscale before clipping so the 1.0 max-norm applies to the true
    # (unscaled) gradients, not the loss-scaled ones.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 1.0)

    # scaler.step skips the update if inf/nan gradients were found;
    # scaler.update then adjusts the loss scale accordingly.
    scaler.step(optimizer)
    scaler.update()
    lr_scheduler.step()
    optimizer.zero_grad()