pengdaqian commited on
Commit
ebf0b64
1 Parent(s): d39fc00
Files changed (1) hide show
  1. img_nsfw.py +6 -5
img_nsfw.py CHANGED
@@ -3,17 +3,18 @@ concepts = ['sexual', 'nude', 'sex', '18+', 'naked', 'nsfw', 'porn', 'dick', 'va
3
  special_concepts = ["small girl (approximation)", "young child", "young girl"]
4
 
5
  import dbimutils
 
6
 
7
 
8
  def init_nsfw_pipe():
9
- import torch
10
  from diffusers import StableDiffusionPipeline
11
  from torch import nn
12
 
13
  # make sure you're logged in with `huggingface-cli login`
14
  pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16",
15
  torch_dtype=torch.float16)
16
- pipe = pipe.to('cuda')
 
17
 
18
  def cosine_distance(image_embeds, text_embeds):
19
  normalized_image_embeds = nn.functional.normalize(image_embeds)
@@ -74,9 +75,9 @@ def init_nsfw_pipe():
74
  def check_nsfw(img, pipe):
75
  if isinstance(img, str):
76
  img = dbimutils.read_img_from_url(img)
77
- safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt").to("cuda")
78
- else:
79
- safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt").to("cuda")
80
  from torch.cuda.amp import autocast
81
  with autocast():
82
  _, nsfw_tags = pipe.safety_checker.forward(clip_input=safety_checker_input.pixel_values, images=img)
 
3
  special_concepts = ["small girl (approximation)", "young child", "young girl"]
4
 
5
  import dbimutils
6
+ import torch
7
 
8
 
9
  def init_nsfw_pipe():
 
10
  from diffusers import StableDiffusionPipeline
11
  from torch import nn
12
 
13
  # make sure you're logged in with `huggingface-cli login`
14
  pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16",
15
  torch_dtype=torch.float16)
16
+ if torch.cuda.is_available():
17
+ pipe = pipe.to('cuda')
18
 
19
  def cosine_distance(image_embeds, text_embeds):
20
  normalized_image_embeds = nn.functional.normalize(image_embeds)
 
75
  def check_nsfw(img, pipe):
76
  if isinstance(img, str):
77
  img = dbimutils.read_img_from_url(img)
78
+ safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt")
79
+ if torch.cuda.is_available():
80
+ safety_checker_input = safety_checker_input.to('cuda')
81
  from torch.cuda.amp import autocast
82
  with autocast():
83
  _, nsfw_tags = pipe.safety_checker.forward(clip_input=safety_checker_input.pixel_values, images=img)