epishchik committed on
Commit 41d73ae
1 Parent(s): b462e09

Upload model

Files changed (5)
  1. config.json +18 -0
  2. config.py +23 -0
  3. model.py +22 -0
  4. model.safetensors +3 -0
  5. rrdbnet.py +194 -0
config.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "architectures": [
+     "RealESRGANModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "config.RealESRGANConfig",
+     "AutoModel": "model.RealESRGANModel"
+   },
+   "model_type": "realesrgan",
+   "num_block": 23,
+   "num_feat": 64,
+   "num_grow_ch": 32,
+   "num_in_ch": 3,
+   "num_out_ch": 3,
+   "scale": 4,
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.1"
+ }
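
The auto_map entries above let the Hugging Face Auto classes build the custom model from the repository's own config.py and model.py. A minimal loading sketch (the repository id below is a placeholder, not confirmed by this commit):

    from transformers import AutoConfig, AutoModel

    repo_id = "user/realesrgan"  # placeholder; use the actual Hub repo id

    # trust_remote_code=True is required so that config.py and model.py
    # shipped with the repository are executed to define the classes
    # referenced in auto_map.
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)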
config.py ADDED
@@ -0,0 +1,23 @@
+ from transformers import PretrainedConfig
+
+
+ class RealESRGANConfig(PretrainedConfig):
+     model_type = "realesrgan"
+
+     def __init__(
+         self,
+         num_in_ch: int = 3,
+         num_out_ch: int = 3,
+         num_feat: int = 64,
+         num_block: int = 23,
+         num_grow_ch: int = 32,
+         scale: int = 4,
+         **kwargs,
+     ):
+         self.num_in_ch = num_in_ch
+         self.num_out_ch = num_out_ch
+         self.num_feat = num_feat
+         self.num_block = num_block
+         self.num_grow_ch = num_grow_ch
+         self.scale = scale
+         super().__init__(**kwargs)
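
RealESRGANConfig assigns every architecture hyperparameter before delegating to PretrainedConfig.__init__, so the values round-trip through config.json. A small sketch of that round trip, assuming it runs in the repository directory:

    from config import RealESRGANConfig

    cfg = RealESRGANConfig()  # defaults match config.json above
    assert cfg.scale == 4 and cfg.num_block == 23

    cfg.save_pretrained("./realesrgan-x4")  # writes config.json
    reloaded = RealESRGANConfig.from_pretrained("./realesrgan-x4")
    assert reloaded.num_feat == 64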
model.py ADDED
@@ -0,0 +1,22 @@
+ from transformers import PreTrainedModel
+
+ from .config import RealESRGANConfig
+ from .rrdbnet import RRDBNet
+
+
+ class RealESRGANModel(PreTrainedModel):
+     config_class = RealESRGANConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = RRDBNet(
+             num_in_ch=config.num_in_ch,
+             num_out_ch=config.num_out_ch,
+             num_feat=config.num_feat,
+             num_block=config.num_block,
+             num_grow_ch=config.num_grow_ch,
+             scale=config.scale,
+         )
+
+     def forward(self, tensor):
+         return self.model(tensor)
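
RealESRGANModel is a thin PreTrainedModel wrapper around RRDBNet, so forward simply maps a batched image tensor to its upscaled counterpart. A usage sketch (the [0, 1] RGB value range is the usual Real-ESRGAN convention, assumed here rather than stated in this commit):

    import torch
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "user/realesrgan", trust_remote_code=True  # placeholder repo id
    ).eval()

    lr = torch.rand(1, 3, 64, 64)  # (b, c, h, w), RGB in [0, 1]
    with torch.no_grad():
        sr = model(lr)
    print(sr.shape)    # torch.Size([1, 3, 256, 256]) for the scale=4 config
    sr = sr.clamp(0, 1)  # clip before converting back to an image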
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:94ffebe0816db7d0f0837f5b5d49ab75144af01797ff5c010a92f314217c32d9
+ size 66862076
rrdbnet.py ADDED
@@ -0,0 +1,194 @@
+ import torch
+ from torch import nn as nn
+ from torch.nn import functional as F
+ from torch.nn import init as init
+ from torch.nn.modules.batchnorm import _BatchNorm
+
+
+ def pixel_unshuffle(x, scale):
+     """Pixel unshuffle.
+
+     Args:
+         x (Tensor): Input feature with shape (b, c, hh, hw).
+         scale (int): Downsample ratio.
+
+     Returns:
+         Tensor: The pixel-unshuffled feature.
+     """
+     b, c, hh, hw = x.size()
+     out_channel = c * (scale**2)
+     assert hh % scale == 0 and hw % scale == 0
+     h = hh // scale
+     w = hw // scale
+     x_view = x.view(b, c, h, scale, w, scale)
+     return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+ @torch.no_grad()
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+     """Initialize network weights.
+
+     Args:
+         module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+         scale (float): Scale initialized weights, especially for residual
+             blocks. Default: 1.
+         bias_fill (float): The value to fill bias. Default: 0.
+         kwargs (dict): Other arguments for initialization function.
+     """
+     if not isinstance(module_list, list):
+         module_list = [module_list]
+     for module in module_list:
+         for m in module.modules():
+             if isinstance(m, nn.Conv2d):
+                 init.kaiming_normal_(m.weight, **kwargs)
+                 m.weight.data *= scale
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+             elif isinstance(m, nn.Linear):
+                 init.kaiming_normal_(m.weight, **kwargs)
+                 m.weight.data *= scale
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+             elif isinstance(m, _BatchNorm):
+                 init.constant_(m.weight, 1)
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+
+
+ def make_layer(basic_block, num_basic_block, **kwarg):
+     """Make layers by stacking the same blocks.
+
+     Args:
+         basic_block (nn.Module): nn.Module class for the basic block.
+         num_basic_block (int): Number of blocks.
+
+     Returns:
+         nn.Sequential: Stacked blocks in nn.Sequential.
+     """
+     layers = []
+     for _ in range(num_basic_block):
+         layers.append(basic_block(**kwarg))
+     return nn.Sequential(*layers)
+
+
+ class ResidualDenseBlock(nn.Module):
+     """Residual Dense Block.
+
+     Used in the RRDB block in ESRGAN.
+
+     Args:
+         num_feat (int): Channel number of intermediate features.
+         num_grow_ch (int): Channels for each growth.
+     """
+
+     def __init__(self, num_feat=64, num_grow_ch=32):
+         super(ResidualDenseBlock, self).__init__()
+         self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+         self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         # initialization
+         default_init_weights(
+             [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1
+         )
+
+     def forward(self, x):
+         x1 = self.lrelu(self.conv1(x))
+         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+         # Empirically, we use 0.2 to scale the residual for better performance
+         return x5 * 0.2 + x
+
+
+ class RRDB(nn.Module):
+     """Residual in Residual Dense Block.
+
+     Used in RRDB-Net in ESRGAN.
+
+     Args:
+         num_feat (int): Channel number of intermediate features.
+         num_grow_ch (int): Channels for each growth.
+     """
+
+     def __init__(self, num_feat, num_grow_ch=32):
+         super(RRDB, self).__init__()
+         self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+         self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+         self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+     def forward(self, x):
+         out = self.rdb1(x)
+         out = self.rdb2(out)
+         out = self.rdb3(out)
+         # Empirically, we use 0.2 to scale the residual for better performance
+         return out * 0.2 + x
+
+
+ class RRDBNet(nn.Module):
+     """Network consisting of Residual in Residual Dense Blocks, which is used
+     in ESRGAN.
+
+     ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+     We extend ESRGAN for scale x2 and scale x1.
+     Note: This is one option for scale 1 and scale 2 in RRDBNet.
+     We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to
+     reduce the spatial size and enlarge the channel size before feeding inputs
+     into the main ESRGAN architecture.
+
+     Args:
+         num_in_ch (int): Channel number of inputs.
+         num_out_ch (int): Channel number of outputs.
+         scale (int): Upsampling factor. Supports x1, x2 and x4. Default: 4.
+         num_feat (int): Channel number of intermediate features.
+             Default: 64.
+         num_block (int): Block number in the trunk network. Default: 23.
+         num_grow_ch (int): Channels for each growth. Default: 32.
+     """
+
+     def __init__(
+         self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32
+     ):
+         super(RRDBNet, self).__init__()
+         self.scale = scale
+         if scale == 2:
+             num_in_ch = num_in_ch * 4
+         elif scale == 1:
+             num_in_ch = num_in_ch * 16
+         self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+         self.body = make_layer(
+             RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch
+         )
+         self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         # upsample
+         self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+     def forward(self, x):
+         if self.scale == 2:
+             feat = pixel_unshuffle(x, scale=2)
+         elif self.scale == 1:
+             feat = pixel_unshuffle(x, scale=4)
+         else:
+             feat = x
+         feat = self.conv_first(feat)
+         body_feat = self.conv_body(self.body(feat))
+         feat = feat + body_feat
+         # upsample
+         feat = self.lrelu(
+             self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))
+         )
+         feat = self.lrelu(
+             self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))
+         )
+         out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+         return out
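
The docstrings above carry the two key ideas: pixel_unshuffle trades spatial resolution for channels, and RRDBNet always upsamples x4 internally (two nearest-neighbor 2x steps), so scale 1 and scale 2 are obtained by pixel-unshuffling the input first. A shape-check sketch, assuming it runs next to rrdbnet.py:

    import torch
    from rrdbnet import RRDBNet, pixel_unshuffle

    x = torch.rand(1, 3, 64, 64)

    # Each 2x2 spatial patch becomes 4 channels: (1, 3, 64, 64) -> (1, 12, 32, 32).
    print(pixel_unshuffle(x, 2).shape)
    # Same channel ordering as torch's built-in op:
    assert torch.equal(pixel_unshuffle(x, 2),
                       torch.nn.functional.pixel_unshuffle(x, 2))

    # scale=2: the input is pixel-unshuffled by 2, then the trunk upsamples x4,
    # for a net x2 overall. num_block=2 keeps this toy check fast.
    net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=2, num_block=2)
    print(net(x).shape)  # torch.Size([1, 3, 128, 128])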