K00B404 commited on
Commit
cf71845
1 Parent(s): 3297549

Upload 2 files

Browse files
Files changed (2) hide show
  1. big_1024_model.py +39 -0
  2. small_256_model.py +29 -0
big_1024_model.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ class UNet(nn.Module):
5
+ def __init__(self):
6
+ super(UNet, self).__init__()
7
+
8
+ # Encoder
9
+ self.encoder = nn.Sequential(
10
+ nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1), # 256 -> 128
11
+ nn.ReLU(inplace=True),
12
+ nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # 128 -> 64
13
+ nn.ReLU(inplace=True),
14
+ nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 64 -> 32
15
+ nn.ReLU(inplace=True),
16
+ nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), # 32 -> 16
17
+ nn.ReLU(inplace=True),
18
+ nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1), # 16 -> 8
19
+ nn.ReLU(inplace=True)
20
+ )
21
+
22
+ # Decoder
23
+ self.decoder = nn.Sequential(
24
+ nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1), # 8 -> 16
25
+ nn.ReLU(inplace=True),
26
+ nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), # 16 -> 32
27
+ nn.ReLU(inplace=True),
28
+ nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # 32 -> 64
29
+ nn.ReLU(inplace=True),
30
+ nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # 64 -> 128
31
+ nn.ReLU(inplace=True),
32
+ nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1), # 128 -> 256
33
+ nn.Tanh() # Output range [-1, 1]
34
+ )
35
+
36
+ def forward(self, x):
37
+ enc = self.encoder(x)
38
+ dec = self.decoder(enc)
39
+ return dec
small_256_model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ class UNet(nn.Module):
5
+ def __init__(self):
6
+ super(UNet, self).__init__()
7
+ # Encoder
8
+ self.encoder = nn.Sequential(
9
+ nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
10
+ nn.ReLU(inplace=True),
11
+ nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
12
+ nn.ReLU(inplace=True),
13
+ nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
14
+ nn.ReLU(inplace=True),
15
+ )
16
+ # Decoder
17
+ self.decoder = nn.Sequential(
18
+ nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
19
+ nn.ReLU(inplace=True),
20
+ nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
21
+ nn.ReLU(inplace=True),
22
+ nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
23
+ nn.Tanh()
24
+ )
25
+
26
+ def forward(self, x):
27
+ enc = self.encoder(x)
28
+ dec = self.decoder(enc)
29
+ return dec