radames committed
Commit fc3c6a1 · 1 Parent(s): 87b827d

Upload gradio-app.py

Files changed (1)
  1. gradio-app.py +140 -0
gradio-app.py ADDED
@@ -0,0 +1,140 @@
import gradio as gr
from PIL import Image
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
import os

SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)

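# Load the LCM Dreamshaper checkpoint with the community "lcm_txt2img" custom
# pipeline; the built-in safety checker is kept only when SAFETY_CHECKER is set.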
if SAFETY_CHECKER:
    pipe = DiffusionPipeline.from_pretrained(
        "SimianLuo/LCM_Dreamshaper_v7",
        custom_pipeline="lcm_txt2img",
        scheduler=None,
    )
else:
    pipe = DiffusionPipeline.from_pretrained(
        "SimianLuo/LCM_Dreamshaper_v7",
        custom_pipeline="lcm_txt2img",
        scheduler=None,
        safety_checker=None,
    )
pipe.to(device="cuda", dtype=torch.float16)
# Swap in the tiny autoencoder for faster decoding. `device` is not a valid
# `from_pretrained` argument, so the VAE is moved to the GPU explicitly below.
pipe.vae = AutoencoderTiny.from_pretrained(
    "madebyollin/taesd", torch_dtype=torch.float16
)
pipe.vae = pipe.vae.cuda()
pipe.unet.to(memory_format=torch.channels_last)
pipe.set_progress_bar_config(disable=True)

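# Optionally compile the heavy submodules. Note that the tokenizer is a plain
# Python object rather than an nn.Module, so compiling it has little effect.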
if TORCH_COMPILE:
    pipe.text_encoder = torch.compile(pipe.text_encoder, mode="max-autotune")
    pipe.tokenizer = torch.compile(pipe.tokenizer, mode="max-autotune")
    pipe.unet = torch.compile(pipe.unet, mode="max-autotune")
    pipe.vae = torch.compile(pipe.vae, mode="max-autotune")


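# predict() renders a single 512x512 image; the merge ratio is forwarded to
# the custom pipeline as its `sv` argument, together with the sharpness value.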
def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231231):
    torch.manual_seed(seed)
    results = pipe(
        prompt1=prompt1,
        prompt2=prompt2,
        sv=merge_ratio,
        sharpness=sharpness,
        width=512,
        height=512,
        num_inference_steps=steps,
        guidance_scale=guidance,
        lcm_origin_steps=50,
        output_type="pil",
        # return_dict=False,
    )
    nsfw_content_detected = (
        results.nsfw_content_detected[0]
        if "nsfw_content_detected" in results
        else False
    )
    if nsfw_content_detected:
        raise gr.Error("NSFW content detected. Please try another prompt.")
    return results.images[0]


css = """
#container{
    margin: 0 auto;
    max-width: 80rem;
}
#intro{
    max-width: 32rem;
    text-align: center;
    margin: 0 auto;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            """# SDZoom

            Welcome to sdzoom, a testbed application for optimizing and experimenting with
            configurations to achieve the fastest possible Stable Diffusion (SD) pipelines.
            sdzoom builds on Latent Consistency Models (LCM). For more information about LCM,
            visit [Latent Consistency Models](https://latent-consistency-models.github.io/).
            """,
            elem_id="intro",
        )
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil")
            with gr.Column():
                merge_ratio = gr.Slider(
                    value=50, minimum=1, maximum=100, step=1, label="Merge Ratio"
                )
                guidance = gr.Slider(
                    label="Guidance", minimum=1, maximum=50, value=10.0, step=0.01
                )
                steps = gr.Slider(label="Steps", value=4, minimum=2, maximum=20, step=1)
                sharpness = gr.Slider(
                    value=1.0, minimum=0, maximum=1, step=0.001, label="Sharpness"
                )
                seed = gr.Slider(
                    randomize=True, minimum=0, maximum=12013012031030, label="Seed"
                )
                prompt1 = gr.Textbox(label="Prompt 1")
                prompt2 = gr.Textbox(label="Prompt 2")
                generate_bt = gr.Button("Generate")

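        # Re-run predict whenever any control changes, so the image updates
        # live instead of only on the Generate button.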
        inputs = [prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed]
        gr.Examples(
            examples=[
                ["Elon Musk", "Mark Zuckerberg", 50, 10.0, 4, 1.0, 1231231],
                ["Elon Musk", "Bill Gates", 50, 10.0, 4, 1.0, 53453],
                [
                    "Asian woman, intricate jewelry in her hair, 8k",
                    "Tom Cruise, intricate jewelry in her hair, 8k",
                    50,
                    10.0,
                    4,
                    1.0,
                    542343,
                ],
            ],
            fn=predict,
            inputs=inputs,
            outputs=image,
        )
        generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        merge_ratio.change(
            fn=predict, inputs=inputs, outputs=image, show_progress=False
        )
        guidance.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        sharpness.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        prompt1.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        prompt2.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)

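# Enable Gradio's request queue so concurrent generations are processed in order.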
demo.queue()
if __name__ == "__main__":
    demo.launch()