File size: 4,256 Bytes
d39fc00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Human-readable names for the safety checker's concept embeddings. The patched
# forward in init_nsfw_pipe looks a name up by the index of the embedding row it
# was scored against, so entry k must describe embedding row k (order matters).
concepts = [
    'sexual',
    'nude',
    'sex',
    '18+',
    'naked',
    'nsfw',
    'porn',
    'dick',
    'vagina',
    'naked person (approximation)',
    'explicit content',
    'uncensored',
    'fuck',
    'nipples',
    'nipples (approximation)',
    'naked breasts',
    'areola',
]
# Names for the "special care" embeddings, index-aligned the same way.
special_concepts = [
    "small girl (approximation)",
    "young child",
    "young girl",
]

import dbimutils


def init_nsfw_pipe(min_confidence=0.4):
    """Load Stable Diffusion v1-4 on CUDA and patch its safety checker to report tags.

    ``pipe.safety_checker.forward`` is replaced so that, instead of its stock
    behaviour, it returns ``(images, result)``. ``images`` is passed through
    unchanged. ``result`` holds one dict per image::

        {"special_tags": [{"tag": str, "confidence": score}, ...],
         "bad_tags":     [{"tag": str, "confidence": score}, ...]}

    A concept's score is the cosine similarity between the image embedding and
    the concept embedding, minus the checker's per-concept threshold. An
    adjustment of 0.01 is added once any special concept has matched for that
    image. The score is rounded to 3 decimals. Tag names come from the
    module-level ``special_concepts`` / ``concepts`` lists.

    Args:
        min_confidence: A matched concept is reported only if its score is
            strictly greater than this value. Defaults to 0.4 (the value that
            used to be hard-coded).

    Returns:
        The patched ``StableDiffusionPipeline`` (fp16 weights, moved to ``'cuda'``).
    """
    import types

    import torch
    from diffusers import StableDiffusionPipeline
    from torch import nn

    # make sure you're logged in with `huggingface-cli login`
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16",
                                                   torch_dtype=torch.float16)
    pipe = pipe.to('cuda')

    def cosine_distance(image_embeds, text_embeds):
        # Despite the name this is cosine *similarity*: the dot products of the
        # L2-normalised rows, giving one column per concept embedding.
        normalized_image_embeds = nn.functional.normalize(image_embeds)
        normalized_text_embeds = nn.functional.normalize(text_embeds)
        return torch.mm(normalized_image_embeds, normalized_text_embeds.t())

    @torch.no_grad()
    def forward_ours(self, clip_input, images):
        """Score every image against the checker's concepts; return ``(images, result)``."""
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()

        # The thresholds are the same for every image, so read them once
        # instead of calling .item() per image and per concept.
        special_thresholds = self.special_care_embeds_weights.tolist()
        concept_thresholds = self.concept_embeds_weights.tolist()

        result = []
        # One row per image in both similarity matrices.
        for special_row, concept_row in zip(special_cos_dist, cos_dist):
            special_care = []
            bad_concepts = []

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx, concept_cos in enumerate(special_row):
                score = round(concept_cos - special_thresholds[concept_idx] + adjustment, 3)
                if score > 0:
                    special_care.append({"tag": special_concepts[concept_idx],
                                         "confidence": score})
                    # After any special-care match, raise every later score
                    # for this image by 0.01 (i.e. flag more readily).
                    adjustment = 0.01
                    print("Special concept matched:", special_concepts[concept_idx])

            for concept_idx, concept_cos in enumerate(concept_row):
                score = round(concept_cos - concept_thresholds[concept_idx] + adjustment, 3)
                if score > 0:
                    bad_concepts.append({"tag": concepts[concept_idx],
                                         "confidence": score})
                    print("NSFW concept found:", concepts[concept_idx])

            result.append({
                "special_tags": [t for t in special_care if t["confidence"] > min_confidence],
                "bad_tags": [t for t in bad_concepts if t["confidence"] > min_confidence],
            })

        return images, result

    # Bind as a real method. nn.Module.__call__ forwards positional arguments
    # to forward, which partial(forward_ours, self=...) would reject with
    # "got multiple values for argument 'self'".
    pipe.safety_checker.forward = types.MethodType(forward_ours, pipe.safety_checker)

    return pipe


def check_nsfw(img, pipe):
    """Run the patched safety checker of ``pipe`` on ``img`` and return its tag report.

    Args:
        img: Either a URL string, which is fetched with
            ``dbimutils.read_img_from_url``, or an image object that
            ``pipe.feature_extractor`` accepts directly.
        pipe: A pipeline prepared by ``init_nsfw_pipe``. Its
            ``safety_checker.forward`` must be the patched version, since only
            the second element of its return value is used.

    Returns:
        The ``result`` list from the patched checker. For a pipeline from
        ``init_nsfw_pipe`` this is one dict per image with ``"special_tags"``
        and ``"bad_tags"`` keys.
    """
    import torch

    if isinstance(img, str):
        img = dbimutils.read_img_from_url(img)
    # Both input kinds go through the same CLIP preprocessing, so do it once.
    safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt").to("cuda")
    # torch.autocast("cuda") replaces the deprecated torch.cuda.amp.autocast().
    with torch.autocast("cuda"):
        _, nsfw_tags = pipe.safety_checker.forward(clip_input=safety_checker_input.pixel_values, images=img)
    return nsfw_tags