Spaces:
Sleeping
Sleeping
pengdaqian
commited on
Commit
•
4516918
1
Parent(s):
fb1826c
warm up model
Browse files- img_nsfw.py +8 -4
img_nsfw.py
CHANGED
@@ -11,10 +11,13 @@ def init_nsfw_pipe():
|
|
11 |
from torch import nn
|
12 |
|
13 |
# make sure you're logged in with `huggingface-cli login`
|
14 |
-
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16",
|
15 |
-
torch_dtype=torch.float16)
|
16 |
if torch.cuda.is_available():
|
|
|
|
|
17 |
pipe = pipe.to('cuda')
|
|
|
|
|
|
|
18 |
|
19 |
def cosine_distance(image_embeds, text_embeds):
|
20 |
normalized_image_embeds = nn.functional.normalize(image_embeds)
|
@@ -79,8 +82,9 @@ def check_nsfw(img, pipe):
|
|
79 |
if torch.cuda.is_available():
|
80 |
safety_checker_input = safety_checker_input.to('cuda')
|
81 |
else:
|
82 |
-
safety_checker_input = safety_checker_input
|
|
|
83 |
from torch.cuda.amp import autocast
|
84 |
with autocast():
|
85 |
-
_, nsfw_tags = pipe.safety_checker.forward(clip_input=safety_checker_input.pixel_values
|
86 |
return nsfw_tags
|
|
|
11 |
from torch import nn
|
12 |
|
13 |
# make sure you're logged in with `huggingface-cli login`
|
|
|
|
|
14 |
if torch.cuda.is_available():
|
15 |
+
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16",
|
16 |
+
torch_dtype=torch.float16)
|
17 |
pipe = pipe.to('cuda')
|
18 |
+
else:
|
19 |
+
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp32",
|
20 |
+
torch_dtype=torch.float32)
|
21 |
|
22 |
def cosine_distance(image_embeds, text_embeds):
|
23 |
normalized_image_embeds = nn.functional.normalize(image_embeds)
|
|
|
82 |
if torch.cuda.is_available():
|
83 |
safety_checker_input = safety_checker_input.to('cuda')
|
84 |
else:
|
85 |
+
safety_checker_input = safety_checker_input.to('cpu')
|
86 |
+
|
87 |
from torch.cuda.amp import autocast
|
88 |
with autocast():
|
89 |
+
_, nsfw_tags = pipe.safety_checker.forward(clip_input=safety_checker_input.pixel_values, images=img)
|
90 |
return nsfw_tags
|