lunarfish commited on
Commit
b82ba17
1 Parent(s): 90b4146

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +262 -52
app.py CHANGED
@@ -1,56 +1,266 @@
 
1
  import gradio as gr
2
  import torch
3
- from torch import autocast
4
- from diffusers import StableDiffusionPipeline
5
-
6
- #model_id = "lunarfish/furrydiffusion"
7
- pipe = StableDiffusionPipeline.from_pretrained("lunarfish/furrydiffusion", torch_type=torch.float16, revision="main")
8
-
9
- num_samples = 2
10
-
11
- def infer(prompt):
12
- images = pipe([prompt] * num_samples, guidance_scale=7.5)["sample"]
13
- return images
14
-
15
-
16
- block = gr.Blocks()
17
-
18
- examples = [
19
- [
20
- 'fox'
21
- ],
22
- [
23
- 'rabbit'
24
- ],
25
- [
26
- 'wolf'
27
- ],
28
- ]
29
-
30
- with block as demo:
31
- with gr.Group():
32
- with gr.Box():
33
- with gr.Row().style(mobile_collapse=False, equal_height=True):
34
-
35
- text = gr.Textbox(
36
- label="Enter your prompt", show_label=False, max_lines=1
37
- ).style(
38
- border=(True, False, True, True),
39
- rounded=(True, False, False, True),
40
- container=False,
41
- )
42
- btn = gr.Button("Run").style(
43
- margin=False,
44
- rounded=(False, True, True, False),
45
- )
46
-
47
- gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="generated_id").style(
48
- grid=[2], height="auto"
49
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- ex = gr.Examples(examples=examples, fn=infer, inputs=[text], outputs=gallery, cache_examples=True)
52
- ex.dataset.headers = [""]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- text.submit(infer, inputs=[text], outputs=gallery)
55
- btn.click(infer, inputs=[text], outputs=gallery)
56
- demo.queue(max_size=25).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler
2
  import gradio as gr
3
  import torch
4
+ from PIL import Image
5
+ import utils
6
+ import datetime
7
+ import time
8
+ import psutil
9
+
10
+ start_time = time.time()
11
+ is_colab = utils.is_google_colab()
12
+
13
+ class Model:
14
+ def __init__(self, name, path="", prefix=""):
15
+ self.name = name
16
+ self.path = path
17
+ self.prefix = prefix
18
+ self.pipe_t2i = None
19
+ self.pipe_i2i = None
20
+
21
+ models = [
22
+ Model("FurryDiffusion", "lunarfish/furrydiffusion", "Furry Diffusion Style"),
23
+ ]
24
+
25
+ scheduler = DPMSolverMultistepScheduler(
26
+ beta_start=0.00085,
27
+ beta_end=0.012,
28
+ beta_schedule="scaled_linear",
29
+ num_train_timesteps=1000,
30
+ trained_betas=None,
31
+ predict_epsilon=True,
32
+ thresholding=False,
33
+ algorithm_type="dpmsolver++",
34
+ solver_type="midpoint",
35
+ lower_order_final=True,
36
+ )
37
+
38
+ custom_model = None
39
+ if is_colab:
40
+ models.insert(0, Model("Custom model"))
41
+ custom_model = models[0]
42
+
43
+ last_mode = "txt2img"
44
+ current_model = models[1] if is_colab else models[0]
45
+ current_model_path = current_model.path
46
+
47
+ if is_colab:
48
+ pipe = StableDiffusionPipeline.from_pretrained(current_model.path, torch_dtype=torch.float16, scheduler=scheduler, safety_checker=lambda images, clip_input: (images, False))
49
+
50
+ else: # download all models
51
+ print(f"{datetime.datetime.now()} Downloading vae...")
52
+ vae = AutoencoderKL.from_pretrained(current_model.path, subfolder="vae", torch_dtype=torch.float16)
53
+ for model in models:
54
+ try:
55
+ print(f"{datetime.datetime.now()} Downloading {model.name} model...")
56
+ unet = UNet2DConditionModel.from_pretrained(model.path, subfolder="unet", torch_dtype=torch.float16)
57
+ model.pipe_t2i = StableDiffusionPipeline.from_pretrained(model.path, unet=unet, vae=vae, torch_dtype=torch.float16, scheduler=scheduler)
58
+ model.pipe_i2i = StableDiffusionImg2ImgPipeline.from_pretrained(model.path, unet=unet, vae=vae, torch_dtype=torch.float16, scheduler=scheduler)
59
+ except Exception as e:
60
+ print(f"{datetime.datetime.now()} Failed to load model " + model.name + ": " + str(e))
61
+ models.remove(model)
62
+ pipe = models[0].pipe_t2i
63
+
64
+ if torch.cuda.is_available():
65
+ pipe = pipe.to("cuda")
66
+
67
+ device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
68
+
69
+ def error_str(error, title="Error"):
70
+ return f"""#### {title}
71
+ {error}""" if error else ""
72
+
73
+ def custom_model_changed(path):
74
+ models[0].path = path
75
+ global current_model
76
+ current_model = models[0]
77
+
78
+ def on_model_change(model_name):
79
+
80
+ prefix = "Enter prompt. \"" + next((m.prefix for m in models if m.name == model_name), None) + "\" is prefixed automatically" if model_name != models[0].name else "Don't forget to use the custom model prefix in the prompt!"
81
+
82
+ return gr.update(visible = model_name == models[0].name), gr.update(placeholder=prefix)
83
+
84
+ def inference(model_name, prompt, guidance, steps, width=512, height=512, seed=0, img=None, strength=0.5, neg_prompt=""):
85
+
86
+ print(psutil.virtual_memory()) # print memory usage
87
+
88
+ global current_model
89
+ for model in models:
90
+ if model.name == model_name:
91
+ current_model = model
92
+ model_path = current_model.path
93
+
94
+ generator = torch.Generator('cuda').manual_seed(seed) if seed != 0 else None
95
+
96
+ try:
97
+ if img is not None:
98
+ return img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, width, height, generator), None
99
+ else:
100
+ return txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, generator), None
101
+ except Exception as e:
102
+ return None, error_str(e)
103
+
104
+ def txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, generator):
105
+
106
+ print(f"{datetime.datetime.now()} txt_to_img, model: {current_model.name}")
107
+
108
+ global last_mode
109
+ global pipe
110
+ global current_model_path
111
+ if model_path != current_model_path or last_mode != "txt2img":
112
+ current_model_path = model_path
113
+
114
+ if is_colab or current_model == custom_model:
115
+ pipe = StableDiffusionPipeline.from_pretrained(current_model_path, torch_dtype=torch.float16, scheduler=scheduler, safety_checker=lambda images, clip_input: (images, False))
116
+ else:
117
+ pipe = pipe.to("cpu")
118
+ pipe = current_model.pipe_t2i
119
+
120
+ if torch.cuda.is_available():
121
+ pipe = pipe.to("cuda")
122
+ last_mode = "txt2img"
123
+
124
+ prompt = current_model.prefix + prompt
125
+ result = pipe(
126
+ prompt,
127
+ negative_prompt = neg_prompt,
128
+ # num_images_per_prompt=n_images,
129
+ num_inference_steps = int(steps),
130
+ guidance_scale = guidance,
131
+ width = width,
132
+ height = height,
133
+ generator = generator)
134
+
135
+ return replace_nsfw_images(result)
136
+
137
+ def img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, width, height, generator):
138
+
139
+ print(f"{datetime.datetime.now()} img_to_img, model: {model_path}")
140
+
141
+ global last_mode
142
+ global pipe
143
+ global current_model_path
144
+ if model_path != current_model_path or last_mode != "img2img":
145
+ current_model_path = model_path
146
+
147
+ if is_colab or current_model == custom_model:
148
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(current_model_path, torch_dtype=torch.float16, scheduler=scheduler, safety_checker=lambda images, clip_input: (images, False))
149
+ else:
150
+ pipe = pipe.to("cpu")
151
+ pipe = current_model.pipe_i2i
152
+
153
+ if torch.cuda.is_available():
154
+ pipe = pipe.to("cuda")
155
+ last_mode = "img2img"
156
+
157
+ prompt = current_model.prefix + prompt
158
+ ratio = min(height / img.height, width / img.width)
159
+ img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS)
160
+ result = pipe(
161
+ prompt,
162
+ negative_prompt = neg_prompt,
163
+ # num_images_per_prompt=n_images,
164
+ init_image = img,
165
+ num_inference_steps = int(steps),
166
+ strength = strength,
167
+ guidance_scale = guidance,
168
+ width = width,
169
+ height = height,
170
+ generator = generator)
171
 
