gaur3009 committed on
Commit
e4638d2
1 Parent(s): 60447dc

Create app.py

Files changed (1)
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
+ import gradio as gr
+ import numpy as np
+ import random
+ from diffusers import DiffusionPipeline
+ import torch
+ from PIL import Image
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ if torch.cuda.is_available():
+     torch.cuda.max_memory_allocated(device=device)
+     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
+     pipe.enable_xformers_memory_efficient_attention()
+     pipe = pipe.to(device)
+ else:
+     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
+     pipe = pipe.to(device)
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
+ # Function to apply FFT and return the magnitude spectrum as an image
+ def apply_fft(image: Image.Image):
+     # Convert the image to grayscale for the FFT (could be extended to color images)
+     image_gray = image.convert("L")
+
+     # Convert the image to a numpy array
+     image_array = np.array(image_gray)
+
+     # Apply the 2D FFT
+     fft_image = np.fft.fft2(image_array)
+     fft_shifted = np.fft.fftshift(fft_image)  # Shift the zero frequency to the center
+
+     # Magnitude spectrum for visualization; the small epsilon avoids log(0)
+     magnitude_spectrum = 20 * np.log(np.abs(fft_shifted) + 1e-8)
+
+     # Normalize the magnitude spectrum to 0-255 for visualization
+     magnitude_spectrum = np.interp(magnitude_spectrum, (magnitude_spectrum.min(), magnitude_spectrum.max()), (0, 255))
+
+     # Convert back to a PIL image
+     fft_image_pil = Image.fromarray(magnitude_spectrum.astype(np.uint8))
+
+     return fft_image_pil
+
+ def infer(prompt_part1, color, dress_type, design, prompt_part5, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+     prompt = f"{prompt_part1} {color} colored plain {dress_type} with {design} design, {prompt_part5}"
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     generator = torch.Generator().manual_seed(seed)
+
+     # Generate the image using the diffusion pipeline
+     image = pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         width=width,
+         height=height,
+         generator=generator
+     ).images[0]
+
+     # Apply FFT post-processing to the generated image
+     fft_image = apply_fft(image)
+
+     return fft_image
+
+ examples = [
+     ["red", "t-shirt", "yellow stripes"],
+     ["blue", "hoodie", "minimalist"],
+     ["red", "sweatshirt", "geometric design"],
+ ]
+
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 520px;
+ }
+ """
+
+ if torch.cuda.is_available():
+     power_device = "GPU"
+ else:
+     power_device = "CPU"
+
+ with gr.Blocks(css=css) as demo:
+
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown(f"""
+         # Text-to-Image Gradio Template with FFT Post-Processing
+         Currently running on {power_device}.
+         """)
+
+         with gr.Row():
+
+             prompt_part1 = gr.Textbox(
+                 value="a single",
+                 label="Prompt Part 1",
+                 show_label=False,
+                 interactive=False,
+                 container=False,
+                 elem_id="prompt_part1",
+                 visible=False,
+             )
+
+             prompt_part2 = gr.Textbox(
+                 label="color",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="color (e.g., red, blue)",
+                 container=False,
+             )
+
+             prompt_part3 = gr.Textbox(
+                 label="dress_type",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="dress_type (e.g., t-shirt, sweatshirt, shirt, hoodie)",
+                 container=False,
+             )
+
+             prompt_part4 = gr.Textbox(
+                 label="design",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="design",
+                 container=False,
+             )
+
+             prompt_part5 = gr.Textbox(
+                 value="hanging on the plain wall",
+                 label="Prompt Part 5",
+                 show_label=False,
+                 interactive=False,
+                 container=False,
+                 elem_id="prompt_part5",
+                 visible=False,
+             )
+
+             run_button = gr.Button("Run", scale=0)
+
+         result = gr.Image(label="Result", show_label=False)
+
+         with gr.Accordion("Advanced Settings", open=False):
+
+             negative_prompt = gr.Textbox(
+                 label="Negative prompt",
+                 max_lines=1,
+                 placeholder="Enter a negative prompt",
+                 visible=False,
+             )
+
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=0,
+             )
+
+             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+             with gr.Row():
+
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=512,
+                 )
+
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=512,
+                 )
+
+             with gr.Row():
+
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=0.0,
+                     maximum=10.0,
+                     step=0.1,
+                     value=0.0,
+                 )
+
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps",
+                     minimum=1,
+                     maximum=12,
+                     step=1,
+                     value=2,
+                 )
+
+         gr.Examples(
+             examples=examples,
+             inputs=[prompt_part2, prompt_part3, prompt_part4]
+         )
+
+     run_button.click(
+         fn=infer,
+         inputs=[prompt_part1, prompt_part2, prompt_part3, prompt_part4, prompt_part5, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+         outputs=[result]
+     )
+
+ demo.queue().launch()