Spaces:
Runtime error
Runtime error
add safety checker
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ import gradio as gr
|
|
6 |
import torch
|
7 |
from einops import rearrange
|
8 |
from PIL import Image
|
|
|
9 |
|
10 |
from flux.cli import SamplingOptions
|
11 |
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
@@ -13,6 +14,7 @@ from flux.util import load_ae, load_clip, load_flow_model, load_t5
|
|
13 |
from pulid.pipeline_flux import PuLIDPipeline
|
14 |
from pulid.utils import resize_numpy_image_long
|
15 |
|
|
|
16 |
|
17 |
def get_models(name: str, device: torch.device, offload: bool):
|
18 |
t5 = load_t5(device, max_length=128)
|
@@ -20,7 +22,8 @@ def get_models(name: str, device: torch.device, offload: bool):
|
|
20 |
model = load_flow_model(name, device="cpu" if offload else device)
|
21 |
model.eval()
|
22 |
ae = load_ae(name, device="cpu" if offload else device)
|
23 |
-
|
|
|
24 |
|
25 |
|
26 |
class FluxGenerator:
|
@@ -28,7 +31,7 @@ class FluxGenerator:
|
|
28 |
self.device = torch.device('cuda')
|
29 |
self.offload = False
|
30 |
self.model_name = 'flux-dev'
|
31 |
-
self.model, self.ae, self.t5, self.clip = get_models(
|
32 |
self.model_name,
|
33 |
device=self.device,
|
34 |
offload=self.offload,
|
@@ -147,7 +150,12 @@ def generate_image(
|
|
147 |
x = rearrange(x[0], "c h w -> h w c")
|
148 |
|
149 |
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
_HEADER_ = '''
|
153 |
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
|
|
|
6 |
import torch
|
7 |
from einops import rearrange
|
8 |
from PIL import Image
|
9 |
+
from transformers import pipeline
|
10 |
|
11 |
from flux.cli import SamplingOptions
|
12 |
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
|
|
14 |
from pulid.pipeline_flux import PuLIDPipeline
|
15 |
from pulid.utils import resize_numpy_image_long
|
16 |
|
17 |
+
NSFW_THRESHOLD = 0.85
|
18 |
|
19 |
def get_models(name: str, device: torch.device, offload: bool):
|
20 |
t5 = load_t5(device, max_length=128)
|
|
|
22 |
model = load_flow_model(name, device="cpu" if offload else device)
|
23 |
model.eval()
|
24 |
ae = load_ae(name, device="cpu" if offload else device)
|
25 |
+
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
|
26 |
+
return model, ae, t5, clip, nsfw_classifier
|
27 |
|
28 |
|
29 |
class FluxGenerator:
|
|
|
31 |
self.device = torch.device('cuda')
|
32 |
self.offload = False
|
33 |
self.model_name = 'flux-dev'
|
34 |
+
self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
|
35 |
self.model_name,
|
36 |
device=self.device,
|
37 |
offload=self.offload,
|
|
|
150 |
x = rearrange(x[0], "c h w -> h w c")
|
151 |
|
152 |
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
153 |
+
nsfw_score = [x["score"] for x in flux_generator.nsfw_classifier(img) if x["label"] == "nsfw"][0]
|
154 |
+
if nsfw_score < NSFW_THRESHOLD:
|
155 |
+
return img, str(opts.seed), flux_generator.pulid_model.debug_img_list
|
156 |
+
else:
|
157 |
+
return (None, f"Your generated image may contain NSFW (with nsfw_score: {nsfw_score}) content",
|
158 |
+
flux_generator.pulid_model.debug_img_list)
|
159 |
|
160 |
_HEADER_ = '''
|
161 |
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
|