Spaces:
Build error
Build error
Ren Jiawei
commited on
Commit
•
d7b89b7
1
Parent(s):
1f2ff91
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .DS_Store +0 -0
- .gitattributes +1 -0
- .gitignore +4 -0
- .idea/.gitignore +8 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/pointcloud-c.iml +8 -0
- .idea/vcs.xml +6 -0
- DGCNN.py +121 -0
- GDANet_WOLFMix.t7 +3 -0
- GDANet_cls.py +113 -0
- __pycache__/DGCNN.cpython-38.pyc +0 -0
- __pycache__/GDANet_cls.cpython-38.pyc +0 -0
- app.py +119 -0
- dgcnn.t7 +3 -0
- downsample.py +12 -0
- modelnet_c/.DS_Store +0 -0
- modelnet_c/add_global_0.h5 +3 -0
- modelnet_c/add_global_1.h5 +3 -0
- modelnet_c/add_global_2.h5 +3 -0
- modelnet_c/add_global_3.h5 +3 -0
- modelnet_c/add_global_4.h5 +3 -0
- modelnet_c/add_local_0.h5 +3 -0
- modelnet_c/add_local_1.h5 +3 -0
- modelnet_c/add_local_2.h5 +3 -0
- modelnet_c/add_local_3.h5 +3 -0
- modelnet_c/add_local_4.h5 +3 -0
- modelnet_c/clean.h5 +3 -0
- modelnet_c/dropout_global_0.h5 +3 -0
- modelnet_c/dropout_global_1.h5 +3 -0
- modelnet_c/dropout_global_2.h5 +3 -0
- modelnet_c/dropout_global_3.h5 +3 -0
- modelnet_c/dropout_global_4.h5 +3 -0
- modelnet_c/dropout_local_0.h5 +3 -0
- modelnet_c/dropout_local_1.h5 +3 -0
- modelnet_c/dropout_local_2.h5 +3 -0
- modelnet_c/dropout_local_3.h5 +3 -0
- modelnet_c/dropout_local_4.h5 +3 -0
- modelnet_c/jitter_0.h5 +3 -0
- modelnet_c/jitter_1.h5 +3 -0
- modelnet_c/jitter_2.h5 +3 -0
- modelnet_c/jitter_3.h5 +3 -0
- modelnet_c/jitter_4.h5 +3 -0
- modelnet_c/rotate_0.h5 +3 -0
- modelnet_c/rotate_1.h5 +3 -0
- modelnet_c/rotate_2.h5 +3 -0
- modelnet_c/rotate_3.h5 +3 -0
- modelnet_c/rotate_4.h5 +3 -0
- modelnet_c/scale_0.h5 +3 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gitattributes
CHANGED
@@ -29,3 +29,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
29 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
29 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.t7 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.DS_store
|
2 |
+
*.pyc
|
3 |
+
flagged
|
4 |
+
*.png
|
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Default ignored files
|
2 |
+
/shelf/
|
3 |
+
/workspace.xml
|
4 |
+
# Editor-based HTTP Client requests
|
5 |
+
/httpRequests/
|
6 |
+
# Datasource local storage ignored files
|
7 |
+
/dataSources/
|
8 |
+
/dataSources.local.xml
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<settings>
|
3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
4 |
+
<version value="1.0" />
|
5 |
+
</settings>
|
6 |
+
</component>
|
.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pointcloud-c)" project-jdk-type="Python SDK" />
|
4 |
+
</project>
|
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectModuleManager">
|
4 |
+
<modules>
|
5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/pointcloud-c.iml" filepath="$PROJECT_DIR$/.idea/pointcloud-c.iml" />
|
6 |
+
</modules>
|
7 |
+
</component>
|
8 |
+
</project>
|
.idea/pointcloud-c.iml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<module type="PYTHON_MODULE" version="4">
|
3 |
+
<component name="NewModuleRootManager">
|
4 |
+
<content url="file://$MODULE_DIR$" />
|
5 |
+
<orderEntry type="jdk" jdkName="Python 3.8 (pointcloud-c)" jdkType="Python SDK" />
|
6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
7 |
+
</component>
|
8 |
+
</module>
|
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="VcsDirectoryMappings">
|
4 |
+
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
5 |
+
</component>
|
6 |
+
</project>
|
DGCNN.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@Author: Yue Wang
|
5 |
+
@Contact: yuewangx@mit.edu
|
6 |
+
@File: model.py
|
7 |
+
@Time: 2018/10/13 6:35 PM
|
8 |
+
"""
|
9 |
+
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import copy
|
13 |
+
import math
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
|
20 |
+
def knn(x, k):
|
21 |
+
inner = -2 * torch.matmul(x.transpose(2, 1), x)
|
22 |
+
xx = torch.sum(x ** 2, dim=1, keepdim=True)
|
23 |
+
pairwise_distance = -xx - inner - xx.transpose(2, 1)
|
24 |
+
|
25 |
+
idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k)
|
26 |
+
return idx
|
27 |
+
|
28 |
+
|
29 |
+
def get_graph_feature(x, k=20, idx=None):
|
30 |
+
batch_size = x.size(0)
|
31 |
+
num_points = x.size(2)
|
32 |
+
x = x.view(batch_size, -1, num_points)
|
33 |
+
if idx is None:
|
34 |
+
idx = knn(x, k=k) # (batch_size, num_points, k)
|
35 |
+
device = torch.device('cpu')
|
36 |
+
|
37 |
+
idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
|
38 |
+
|
39 |
+
idx = idx + idx_base
|
40 |
+
|
41 |
+
idx = idx.view(-1)
|
42 |
+
|
43 |
+
_, num_dims, _ = x.size()
|
44 |
+
|
45 |
+
x = x.transpose(2,
|
46 |
+
1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points)
|
47 |
+
feature = x.view(batch_size * num_points, -1)[idx, :]
|
48 |
+
feature = feature.view(batch_size, num_points, k, num_dims)
|
49 |
+
x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
|
50 |
+
|
51 |
+
feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()
|
52 |
+
|
53 |
+
return feature
|
54 |
+
|
55 |
+
class DGCNN(nn.Module):
|
56 |
+
def __init__(self, output_channels=40):
|
57 |
+
super(DGCNN, self).__init__()
|
58 |
+
self.k = 20
|
59 |
+
emb_dims = 1024
|
60 |
+
dropout = 0.5
|
61 |
+
|
62 |
+
self.bn1 = nn.BatchNorm2d(64)
|
63 |
+
self.bn2 = nn.BatchNorm2d(64)
|
64 |
+
self.bn3 = nn.BatchNorm2d(128)
|
65 |
+
self.bn4 = nn.BatchNorm2d(256)
|
66 |
+
self.bn5 = nn.BatchNorm1d(emb_dims)
|
67 |
+
|
68 |
+
self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
|
69 |
+
self.bn1,
|
70 |
+
nn.LeakyReLU(negative_slope=0.2))
|
71 |
+
self.conv2 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
|
72 |
+
self.bn2,
|
73 |
+
nn.LeakyReLU(negative_slope=0.2))
|
74 |
+
self.conv3 = nn.Sequential(nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False),
|
75 |
+
self.bn3,
|
76 |
+
nn.LeakyReLU(negative_slope=0.2))
|
77 |
+
self.conv4 = nn.Sequential(nn.Conv2d(128 * 2, 256, kernel_size=1, bias=False),
|
78 |
+
self.bn4,
|
79 |
+
nn.LeakyReLU(negative_slope=0.2))
|
80 |
+
self.conv5 = nn.Sequential(nn.Conv1d(512, emb_dims, kernel_size=1, bias=False),
|
81 |
+
self.bn5,
|
82 |
+
nn.LeakyReLU(negative_slope=0.2))
|
83 |
+
self.linear1 = nn.Linear(emb_dims * 2, 512, bias=False)
|
84 |
+
self.bn6 = nn.BatchNorm1d(512)
|
85 |
+
self.dp1 = nn.Dropout(p=dropout)
|
86 |
+
self.linear2 = nn.Linear(512, 256)
|
87 |
+
self.bn7 = nn.BatchNorm1d(256)
|
88 |
+
self.dp2 = nn.Dropout(p=dropout)
|
89 |
+
self.linear3 = nn.Linear(256, output_channels)
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
batch_size = x.size(0)
|
93 |
+
x = get_graph_feature(x, k=self.k)
|
94 |
+
x = self.conv1(x)
|
95 |
+
x1 = x.max(dim=-1, keepdim=False)[0]
|
96 |
+
|
97 |
+
x = get_graph_feature(x1, k=self.k)
|
98 |
+
x = self.conv2(x)
|
99 |
+
x2 = x.max(dim=-1, keepdim=False)[0]
|
100 |
+
|
101 |
+
x = get_graph_feature(x2, k=self.k)
|
102 |
+
x = self.conv3(x)
|
103 |
+
x3 = x.max(dim=-1, keepdim=False)[0]
|
104 |
+
|
105 |
+
x = get_graph_feature(x3, k=self.k)
|
106 |
+
x = self.conv4(x)
|
107 |
+
x4 = x.max(dim=-1, keepdim=False)[0]
|
108 |
+
|
109 |
+
x = torch.cat((x1, x2, x3, x4), dim=1)
|
110 |
+
|
111 |
+
x = self.conv5(x)
|
112 |
+
x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
|
113 |
+
x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
|
114 |
+
x = torch.cat((x1, x2), 1)
|
115 |
+
|
116 |
+
x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)
|
117 |
+
x = self.dp1(x)
|
118 |
+
x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)
|
119 |
+
x = self.dp2(x)
|
120 |
+
x = self.linear3(x)
|
121 |
+
return x
|
GDANet_WOLFMix.t7
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ef1f05156c6ace4f72e9e70ac373dd7f5d8ece8fe2af15a1099c56b8e13431dd
|
3 |
+
size 3796397
|
GDANet_cls.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from util.GDANet_util import local_operator, GDM, SGCAM
|
5 |
+
|
6 |
+
|
7 |
+
class GDANET(nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super(GDANET, self).__init__()
|
10 |
+
|
11 |
+
self.bn1 = nn.BatchNorm2d(64, momentum=0.1)
|
12 |
+
self.bn11 = nn.BatchNorm2d(64, momentum=0.1)
|
13 |
+
self.bn12 = nn.BatchNorm1d(64, momentum=0.1)
|
14 |
+
|
15 |
+
self.bn2 = nn.BatchNorm2d(64, momentum=0.1)
|
16 |
+
self.bn21 = nn.BatchNorm2d(64, momentum=0.1)
|
17 |
+
self.bn22 = nn.BatchNorm1d(64, momentum=0.1)
|
18 |
+
|
19 |
+
self.bn3 = nn.BatchNorm2d(128, momentum=0.1)
|
20 |
+
self.bn31 = nn.BatchNorm2d(128, momentum=0.1)
|
21 |
+
self.bn32 = nn.BatchNorm1d(128, momentum=0.1)
|
22 |
+
|
23 |
+
self.bn4 = nn.BatchNorm1d(512, momentum=0.1)
|
24 |
+
|
25 |
+
self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=True),
|
26 |
+
self.bn1)
|
27 |
+
self.conv11 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=True),
|
28 |
+
self.bn11)
|
29 |
+
self.conv12 = nn.Sequential(nn.Conv1d(64 * 2, 64, kernel_size=1, bias=True),
|
30 |
+
self.bn12)
|
31 |
+
|
32 |
+
self.conv2 = nn.Sequential(nn.Conv2d(67 * 2, 64, kernel_size=1, bias=True),
|
33 |
+
self.bn2)
|
34 |
+
self.conv21 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=True),
|
35 |
+
self.bn21)
|
36 |
+
self.conv22 = nn.Sequential(nn.Conv1d(64 * 2, 64, kernel_size=1, bias=True),
|
37 |
+
self.bn22)
|
38 |
+
|
39 |
+
self.conv3 = nn.Sequential(nn.Conv2d(131 * 2, 128, kernel_size=1, bias=True),
|
40 |
+
self.bn3)
|
41 |
+
self.conv31 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=1, bias=True),
|
42 |
+
self.bn31)
|
43 |
+
self.conv32 = nn.Sequential(nn.Conv1d(128, 128, kernel_size=1, bias=True),
|
44 |
+
self.bn32)
|
45 |
+
|
46 |
+
self.conv4 = nn.Sequential(nn.Conv1d(256, 512, kernel_size=1, bias=True),
|
47 |
+
self.bn4)
|
48 |
+
|
49 |
+
self.SGCAM_1s = SGCAM(64)
|
50 |
+
self.SGCAM_1g = SGCAM(64)
|
51 |
+
self.SGCAM_2s = SGCAM(64)
|
52 |
+
self.SGCAM_2g = SGCAM(64)
|
53 |
+
|
54 |
+
self.linear1 = nn.Linear(1024, 512, bias=True)
|
55 |
+
self.bn6 = nn.BatchNorm1d(512)
|
56 |
+
self.dp1 = nn.Dropout(p=0.4)
|
57 |
+
self.linear2 = nn.Linear(512, 256, bias=True)
|
58 |
+
self.bn7 = nn.BatchNorm1d(256)
|
59 |
+
self.dp2 = nn.Dropout(p=0.4)
|
60 |
+
self.linear3 = nn.Linear(256, 40, bias=True)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
B, C, N = x.size()
|
64 |
+
###############
|
65 |
+
"""block 1"""
|
66 |
+
# Local operator:
|
67 |
+
x1 = local_operator(x, k=30)
|
68 |
+
x1 = F.relu(self.conv1(x1))
|
69 |
+
x1 = F.relu(self.conv11(x1))
|
70 |
+
x1 = x1.max(dim=-1, keepdim=False)[0]
|
71 |
+
|
72 |
+
# Geometry-Disentangle Module:
|
73 |
+
x1s, x1g = GDM(x1, M=256)
|
74 |
+
|
75 |
+
# Sharp-Gentle Complementary Attention Module:
|
76 |
+
y1s = self.SGCAM_1s(x1, x1s.transpose(2, 1))
|
77 |
+
y1g = self.SGCAM_1g(x1, x1g.transpose(2, 1))
|
78 |
+
z1 = torch.cat([y1s, y1g], 1)
|
79 |
+
z1 = F.relu(self.conv12(z1))
|
80 |
+
###############
|
81 |
+
"""block 2"""
|
82 |
+
x1t = torch.cat((x, z1), dim=1)
|
83 |
+
x2 = local_operator(x1t, k=30)
|
84 |
+
x2 = F.relu(self.conv2(x2))
|
85 |
+
x2 = F.relu(self.conv21(x2))
|
86 |
+
x2 = x2.max(dim=-1, keepdim=False)[0]
|
87 |
+
|
88 |
+
x2s, x2g = GDM(x2, M=256)
|
89 |
+
|
90 |
+
y2s = self.SGCAM_2s(x2, x2s.transpose(2, 1))
|
91 |
+
y2g = self.SGCAM_2g(x2, x2g.transpose(2, 1))
|
92 |
+
z2 = torch.cat([y2s, y2g], 1)
|
93 |
+
z2 = F.relu(self.conv22(z2))
|
94 |
+
###############
|
95 |
+
x2t = torch.cat((x1t, z2), dim=1)
|
96 |
+
x3 = local_operator(x2t, k=30)
|
97 |
+
x3 = F.relu(self.conv3(x3))
|
98 |
+
x3 = F.relu(self.conv31(x3))
|
99 |
+
x3 = x3.max(dim=-1, keepdim=False)[0]
|
100 |
+
z3 = F.relu(self.conv32(x3))
|
101 |
+
###############
|
102 |
+
x = torch.cat((z1, z2, z3), dim=1)
|
103 |
+
x = F.relu(self.conv4(x))
|
104 |
+
x11 = F.adaptive_max_pool1d(x, 1).view(B, -1)
|
105 |
+
x22 = F.adaptive_avg_pool1d(x, 1).view(B, -1)
|
106 |
+
x = torch.cat((x11, x22), 1)
|
107 |
+
|
108 |
+
x = F.relu(self.bn6(self.linear1(x)))
|
109 |
+
x = self.dp1(x)
|
110 |
+
x = F.relu(self.bn7(self.linear2(x)))
|
111 |
+
x = self.dp2(x)
|
112 |
+
x = self.linear3(x)
|
113 |
+
return x
|
__pycache__/DGCNN.cpython-38.pyc
ADDED
Binary file (3.23 kB). View file
|
|
__pycache__/GDANet_cls.cpython-38.pyc
ADDED
Binary file (2.94 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import mathutils
|
3 |
+
import math
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import matplotlib
|
7 |
+
import matplotlib.cm as cmx
|
8 |
+
import os.path as osp
|
9 |
+
import h5py
|
10 |
+
import random
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
|
14 |
+
from GDANet_cls import GDANET
|
15 |
+
from DGCNN import DGCNN
|
16 |
+
|
17 |
+
with open('shape_names.txt') as f:
|
18 |
+
CLASS_NAME = f.read().splitlines()
|
19 |
+
|
20 |
+
model_gda = GDANET()
|
21 |
+
model_gda = nn.DataParallel(model_gda)
|
22 |
+
model_gda.load_state_dict(torch.load('./GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
|
23 |
+
model_gda.eval()
|
24 |
+
|
25 |
+
model_dgcnn = DGCNN()
|
26 |
+
model_dgcnn = nn.DataParallel(model_dgcnn)
|
27 |
+
model_dgcnn.load_state_dict(torch.load('./dgcnn.t7', map_location=torch.device('cpu')))
|
28 |
+
model_dgcnn.eval()
|
29 |
+
|
30 |
+
def pyplot_draw_point_cloud(points, corruption):
|
31 |
+
rot1 = mathutils.Euler([-math.pi / 2, 0, 0]).to_matrix().to_3x3()
|
32 |
+
rot2 = mathutils.Euler([0, 0, math.pi]).to_matrix().to_3x3()
|
33 |
+
points = np.dot(points, rot1)
|
34 |
+
points = np.dot(points, rot2)
|
35 |
+
x, y, z = points[:, 0], points[:, 1], points[:, 2]
|
36 |
+
colorsMap = 'winter'
|
37 |
+
cs = y
|
38 |
+
cm = plt.get_cmap(colorsMap)
|
39 |
+
cNorm = matplotlib.colors.Normalize(vmin=-1, vmax=1)
|
40 |
+
scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cm)
|
41 |
+
fig = plt.figure(figsize=(5, 5))
|
42 |
+
ax = fig.add_subplot(111, projection='3d')
|
43 |
+
ax.scatter(x, y, z, c=scalarMap.to_rgba(cs))
|
44 |
+
scalarMap.set_array(cs)
|
45 |
+
ax.set_xlim(-1, 1)
|
46 |
+
ax.set_ylim(-1, 1)
|
47 |
+
ax.set_zlim(-1, 1)
|
48 |
+
plt.axis('off')
|
49 |
+
plt.title(corruption, fontsize=30)
|
50 |
+
plt.tight_layout()
|
51 |
+
plt.savefig('visualization.png', bbox_inches='tight', dpi=200)
|
52 |
+
plt.close()
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
def load_dataset(corruption_idx, severity):
|
57 |
+
corruptions = [
|
58 |
+
'clean',
|
59 |
+
'scale',
|
60 |
+
'jitter',
|
61 |
+
'rotate',
|
62 |
+
'dropout_global',
|
63 |
+
'dropout_local',
|
64 |
+
'add_global',
|
65 |
+
'add_local',
|
66 |
+
]
|
67 |
+
corruption_type = corruptions[corruption_idx]
|
68 |
+
if corruption_type == 'clean':
|
69 |
+
f = h5py.File(osp.join('modelnet_c', corruption_type + '.h5'))
|
70 |
+
else:
|
71 |
+
f = h5py.File(osp.join('modelnet_c', corruption_type + '_{}'.format(severity-1) + '.h5'))
|
72 |
+
data = f['data'][:].astype('float32')
|
73 |
+
label = f['label'][:].astype('int64')
|
74 |
+
f.close()
|
75 |
+
return data, label
|
76 |
+
|
77 |
+
def recognize_pcd(model, pcd):
|
78 |
+
pcd = torch.tensor(pcd).unsqueeze(0)
|
79 |
+
pcd = pcd.permute(0, 2, 1)
|
80 |
+
output = model(pcd)
|
81 |
+
prediction = output.softmax(-1).flatten()
|
82 |
+
_, top5_idx = torch.topk(prediction, 5)
|
83 |
+
return {CLASS_NAME[i]: float(prediction[i]) for i in top5_idx.tolist()}
|
84 |
+
|
85 |
+
def run(seed, corruption_idx, severity):
|
86 |
+
data, label = load_dataset(corruption_idx, severity)
|
87 |
+
sample_indx = int(seed)
|
88 |
+
pcd, cls = data[sample_indx], label[sample_indx]
|
89 |
+
pyplot_draw_point_cloud(pcd, CLASS_NAME[cls[0]])
|
90 |
+
output = 'visualization.png'
|
91 |
+
return output, recognize_pcd(model_dgcnn, pcd), recognize_pcd(model_gda, pcd)
|
92 |
+
|
93 |
+
if __name__ == '__main__':
|
94 |
+
iface = gr.Interface(
|
95 |
+
fn=run,
|
96 |
+
inputs=[
|
97 |
+
gr.components.Number(label='Sample Seed', precision=0),
|
98 |
+
gr.components.Radio(
|
99 |
+
['Clean', 'Scale', 'Jitter', 'Rotate', 'Drop Global', 'Drop Local', 'Add Global', 'Add Local'],
|
100 |
+
value='Clean', type="index", label='Corruption Type'),
|
101 |
+
gr.components.Slider(1, 5, step=1, label='Corruption severity'),
|
102 |
+
],
|
103 |
+
outputs=[
|
104 |
+
gr.components.Image(type="file", label="Visualization"),
|
105 |
+
gr.components.Label(num_top_classes=5, label="Baseline (DGCNN) Prediction"),
|
106 |
+
gr.components.Label(num_top_classes=5, label="Ours (GDANet+WolfMix) Prediction")
|
107 |
+
],
|
108 |
+
live=False,
|
109 |
+
allow_flagging='never',
|
110 |
+
title="Benchmarking and Analyzing Point Cloud Classification under Corruptions [ICML 2022]",
|
111 |
+
description="Welcome to the demo of ModelNet-C! You can visualize various types of corrupted point clouds in ModelNet-C and see how our proposed techniques contribute to robust predicitions compared to baseline methods.",
|
112 |
+
examples=[
|
113 |
+
[0, 'Jitter', 5],
|
114 |
+
[999, 'Drop Local', 5],
|
115 |
+
],
|
116 |
+
# css=".output-image, .image-preview {height: 500px !important}",
|
117 |
+
article="<p style='text-align: center'><a href='https://github.com/jiawei-ren/ModelNet-C' target='_blank'>ModelNet-C @ GitHub</a></p> "
|
118 |
+
)
|
119 |
+
iface.launch()
|
dgcnn.t7
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f366f60ca9dacf42cff5b747ba86020f10a6480ab31bb8122a8a609152ce4baa
|
3 |
+
size 7268024
|
downsample.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import h5py
|
3 |
+
|
4 |
+
for fpath in glob.glob('modelnet_c/*.h5'):
|
5 |
+
f = h5py.File(fpath)
|
6 |
+
data = f['data'][:].astype('float32')
|
7 |
+
label = f['label'][:].astype('int64')
|
8 |
+
f.close()
|
9 |
+
f = h5py.File(fpath, 'w')
|
10 |
+
f.create_dataset('data', data=data[:100])
|
11 |
+
f.create_dataset('label', data=label[:100])
|
12 |
+
f.close()
|
modelnet_c/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
modelnet_c/add_global_0.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0588ab009c0d0b4f4f8598e2fd7a6df0df14937b30a42ecdfc6c2811e70d494e
|
3 |
+
size 61267680
|
modelnet_c/add_global_1.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3e18f34cf1e30b58758edfbe20f4e0db0a53e067bc87520e20ee8eb9fc85035d
|
3 |
+
size 61860000
|
modelnet_c/add_global_2.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1c01495ed18ede2c5d9bab1cfc29555b9153e862c827909295504f76909addef
|
3 |
+
size 62452320
|
modelnet_c/add_global_3.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f8d5e4594ca8aeb84e10397e0189b45a594f47009a4669d4efe070ffe72d4c4c
|
3 |
+
size 63044640
|
modelnet_c/add_global_4.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:989951ee7a7056f70c03f671a61a7ac55ca3d2b78eeccfeb855fcc033b276797
|
3 |
+
size 63636960
|
modelnet_c/add_local_0.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bd0e3adfcc8c98136b33f0a712033340e594cfcd7dafa3f76e6a8f517a59a429
|
3 |
+
size 33310176
|
modelnet_c/add_local_1.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:43986bd483fdcc5f4a8779c31f3b45d45401b82d18443161872d0862106a0fa6
|
3 |
+
size 36271776
|
modelnet_c/add_local_2.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d69968a6f58f25687058eaaca1a67a6890396cac812529e9529b978054e94c36
|
3 |
+
size 39233376
|
modelnet_c/add_local_3.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:029c794a19cffe80797950d2aac661fe2cc35b978d486e3684024a0bcf42c10c
|
3 |
+
size 42194976
|
modelnet_c/add_local_4.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dd8597a0a95f358cacb9f5ad8c1d76b74179988b4037b277dff7c58cd95c8146
|
3 |
+
size 45156576
|
modelnet_c/clean.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:203ec0037ebdee84f13df5abe12e3cf0e1047192832ff6149ac54b4deaf37931
|
3 |
+
size 30348576
|
modelnet_c/dropout_global_0.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d8b7527cc2512fd83e074b3f814cb88253781b2331d9f3c79c54acc3a0f6e64e
|
3 |
+
size 22766880
|
modelnet_c/dropout_global_1.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:00c769d0f5f3bd7f9a4bf261d23b567fe75861b54c83671f5c46fe7ea288c79c
|
3 |
+
size 18976032
|
modelnet_c/dropout_global_2.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ff6fe9780aee5fb4995e81313a3341df7b6c23d710e2bb0cd1bc46c9cb6e6815
|
3 |
+
size 15185184
|
modelnet_c/dropout_global_3.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f8d9c46fe1d02d779895135d42861c34988355893c4d81ba54586ebb7571b4cd
|
3 |
+
size 11394336
|
modelnet_c/dropout_global_4.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0342a91f9e16e29751e90296171c593e1eb35da0ffe3579a6664bb18c2f28d9b
|
3 |
+
size 7603488
|
modelnet_c/dropout_local_0.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f8751d701b16e7f37065a1322c4008bff807db5b2f2ab5ce9254fa941a519742
|
3 |
+
size 27386976
|
modelnet_c/dropout_local_1.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a2441fae91bc29fba8953799e11f604ec5beeb20913380ca225e1b9a8fe3b500
|
3 |
+
size 24425376
|
modelnet_c/dropout_local_2.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0fa9a2ba4146bc522d7731a495134babd19d8b212105530abbcd40256c7eaaed
|
3 |
+
size 21463776
|
modelnet_c/dropout_local_3.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:afc6742314f168e7a53209e81ce4298485fb665578953e9f231eb80fc7dd9940
|
3 |
+
size 18502176
|
modelnet_c/dropout_local_4.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6fc12bb058cb674274a0e01e79b3810cf693e85aaa0d8a21832a3f3108d8c69f
|
3 |
+
size 15540576
|
modelnet_c/jitter_0.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:36b0e5c417d801b2b12d9a97cb4a998eb3d639d3a9bc7a714ee70dbda25d2284
|
3 |
+
size 60675360
|
modelnet_c/jitter_1.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:92b770e594a0c72fbd4b7394e02d2a6bf0490c44acb7474fc2e3837de0827c5d
|
3 |
+
size 60675360
|
modelnet_c/jitter_2.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4d95bbc72d191f632a368873552bcaa15068082f424f184485b997d4e1ba6f05
|
3 |
+
size 60675360
|
modelnet_c/jitter_3.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ed401170657a4326638787c71004960db5cc5bb03de56df64ef335d73c975032
|
3 |
+
size 60675360
|
modelnet_c/jitter_4.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4547327b47fa368814145f8397bf87061a86281afbf253f85c0a6d1abbd8251b
|
3 |
+
size 60675360
|
modelnet_c/rotate_0.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:95246d8c9506578039c4ab8f0cb1b8794d3d32c533e906aea724a04de8c3a1c7
|
3 |
+
size 60675360
|
modelnet_c/rotate_1.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:70147de9808e5ed0cdcb0184a1ade44241497c3d69d74576e5dc1b333cfc87aa
|
3 |
+
size 60675360
|
modelnet_c/rotate_2.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f6ae9db2079eb728eb2f8cdc7dfdcb7ff5456575a986c7e10777cbcd29fb4fbb
|
3 |
+
size 60675360
|
modelnet_c/rotate_3.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7c13ab0be340e82c63ef325619759c71aae63a622a1846c95f0697152c3e7b1c
|
3 |
+
size 60675360
|
modelnet_c/rotate_4.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f1d70d56407cb9a02308a34806bb908a1bba8db2779881a7a1a49ba7005119fb
|
3 |
+
size 60675360
|
modelnet_c/scale_0.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b4bb6638a9330960261aadda500d34060ad00143322cc2a40f3b5dea55ad7a40
|
3 |
+
size 30348576
|