abhirajeshbhai committed
Commit 9205986
1 Parent(s): 46918b5

implement unet and deploy

Files changed (4)
  1. app.py +30 -0
  2. banana_colorizer_unet.pth +3 -0
  3. model.py +103 -0
  4. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,30 @@
+ import streamlit as st
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from PIL import Image
+ from model import model, image_transforms
+
+
+
+ def col_select(value):  # helper, not currently wired to any widget
+     print(value)
+
+
+ st.title("Banana Image Colorizer")
+
+ upload_file = st.file_uploader("Upload Image")
+
+ if upload_file:
+     image = upload_file
+     image = Image.open(image)
+     image_gs = image_transforms(image)  # resize to 256x256, grayscale, convert to tensor
+     image_gs_prev = image_gs.permute(1, 2, 0).detach().cpu().numpy()  # CHW -> HWC for display
+
+     image_color = model(image_gs.unsqueeze(0)).squeeze().permute(1, 2, 0).detach().cpu().numpy()  # add batch dim, run the U-Net, drop batch dim
+
+
+     col1, col2 = st.columns(2)
+     col1.image(image_gs_prev)
+     col2.image(image_color, clamp=True, channels='RGB')  # clamp since the network output is unbounded
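For reference, the same inference path can be exercised outside Streamlit. A minimal sketch, assuming the LFS checkpoint banana_colorizer_unet.pth has been pulled (model.py loads it at import time) and using a hypothetical local file test_banana.jpg in place of the uploaded image:

import torch
from PIL import Image
from model import model, image_transforms

# "test_banana.jpg" is a hypothetical stand-in for the file uploaded in the app
image = Image.open("test_banana.jpg")

with torch.no_grad():                       # inference only, no gradients needed
    gray = image_transforms(image)          # 1 x 256 x 256 grayscale tensor in [0, 1]
    color = model(gray.unsqueeze(0))        # add batch dim -> 1 x 3 x 256 x 256

# CHW -> HWC, clamped to [0, 1] the same way st.image(clamp=True) would display it
color_np = color.squeeze(0).permute(1, 2, 0).clamp(0, 1).numpy()
print(gray.shape, color_np.shape)           # torch.Size([1, 256, 256]) (256, 256, 3)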
banana_colorizer_unet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:97d8406c9bae3c4ccda6962d483d23e1e649e46a8bbe25e1e2bef95e5abc13b3
+ size 124265610
model.py ADDED
@@ -0,0 +1,103 @@
+ import torch
+ import torch.nn as nn
+ import torchvision
+
+ import torch.nn.functional as F
+
+
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"  # used as map_location when loading the checkpoint
+
+
+ image_transforms = torchvision.transforms.Compose([
+     torchvision.transforms.Resize((256, 256)),
+     torchvision.transforms.Grayscale(),
+     torchvision.transforms.ToTensor(),
+     torchvision.transforms.Normalize(mean=[0.0], std=[1.0])  # no-op normalization, keeps pixels in [0, 1]
+ ])
+
+ class ConvBlock(nn.Module):
+     def __init__(self, in_channel, out_channel):
+         super(ConvBlock, self).__init__()
+         self.main = nn.Sequential(  # (3x3 conv -> BN -> ReLU) twice
+             nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(out_channel),
+             nn.ReLU(True),
+             nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(out_channel),
+             nn.ReLU(True)
+         )
+
+     def forward(self, x):
+         return self.main(x)
+
+ class UNETFruitColor(nn.Module):
+     def __init__(self):
+         super(UNETFruitColor, self).__init__()
+
+         self.convs = [64, 128, 256, 512]  # encoder channel widths
+         self.convEncoder = nn.ModuleList()
+
+         in_feature = 1  # single-channel grayscale input
+         for conv in self.convs:
+             self.convEncoder.append(ConvBlock(in_feature, conv))
+             in_feature = conv
+
+         self.bottleNeck = ConvBlock(self.convs[-1], self.convs[-1]*2)  # 512 -> 1024 at the lowest resolution
+
+         in_feature = self.convs[-1]*2
+
+         self.convDecoder = nn.ModuleList()
+         self.decoderUpConvs = nn.ModuleList()
+
+         for conv in self.convs[::-1]:
+             self.convDecoder.append(ConvBlock(in_feature, conv))
+             self.decoderUpConvs.append(nn.ConvTranspose2d(in_feature, conv, kernel_size=2, stride=2, padding=0))
+             in_feature = conv
+
+
+         # final 1x1 conv maps the last 64 channels to 3 (RGB)
+         self.finalUpConv = nn.Conv2d(in_feature, 3, (1, 1))
+         self.sigmoid = nn.Sigmoid()  # defined but unused; the forward pass leaves the output unbounded
+
+     def forward(self, x):
+         skip_conns = []
+         for conv in self.convEncoder:
+             # double conv block
+             x = conv(x)
+             # save the feature map for the skip connection
+             skip_conns.append(x)
+             # downsample by 2
+             x = F.max_pool2d(x, (2, 2), stride=2)
+
+         x = self.bottleNeck(x)
+
+         skip_conns = skip_conns[::-1]
+
+         for idx in range(len(self.convDecoder)):
+             # pick the matching up-conv, decoder block and skip connection
+             upconv = self.decoderUpConvs[idx]
+             deconv = self.convDecoder[idx]
+             skp = skip_conns[idx]
+
+             # upsample by 2
+             x = upconv(x)
+
+             # resize the skip feature map to match, then concatenate along channels
+             x_cat = torchvision.transforms.Resize((x.shape[2], x.shape[3]))(skp)
+             x = torch.cat([x_cat, x], dim=1)
+
+             # decoder double conv
+             x = deconv(x)
+
+         # final projection to RGB
+
+         x = self.finalUpConv(x)
+         # x = self.sigmoid(x)
+         return x
+
+
+
+ model = UNETFruitColor()
+ model.load_state_dict(torch.load("banana_colorizer_unet.pth", map_location=device), strict=True)
+ model.eval()
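A quick way to sanity-check the architecture: each encoder stage halves the spatial resolution, the bottleneck doubles the channels to 1024, and the decoder's transposed convolutions upsample back, so a 256x256 grayscale input should come out as a 3-channel map of the same size. A minimal sketch (importing model.py runs the checkpoint load, so it assumes the .pth file is present):

import torch
from model import UNETFruitColor

net = UNETFruitColor().eval()               # fresh weights are enough for a shape check
with torch.no_grad():
    out = net(torch.randn(1, 1, 256, 256))  # batch of one 256x256 grayscale image
print(out.shape)                            # expected: torch.Size([1, 3, 256, 256])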
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ streamlit
+ torch
+ torchvision
+ numpy
+ Pillow
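With these dependencies installed (e.g. pip install -r requirements.txt), the app can be launched locally with streamlit run app.py; torch, torchvision, numpy and Pillow are listed alongside streamlit because app.py and model.py import them directly.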