# image-label-2 / img_nsfw.py
# pengdaqian, commit ebf0b64 ("init scan"), 4.26 kB
# Human-readable labels for the safety checker's concept embeddings, listed in
# embedding order: the patched checker reports concepts[i] when the i-th
# column of its concept similarity matrix clears that concept's threshold.
concepts = [
    "sexual",
    "nude",
    "sex",
    "18+",
    "naked",
    "nsfw",
    "porn",
    "dick",
    "vagina",
    "naked person (approximation)",
    "explicit content",
    "uncensored",
    "fuck",
    "nipples",
    "nipples (approximation)",
    "naked breasts",
    "areola",
]
# Labels for the checker's "special care" embeddings, also in embedding order.
special_concepts = [
    "small girl (approximation)",
    "young child",
    "young girl",
]
import dbimutils
import torch
def init_nsfw_pipe(min_confidence=0.4):
    """Load Stable Diffusion v1.4 and repurpose its safety checker as an NSFW tagger.

    The pipeline's ``safety_checker.forward`` is replaced with a version that
    passes the images through untouched and instead reports which labels from
    ``special_concepts`` / ``concepts`` matched each image.

    Args:
        min_confidence: a matched concept is reported only when its confidence
            (cosine similarity minus the checker's per-concept threshold, plus
            the current adjustment, rounded to 3 decimals) is strictly greater
            than this value. The default 0.4 reproduces the original
            hard-coded cut-off. Values <= 0 report every matched concept.

    Returns:
        The ``StableDiffusionPipeline``, moved to CUDA when available, intended
        for use with ``check_nsfw``.

    NOTE(review): the patched forward returns per-image tag dicts rather than
    the stock checker's output, so reusing this pipe for image generation is
    likely to break. Confirm before doing so.
    """
    import types

    from diffusers import StableDiffusionPipeline
    from torch import nn

    use_cuda = torch.cuda.is_available()
    # make sure you're logged in with `huggingface-cli login`
    # Half precision only on GPU: many torch CPU kernels (especially in older
    # releases) have no fp16 implementation.
    # NOTE(review): newer diffusers releases prefer `variant="fp16"` over `revision="fp16"`.
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        revision="fp16",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    if use_cuda:
        pipe = pipe.to('cuda')

    def cosine_distance(image_embeds, text_embeds):
        """Cosine *similarity* between every row of image_embeds and every row of text_embeds."""
        normalized_image_embeds = nn.functional.normalize(image_embeds)
        normalized_text_embeds = nn.functional.normalize(text_embeds)
        return torch.mm(normalized_image_embeds, normalized_text_embeds.t())

    @torch.no_grad()
    def forward_ours(self, clip_input, images):
        """Tag each image in ``clip_input`` with the concepts it matches.

        Returns ``(images, result)``: ``images`` unchanged, and ``result`` with
        one ``{"special_tags": [...], "bad_tags": [...]}`` dict per batch row.
        Each tag is ``{"tag": label, "confidence": score}``.
        """
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)
        # Cast to float32 before leaving torch so scores have the same precision
        # whether the checker weights are fp16 (GPU) or fp32 (CPU).
        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()

        result = []
        for i in range(image_embeds.shape[0]):
            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            special_care = []
            for concept_idx, concept_cos in enumerate(special_cos_dist[i]):
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                # float() gives a plain Python float (JSON-serialisable) instead of a numpy scalar.
                score = round(float(concept_cos) - concept_threshold + adjustment, 3)
                if score > 0:
                    special_care.append({"tag": special_concepts[concept_idx], "confidence": score})
                    # After a special-care match, every later concept of this image
                    # (remaining special ones and all regular ones) scores 0.01 higher.
                    adjustment = 0.01
                    print("Special concept matched:", special_concepts[concept_idx])

            bad_concepts = []
            for concept_idx, concept_cos in enumerate(cos_dist[i]):
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                score = round(float(concept_cos) - concept_threshold + adjustment, 3)
                if score > 0:
                    bad_concepts.append({"tag": concepts[concept_idx], "confidence": score})
                    print("NSFW concept found:", concepts[concept_idx])

            result.append({
                "special_tags": [t for t in special_care if t["confidence"] > min_confidence],
                "bad_tags": [t for t in bad_concepts if t["confidence"] > min_confidence],
            })
        return images, result

    # Bind as a real method (rather than partial(..., self=...)) so the patched
    # forward also accepts positional arguments.
    pipe.safety_checker.forward = types.MethodType(forward_ours, pipe.safety_checker)
    return pipe
def check_nsfw(img, pipe):
    """Return the NSFW tags the patched safety checker of ``pipe`` finds in ``img``.

    Args:
        img: an image accepted by ``pipe.feature_extractor``, or a URL string.
            A URL is fetched first with ``dbimutils.read_img_from_url``.
        pipe: a pipeline whose ``safety_checker.forward`` was patched by
            ``init_nsfw_pipe``.

    Returns:
        The second element returned by the patched ``safety_checker.forward``.
        For ``init_nsfw_pipe`` pipes, that is one
        ``{"special_tags": [...], "bad_tags": [...]}`` dict per image.
    """
    if isinstance(img, str):
        img = dbimutils.read_img_from_url(img)
    checker = pipe.safety_checker
    safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt")
    # Follow the checker's own device and dtype instead of assuming CUDA. The
    # extractor's pixel values need not match the checker's (possibly fp16)
    # weights, or the device the checker actually lives on.
    clip_input = safety_checker_input.pixel_values.to(device=checker.device, dtype=checker.dtype)
    with torch.autocast(device_type="cuda", enabled=checker.device.type == "cuda"):
        _, nsfw_tags = checker.forward(clip_input=clip_input, images=img)
    return nsfw_tags