172
+ return replace_nsfw_images(result)
173
+
174
+ def replace_nsfw_images(results):
175
+
176
+ if is_colab:
177
+ return results.images[0]
178
+
179
+ for i in range(len(results.images)):
180
+ if results.nsfw_content_detected[i]:
181
+ results.images[i] = Image.open("nsfw.png")
182
+ return results.images[0]
183
+
184
+ css = """.finetuned-diffusion-div div{display:inline-flex;align-items:center;gap:.8rem;font-size:1.75rem}.finetuned-diffusion-div div h1{font-weight:900;margin-bottom:7px}.finetuned-diffusion-div p{margin-bottom:10px;font-size:94%}a{text-decoration:underline}.tabs{margin-top:0;margin-bottom:0}#gallery{min-height:20rem}
185
+ """
186
+ with gr.Blocks(css=css) as demo:
187
+ gr.HTML(
188
+ f"""
189
+ <div class="finetuned-diffusion-div">
190
+ <div>
191
+ <h1>Furry Diffusion</h1>
192
+ </div>
193
+ <p>This demo is slow on cpu, to use it upgrade to gpu by going to settings after duplicating this space: <a style="display:inline-block" href="https://huggingface.co/spaces/lunarfish/furrydiffusion?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a> </p>
194
+ </p>
195
+ </div>
196
+ """
197
+ )
198
+ with gr.Row():
199
 
