Spaces:
tubui
/
Runtime error

File size: 4,859 Bytes
dfec228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from torch import nn
from torch.autograd import Variable
import torch
import torch.nn.functional as F
from .munit import ResBlocks, Conv2dBlock
import math


class Unet(nn.Module):
    """U-Net that embeds a secret bit-vector into an image.

    The secret is linearly projected to a 16x16x3 map, upsampled to the
    image resolution, concatenated channel-wise with the image, and run
    through an Encoder/Decoder pair wired with skip connections.

    Args:
        resolution: input image side length; must be a power of 2.
        secret_len: number of secret bits accepted by ``forward``.
        return_residual: stored on the instance.
            NOTE(review): never read in ``forward`` — confirm whether a
            residual output path (image + out) was intended.
    """

    def __init__(self, resolution=256, secret_len=100, return_residual=False) -> None:
        super().__init__()
        self.secret_len = secret_len
        self.return_residual = return_residual
        self.secret_dense = nn.Linear(secret_len, 16 * 16 * 3)
        # bit_length() gives an exact integer log2 for powers of two,
        # avoiding the float rounding risk of int(math.log(resolution, 2)).
        log_resolution = resolution.bit_length() - 1
        if resolution <= 0 or resolution != 2 ** log_resolution:
            # Raise instead of assert: asserts are stripped under `python -O`.
            raise ValueError(f"Image resolution must be a power of 2, got {resolution}.")
        # Upsample the 16x16 secret map to resolution x resolution.
        self.secret_upsample = nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4)))

        self.enc = Encoder(2, 4, 6, 64, 'bn' , 'relu', 'reflect')
        self.dec = Decoder(2, 4, self.enc.output_dim, 3, 'bn', 'relu', 'reflect')

    def forward(self, image, secret):
        """Encode ``secret`` (B, secret_len) into ``image`` (B, 3, H, W).

        Returns the decoder output (same spatial size as ``image``).
        """
        fingerprint = F.relu(self.secret_dense(secret))
        fingerprint = fingerprint.view((-1, 3, 16, 16))
        fingerprint_enlarged = self.secret_upsample(fingerprint)
        # 3 secret channels + 3 image channels -> the encoder's 6 input channels.
        inputs = torch.cat([fingerprint_enlarged, image], dim=1)
        emb = self.enc(inputs)
        out = self.dec(emb)
        return out

class Encoder(nn.Module):
    """Downsampling encoder that exposes every stage's feature map.

    Stages: a stride-1 7x7 stem, then ``n_downsample`` stride-2 conv
    blocks (each doubling the channel count), then a stack of ``n_res``
    residual blocks. ``forward`` returns the outputs of all stages as a
    list so a decoder can consume them as skip connections.
    """

    def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
        super().__init__()
        blocks = [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        width = dim
        # Stride-2 downsampling stages, doubling channels each time.
        for _ in range(n_downsample):
            blocks.append(Conv2dBlock(width, 2 * width, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type))
            width *= 2
        # Residual stack at the deepest resolution.
        blocks.append(ResBlocks(n_res, width, norm=norm, activation=activ, pad_type=pad_type))
        # ModuleList (not Sequential): forward needs each stage's output.
        self.model = nn.ModuleList(blocks)
        # Channel count of the deepest feature map; read by Unet/Decoder.
        self.output_dim = width

    def forward(self, x):
        """Return the list of per-stage feature maps, shallowest first."""
        features = []
        for stage in self.model:
            x = stage(x)
            features.append(x)
        return features


class Decoder(nn.Module):
    """Upsampling decoder that consumes the Encoder's per-stage feature list.

    One residual DecoderBlock at the deepest resolution, then
    ``n_upsample`` upsampling DecoderBlocks (each halving the channel
    count), then a final 7x7 conv with tanh activation.

    Args:
        n_upsample: number of 2x upsampling stages.
        n_res: number of residual blocks in the first stage.
        dim: channel count of the deepest encoder feature map.
        output_dim: channels of the final output (e.g. 3 for RGB).
    """

    def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
        super(Decoder, self).__init__()

        self.model = []
        # Residual blocks operating at the deepest resolution.
        self.model += [DecoderBlock('resblock', n_res, dim, res_norm, activ, pad_type=pad_type)]
        # Upsampling blocks, each halving the channel count.
        for i in range(n_upsample):
            self.model += [DecoderBlock('upsample', dim, dim // 2, 'bn', activ, pad_type)]
            dim //= 2
        # Final projection to output_dim channels, tanh-bounded.
        self.output_layer = Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
        self.model = nn.ModuleList(self.model)

    def forward(self, x):
        """Decode a list of encoder feature maps (shallowest first).

        Returns the output image tensor. Requires
        ``len(x) >= len(self.model) + 1``.
        """
        # Bug fix: the original popped from `x` directly, emptying the
        # caller's list as a side effect. Pop from a local copy instead
        # (same consumption order, same IndexError on short input).
        skips = list(x)
        out = skips.pop()
        for block in self.model:
            out = block(out, skips.pop())
        return self.output_layer(out)


class Merge(nn.Module):
    """Fuse two same-shaped feature maps: concat -> 3x3 conv -> activation.

    Used by DecoderBlock to merge the upsampled decoder path with an
    encoder skip connection.

    Args:
        dim: channel count of each input (and of the output).
        activation: one of 'relu', 'lrelu', 'prelu', 'selu', 'tanh', 'none'.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, dim, activation='relu'):
        super().__init__()
        # Concatenation doubles the channels; the conv projects back to `dim`.
        self.conv = nn.Conv2d(2*dim, dim, 3, 1, 1)
        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            # Raise instead of `assert 0`: asserts vanish under `python -O`.
            raise ValueError("Unsupported activation: {}".format(activation))

    def forward(self, x1, x2):
        """Merge x1 and x2, both (B, dim, H, W), into a (B, dim, H, W) map."""
        x = torch.cat([x1, x2], dim=1)  # 2xdim
        x = self.conv(x)  # B,dim,H,W
        # Bug fix: with activation='none', self.activation is None and the
        # original called it unconditionally, raising a TypeError.
        if self.activation is not None:
            x = self.activation(x)
        return x

class DecoderBlock(nn.Module):
    """One decoder stage: a core transform followed by a skip-connection merge.

    block_type 'resblock': core is a ResBlocks stack. As called by Decoder,
    ``in_dim`` is the *number* of residual blocks and ``out_dim`` the channel
    width — the parameter names are misleading for this branch.
    block_type 'upsample': core is a 2x nearest upsample plus a 5x5 conv that
    halves the channel count (``out_dim`` must equal ``in_dim // 2``).
    """

    def __init__(self, block_type, in_dim, out_dim, norm, activ='relu', pad_type='reflect'):
        super().__init__()
        assert block_type in ['resblock', 'upsample']
        if block_type == 'resblock':
            core = ResBlocks(in_dim, out_dim, norm, activ, pad_type=pad_type)
        else:
            assert out_dim == in_dim//2
            core = nn.Sequential(
                nn.Upsample(scale_factor=2),
                Conv2dBlock(in_dim, out_dim, 5, 1, 2, norm=norm, activation=activ, pad_type=pad_type),
            )
        self.core_layer = core
        self.merge = Merge(out_dim, activ)

    def forward(self, x1, x2):
        """Transform decoder-path x1, then fuse it with skip connection x2."""
        return self.merge(self.core_layer(x1), x2)