yamildiego
commited on
Commit
•
471adc0
1
Parent(s):
63859a4
- handler.py +30 -5
handler.py
CHANGED
@@ -20,16 +20,41 @@ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.
|
|
20 |
|
21 |
class EndpointHandler():
|
22 |
def __init__(self, path=""):
|
23 |
-
|
24 |
|
25 |
|
26 |
|
27 |
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
35 |
# """
|
|
|
20 |
|
21 |
class EndpointHandler():
|
22 |
def __init__(self, path=""):
|
23 |
+
self.stable_diffusion_id = "Lykon/dreamshaper-8"
|
24 |
|
25 |
|
26 |
|
27 |
|
28 |
|
29 |
+
self.pipe = StableDiffusionPipeline.from_pretrained(self.stable_diffusion_id,torch_dtype=dtype,safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to(device.type)
|
30 |
+
#self.pipe.enable_xformers_memory_efficient_attention()
|
31 |
+
#self.pipe.enable_vae_tiling()
|
32 |
+
self.generator = torch.Generator(device=device.type).manual_seed(3)
|
33 |
+
|
34 |
+
|
35 |
+
from typing import Optional
|
36 |
+
from torch import Tensor
|
37 |
+
from torch.nn import functional as F
|
38 |
+
from torch.nn import Conv2d
|
39 |
+
from torch.nn.modules.utils import _pair
|
40 |
+
|
41 |
+
def asymmetricConv2DConvForward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
|
42 |
+
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
|
43 |
+
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
|
44 |
+
working = F.pad(input, self.paddingX, mode='circular')
|
45 |
+
working = F.pad(working, self.paddingY, mode='constant')
|
46 |
+
return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
|
47 |
+
|
48 |
+
targets = [pipe.vae, pipe.text_encoder, pipe.unet,]
|
49 |
+
conv_layers = []
|
50 |
+
for target in targets:
|
51 |
+
for module in target.modules():
|
52 |
+
if isinstance(module, torch.nn.Conv2d):
|
53 |
+
conv_layers.append(module)
|
54 |
+
|
55 |
+
for cl in conv_layers:
|
56 |
+
cl._conv_forward = asymmetricConv2DConvForward.__get__(cl, Conv2d)
|
57 |
+
|
58 |
|
59 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
60 |
# """
|