Arnaudding001 commited on
Commit
9fe2945
1 Parent(s): 956e60b

Create stylegan_op_upfirdn2d.py

Browse files
Files changed (1) hide show
  1. stylegan_op_upfirdn2d.py +61 -0
stylegan_op_upfirdn2d.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import abc
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def upfirdn2d(inputs, kernel, up=1, down=1, pad=(0, 0)):
8
+ if not isinstance(up, abc.Iterable):
9
+ up = (up, up)
10
+
11
+ if not isinstance(down, abc.Iterable):
12
+ down = (down, down)
13
+
14
+ if len(pad) == 2:
15
+ pad = (pad[0], pad[1], pad[0], pad[1])
16
+
17
+ return upfirdn2d_native(inputs, kernel, *up, *down, *pad)
18
+
19
+
20
+ def upfirdn2d_native(
21
+ inputs, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
22
+ ):
23
+ _, channel, in_h, in_w = inputs.shape
24
+ inputs = inputs.reshape(-1, in_h, in_w, 1)
25
+
26
+ _, in_h, in_w, minor = inputs.shape
27
+ kernel_h, kernel_w = kernel.shape
28
+
29
+ out = inputs.view(-1, in_h, 1, in_w, 1, minor)
30
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
31
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
32
+
33
+ out = F.pad(
34
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
35
+ )
36
+ out = out[
37
+ :,
38
+ max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
39
+ max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
40
+ :,
41
+ ]
42
+
43
+ out = out.permute(0, 3, 1, 2)
44
+ out = out.reshape(
45
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
46
+ )
47
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
48
+ out = F.conv2d(out, w)
49
+ out = out.reshape(
50
+ -1,
51
+ minor,
52
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
53
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
54
+ )
55
+ out = out.permute(0, 2, 3, 1)
56
+ out = out[:, ::down_y, ::down_x, :]
57
+
58
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
59
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
60
+
61
+ return out.view(-1, channel, out_h, out_w)