File size: 1,969 Bytes
34fb220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
import torchvision

class FeatureExtractor(nn.Module):
    def __init__(self, cnn, feature_layer=11):
        super(FeatureExtractor, self).__init__()
        self.features = nn.Sequential(*list(cnn.features.children())[:(feature_layer + 1)])

    def normalize(self, tensors, mean, std):
        if not torch.is_tensor(tensors):
            raise TypeError('tensor is not a torch image.')
        for tensor in tensors:
            for t, m, s in zip(tensor, mean, std):
                t.sub_(m).div_(s)
        return tensors

    def forward(self, x):
        # it image is gray scale then make it to 3 channel
        if x.size()[1] == 1:
            x = x.expand(-1, 3, -1, -1)
            
        # [-1: 1] image to  [0:1] image---------------------------------------------------(1)
        x = (x + 1) * 0.5
        
        # https://pytorch.org/docs/stable/torchvision/models.html
        x.data = self.normalize(x.data, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        return self.features(x)

# Feature extracting using vgg19
vgg19 = torchvision.models.vgg19(pretrained=True)
feature_extractor = FeatureExtractor(vgg19, feature_layer=35)
feature_extractor.eval()

class VGG19Loss(object):
    def __init__(self):
        global feature_extractor
        self.initialized = False
        self.feature_extractor = feature_extractor
        self.MSE = nn.MSELoss()

    def __call__(self, output, target, device):
        if self.initialized == False:
            self.feature_extractor = self.feature_extractor.to(device)
            self.MSE = self.MSE.to(device)
            self.initialized = True

        # [-1: 1] image to  [0:1] image---------------------------------------------------(2)
        output = (output + 1) * 0.5
        target = (target + 1) * 0.5

        output = self.feature_extractor(output)
        target = self.feature_extractor(target).data
        return self.MSE(output, target)