200
+ with gr.Column(scale=55):
201
+ with gr.Group():
202
+ model_name = gr.Dropdown(label="Model", choices=[m.name for m in models], value=current_model.name)
203
+ with gr.Box(visible=False) as custom_model_group:
204
+ custom_model_path = gr.Textbox(label="Custom model path", placeholder="Path to model, e.g. nitrosocke/Arcane-Diffusion", interactive=True)
205
+ gr.HTML("<div><font size='2'>Custom models have to be downloaded first, so give it some time.</font></div>")
206
+
207
+ with gr.Row():
208
+ prompt = gr.Textbox(label="Prompt", show_label=False, max_lines=2,placeholder="Enter prompt. Style applied automatically").style(container=False)
209
+ generate = gr.Button(value="Generate").style(rounded=(False, True, True, False))
210
+
211
+
212
+ image_out = gr.Image(height=512)
213
+ # gallery = gr.Gallery(
214
+ # label="Generated images", show_label=False, elem_id="gallery"
215
+ # ).style(grid=[1], height="auto")
216
+ error_output = gr.Markdown()
217
+
218
+ with gr.Column(scale=45):
219
+ with gr.Tab("Options"):
220
+ with gr.Group():
221
+ neg_prompt = gr.Textbox(label="Negative prompt", placeholder="What to exclude from the image")
222
+
223
+ # n_images = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1)
224
+
225
+ with gr.Row():
226
+ guidance = gr.Slider(label="Guidance scale", value=7.5, maximum=15)
227
+ steps = gr.Slider(label="Steps", value=25, minimum=2, maximum=75, step=1)
228
+
229
+ with gr.Row():
230
+ width = gr.Slider(label="Width", value=512, minimum=64, maximum=1024, step=8)
231
+ height = gr.Slider(label="Height", value=512, minimum=64, maximum=1024, step=8)
232
+
233
+ seed = gr.Slider(0, 2147483647, label='Seed (0 = random)', value=0, step=1)
234
+
235
+ with gr.Tab("Image to image"):
236
+ with gr.Group():
237
+ image = gr.Image(label="Image", height=256, tool="editor", type="pil")
238
+ strength = gr.Slider(label="Transformation strength", minimum=0, maximum=1, step=0.01, value=0.5)
239
+
240
+ if is_colab:
241
+ model_name.change(on_model_change, inputs=model_name, outputs=[custom_model_group, prompt], queue=False)
242
+ custom_model_path.change(custom_model_changed, inputs=custom_model_path, outputs=None)
243
+ # n_images.change(lambda n: gr.Gallery().style(grid=[2 if n > 1 else 1], height="auto"), inputs=n_images, outputs=gallery)
244
+
245
+ inputs = [model_name, prompt, guidance, steps, width, height, seed, image, strength, neg_prompt]
246
+ outputs = [image_out, error_output]
247
+ prompt.submit(inference, inputs=inputs, outputs=outputs)
248
+ generate.click(inference, inputs=inputs, outputs=outputs)
249
+
250
+ ex = gr.Examples([
251
+ [models[0].name, "iron man", 7.5, 50],
252
+
253
+ ], inputs=[model_name, prompt, guidance, steps, seed], outputs=outputs, fn=inference, cache_examples=False)
254
+
255
+ gr.HTML("""
256
+ <div style="border-top: 1px solid #303030;">
257
+ <br>
258
+ <p>Model by Linaqruf</p>
259
+ </div>
260
+ """)
261
+
262
+ print(f"Space built in {time.time() - start_time:.2f} seconds")
263
+
264
+ if not is_colab:
265
+ demo.queue(concurrency_count=1)
266
+ demo.launch(debug=is_colab, share=is_colab)