Spaces:
Sleeping
Sleeping
File size: 4,256 Bytes
d39fc00 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
# Human-readable labels for the safety checker's concept embeddings.
# Position k in `concepts` names the k-th column of the cosine-similarity
# matrix against `concept_embeds`; position k in `special_concepts` names the
# k-th column against `special_care_embeds` (see `forward_ours`).
# NOTE(review): order must match the checker's weight rows — confirm if the
# model checkpoint changes.
concepts = [
    "sexual",
    "nude",
    "sex",
    "18+",
    "naked",
    "nsfw",
    "porn",
    "dick",
    "vagina",
    "naked person (approximation)",
    "explicit content",
    "uncensored",
    "fuck",
    "nipples",
    "nipples (approximation)",
    "naked breasts",
    "areola",
]
special_concepts = [
    "small girl (approximation)",
    "young child",
    "young girl",
]
import dbimutils
def init_nsfw_pipe(min_confidence=0.4):
    """Load Stable Diffusion v1-4 on CUDA and patch its safety checker to report tags.

    The pipeline's ``safety_checker.forward`` is replaced by ``forward_ours``
    (defined below). The replacement leaves the images untouched. Instead of
    the usual per-image boolean, it returns per-image lists of matched concept
    tags with their scores.

    Args:
        min_confidence: Only concepts whose score is strictly greater than
            this value appear in the returned ``special_tags`` / ``bad_tags``.
            Defaults to 0.4, the value previously hard-coded here.

    Returns:
        The ``StableDiffusionPipeline``, loaded from the ``fp16`` revision
        with ``torch.float16`` weights and moved to ``'cuda'``, with its
        safety checker patched.
    """
    import torch
    from types import MethodType
    from diffusers import StableDiffusionPipeline
    from torch import nn

    # make sure you're logged in with `huggingface-cli login`
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16",
                                                   torch_dtype=torch.float16)
    pipe = pipe.to('cuda')

    def cosine_distance(image_embeds, text_embeds):
        """Return the matrix of cosine similarities between rows of two 2-D embedding tensors.

        Despite the name, this is a similarity: each row is L2-normalised and
        the result is ``image @ text.T``, with shape (n_images, n_texts).
        """
        normalized_image_embeds = nn.functional.normalize(image_embeds)
        normalized_text_embeds = nn.functional.normalize(text_embeds)
        return torch.mm(normalized_image_embeds, normalized_text_embeds.t())

    @torch.no_grad()
    def forward_ours(self, clip_input, images):
        """Score CLIP image embeddings against the checker's concept embeddings.

        This replaces the safety checker's ``forward``. ``images`` is returned
        unchanged. The second return value is a list with one dict per image:
        ``{"special_tags": [...], "bad_tags": [...]}``. Each entry has the form
        ``{"tag": <label>, "confidence": <score>}`` and is included only when
        its score exceeds ``min_confidence``.

        A concept's score is ``round(cos_sim - threshold + adjustment, 3)``.
        ``threshold`` is the checker's per-concept weight.
        """
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()

        result = []
        batch_size = image_embeds.shape[0]
        for i in range(batch_size):
            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx in range(len(special_cos_dist[0])):
                concept_cos = special_cos_dist[i][concept_idx]
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["special_scores"][concept_idx] > 0:
                    result_img["special_care"].append({"tag": special_concepts[concept_idx],
                                                       "confidence": result_img["special_scores"][concept_idx]})
                    # Once a special-care concept matches, every later score for
                    # this image gets +0.01. That includes the remaining special
                    # concepts and all regular concepts.
                    adjustment = 0.01
                    print("Special concept matched:", special_concepts[concept_idx])

            for concept_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concept_idx]
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["concept_scores"][concept_idx] > 0:
                    result_img["bad_concepts"].append({"tag": concepts[concept_idx],
                                                       "confidence": result_img["concept_scores"][concept_idx]})
                    print("NSFW concept found:", concepts[concept_idx])

            # Any positive score is a "match"; only report the confident ones.
            special_tags = [tag for tag in result_img["special_care"] if tag["confidence"] > min_confidence]
            bad_tags = [tag for tag in result_img["bad_concepts"] if tag["confidence"] > min_confidence]
            result.append({"special_tags": special_tags,
                           "bad_tags": bad_tags, })
        return images, result

    # Bind as a real bound method so the checker works with both keyword and
    # positional arguments. With functools.partial(..., self=...), a positional
    # call failed with "got multiple values for argument 'self'".
    pipe.safety_checker.forward = MethodType(forward_ours, pipe.safety_checker)
    return pipe
def check_nsfw(img, pipe):
    """Run the pipeline's safety checker on an image (or URL) and return its tags.

    Args:
        img: Either a URL string, which is downloaded with
            ``dbimutils.read_img_from_url``, or image data that
            ``pipe.feature_extractor`` accepts directly (presumably a PIL
            image or a list of them — confirm against callers).
        pipe: An object exposing ``feature_extractor`` and ``safety_checker``,
            normally the pipeline built by ``init_nsfw_pipe``.

    Returns:
        The second value returned by ``pipe.safety_checker.forward(...)``.
        With the patched checker from ``init_nsfw_pipe``, this is a list with
        one ``{"special_tags": [...], "bad_tags": [...]}`` dict per image.
    """
    if isinstance(img, str):
        # Strings are treated as URLs: fetch the real image before preprocessing.
        img = dbimutils.read_img_from_url(img)

    # CLIP preprocessing into PyTorch tensors, moved to the GPU.
    clip_batch = pipe.feature_extractor(images=img, return_tensors="pt").to("cuda")

    from torch.cuda.amp import autocast
    # Mixed precision (the pipeline from init_nsfw_pipe holds float16 weights).
    with autocast():
        _passthrough_images, tags = pipe.safety_checker.forward(
            clip_input=clip_batch.pixel_values,
            images=img,
        )
    return tags
|