Spaces:
Runtime error
Runtime error
Arnaudding001
commited on
Commit
•
9fe2945
1
Parent(s):
956e60b
Create stylegan_op_upfirdn2d.py
Browse files- stylegan_op_upfirdn2d.py +61 -0
stylegan_op_upfirdn2d.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import abc
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
|
7 |
+
def upfirdn2d(inputs, kernel, up=1, down=1, pad=(0, 0)):
|
8 |
+
if not isinstance(up, abc.Iterable):
|
9 |
+
up = (up, up)
|
10 |
+
|
11 |
+
if not isinstance(down, abc.Iterable):
|
12 |
+
down = (down, down)
|
13 |
+
|
14 |
+
if len(pad) == 2:
|
15 |
+
pad = (pad[0], pad[1], pad[0], pad[1])
|
16 |
+
|
17 |
+
return upfirdn2d_native(inputs, kernel, *up, *down, *pad)
|
18 |
+
|
19 |
+
|
20 |
+
def upfirdn2d_native(
|
21 |
+
inputs, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
22 |
+
):
|
23 |
+
_, channel, in_h, in_w = inputs.shape
|
24 |
+
inputs = inputs.reshape(-1, in_h, in_w, 1)
|
25 |
+
|
26 |
+
_, in_h, in_w, minor = inputs.shape
|
27 |
+
kernel_h, kernel_w = kernel.shape
|
28 |
+
|
29 |
+
out = inputs.view(-1, in_h, 1, in_w, 1, minor)
|
30 |
+
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
31 |
+
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
32 |
+
|
33 |
+
out = F.pad(
|
34 |
+
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
35 |
+
)
|
36 |
+
out = out[
|
37 |
+
:,
|
38 |
+
max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
|
39 |
+
max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
|
40 |
+
:,
|
41 |
+
]
|
42 |
+
|
43 |
+
out = out.permute(0, 3, 1, 2)
|
44 |
+
out = out.reshape(
|
45 |
+
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
46 |
+
)
|
47 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
48 |
+
out = F.conv2d(out, w)
|
49 |
+
out = out.reshape(
|
50 |
+
-1,
|
51 |
+
minor,
|
52 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
53 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
54 |
+
)
|
55 |
+
out = out.permute(0, 2, 3, 1)
|
56 |
+
out = out[:, ::down_y, ::down_x, :]
|
57 |
+
|
58 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
|
59 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
|
60 |
+
|
61 |
+
return out.view(-1, channel, out_h, out_w)
|