"""NSFW detection built on the Stable Diffusion v1-4 safety checker.

``init_nsfw_pipe`` loads the pipeline and replaces its safety checker's
``forward`` with a variant that reports which concepts matched (as tag dicts)
instead of returning a per-image boolean. ``check_nsfw`` runs that patched
checker on an image or image URL.
"""
import contextlib
import types

import dbimutils
import torch

# Names of the safety checker's NSFW concept embeddings, by index.
# NOTE(review): assumed to be ordered like `safety_checker.concept_embeds` — confirm.
concepts = ['sexual', 'nude', 'sex', '18+', 'naked', 'nsfw', 'porn', 'dick', 'vagina', 'naked person (approximation)', 'explicit content', 'uncensored', 'fuck', 'nipples', 'nipples (approximation)', 'naked breasts', 'areola']
# Names of the "special care" concept embeddings, by index.
# NOTE(review): assumed to be ordered like `safety_checker.special_care_embeds` — confirm.
special_concepts = ["small girl (approximation)", "young child", "young girl"]


def _tags_for_image(special_cos, concept_cos, special_thresholds, concept_thresholds,
                    tag_threshold=0.4):
    """Score one image against the special-care and NSFW concepts.

    Each score is ``similarity - per-concept threshold + adjustment``, rounded
    to 3 decimals. A concept counts as matched when its score is > 0.

    Args:
        special_cos: cosine similarities to each special-care concept embedding.
        concept_cos: cosine similarities to each NSFW concept embedding.
        special_thresholds: per-special-concept thresholds (Python floats).
        concept_thresholds: per-NSFW-concept thresholds (Python floats).
        tag_threshold: a matched concept is only reported when its score also
            exceeds this value.

    Returns:
        ``{"special_tags": [...], "bad_tags": [...]}``. Each tag is
        ``{"tag": <concept name>, "confidence": <score as Python float>}``.
    """
    # Increase this value to create a stronger NSFW filter, at the cost of
    # increasing the chance of flagging benign images.
    adjustment = 0.0

    special_care = []
    for idx, (cos, threshold) in enumerate(zip(special_cos, special_thresholds)):
        # float() keeps the arithmetic in float64 whatever the array dtype
        # (fp16 on CUDA) and yields JSON-serializable confidences.
        score = round(float(cos) - threshold + adjustment, 3)
        if score > 0:
            special_care.append({"tag": special_concepts[idx], "confidence": score})
            # Once any special-care concept matches, every later score (the
            # remaining special concepts and all NSFW concepts) is raised by
            # 0.01, which makes the filter stricter.
            adjustment = 0.01
            print("Special concept matched:", special_concepts[idx])

    bad_concepts = []
    for idx, (cos, threshold) in enumerate(zip(concept_cos, concept_thresholds)):
        score = round(float(cos) - threshold + adjustment, 3)
        if score > 0:
            bad_concepts.append({"tag": concepts[idx], "confidence": score})
            print("NSFW concept found:", concepts[idx])

    # NOTE(review): scores are similarity minus threshold, so reporting only
    # scores above 0.4 is a strict bar — confirm this default is intentional.
    return {
        "special_tags": [t for t in special_care if t["confidence"] > tag_threshold],
        "bad_tags": [t for t in bad_concepts if t["confidence"] > tag_threshold],
    }


def init_nsfw_pipe(tag_threshold=0.4):
    """Load Stable Diffusion v1-4 and patch its safety checker to report tags.

    On CUDA the fp16 weights are loaded and moved to the GPU. Otherwise the
    float32 weights are loaded on the CPU. Make sure you're logged in with
    `huggingface-cli login` so the weights can be downloaded.

    Args:
        tag_threshold: minimum score a matched concept must exceed to appear
            in ``special_tags`` / ``bad_tags``. Defaults to 0.4, the value
            that was previously hard-coded.

    Returns:
        The ``StableDiffusionPipeline``. Its ``safety_checker.forward(clip_input,
        images)`` now returns ``(images, result)``, where ``result`` holds one
        ``{"special_tags": [...], "bad_tags": [...]}`` dict per image.
    """
    from diffusers import StableDiffusionPipeline
    from torch import nn

    if torch.cuda.is_available():
        pipe = StableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16)
        pipe = pipe.to('cuda')
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)

    def cosine_similarity(image_embeds, text_embeds):
        """Return the matrix of cosine similarities between the two row sets."""
        normalized_image_embeds = nn.functional.normalize(image_embeds)
        normalized_text_embeds = nn.functional.normalize(text_embeds)
        return torch.mm(normalized_image_embeds, normalized_text_embeds.t())

    @torch.no_grad()
    def forward_ours(self, clip_input, images):
        """Replacement safety-checker forward that returns tags per image."""
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_similarity(image_embeds, self.special_care_embeds).cpu().numpy()
        cos_dist = cosine_similarity(image_embeds, self.concept_embeds).cpu().numpy()

        # Thresholds are the same for every image; convert them once.
        special_thresholds = self.special_care_embeds_weights.tolist()
        concept_thresholds = self.concept_embeds_weights.tolist()

        result = [
            _tags_for_image(special_row, concept_row, special_thresholds,
                            concept_thresholds, tag_threshold)
            for special_row, concept_row in zip(special_cos_dist, cos_dist)
        ]
        return images, result

    # Bind as a real method so both keyword and positional calls work
    # (functools.partial(..., self=...) breaks positional invocation).
    pipe.safety_checker.forward = types.MethodType(forward_ours, pipe.safety_checker)
    return pipe


def check_nsfw(img, pipe):
    """Run the patched safety checker on one image and return its tag report.

    Args:
        img: an image accepted by ``pipe.feature_extractor``, or a URL string,
            which is fetched with ``dbimutils.read_img_from_url``.
        pipe: a pipeline returned by ``init_nsfw_pipe``.

    Returns:
        The ``result`` list produced by the patched forward: one
        ``{"special_tags": [...], "bad_tags": [...]}`` dict per image.
    """
    if isinstance(img, str):
        img = dbimutils.read_img_from_url(img)

    use_cuda = torch.cuda.is_available()
    safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt")
    safety_checker_input = safety_checker_input.to('cuda' if use_cuda else 'cpu')

    # On CUDA the checker is float16 while the pixel values are float32, so
    # autocast reconciles the dtypes. On CPU everything is float32 and no
    # autocast is needed.
    amp_context = torch.autocast(device_type="cuda") if use_cuda else contextlib.nullcontext()
    with amp_context:
        _, nsfw_tags = pipe.safety_checker.forward(
            clip_input=safety_checker_input.pixel_values, images=img)
    return nsfw_tags