# Captured from a repository file view: app.py by Silence1412,
# commit 3a7ed92 ("Update app.py"), 6.11 kB.
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
from transformers import pipeline, set_seed
import random
import re
# Image model: Stable Diffusion v1.5, loaded once at import time and moved
# to the CPU.
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to('cpu')

# Prompt-expansion models (text-generation pipelines). MagicPrompt is paired
# explicitly with the stock GPT-2 tokenizer; the succinctly model is loaded
# with whatever tokenizer the pipeline resolves for it by default.
gpt2_pipe = pipeline(
    task='text-generation',
    model='Gustavosta/MagicPrompt-Stable-Diffusion',
    tokenizer='gpt2',
)
gpt2_pipe2 = pipeline(
    task='text-generation',
    model='succinctly/text2image-prompt-generator',
)
def infer1(starting_text):
    """Expand the user's text into Stable Diffusion prompts with MagicPrompt.

    Every call sets a fresh random seed. ``gpt2_pipe`` then samples four
    continuations of *starting_text*. A continuation is kept only if it:

    * differs from the input,
    * is more than four characters longer than it, and
    * does not end with ``:``, ``-`` or an em dash.

    From the joined result, every whitespace-free token with a dot inside it
    (URL-like text such as ``example.com``) is removed, and so is every
    ``<`` / ``>`` character.

    Args:
        starting_text: Text typed by the user. An empty string is passed to
            the model unchanged.

    Returns:
        The kept prompts, separated by blank lines, or ``""`` when no sample
        survived the filters.
    """
    # Re-seed so that repeated clicks produce different prompts.
    set_seed(random.randint(100, 1000000))

    # ``max_length`` is counted in tokens and includes the prompt. The
    # prompt's character count serves as a rough stand-in for its token
    # count, leaving roughly 60-90 tokens for the continuation.
    samples = gpt2_pipe(
        starting_text,
        max_length=len(starting_text) + random.randint(60, 90),
        num_return_sequences=4,
    )

    kept = []
    for sample in samples:
        candidate = sample["generated_text"].strip()
        if (
            candidate != starting_text
            and len(candidate) > len(starting_text) + 4
            # Drop continuations that stop on a dangling separator.
            and not candidate.endswith((":", "-", "—"))
        ):
            kept.append(candidate + "\n")

    result = "\n".join(kept)
    # Use \S rather than "[^ ]". Prompts are separated by newlines, and a
    # pattern that can cross them would splice two prompts together.
    result = re.sub(r"\S+\.\S+", "", result)
    result = result.replace("<", "").replace(">", "")
    return result
def infer2(starting_text):
    """Expand the user's text into prompts with the succinctly generator.

    Up to six attempts are made. Each attempt sets a fresh random seed and
    samples eight continuations from ``gpt2_pipe2``. A continuation is kept
    only if it:

    * differs from the input,
    * is more than four characters longer than it, and
    * does not end with ``:``, ``-`` or an em dash.

    Whitespace-free tokens with a dot inside them (URL-like text) and all
    ``<`` / ``>`` characters are then removed. The first non-empty result is
    returned.

    Args:
        starting_text: Text typed by the user. An empty string is passed to
            the model unchanged.

    Returns:
        The kept prompts, one per line, or ``""`` if all six attempts came
        back empty.
    """
    for _attempt in range(6):
        set_seed(random.randint(100, 1000000))
        print(starting_text)
        samples = gpt2_pipe2(
            starting_text,
            # The token budget includes the prompt, so it grows with the
            # input. A fixed 60-90 would leave no room after a long prompt.
            max_length=len(starting_text) + random.randint(60, 90),
            num_return_sequences=8,
        )
        kept = [
            candidate
            for candidate in (s["generated_text"].strip() for s in samples)
            if candidate != starting_text
            and len(candidate) > len(starting_text) + 4
            and not candidate.endswith((":", "-", "—"))
        ]
        # \S keeps each match within one line, i.e. within one prompt.
        result = re.sub(r"\S+\.\S+", "", "\n".join(kept))
        result = result.replace("<", "").replace(">", "")
        if result:
            return result
    return ""
