yamildiego
commited on
Commit
•
15a8194
1
Parent(s):
471adc0
test
Browse files- handler.py +10 -21
handler.py
CHANGED
@@ -31,29 +31,18 @@ class EndpointHandler():
|
|
31 |
#self.pipe.enable_vae_tiling()
|
32 |
self.generator = torch.Generator(device=device.type).manual_seed(3)
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
def asymmetricConv2DConvForward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
|
42 |
-
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
|
43 |
-
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
|
44 |
-
working = F.pad(input, self.paddingX, mode='circular')
|
45 |
-
working = F.pad(working, self.paddingY, mode='constant')
|
46 |
-
return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
|
47 |
-
|
48 |
-
targets = [pipe.vae, pipe.text_encoder, pipe.unet,]
|
49 |
-
conv_layers = []
|
50 |
for target in targets:
|
51 |
for module in target.modules():
|
52 |
-
if isinstance(module, torch.nn.Conv2d):
|
53 |
-
conv_layers.append(module)
|
54 |
-
|
55 |
-
for cl in conv_layers:
|
56 |
-
cl._conv_forward = asymmetricConv2DConvForward.__get__(cl, Conv2d)
|
57 |
|
58 |
|
59 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
|
|
31 |
#self.pipe.enable_vae_tiling()
|
32 |
self.generator = torch.Generator(device=device.type).manual_seed(3)
|
33 |
|
34 |
+
targets = [
|
35 |
+
self.pipe.vae,
|
36 |
+
self.pipe.text_encoder,
|
37 |
+
self.pipe.unet,
|
38 |
+
]
|
39 |
+
self.conv_layers = []
|
40 |
+
self.conv_layers_original_paddings = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
for target in targets:
|
42 |
for module in target.modules():
|
43 |
+
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.ConvTranspose2d):
|
44 |
+
self.conv_layers.append(module)
|
45 |
+
self.conv_layers_original_paddings.append(module.padding_mode)
|
|
|
|
|
46 |
|
47 |
|
48 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|