Spaces:
Runtime error
Runtime error
File size: 1,969 Bytes
34fb220 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import torch
import torch.nn as nn
import torchvision
class FeatureExtractor(nn.Module):
    """Truncated CNN feature extractor used for perceptual losses.

    Keeps the children ``0..feature_layer`` (inclusive) of ``cnn.features``.
    ``forward`` takes images in [-1, 1], maps them to [0, 1], applies ImageNet
    per-channel normalization and runs them through the kept layers.
    """

    # Per-channel ImageNet statistics, as listed in the torchvision models docs
    # (https://pytorch.org/docs/stable/torchvision/models.html).
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, cnn, feature_layer=11):
        """Build the extractor.

        Args:
            cnn: model exposing a ``features`` container whose children are
                sliced (e.g. a torchvision VGG).
            feature_layer: index of the last child of ``cnn.features`` to keep.
        """
        super(FeatureExtractor, self).__init__()
        self.features = nn.Sequential(*list(cnn.features.children())[:(feature_layer + 1)])

    def normalize(self, tensors, mean, std):
        """Normalize a batch IN PLACE, per channel: ``t = (t - m) / s``.

        Kept for backward compatibility. ``forward`` does not use it, because
        in-place edits made through ``.data`` are invisible to autograd.

        Args:
            tensors: batch tensor, iterated as batch -> channel (presumably
                (N, C, H, W)). Channels beyond ``len(mean)`` are left untouched
                because ``zip`` truncates.
            mean: per-channel means.
            std: per-channel standard deviations.

        Returns:
            The same, mutated ``tensors`` object.

        Raises:
            TypeError: if ``tensors`` is not a torch tensor.
        """
        if not torch.is_tensor(tensors):
            raise TypeError('tensor is not a torch image.')
        for tensor in tensors:
            for t, m, s in zip(tensor, mean, std):
                t.sub_(m).div_(s)
        return tensors

    def forward(self, x):
        """Return the truncated-CNN features of ``x``.

        Args:
            x: image batch with values in [-1, 1]. It is expected to be
                (N, C, H, W), since the ``expand`` call below assumes 4 dims.
                A single-channel batch is replicated to 3 channels.

        Returns:
            Output of ``self.features`` on the ImageNet-normalized input. The
            caller's tensor is never modified.
        """
        # Grayscale input: replicate the single channel to the 3 the CNN expects.
        if x.size(1) == 1:
            x = x.expand(-1, 3, -1, -1)
        # [-1, 1] image -> [0, 1] image.
        x = (x + 1) * 0.5
        # ImageNet normalization done out of place with broadcasting, so autograd
        # includes the 1/std factor in the gradient. Writing to x.data would
        # hide this step from autograd. new_tensor matches x's dtype/device.
        mean = x.new_tensor(self.IMAGENET_MEAN).view(1, -1, 1, 1)
        std = x.new_tensor(self.IMAGENET_STD).view(1, -1, 1, 1)
        x = (x - mean) / std
        return self.features(x)
# Shared perceptual-loss backbone. It is built once at import time and picked
# up by every VGG19Loss instance through the module-level `feature_extractor`.
# Keeps vgg19.features children 0..35 (presumably up to relu5_4 in
# torchvision's VGG19 layout -- TODO confirm).
# NOTE(review): pretrained=True presumably downloads weights when they are not
# cached, so importing this module may require network access.
vgg19 = torchvision.models.vgg19(pretrained=True)
feature_extractor = FeatureExtractor(vgg19, feature_layer=35).eval()  # eval() returns the module itself
class VGG19Loss(object):
    """Perceptual loss: MSE between VGG19 feature maps of two image batches.

    Both batches are passed through the shared module-level
    ``feature_extractor``, and ``nn.MSELoss`` is applied to the results.
    """

    def __init__(self):
        # Reuse the single module-level extractor rather than loading a new
        # VGG19 per loss instance (reading a global needs no `global` statement).
        self.initialized = False
        self.feature_extractor = feature_extractor
        self.MSE = nn.MSELoss()

    def __call__(self, output, target, device):
        """Return the feature-space MSE between ``output`` and ``target``.

        Args:
            output: predicted image batch, expected in [-1, 1]; gradients flow
                back through it.
            target: reference image batch, expected in [-1, 1]; treated as a
                constant (no gradient).
            device: device the extractor and criterion are moved to on the
                first call.

        Returns:
            Scalar loss tensor produced by ``nn.MSELoss``.
        """
        if not self.initialized:
            # Lazily move the models on first use. Later calls assume the same device.
            self.feature_extractor = self.feature_extractor.to(device)
            self.MSE = self.MSE.to(device)
            self.initialized = True
        # No [-1, 1] -> [0, 1] rescale here: FeatureExtractor.forward already
        # does it. Doing it twice squeezed the inputs into [0.5, 1].
        output_features = self.feature_extractor(output)
        # The target is a constant. no_grad avoids building a graph for it, and
        # detach guarantees it is excluded whatever the extractor returns.
        with torch.no_grad():
            target_features = self.feature_extractor(target).detach()
        return self.MSE(output_features, target_features)