pengdaqian commited on
Commit
1960aa0
1 Parent(s): d2ae0ec

warm up model

Browse files
Files changed (1) hide show
  1. img_nsfw.py +2 -0
img_nsfw.py CHANGED
@@ -78,6 +78,8 @@ def check_nsfw(img, pipe):
78
  safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt")
79
  if torch.cuda.is_available():
80
  safety_checker_input = safety_checker_input.to('cuda')
 
 
81
  from torch.cuda.amp import autocast
82
  with autocast():
83
  _, nsfw_tags = pipe.safety_checker.forward(clip_input=safety_checker_input.pixel_values, images=img)
 
78
  safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt")
79
  if torch.cuda.is_available():
80
  safety_checker_input = safety_checker_input.to('cuda')
81
+ else:
82
+ safety_checker_input = safety_checker_input.half()
83
  from torch.cuda.amp import autocast
84
  with autocast():
85
  _, nsfw_tags = pipe.safety_checker.forward(clip_input=safety_checker_input.pixel_values, images=img)