import torch
import torch.nn as nn
from torch.nn import init
from torch.optim import lr_scheduler


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler selected by opt.lr_policy."""
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            # lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            # Multiply the base lr by 0.3 for every 5 epochs past opt.n_epochs.
            lr_l = 0.3 ** max(0, (epoch + opt.epoch_count - opt.n_epochs) // 5)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
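
# Usage sketch (assumes an argparse-style `opt`; the field values are illustrative,
# not the original training configuration):
#
#   from argparse import Namespace
#   opt = Namespace(lr_policy='step', lr_decay_iters=50, epoch_count=1, n_epochs=100)
#   optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
#   scheduler = get_scheduler(optimizer, opt)
#   for epoch in range(opt.n_epochs):
#       ...  # train one epoch
#       scheduler.step()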


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize Conv/Linear/BatchNorm2d weights in place using the chosen strategy."""
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm affine parameters: scale near 1, bias at 0.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=()):
    """Move the network to GPU (wrapping in DataParallel for multiple GPUs) and initialize its weights."""
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPU
    init_weights(net, init_type, init_gain=init_gain)
    return net
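
# A minimal sketch of initializing an arbitrary module tree in place
# (the layer sizes are illustrative):
#
#   layer = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
#   init_weights(layer, init_type='xavier', init_gain=0.02)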


class SignWithSigmoidGrad(torch.autograd.Function):
    """Binarize the input in the forward pass while backpropagating the sigmoid's gradient."""

    @staticmethod
    def forward(ctx, x):
        result = (x > 0).float()
        sigmoid_result = torch.sigmoid(x)
        ctx.save_for_backward(sigmoid_result)
        return result

    @staticmethod
    def backward(ctx, grad_result):
        (sigmoid_result,) = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
            grad_input = grad_result * sigmoid_result * (1 - sigmoid_result)
        else:
            grad_input = None
        return grad_input
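
# Usage sketch: SignWithSigmoidGrad.apply yields hard 0/1 decisions in the forward
# pass while remaining differentiable (a straight-through-style surrogate):
#
#   logits = torch.randn(4, requires_grad=True)
#   hard = SignWithSigmoidGrad.apply(logits)   # tensor of 0.0 / 1.0
#   hard.sum().backward()                      # gradients flow via sigmoid'(logits)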


class Painter(nn.Module):
    """Predict stroke parameters and keep/discard decisions from a target image and the current canvas."""

    def __init__(self, param_per_stroke, total_strokes, hidden_dim, n_heads=8, n_enc_layers=3, n_dec_layers=3):
        super().__init__()
        # Two identical CNN encoders (overall stride 4) for the target image and the canvas.
        self.enc_img = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(3, 32, 3, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(32, 64, 3, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(64, 128, 3, 2),
            nn.BatchNorm2d(128),
            nn.ReLU(True))
        self.enc_canvas = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(3, 32, 3, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(32, 64, 3, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(64, 128, 3, 2),
            nn.BatchNorm2d(128),
            nn.ReLU(True))
        # Fuse the two feature maps, then decode strokes with learned queries.
        self.conv = nn.Conv2d(128 * 2, hidden_dim, 1)
        self.transformer = nn.Transformer(hidden_dim, n_heads, n_enc_layers, n_dec_layers)
        self.linear_param = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, param_per_stroke))
        self.linear_decider = nn.Linear(hidden_dim, 1)
        self.query_pos = nn.Parameter(torch.rand(total_strokes, hidden_dim))
        self.row_embed = nn.Parameter(torch.rand(8, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(8, hidden_dim // 2))

    def forward(self, img, canvas):
        b, _, H, W = img.shape
        img_feat = self.enc_img(img)
        canvas_feat = self.enc_canvas(canvas)
        h, w = img_feat.shape[-2:]
        feat = torch.cat([img_feat, canvas_feat], dim=1)
        feat_conv = self.conv(feat)
        # Learned 2D positional embedding: column and row halves concatenated,
        # flattened to (h * w, 1, hidden_dim) so it broadcasts over the batch.
        pos_embed = torch.cat([
            self.col_embed[:w].unsqueeze(0).contiguous().repeat(h, 1, 1),
            self.row_embed[:h].unsqueeze(1).contiguous().repeat(1, w, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        hidden_state = self.transformer(pos_embed + feat_conv.flatten(2).permute(2, 0, 1).contiguous(),
                                        self.query_pos.unsqueeze(1).contiguous().repeat(1, b, 1))
        hidden_state = hidden_state.permute(1, 0, 2).contiguous()
        param = self.linear_param(hidden_state)
        s = hidden_state.shape[1]
        # Sample each stroke's color from the target image at the stroke's predicted
        # (x, y) position; grid_sample expects coordinates in [-1, 1].
        grid = param[:, :, :2].view(b * s, 1, 1, 2).contiguous()
        img_temp = img.unsqueeze(1).contiguous().repeat(1, s, 1, 1, 1).view(b * s, 3, H, W).contiguous()
        color = nn.functional.grid_sample(img_temp, 2 * grid - 1, align_corners=False).view(b, s, 3).contiguous()
        decision = self.linear_decider(hidden_state)
        return torch.cat([param, color, color, torch.rand(b, s, 1, device=img.device)], dim=-1), decision
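

if __name__ == '__main__':
    # Minimal CPU smoke test; the hyperparameters below are illustrative, not the
    # original training configuration. The input spatial size must be at most
    # 32x32 so the 8x8 row/col positional embeddings cover the stride-4 feature map.
    net = init_net(Painter(param_per_stroke=5, total_strokes=8, hidden_dim=256))
    img = torch.rand(2, 3, 32, 32)
    canvas = torch.zeros(2, 3, 32, 32)
    params, decision = net(img, canvas)
    print(params.shape)    # torch.Size([2, 8, 12]): 5 shape + 3 color + 3 color + 1 alpha
    print(decision.shape)  # torch.Size([2, 8, 1])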