pengdaqian commited on
Commit
83c3831
1 Parent(s): 4516918

warm up model

Browse files
Files changed (1) hide show
  1. img_nsfw.py +10 -11
img_nsfw.py CHANGED
@@ -11,13 +11,12 @@ def init_nsfw_pipe():
11
  from torch import nn
12
 
13
  # make sure you're logged in with `huggingface-cli login`
14
- if torch.cuda.is_available():
15
- pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16",
16
- torch_dtype=torch.float16)
17
- pipe = pipe.to('cuda')
18
- else:
19
- pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp32",
20
- torch_dtype=torch.float32)
21
 
22
  def cosine_distance(image_embeds, text_embeds):
23
  normalized_image_embeds = nn.functional.normalize(image_embeds)
@@ -79,10 +78,10 @@ def check_nsfw(img, pipe):
79
  if isinstance(img, str):
80
  img = dbimutils.read_img_from_url(img)
81
  safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt")
82
- if torch.cuda.is_available():
83
- safety_checker_input = safety_checker_input.to('cuda')
84
- else:
85
- safety_checker_input = safety_checker_input.to('cpu')
86
 
87
  from torch.cuda.amp import autocast
88
  with autocast():
 
11
  from torch import nn
12
 
13
  # make sure you're logged in with `huggingface-cli login`
14
+ # if torch.cuda.is_available():
15
+ # pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16",
16
+ # torch_dtype=torch.float16)
17
+ # pipe = pipe.to('cuda')
18
+ # else:
19
+ pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)
 
20
 
21
  def cosine_distance(image_embeds, text_embeds):
22
  normalized_image_embeds = nn.functional.normalize(image_embeds)
 
78
  if isinstance(img, str):
79
  img = dbimutils.read_img_from_url(img)
80
  safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt")
81
+ # if torch.cuda.is_available():
82
+ # safety_checker_input = safety_checker_input.to('cuda')
83
+ # else:
84
+ safety_checker_input = safety_checker_input.to('cpu')
85
 
86
  from torch.cuda.amp import autocast
87
  with autocast():