def infer3(prompt, negative, steps, scale, seed, height=512, width=512):
    """Render an image from *prompt* with the Stable Diffusion pipeline.

    Args:
        prompt: Text prompt describing the image.
        negative: Negative prompt, forwarded as ``negative_prompt``.
        steps: Number of inference steps. Cast to ``int`` because UI widgets
            may deliver numbers as floats.
        scale: Guidance scale, forwarded as ``guidance_scale``.
        seed: Seed for a CPU ``torch.Generator``. Cast to ``int`` because
            ``manual_seed`` requires an integer.
        height: Output height in pixels (default 512, as before).
        width: Output width in pixels (default 512, as before).

    Returns:
        The ``images`` list from the pipeline output, shown in the gallery.
    """
    generator = torch.Generator(device="cpu").manual_seed(int(seed))
    output = pipe(
        prompt,
        height=height,
        width=width,
        num_inference_steps=int(steps),
        guidance_scale=scale,
        negative_prompt=negative,
        generator=generator,
    )
    return output.images
# Build the Gradio UI in three sections: two prompt generators (infer1,
# infer2) and an image generator (infer3).
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation was lost; confirm it against the deployed layout.
block = gr.Blocks()
with block:
    with gr.Group():
        # Section 1: prompt expansion with MagicPrompt (wired to infer1).
        with gr.Box():
            gr.Markdown(
                """
Model: Gustavosta/MagicPrompt-Stable-Diffusion
"""
            )
            with gr.Row() as row:
                with gr.Column():
                    txt = gr.Textbox(lines=1, label="Initial Text", placeholder="English Text here")
                    gpt_btn = gr.Button("Generate prompt").style(
                        margin=False,
                        rounded=(False, True, True, False),
                    )
                with gr.Column():
                    out = gr.Textbox(lines=4, label="Generated Prompts")
        # Section 2: prompt expansion with the succinctly generator (infer2).
        with gr.Box():
            gr.Markdown(
                """
Model: succinctly/text2image-prompt-generator
"""
            )
            with gr.Row() as row:
                with gr.Column():
                    txt2 = gr.Textbox(lines=1, label="Initial Text", placeholder="English Text here")
                    gpt_btn2 = gr.Button("Generate prompt").style(
                        margin=False,
                        rounded=(False, True, True, False),
                    )
                with gr.Column():
                    out2 = gr.Textbox(lines=4, label="Generated Prompts")
        # Section 3: text-to-image with Stable Diffusion v1.5 (infer3).
        with gr.Box():
            gr.Markdown(
                """
Model: stable diffusion v1.5
"""
            )
            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                with gr.Column():
                    text = gr.Textbox(
                        label="Enter your prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt",
                    ).style(
                        border=(True, False, True, True),
                        rounded=(True, False, False, True),
                        container=False,
                    )
                    negative = gr.Textbox(
                        label="Enter your negative prompt",
                        show_label=False,
                        placeholder="Enter a negative prompt",
                        elem_id="negative-prompt-text-input",
                    ).style(
                        border=(True, False, True, True),
                        rounded=(True, False, False, True),container=False,
                    )
                # The corner rounding mirrors the textboxes' rounding,
                # presumably so the button sits flush against them.
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        # Output area for the list of images returned by infer3.
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(columns=(1, 2), height="auto")
        # Generation settings. steps, scale and seed are passed to infer3.
        # The "Images" slider is pinned to 1, is non-interactive, and is not
        # connected to any callback.
        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1, maximum=1, value=1, step=1, interactive=False)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=12, step=1, interactive=True)
            scale = gr.Slider(label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1, interactive=True)
            seed = gr.Slider(label="Random seed",minimum=0,maximum=2147483647,step=1,randomize=True,interactive=True)
    # Event wiring: each button sends its input components to its callback
    # and writes the return value into the output component.
    gpt_btn.click(infer1,inputs=txt,outputs=out)
    gpt_btn2.click(infer2,inputs=txt2,outputs=out2)
    btn.click(infer3, inputs=[text, negative, steps, scale, seed], outputs=[gallery])
# Enable request queuing and start the server with the API page hidden and
# debug mode on. Passing ``enable_queue`` to ``launch`` is deprecated in
# Gradio 3.x; ``Blocks.queue()`` is the documented replacement and returns the
# Blocks instance, so the calls can be chained.
block.queue().launch(show_api=False, debug=True)