import os

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from tqdm import tqdm

from nets.cyclegan import compute_gradient_penalty
from utils.utils import get_lr, show_result


def fit_one_epoch(G_model_A2B_train, G_model_B2A_train, D_model_A_train, D_model_B_train, G_model_A2B, G_model_B2A, D_model_A, D_model_B, VGG_feature_model, ResNeSt_model, loss_history,
                  G_optimizer, D_optimizer_A, D_optimizer_B, BCE_loss, L1_loss, Face_loss, epoch, epoch_step, gen, Epoch, cuda, fp16, scaler, save_period, save_dir, photo_save_step, local_rank=0):
    G_total_loss = 0
    D_total_loss_A = 0
    D_total_loss_B = 0

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', mininterval=0.3)

    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break

        images_A, images_B = batch[0], batch[1]
        batch_size = images_A.size()[0]
        # Target labels for the relativistic (RaGAN) BCE losses below.
        y_real = torch.ones(batch_size)
        y_fake = torch.zeros(batch_size)

        with torch.no_grad():
            if cuda:
                images_A, images_B, y_real, y_fake = images_A.cuda(local_rank), images_B.cuda(local_rank), y_real.cuda(local_rank), y_fake.cuda(local_rank)

        if not fp16:
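            # Generator update: identity, relativistic adversarial,
            # cycle-consistency, VGG perceptual and face-identity losses.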
            G_optimizer.zero_grad()

            # Identity losses: each generator should leave images from its own
            # output domain unchanged.
            Same_B = G_model_A2B_train(images_B)
            loss_identity_B = L1_loss(Same_B, images_B)

            Same_A = G_model_B2A_train(images_A)
            loss_identity_A = L1_loss(Same_A, images_A)

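            # Relativistic average adversarial loss: each prediction is scored
            # against the mean prediction of the opposite set. For the
            # generator update the real/fake targets are swapped relative to
            # the discriminator updates below.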
            fake_B = G_model_A2B_train(images_A)
            pred_real = D_model_B_train(images_B)
            pred_fake = D_model_B_train(fake_B)
            pred_rf = pred_real - pred_fake.mean()
            pred_fr = pred_fake - pred_real.mean()
            D_train_loss_rf = BCE_loss(pred_rf, y_fake)
            D_train_loss_fr = BCE_loss(pred_fr, y_real)
            loss_GAN_A2B = (D_train_loss_rf + D_train_loss_fr) / 2

            fake_A = G_model_B2A_train(images_B)
            pred_real = D_model_A_train(images_A)
            pred_fake = D_model_A_train(fake_A)
            pred_rf = pred_real - pred_fake.mean()
            pred_fr = pred_fake - pred_real.mean()
            D_train_loss_rf = BCE_loss(pred_rf, y_fake)
            D_train_loss_fr = BCE_loss(pred_fr, y_real)
            loss_GAN_B2A = (D_train_loss_rf + D_train_loss_fr) / 2

            # Cycle consistency: B2A(A2B(A)) should reconstruct A.
            recovered_A = G_model_B2A_train(fake_B)
            loss_cycle_ABA = L1_loss(recovered_A, images_A)

            # Perceptual loss on VGG feature maps of the reconstruction.
            loss_per_ABA = L1_loss(VGG_feature_model(recovered_A), VGG_feature_model(images_A))

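            # Face-identity loss: resize to the 112x112 input expected by the
            # ResNeSt face encoder and penalise 1 - similarity between
            # embeddings (Face_loss is assumed to return a cosine-style
            # similarity score).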
            recovered_A_face = F.interpolate(recovered_A, size=(112, 112), mode='bicubic', align_corners=True)
            images_A_face = F.interpolate(images_A, size=(112, 112), mode='bicubic', align_corners=True)
            loss_face_ABA = torch.mean(1. - Face_loss(ResNeSt_model(recovered_A_face), ResNeSt_model(images_A_face)))

            # Same cycle, perceptual and face terms for the B -> A -> B direction.
            recovered_B = G_model_A2B_train(fake_A)
            loss_cycle_BAB = L1_loss(recovered_B, images_B)

            loss_per_BAB = L1_loss(VGG_feature_model(recovered_B), VGG_feature_model(images_B))

            recovered_B_face = F.interpolate(recovered_B, size=(112, 112), mode='bicubic', align_corners=True)
            images_B_face = F.interpolate(images_B, size=(112, 112), mode='bicubic', align_corners=True)
            loss_face_BAB = torch.mean(1. - Face_loss(ResNeSt_model(recovered_B_face), ResNeSt_model(images_B_face)))

            # Weighted total: identity x5, adversarial x1, perceptual x2.5,
            # cycle x10, face-identity x5.
            G_loss = loss_identity_A * 5.0 + loss_identity_B * 5.0 + loss_GAN_A2B + loss_GAN_B2A + loss_per_ABA * 2.5 \
                     + loss_per_BAB * 2.5 + loss_cycle_ABA * 10.0 + loss_cycle_BAB * 10.0 + loss_face_ABA * 5 + loss_face_BAB * 5
            G_loss.backward()
            G_optimizer.step()

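            # Discriminator A update: relativistic BCE on real A against
            # detached fake A, plus a gradient penalty weighted by 10.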
            D_optimizer_A.zero_grad()
            pred_real = D_model_A_train(images_A)
            pred_fake = D_model_A_train(fake_A.detach())
            pred_rf = pred_real - pred_fake.mean()
            pred_fr = pred_fake - pred_real.mean()
            D_train_loss_rf = BCE_loss(pred_rf, y_real)
            D_train_loss_fr = BCE_loss(pred_fr, y_fake)
            gradient_penalty = compute_gradient_penalty(D_model_A_train, images_A, fake_A.detach())

            D_loss_A = 10 * gradient_penalty + (D_train_loss_rf + D_train_loss_fr) / 2
            D_loss_A.backward()
            D_optimizer_A.step()

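            # Discriminator B update: mirror of the update above, on domain B.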
            D_optimizer_B.zero_grad()
            pred_real = D_model_B_train(images_B)
            pred_fake = D_model_B_train(fake_B.detach())
            pred_rf = pred_real - pred_fake.mean()
            pred_fr = pred_fake - pred_real.mean()
            D_train_loss_rf = BCE_loss(pred_rf, y_real)
            D_train_loss_fr = BCE_loss(pred_fr, y_fake)
            gradient_penalty = compute_gradient_penalty(D_model_B_train, images_B, fake_B.detach())

            D_loss_B = 10 * gradient_penalty + (D_train_loss_rf + D_train_loss_fr) / 2
            D_loss_B.backward()
            D_optimizer_B.step()

        else:
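            # fp16 path: the same generator and discriminator updates as the
            # fp32 branch above, run under autocast with the GradScaler
            # handling loss scaling.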
            with autocast():
                G_optimizer.zero_grad()
                Same_B = G_model_A2B_train(images_B)
                loss_identity_B = L1_loss(Same_B, images_B)

                Same_A = G_model_B2A_train(images_A)
                loss_identity_A = L1_loss(Same_A, images_A)

                fake_B = G_model_A2B_train(images_A)
                pred_real = D_model_B_train(images_B)
                pred_fake = D_model_B_train(fake_B)
                pred_rf = pred_real - pred_fake.mean()
                pred_fr = pred_fake - pred_real.mean()
                D_train_loss_rf = BCE_loss(pred_rf, y_fake)
                D_train_loss_fr = BCE_loss(pred_fr, y_real)
                loss_GAN_A2B = (D_train_loss_rf + D_train_loss_fr) / 2

                fake_A = G_model_B2A_train(images_B)
                pred_real = D_model_A_train(images_A)
                pred_fake = D_model_A_train(fake_A)
                pred_rf = pred_real - pred_fake.mean()
                pred_fr = pred_fake - pred_real.mean()
                D_train_loss_rf = BCE_loss(pred_rf, y_fake)
                D_train_loss_fr = BCE_loss(pred_fr, y_real)
                loss_GAN_B2A = (D_train_loss_rf + D_train_loss_fr) / 2

                recovered_A = G_model_B2A_train(fake_B)
                loss_cycle_ABA = L1_loss(recovered_A, images_A)
                # VGG perceptual and face-identity terms, as in the fp32 branch.
                loss_per_ABA = L1_loss(VGG_feature_model(recovered_A), VGG_feature_model(images_A))
                recovered_A_face = F.interpolate(recovered_A, size=(112, 112), mode='bicubic', align_corners=True)
                images_A_face = F.interpolate(images_A, size=(112, 112), mode='bicubic', align_corners=True)
                loss_face_ABA = torch.mean(1. - Face_loss(ResNeSt_model(recovered_A_face), ResNeSt_model(images_A_face)))

                recovered_B = G_model_A2B_train(fake_A)
                loss_cycle_BAB = L1_loss(recovered_B, images_B)
                loss_per_BAB = L1_loss(VGG_feature_model(recovered_B), VGG_feature_model(images_B))
                recovered_B_face = F.interpolate(recovered_B, size=(112, 112), mode='bicubic', align_corners=True)
                images_B_face = F.interpolate(images_B, size=(112, 112), mode='bicubic', align_corners=True)
                loss_face_BAB = torch.mean(1. - Face_loss(ResNeSt_model(recovered_B_face), ResNeSt_model(images_B_face)))

                G_loss = loss_identity_A * 5.0 + loss_identity_B * 5.0 + loss_GAN_A2B + loss_GAN_B2A + loss_per_ABA * 2.5 \
                         + loss_per_BAB * 2.5 + loss_cycle_ABA * 10.0 + loss_cycle_BAB * 10.0 + loss_face_ABA * 5 + loss_face_BAB * 5

            # Backward and optimizer step run through the GradScaler, outside
            # the autocast region.
            scaler.scale(G_loss).backward()
            scaler.step(G_optimizer)
            scaler.update()

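            # Discriminator A update (fp16), as in the fp32 branch.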
            with autocast():
                D_optimizer_A.zero_grad()
                pred_real = D_model_A_train(images_A)
                pred_fake = D_model_A_train(fake_A.detach())
                pred_rf = pred_real - pred_fake.mean()
                pred_fr = pred_fake - pred_real.mean()
                D_train_loss_rf = BCE_loss(pred_rf, y_real)
                D_train_loss_fr = BCE_loss(pred_fr, y_fake)
                gradient_penalty = compute_gradient_penalty(D_model_A_train, images_A, fake_A.detach())

                D_loss_A = 10 * gradient_penalty + (D_train_loss_rf + D_train_loss_fr) / 2

            scaler.scale(D_loss_A).backward()
            scaler.step(D_optimizer_A)
            scaler.update()

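            # Discriminator B update (fp16), as in the fp32 branch.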
            with autocast():
                D_optimizer_B.zero_grad()
                pred_real = D_model_B_train(images_B)
                pred_fake = D_model_B_train(fake_B.detach())
                pred_rf = pred_real - pred_fake.mean()
                pred_fr = pred_fake - pred_real.mean()
                D_train_loss_rf = BCE_loss(pred_rf, y_real)
                D_train_loss_fr = BCE_loss(pred_fr, y_fake)
                gradient_penalty = compute_gradient_penalty(D_model_B_train, images_B, fake_B.detach())

                D_loss_B = 10 * gradient_penalty + (D_train_loss_rf + D_train_loss_fr) / 2

            scaler.scale(D_loss_B).backward()
            scaler.step(D_optimizer_B)
            scaler.update()

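        # Accumulate running losses; the progress bar shows per-iteration averages.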
        G_total_loss += G_loss.item()
        D_total_loss_A += D_loss_A.item()
        D_total_loss_B += D_loss_B.item()

        if local_rank == 0:
            pbar.set_postfix(**{'G_loss': G_total_loss / (iteration + 1),
                                'D_loss_A': D_total_loss_A / (iteration + 1),
                                'D_loss_B': D_total_loss_B / (iteration + 1),
                                'lr': get_lr(G_optimizer)})
            pbar.update(1)

            # Periodically save a grid of sample translations.
            if iteration % photo_save_step == 0:
                show_result(epoch + 1, G_model_A2B, G_model_B2A, images_A, images_B)

    G_total_loss = G_total_loss / epoch_step
    D_total_loss_A = D_total_loss_A / epoch_step
    D_total_loss_B = D_total_loss_B / epoch_step

    if local_rank == 0:
        pbar.close()
        print('Epoch: ' + str(epoch + 1) + '/' + str(Epoch))
        print('G Loss: %.4f || D Loss A: %.4f || D Loss B: %.4f' % (G_total_loss, D_total_loss_A, D_total_loss_B))
        loss_history.append_loss(epoch + 1, G_total_loss=G_total_loss, D_total_loss_A=D_total_loss_A, D_total_loss_B=D_total_loss_B)

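        # Save numbered checkpoints every save_period epochs (and on the final
        # epoch), plus rolling "last epoch" weights.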
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(G_model_A2B.state_dict(), os.path.join(save_dir, 'G_model_A2B_Epoch%d-GLoss%.4f-DALoss%.4f-DBLoss%.4f.pth' % (epoch + 1, G_total_loss, D_total_loss_A, D_total_loss_B)))
            torch.save(G_model_B2A.state_dict(), os.path.join(save_dir, 'G_model_B2A_Epoch%d-GLoss%.4f-DALoss%.4f-DBLoss%.4f.pth' % (epoch + 1, G_total_loss, D_total_loss_A, D_total_loss_B)))
            torch.save(D_model_A.state_dict(), os.path.join(save_dir, 'D_model_A_Epoch%d-GLoss%.4f-DALoss%.4f-DBLoss%.4f.pth' % (epoch + 1, G_total_loss, D_total_loss_A, D_total_loss_B)))
            torch.save(D_model_B.state_dict(), os.path.join(save_dir, 'D_model_B_Epoch%d-GLoss%.4f-DALoss%.4f-DBLoss%.4f.pth' % (epoch + 1, G_total_loss, D_total_loss_A, D_total_loss_B)))

        torch.save(G_model_A2B.state_dict(), os.path.join(save_dir, "G_model_A2B_last_epoch_weights.pth"))
        torch.save(G_model_B2A.state_dict(), os.path.join(save_dir, "G_model_B2A_last_epoch_weights.pth"))
        torch.save(D_model_A.state_dict(), os.path.join(save_dir, "D_model_A_last_epoch_weights.pth"))
        torch.save(D_model_B.state_dict(), os.path.join(save_dir, "D_model_B_last_epoch_weights.pth"))