pengdaqian commited on
Commit
4516918
1 Parent(s): fb1826c

warm up model

Browse files
Files changed (1) hide show
  1. img_nsfw.py +8 -4
img_nsfw.py CHANGED
@@ -11,10 +11,13 @@ def init_nsfw_pipe():
11
  from torch import nn
12
 
13
  # make sure you're logged in with `huggingface-cli login`
14
- pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16",
15
- torch_dtype=torch.float16)
16
  if torch.cuda.is_available():
 
 
17
  pipe = pipe.to('cuda')
 
 
 
18
 
19
  def cosine_distance(image_embeds, text_embeds):
20
  normalized_image_embeds = nn.functional.normalize(image_embeds)
@@ -79,8 +82,9 @@ def check_nsfw(img, pipe):
79
  if torch.cuda.is_available():
80
  safety_checker_input = safety_checker_input.to('cuda')
81
  else:
82
- safety_checker_input = safety_checker_input
 
83
  from torch.cuda.amp import autocast
84
  with autocast():
85
- _, nsfw_tags = pipe.safety_checker.forward(clip_input=safety_checker_input.pixel_values.half(), images=img)
86
  return nsfw_tags
 
11
  from torch import nn
12
 
13
  # make sure you're logged in with `huggingface-cli login`
 
 
14
  if torch.cuda.is_available():
15
+ pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16",
16
+ torch_dtype=torch.float16)
17
  pipe = pipe.to('cuda')
18
+ else:
19
+ pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp32",
20
+ torch_dtype=torch.float32)
21
 
22
  def cosine_distance(image_embeds, text_embeds):
23
  normalized_image_embeds = nn.functional.normalize(image_embeds)
 
82
  if torch.cuda.is_available():
83
  safety_checker_input = safety_checker_input.to('cuda')
84
  else:
85
+ safety_checker_input = safety_checker_input.to('cpu')
86
+
87
  from torch.cuda.amp import autocast
88
  with autocast():
89
+ _, nsfw_tags = pipe.safety_checker.forward(clip_input=safety_checker_input.pixel_values, images=img)
90
  return nsfw_tags