penpen committed
Commit 61d6774
1 Parent(s): 93fdd33

Create app.py

Files changed (1)
  1. app.py +206 -0
app.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
import kornia
import gradio as gr

from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2


device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Define model

class BlurUpSample(nn.Module):
    """Gaussian blur followed by 2x bilinear upsampling (anti-aliased upsampling)."""

    def __init__(self, c):
        super(BlurUpSample, self).__init__()
        self.blurpool = kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
        self.upsample = nn.Upsample(scale_factor=(2, 2), mode='bilinear', align_corners=False)

    def forward(self, x):
        x = self.blurpool(x)
        x = self.upsample(x)
        return x

class DownLayer(nn.Module):
    """Encoder block: anti-aliased max pooling, then two conv + BN + LeakyReLU stages."""

    def __init__(self, c_in, c_out):
        super(DownLayer, self).__init__()
        self.maxblurpool = kornia.filters.MaxBlurPool2D(kernel_size=3)
        self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(c_out)
        self.leakyrelu = nn.LeakyReLU(inplace=True)
        self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(c_out)

    def forward(self, x):
        x = self.maxblurpool(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.leakyrelu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.leakyrelu(x)
        return x

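# Note: MaxBlurPool2D downsamples by 2x, so each DownLayer halves the spatial
# resolution while growing the channel count from c_in to c_out.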

class UpLayer(nn.Module):
    """Decoder block: blur-upsample, pad to the skip connection's size, concatenate, then two conv + BN + LeakyReLU stages."""

    def __init__(self, c_in, c_out):
        super(UpLayer, self).__init__()
        self.upsample = BlurUpSample(c_in)
        self.conv1 = nn.Conv2d(c_in + c_out, c_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(c_out)
        self.leakyrelu = nn.LeakyReLU(inplace=True)
        self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(c_out)

    def forward(self, x, skip_x):
        x = self.upsample(x)

        # Pad the upsampled map so it matches the skip connection's spatial size.
        dh = skip_x.size(2) - x.size(2)
        dw = skip_x.size(3) - x.size(3)
        x = F.pad(x, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))

        x = torch.cat([x, skip_x], dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.leakyrelu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.leakyrelu(x)
        return x

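# Example: if skip_x is 64x64 but the upsampled x is 63x63, then dh = dw = 1 and
# F.pad adds one extra row and column so the channel-wise concatenation lines up.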
class Generator(nn.Module):
    """U-Net-style generator: a 5-channel input (RGB + input-age map + target-age map) maps to a 3-channel RGB residual."""

    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(5, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(64)
        self.leakyrelu = nn.LeakyReLU(inplace=True)
        self.downlayer1 = DownLayer(64, 128)
        self.downlayer2 = DownLayer(128, 256)
        self.downlayer3 = DownLayer(256, 512)
        self.downlayer4 = DownLayer(512, 1024)
        self.uplayer1 = UpLayer(1024, 512)
        self.uplayer2 = UpLayer(512, 256)
        self.uplayer3 = UpLayer(256, 128)
        self.uplayer4 = UpLayer(128, 64)
        self.conv3 = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Stem: two convs that share one BatchNorm, kept as-is for checkpoint compatibility.
        x1 = self.conv1(x)
        x1 = self.batchnorm1(x1)
        x1 = self.leakyrelu(x1)
        x1 = self.conv2(x1)
        x1 = self.batchnorm1(x1)
        x1 = self.leakyrelu(x1)

        # Encoder
        x2 = self.downlayer1(x1)
        x3 = self.downlayer2(x2)
        x4 = self.downlayer3(x3)
        x5 = self.downlayer4(x4)

        # Decoder with skip connections
        x = self.uplayer1(x5, x4)
        x = self.uplayer2(x, x3)
        x = self.uplayer3(x, x2)
        x = self.uplayer4(x, x1)
        x = self.conv3(x)
        return x

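# Illustrative sanity check (not run by the app; the dummy tensor is hypothetical):
#
#   _dummy = torch.randn(1, 5, 512, 512)
#   assert Generator()(_dummy).shape == (1, 3, 512, 512)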

transform_resize = A.Compose([
    A.Resize(512, 512),
    ToTensorV2(),
])

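# Note: ToTensorV2 only converts the HWC uint8 array to a CHW tensor; it does not
# rescale values, which is why age_filter divides by 255 below.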

# Load model
generator_model = Generator()
# map_location lets a GPU-trained checkpoint load on CPU-only machines.
generator_model.load_state_dict(torch.load('age-transformation/large-aging-model.h5', map_location=torch.device(device)))
generator_model.to(device)
#generator_model.eval()


def age_filter(image, input_age, output_age):
    # Keep a 512x512 copy of the input; the model's output is added back onto it.
    resized_image = image.resize((512, 512))

    # Resize, convert to a CHW tensor, and scale to [0, 1].
    input_image = transform_resize(image=np.array(image))['image'] / 255

    # Age conditioning: two constant-valued maps, with ages scaled to [0, 1].
    age_map1 = torch.full((1, 512, 512), input_age / 100)
    age_map2 = torch.full((1, 512, 512), output_age / 100)

    # Stack RGB and both age maps into the 5-channel model input.
    input_tensor = torch.cat((input_image, age_map1, age_map2), dim=0)

    with torch.no_grad():
        model_output = generator_model(input_tensor.unsqueeze(0).to(device))

    original_array = np.array(image)

    # Add the predicted residual onto the resized input; clip to avoid uint8 overflow.
    new_image = np.clip(model_output.squeeze(0).cpu().permute(1, 2, 0).numpy() * 255 + np.array(resized_image), 0, 255).astype('uint8')

    # Resize the result back to the original resolution.
    sample_image = np.array(Image.fromarray(new_image).resize((original_array.shape[1], original_array.shape[0]))).astype('uint8')
    return sample_image

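# Illustrative usage (the file name 'face.jpg' is hypothetical):
#
#   aged = age_filter(Image.open('face.jpg').convert('RGB'), input_age=25, output_age=60)
#   Image.fromarray(aged).save('face_aged.jpg')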

def crop_and_process_image(input_img, input_age, output_age):
    # Convert the (optionally cropped) numpy image from Gradio into a PIL image.
    processed_image = Image.fromarray(input_img)

    # Run the processed image through the model and return the result.
    output = age_filter(processed_image, input_age, output_age)
    return output

# Define the input image component with the crop tool
input_image = gr.Image(label="Input Image", interactive=True)

# Define the output image component
output_image = gr.Image(label="Output Image", type="pil")

# Note: .style() is the Gradio 3.x API; newer Gradio versions take height/width
# in the component constructor instead.
input_image.style(height=512, width=512)
output_image.style(height=512, width=512)

input_age = gr.Slider(minimum=0, maximum=100, label="Input Age")
output_age = gr.Slider(minimum=0, maximum=100, label="Output Age")


# Define the function to be called when the button is pressed
def process_image(input_img, input_age, output_age):
    output = crop_and_process_image(input_img, input_age, output_age)

    # Convert the output array to a PIL image for display.
    return Image.fromarray(output)


# Create the Gradio interface
gr.Interface(fn=process_image, inputs=[input_image, input_age, output_age], outputs=output_image, title="Image Crop and Process").launch(debug=True)