swin2mose
#1
by
hachreak
- opened
- .gitignore +1 -0
- swin2_mose/libs.py +56 -0
- swin2_mose/model.py +9 -12
- swin2_mose/moe.py +3 -2
- swin2_mose/run.py +36 -20
- swin2_mose/utils.py +77 -56
- swin2_mose/weights/config-70.yml +46 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
*.pyc
|
swin2_mose/libs.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
|
4 |
+
def window_reverse(windows, window_size, H, W):
|
5 |
+
"""
|
6 |
+
Args:
|
7 |
+
windows: (num_windows*B, window_size, window_size, C)
|
8 |
+
window_size (int): Window size
|
9 |
+
H (int): Height of image
|
10 |
+
W (int): Width of image
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
x: (B, H, W, C)
|
14 |
+
"""
|
15 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
16 |
+
x = windows.view(B, H // window_size, W // window_size, window_size,
|
17 |
+
window_size, -1)
|
18 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
19 |
+
return x
|
20 |
+
|
21 |
+
|
22 |
+
class Mlp(nn.Module):
|
23 |
+
def __init__(self, in_features, hidden_features=None, out_features=None,
|
24 |
+
act_layer=nn.GELU, drop=0.):
|
25 |
+
super().__init__()
|
26 |
+
out_features = out_features or in_features
|
27 |
+
hidden_features = hidden_features or in_features
|
28 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
29 |
+
self.act = act_layer()
|
30 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
31 |
+
self.drop = nn.Dropout(drop)
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
x = self.fc1(x)
|
35 |
+
x = self.act(x)
|
36 |
+
x = self.drop(x)
|
37 |
+
x = self.fc2(x)
|
38 |
+
x = self.drop(x)
|
39 |
+
return x
|
40 |
+
|
41 |
+
|
42 |
+
def window_partition(x, window_size):
|
43 |
+
"""
|
44 |
+
Args:
|
45 |
+
x: (B, H, W, C)
|
46 |
+
window_size (int): window size
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
windows: (num_windows*B, window_size, window_size, C)
|
50 |
+
"""
|
51 |
+
B, H, W, C = x.shape
|
52 |
+
x = x.view(B, H // window_size, window_size,
|
53 |
+
W // window_size, window_size, C)
|
54 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(
|
55 |
+
-1, window_size, window_size, C)
|
56 |
+
return windows
|
swin2_mose/model.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
#
|
2 |
-
# Source code: https://github.com/
|
3 |
#
|
4 |
-
#
|
5 |
-
#
|
6 |
-
#
|
7 |
-
# -----------------------------------------------------------------------------------
|
8 |
|
9 |
import math
|
10 |
import numpy as np
|
@@ -14,7 +13,7 @@ import torch.nn.functional as F
|
|
14 |
import torch.utils.checkpoint as checkpoint
|
15 |
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
16 |
|
17 |
-
from
|
18 |
from moe import MoE
|
19 |
|
20 |
|
@@ -746,9 +745,8 @@ class UpsampleOneStep(nn.Sequential):
|
|
746 |
|
747 |
|
748 |
|
749 |
-
class
|
750 |
-
r"""
|
751 |
-
A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
|
752 |
|
753 |
Args:
|
754 |
img_size (int | tuple(int)): Input image size. Default 64
|
@@ -784,8 +782,7 @@ class Swin2SR(nn.Module):
|
|
784 |
MoE_config=None,
|
785 |
use_rpe_bias=False,
|
786 |
**kwargs):
|
787 |
-
super(
|
788 |
-
print('==== SWIN 2SR')
|
789 |
num_in_ch = in_chans
|
790 |
num_out_ch = in_chans
|
791 |
num_feat = 64
|
@@ -1154,4 +1151,4 @@ class Swin2SR(nn.Module):
|
|
1154 |
flops += layer.flops()
|
1155 |
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
1156 |
flops += self.upsample.flops()
|
1157 |
-
return flops
|
|
|
1 |
#
|
2 |
+
# Source code: https://github.com/IMPLabUniPr/swin2-mose
|
3 |
#
|
4 |
+
# ----------------------------------------------------------------------------
|
5 |
+
# https://arxiv.org/abs/2404.18924
|
6 |
+
# ----------------------------------------------------------------------------
|
|
|
7 |
|
8 |
import math
|
9 |
import numpy as np
|
|
|
13 |
import torch.utils.checkpoint as checkpoint
|
14 |
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
15 |
|
16 |
+
from libs import window_reverse, Mlp, window_partition
|
17 |
from moe import MoE
|
18 |
|
19 |
|
|
|
745 |
|
746 |
|
747 |
|
748 |
+
class Swin2MoSE(nn.Module):
|
749 |
+
r""" Swin2-MoSE
|
|
|
750 |
|
751 |
Args:
|
752 |
img_size (int | tuple(int)): Input image size. Default 64
|
|
|
782 |
MoE_config=None,
|
783 |
use_rpe_bias=False,
|
784 |
**kwargs):
|
785 |
+
super(Swin2MoSE, self).__init__()
|
|
|
786 |
num_in_ch = in_chans
|
787 |
num_out_ch = in_chans
|
788 |
num_feat = 64
|
|
|
1151 |
flops += layer.flops()
|
1152 |
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
1153 |
flops += self.upsample.flops()
|
1154 |
+
return flops
|
swin2_mose/moe.py
CHANGED
@@ -18,7 +18,8 @@ from torch.distributions.normal import Normal
|
|
18 |
from copy import deepcopy
|
19 |
import numpy as np
|
20 |
|
21 |
-
from
|
|
|
22 |
|
23 |
class SparseDispatcher(object):
|
24 |
"""Helper for implementing a mixture of experts.
|
@@ -320,4 +321,4 @@ class MoE(nn.Module):
|
|
320 |
expert_outputs = [self.experts[i](expert_inputs[i])
|
321 |
for i in range(self.num_experts)]
|
322 |
y = dispatcher.combine(expert_outputs, cnn_combine=self.cnn_combine)
|
323 |
-
return y, loss
|
|
|
18 |
from copy import deepcopy
|
19 |
import numpy as np
|
20 |
|
21 |
+
from libs import Mlp as MLP
|
22 |
+
|
23 |
|
24 |
class SparseDispatcher(object):
|
25 |
"""Helper for implementing a mixture of experts.
|
|
|
321 |
expert_outputs = [self.experts[i](expert_inputs[i])
|
322 |
for i in range(self.num_experts)]
|
323 |
y = dispatcher.combine(expert_outputs, cnn_combine=self.cnn_combine)
|
324 |
+
return y, loss
|
swin2_mose/run.py
CHANGED
@@ -1,20 +1,36 @@
|
|
1 |
-
import
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import benchmark
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import opensr_test
|
4 |
+
|
5 |
+
from utils import load_swin2_mose, load_config, run_swin2_mose
|
6 |
+
|
7 |
+
|
8 |
+
path = 'swin2_mose/weights/config-70.yml'
|
9 |
+
model_weights = "swin2_mose/weights/model-70.pt"
|
10 |
+
index = 2
|
11 |
+
|
12 |
+
# load config
|
13 |
+
cfg = load_config(path)
|
14 |
+
# load model
|
15 |
+
model = load_swin2_mose(model_weights, cfg)
|
16 |
+
|
17 |
+
# load the dataset
|
18 |
+
dataset = opensr_test.load("venus")
|
19 |
+
lr_dataset, hr_dataset = dataset["L2A"], dataset["HRharm"]
|
20 |
+
|
21 |
+
results = run_swin2_mose(model, lr_dataset[index], hr_dataset[index])
|
22 |
+
|
23 |
+
# Display the results
|
24 |
+
fig, ax = plt.subplots(1, 3, figsize=(10, 5))
|
25 |
+
ax[0].imshow(results['lr'].numpy().transpose(1, 2, 0)/3000)
|
26 |
+
ax[0].set_title("LR")
|
27 |
+
ax[0].axis("off")
|
28 |
+
ax[1].imshow(results["sr"].detach().numpy().transpose(1, 2, 0)/3000)
|
29 |
+
ax[1].set_title("SR")
|
30 |
+
ax[1].axis("off")
|
31 |
+
ax[2].imshow(results['hr'].numpy().transpose(1, 2, 0) / 3000)
|
32 |
+
ax[2].set_title("HR")
|
33 |
+
# plt.show()
|
34 |
+
|
35 |
+
# Run the experiment
|
36 |
+
benchmark.create_geotiff(model, run_swin2_mose, "all", "swin2mose/")
|
swin2_mose/utils.py
CHANGED
@@ -1,56 +1,77 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import yaml
|
3 |
+
|
4 |
+
from model import Swin2MoSE
|
5 |
+
|
6 |
+
|
7 |
+
def to_shape(t1, t2):
|
8 |
+
t1 = t1[None].repeat(t2.shape[0], 1)
|
9 |
+
t1 = t1.view((t2.shape[:2] + (1, 1)))
|
10 |
+
return t1
|
11 |
+
|
12 |
+
|
13 |
+
def norm(tensor, mean, std):
|
14 |
+
# get stats
|
15 |
+
mean = torch.tensor(mean).to(tensor.device)
|
16 |
+
std = torch.tensor(std).to(tensor.device)
|
17 |
+
# denorm
|
18 |
+
return (tensor - to_shape(mean, tensor)) / to_shape(std, tensor)
|
19 |
+
|
20 |
+
|
21 |
+
def denorm(tensor, mean, std):
|
22 |
+
# get stats
|
23 |
+
mean = torch.tensor(mean).to(tensor.device)
|
24 |
+
std = torch.tensor(std).to(tensor.device)
|
25 |
+
# denorm
|
26 |
+
return (tensor * to_shape(std, tensor)) + to_shape(mean, tensor)
|
27 |
+
|
28 |
+
|
29 |
+
def load_config(path):
|
30 |
+
# load config
|
31 |
+
with open(path, 'r') as f:
|
32 |
+
cfg = yaml.safe_load(f)
|
33 |
+
return cfg
|
34 |
+
|
35 |
+
|
36 |
+
def load_swin2_mose(model_weights, cfg):
|
37 |
+
# load checkpoint
|
38 |
+
checkpoint = torch.load(model_weights)
|
39 |
+
|
40 |
+
# build model
|
41 |
+
sr_model = Swin2MoSE(**cfg['super_res']['model'])
|
42 |
+
sr_model.load_state_dict(
|
43 |
+
checkpoint['model_state_dict'])
|
44 |
+
|
45 |
+
sr_model.cfg = cfg
|
46 |
+
|
47 |
+
return sr_model
|
48 |
+
|
49 |
+
|
50 |
+
def run_swin2_mose(model, lr, hr):
|
51 |
+
cfg = model.cfg
|
52 |
+
|
53 |
+
# norm fun
|
54 |
+
hr_stats = cfg['dataset']['stats']['tensor_05m_b2b3b4b8']
|
55 |
+
lr_stats = cfg['dataset']['stats']['tensor_10m_b2b3b4b8']
|
56 |
+
|
57 |
+
# select 10m lr bands: B02, B03, B04, B08 and hr bands
|
58 |
+
lr_orig = torch.tensor(lr)[None].float()[:, [3, 2, 1, 7]]
|
59 |
+
hr_orig = torch.tensor(hr)[None].float()
|
60 |
+
|
61 |
+
# normalize data
|
62 |
+
lr = norm(lr_orig, mean=lr_stats['mean'], std=lr_stats['std'])
|
63 |
+
hr = norm(hr_orig, mean=hr_stats['mean'], std=hr_stats['std'])
|
64 |
+
|
65 |
+
# predict a image
|
66 |
+
sr = model(lr)
|
67 |
+
if not torch.is_tensor(sr):
|
68 |
+
sr, _ = sr
|
69 |
+
|
70 |
+
# denorm sr
|
71 |
+
sr = denorm(sr, mean=hr_stats['mean'], std=hr_stats['std'])
|
72 |
+
|
73 |
+
return {
|
74 |
+
"lr": lr_orig[0],
|
75 |
+
"sr": sr[0],
|
76 |
+
"hr": hr_orig[0],
|
77 |
+
}
|
swin2_mose/weights/config-70.yml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset:
|
2 |
+
root_path: data/sen2venus
|
3 |
+
stats:
|
4 |
+
use_minmax: true
|
5 |
+
tensor_05m_b2b3b4b8: {
|
6 |
+
mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875],
|
7 |
+
std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625],
|
8 |
+
min: [-1025.0, -3112.0, -5122.0, -3851.0],
|
9 |
+
max: [14748.0, 14960.0, 16472.0, 16109.0]
|
10 |
+
}
|
11 |
+
tensor_10m_b2b3b4b8: {
|
12 |
+
mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875],
|
13 |
+
std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875],
|
14 |
+
min: [-848.0, -902.0, -946.0, -323.0],
|
15 |
+
max: [19684.0, 17982.0, 17064.0, 15958.0]
|
16 |
+
}
|
17 |
+
hr_name: tensor_05m_b2b3b4b8
|
18 |
+
lr_name: tensor_10m_b2b3b4b8
|
19 |
+
collate_fn: mods.v3.collate_fn
|
20 |
+
denorm: mods.v3.uncollate_fn
|
21 |
+
printable: mods.v3.printable
|
22 |
+
super_res: {
|
23 |
+
version: 'v2',
|
24 |
+
model: {
|
25 |
+
upscale: 2,
|
26 |
+
use_lepe: true,
|
27 |
+
use_cpb_bias: false,
|
28 |
+
use_rpe_bias: true,
|
29 |
+
mlp_ratio: 1,
|
30 |
+
MoE_config: {
|
31 |
+
k: 2,
|
32 |
+
num_experts: 8,
|
33 |
+
with_noise: false,
|
34 |
+
with_smart_merger: v1,
|
35 |
+
},
|
36 |
+
depths: [6, 6, 6, 6],
|
37 |
+
embed_dim: 90,
|
38 |
+
img_range: 1.,
|
39 |
+
img_size: 64,
|
40 |
+
in_chans: 4,
|
41 |
+
num_heads: [6, 6, 6, 6],
|
42 |
+
resi_connection: 1conv,
|
43 |
+
upsampler: pixelshuffledirect,
|
44 |
+
window_size: 16,
|
45 |
+
}
|
46 |
+
}
|