ubuntu commited on
Commit
32603e9
1 Parent(s): ad86786

Initial Commit

Browse files
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import shutil
3
+ import gradio as gr
4
+ from pygm_rrwm import pygm_rrwm
5
+
6
+
7
+ PYGM_IMG_DEFAULT_PATH = "src/pygm_default.png"
8
+ PYGM_SOLUTION_1_PATH = "src/pygm_image_1.png"
9
+ PYGM_SOLUTION_2_PATH = "src/pygm_image_2.png"
10
+
11
+
12
+ def _handle_pygm_solve(
13
+ img_1_path: str,
14
+ img_2_path: str,
15
+ kpts1_path: str,
16
+ kpts2_path: str,
17
+ ):
18
+ if img_1_path is None:
19
+ raise gr.Error("Please upload file completely!")
20
+ if img_2_path is None:
21
+ raise gr.Error("Please upload file completely!")
22
+ if kpts1_path is None:
23
+ raise gr.Error("Please upload file completely!")
24
+ if kpts1_path is None:
25
+ raise gr.Error("Please upload file completely!")
26
+
27
+ start_time = time.time()
28
+ pygm_rrwm(
29
+ img1_path=img_1_path,
30
+ img2_path=img_2_path,
31
+ kpts1_path=kpts1_path,
32
+ kpts2_path=kpts2_path,
33
+ output_path="src",
34
+ filename="pygm_image"
35
+ )
36
+ solved_time = time.time() - start_time
37
+
38
+ message = "Successfully solve the TSP problem, using time ({:.3f}s).".format(solved_time)
39
+
40
+ return message, PYGM_SOLUTION_1_PATH, PYGM_SOLUTION_2_PATH
41
+
42
+
43
+ def handle_pygm_solve(
44
+ img_1_path: str,
45
+ img_2_path: str,
46
+ kpts1_path: str,
47
+ kpts2_path: str,
48
+ ):
49
+ try:
50
+ message = _handle_pygm_solve(
51
+ img_1_path=img_1_path,
52
+ img_2_path=img_2_path,
53
+ kpts1_path=kpts1_path,
54
+ kpts2_path=kpts2_path,
55
+ )
56
+ return message
57
+ except Exception as e:
58
+ message = str(e)
59
+ return message, PYGM_SOLUTION_1_PATH, PYGM_SOLUTION_2_PATH
60
+
61
+
62
+ def handle_pygm_clear():
63
+ shutil.copy(
64
+ src=PYGM_IMG_DEFAULT_PATH,
65
+ dst=PYGM_SOLUTION_1_PATH
66
+ )
67
+ shutil.copy(
68
+ src=PYGM_IMG_DEFAULT_PATH,
69
+ dst=PYGM_SOLUTION_2_PATH
70
+ )
71
+
72
+ message = "successfully clear the files!"
73
+ return message, PYGM_SOLUTION_1_PATH, PYGM_SOLUTION_2_PATH
74
+
75
+
76
+ def convert_image_path_to_bytes(image_path):
77
+ with open(image_path, "rb") as f:
78
+ image_bytes = f.read()
79
+ return image_bytes
80
+
81
+
82
+ with gr.Blocks() as pygm_page:
83
+
84
+ gr.Markdown(
85
+ '''
86
+ This space displays the solution to the Graph Matching problem.
87
+ ## How to use this Space?
88
+ - Upload a '.pygm' file from pygmlib .
89
+ - The images of the TSP problem and solution will be shown after you click the solve button.
90
+ - Click the 'clear' button to clear all the files.
91
+ '''
92
+ )
93
+
94
+ with gr.Row(variant="panel"):
95
+ with gr.Column(scale=2):
96
+ with gr.Row():
97
+ pygm_img_1 = gr.File(
98
+ label="Upload .png File",
99
+ file_types=[".png"],
100
+ min_width=40,
101
+ )
102
+ pygm_img_2 = gr.File(
103
+ label="Upload .png File",
104
+ file_types=[".png"],
105
+ min_width=40,
106
+ )
107
+ with gr.Row():
108
+ pygm_kpts_1 = gr.File(
109
+ label="Upload .mat File",
110
+ file_types=[".mat"],
111
+ min_width=40,
112
+ )
113
+ pygm_kpts_2 = gr.File(
114
+ label="Upload .mat File",
115
+ file_types=[".mat"],
116
+ min_width=40,
117
+ )
118
+ info = gr.Textbox(
119
+ value="",
120
+ label="Log",
121
+ scale=4,
122
+ )
123
+ with gr.Column(scale=2):
124
+ pygm_solution_1 = gr.Image(
125
+ value=PYGM_SOLUTION_1_PATH,
126
+ type="filepath",
127
+ label="Original Images"
128
+ )
129
+ pygm_solution_2 = gr.Image(
130
+ value=PYGM_SOLUTION_2_PATH,
131
+ type="filepath",
132
+ label="Graph Matching Results"
133
+ )
134
+ with gr.Row():
135
+ with gr.Column(scale=1, min_width=100):
136
+ solve_button = gr.Button(
137
+ value="Solve",
138
+ variant="primary",
139
+ scale=1
140
+ )
141
+ with gr.Column(scale=1, min_width=100):
142
+ clear_button = gr.Button(
143
+ "Clear",
144
+ variant="secondary",
145
+ scale=1
146
+ )
147
+ with gr.Column(scale=8):
148
+ pass
149
+
150
+ solve_button.click(
151
+ handle_pygm_solve,
152
+ [pygm_img_1, pygm_img_2, pygm_kpts_1, pygm_kpts_2],
153
+ outputs=[info, pygm_solution_1, pygm_solution_2]
154
+ )
155
+
156
+ clear_button.click(
157
+ handle_pygm_clear,
158
+ inputs=None,
159
+ outputs=[info, pygm_solution_1, pygm_solution_2]
160
+ )
161
+
162
+
163
+ if __name__ == "__main__":
164
+ pygm_page.launch(debug = True)
pygm_rrwm.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch # pytorch backend
3
+ import torchvision # CV models
4
+ import pygmtools as pygm
5
+ import matplotlib.pyplot as plt # for plotting
6
+ from matplotlib.patches import ConnectionPatch # for plotting matching result
7
+ import scipy.io as sio # for loading .mat file
8
+ import scipy.spatial as spa # for Delaunay triangulation
9
+ from sklearn.decomposition import PCA as PCAdimReduc
10
+ import itertools
11
+ import numpy as np
12
+ from PIL import Image
13
+ pygm.set_backend('pytorch') # set default backend for pygmtools
14
+
15
+
16
+ ##################################################################
17
+ # Utils Func #
18
+ ##################################################################
19
+
20
+ def plot_image_with_graph(img, kpt, A=None):
21
+ plt.imshow(img)
22
+ plt.scatter(kpt[0], kpt[1], c='w', edgecolors='k')
23
+ if A is not None:
24
+ for idx in torch.nonzero(A, as_tuple=False):
25
+ plt.plot((kpt[0, idx[0]], kpt[0, idx[1]]), (kpt[1, idx[0]], kpt[1, idx[1]]), 'k-')
26
+
27
+
28
+ def delaunay_triangulation(kpt):
29
+ d = spa.Delaunay(kpt.numpy().transpose())
30
+ A = torch.zeros(len(kpt[0]), len(kpt[0]))
31
+ for simplex in d.simplices:
32
+ for pair in itertools.permutations(simplex, 2):
33
+ A[pair] = 1
34
+ return A
35
+
36
+
37
+ def plot_image_with_graphs(img1, img2, kpts1, kpts2, A1=None, A2=None,
38
+ title_1: str="Image 1", title_2: str="Image 2", filename="examples.png"):
39
+ plt.figure(figsize=(8, 4))
40
+ plt.subplot(1, 2, 1)
41
+ plt.title(title_1)
42
+ plot_image_with_graph(img1, kpts1, A1)
43
+ plt.subplot(1, 2, 2)
44
+ plt.title(title_2)
45
+ plot_image_with_graph(img2, kpts2, A2)
46
+ plt.savefig(filename)
47
+
48
+
49
+ def load_images(
50
+ img1_path: str,
51
+ img2_path: str,
52
+ kpts1_path: str,
53
+ kpts2_path: str,
54
+ obj_resize: tuple=(256, 256)
55
+ ):
56
+ img1 = Image.open(img1_path)
57
+ img2 = Image.open(img2_path)
58
+ kpts1 = torch.tensor(sio.loadmat(kpts1_path)['pts_coord'])
59
+ kpts2 = torch.tensor(sio.loadmat(kpts2_path)['pts_coord'])
60
+ kpts1[0] = kpts1[0] * obj_resize[0] / img1.size[0]
61
+ kpts1[1] = kpts1[1] * obj_resize[1] / img1.size[1]
62
+ kpts2[0] = kpts2[0] * obj_resize[0] / img2.size[0]
63
+ kpts2[1] = kpts2[1] * obj_resize[1] / img2.size[1]
64
+ img1 = img1.resize(obj_resize, resample=Image.Resampling.BILINEAR)
65
+ img2 = img2.resize(obj_resize, resample=Image.Resampling.BILINEAR)
66
+ return img1, img2, kpts1, kpts2
67
+
68
+
69
+ ##################################################################
70
+ # Process #
71
+ ##################################################################
72
+
73
+ def pygm_rrwm(
74
+ img1_path: str,
75
+ img2_path: str,
76
+ kpts1_path: str,
77
+ kpts2_path: str,
78
+ obj_resize: tuple=(256, 256),
79
+ output_path: str="examples",
80
+ filename: str="example"
81
+ ):
82
+ if not os.path.exists(output_path):
83
+ os.mkdir(output_path)
84
+ output_filename = os.path.join(output_path, filename) + "_{}.png"
85
+
86
+ # Load the images
87
+ img1, img2, kpts1, kpts2 = load_images(img1_path, img2_path, kpts1_path, kpts2_path, obj_resize)
88
+ plot_image_with_graphs(img1, img2, kpts1, kpts2, filename=output_filename.format(1))
89
+
90
+ # Build the graphs
91
+ A1 = delaunay_triangulation(kpts1)
92
+ A2 = delaunay_triangulation(kpts2)
93
+ A1 = ((kpts1.unsqueeze(1) - kpts1.unsqueeze(2)) ** 2).sum(dim=0) * A1
94
+ A1 = (A1 / A1.max()).to(dtype=torch.float32)
95
+ A2 = ((kpts2.unsqueeze(1) - kpts2.unsqueeze(2)) ** 2).sum(dim=0) * A2
96
+ A2 = (A2 / A2.max()).to(dtype=torch.float32)
97
+ # plot_image_with_graphs(img1, img2, kpts1, kpts2, A1, A2,
98
+ # "Image 1 with Graphs", "Image 2 with Graphs", output_filename.format(2))
99
+
100
+ # Extract node features
101
+ vgg16_cnn = torchvision.models.vgg16_bn(True)
102
+ torch_img1 = torch.from_numpy(np.array(img1, dtype=np.float32) / 256).permute(2, 0, 1).unsqueeze(0) # shape: BxCxHxW
103
+ torch_img2 = torch.from_numpy(np.array(img2, dtype=np.float32) / 256).permute(2, 0, 1).unsqueeze(0) # shape: BxCxHxW
104
+ with torch.set_grad_enabled(False):
105
+ feat1 = vgg16_cnn.features(torch_img1)
106
+ feat2 = vgg16_cnn.features(torch_img2)
107
+
108
+ # Normalize the features
109
+ num_features = feat1.shape[1]
110
+ def l2norm(node_feat):
111
+ return torch.nn.functional.local_response_norm(
112
+ node_feat, node_feat.shape[1] * 2, alpha=node_feat.shape[1] * 2, beta=0.5, k=0)
113
+ feat1 = l2norm(feat1)
114
+ feat2 = l2norm(feat2)
115
+
116
+ # Up-sample the features to the original image size
117
+ feat1_upsample = torch.nn.functional.interpolate(feat1, (obj_resize[1], obj_resize[0]), mode='bilinear')
118
+ feat2_upsample = torch.nn.functional.interpolate(feat2, (obj_resize[1], obj_resize[0]), mode='bilinear')
119
+
120
+ # Visualize the extracted CNN feature (dimensionality reduction via principle component analysis)
121
+ pca_dim_reduc = PCAdimReduc(n_components=3, whiten=True)
122
+ feat_dim_reduc = pca_dim_reduc.fit_transform(
123
+ np.concatenate((
124
+ feat1_upsample.permute(0, 2, 3, 1).reshape(-1, num_features).numpy(),
125
+ feat2_upsample.permute(0, 2, 3, 1).reshape(-1, num_features).numpy()
126
+ ), axis=0)
127
+ )
128
+ feat_dim_reduc = feat_dim_reduc / np.max(np.abs(feat_dim_reduc), axis=0, keepdims=True) / 2 + 0.5
129
+ feat1_dim_reduc = feat_dim_reduc[:obj_resize[0] * obj_resize[1], :]
130
+ feat2_dim_reduc = feat_dim_reduc[obj_resize[0] * obj_resize[1]:, :]
131
+
132
+ # Plot
133
+ # plt.figure(figsize=(8, 4))
134
+ # plt.subplot(1, 2, 1)
135
+ # plt.title('Image 1 with CNN features')
136
+ # plot_image_with_graph(img1, kpts1, A1)
137
+ # plt.imshow(feat1_dim_reduc.reshape(obj_resize[1], obj_resize[0], 3), alpha=0.5)
138
+ # plt.subplot(1, 2, 2)
139
+ # plt.title('Image 2 with CNN features')
140
+ # plot_image_with_graph(img2, kpts2, A2)
141
+ # plt.imshow(feat2_dim_reduc.reshape(obj_resize[1], obj_resize[0], 3), alpha=0.5)
142
+ # plt.savefig(output_filename.format(3))
143
+
144
+ # Extract node features by nearest interpolation
145
+ rounded_kpts1 = torch.round(kpts1).to(dtype=torch.long)
146
+ rounded_kpts2 = torch.round(kpts2).to(dtype=torch.long)
147
+ node1 = feat1_upsample[0, :, rounded_kpts1[1], rounded_kpts1[0]].t() # shape: NxC
148
+ node2 = feat2_upsample[0, :, rounded_kpts2[1], rounded_kpts2[0]].t() # shape: NxC
149
+
150
+ # Build affinity matrix
151
+ conn1, edge1 = pygm.utils.dense_to_sparse(A1)
152
+ conn2, edge2 = pygm.utils.dense_to_sparse(A2)
153
+ import functools
154
+ gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1) # set affinity function
155
+ K = pygm.utils.build_aff_mat(node1, edge1, conn1, node2, edge2, conn2, edge_aff_fn=gaussian_aff)
156
+
157
+ # Plot affinity matrix
158
+ # plt.figure(figsize=(4, 4))
159
+ # plt.title(f'Affinity Matrix (size: {K.shape[0]}$\\times${K.shape[1]})')
160
+ # plt.imshow(K.numpy(), cmap='Blues')
161
+ # plt.savefig(output_filename.format(4))
162
+
163
+ # Solve graph matching problem by RRWM solver
164
+ X = pygm.rrwm(K, kpts1.shape[1], kpts2.shape[1])
165
+ X = pygm.hungarian(X)
166
+
167
+ # Plot the matching
168
+ plt.figure(figsize=(8, 4))
169
+ plt.suptitle('Image Matching Result by RRWM')
170
+ ax1 = plt.subplot(1, 2, 1)
171
+ plot_image_with_graph(img1, kpts1, A1)
172
+ ax2 = plt.subplot(1, 2, 2)
173
+ plot_image_with_graph(img2, kpts2, A2)
174
+ for i in range(X.shape[0]):
175
+ j = torch.argmax(X[i]).item()
176
+ con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
177
+ axesA=ax1, axesB=ax2, color="red" if i != j else "green")
178
+ plt.gca().add_artist(con)
179
+ plt.savefig(output_filename.format(2))
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ pygmtools
3
+ matplotlib
4
+ torch==2.0.0
5
+ torchvision==0.15.1
6
+ scikit-learn
src/pygm_default.png ADDED
src/pygm_image_1.png ADDED
src/pygm_image_2.png ADDED