Spaces:
Runtime error
Runtime error
File size: 867 Bytes
0b7b08a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
# Illustrative training loop interleaving LAION and Pile batches under AMP.
#
# The module-level names below are placeholders for the real components
# (model factory, optimizer, AMP helpers, data loaders, torch). They are
# ``None`` here, so running the script unmodified raises a descriptive
# NotImplementedError rather than failing on the first placeholder call.
import contextlib

build_model = None
ZeroRedundancyOptimizer = None
GradScaler = None
laion_loader = None
pile_loader = None
autocast = None
zero_embedding_gradient = None
torch = None
get_cosine_schedule_with_warmup = None

# Every placeholder that must be replaced before ``main()`` can run.
_PLACEHOLDER_NAMES = (
    "build_model",
    "ZeroRedundancyOptimizer",
    "GradScaler",
    "laion_loader",
    "pile_loader",
    "autocast",
    "zero_embedding_gradient",
    "torch",
    "get_cosine_schedule_with_warmup",
)


def train(
    ddp_model,
    optimizer,
    scheduler,
    scaler,
    laion_batches,
    pile_batches,
    *,
    autocast_ctx,
    zero_embedding_grad,
    clip_grad_norm,
    max_grad_norm=1.0,
):
    """Run the paired LAION/Pile training loop with gradient scaling.

    Each iteration does a forward/backward pass on one LAION batch and one
    Pile batch, accumulating both gradients, then performs a single
    optimizer step. Iteration stops at the shorter of the two batch
    iterables (``zip`` semantics).

    Args:
        ddp_model: Callable model; ``ddp_model(batch)`` must return a loss
            supporting ``scaler.scale(loss).backward()``. If it provides a
            ``no_sync()`` context manager (as DDP does), that is used for
            the first forward/backward pass of each iteration.
        optimizer: Optimizer handed to the scaler; ``zero_grad()`` is called
            at the end of every iteration.
        scheduler: Learning-rate scheduler; ``step()`` is called once per
            iteration.
        scaler: AMP gradient scaler exposing ``scale``, ``unscale_``,
            ``step`` and ``update``.
        laion_batches: Iterable of LAION batches.
        pile_batches: Iterable of Pile batches.
        autocast_ctx: Zero-argument factory returning the mixed-precision
            context manager wrapped around each forward pass.
        zero_embedding_grad: Zero-argument hook run after both backward
            passes and before unscaling (presumably clears gradients of
            embedding rows that should stay frozen -- confirm against its
            real definition).
        clip_grad_norm: ``clip_grad_norm(parameters, max_norm)``, e.g.
            ``torch.nn.utils.clip_grad_norm_``.
        max_grad_norm: Gradient-norm clipping threshold (default 1.0).
    """
    # Only the second backward needs to all-reduce: under ``no_sync`` the
    # first forward/backward just accumulates local gradients, and the
    # second backward reduces their sum once instead of communicating twice
    # per step. Models without ``no_sync`` fall back to a no-op context.
    no_sync = getattr(ddp_model, "no_sync", contextlib.nullcontext)
    for batch_laion, batch_pile in zip(laion_batches, pile_batches):
        with no_sync():
            with autocast_ctx():
                loss_laion = ddp_model(batch_laion)
            scaler.scale(loss_laion).backward()
        with autocast_ctx():
            loss_pile = ddp_model(batch_pile)
        scaler.scale(loss_pile).backward()

        zero_embedding_grad()
        # Unscale first so the clipping threshold applies to true gradients.
        scaler.unscale_(optimizer)
        clip_grad_norm(ddp_model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()


def main():
    """Build the components from the placeholders and run ``train``.

    Raises:
        NotImplementedError: if any placeholder is still ``None``.
    """
    missing = [name for name in _PLACEHOLDER_NAMES if globals()[name] is None]
    if missing:
        raise NotImplementedError(
            "Replace these placeholders with real implementations before "
            "running: " + ", ".join(missing)
        )
    # ``...`` stands in for the real constructor arguments.
    ddp_model = build_model(...)
    optimizer = ZeroRedundancyOptimizer(...)
    scheduler = get_cosine_schedule_with_warmup(...)
    scaler = GradScaler()
    train(
        ddp_model,
        optimizer,
        scheduler,
        scaler,
        laion_loader,
        pile_loader,
        autocast_ctx=autocast,
        zero_embedding_grad=zero_embedding_gradient,
        clip_grad_norm=torch.nn.utils.clip_grad_norm_,
    )


if __name__ == "__main__":
    main()
|