import dataclasses
import gradio as gr
import spaces
import torch
from PIL import Image
from diffusers import DiffusionPipeline
from diffusers.utils import make_image_grid
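
# Diffusers models that can be loaded directly from the Hugging Face Hub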
DIFFUSERS_MODEL_IDS = [
    # SD Models
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/stable-diffusion-2-1",
    "runwayml/stable-diffusion-v1-5",
    # Other Models
    "Prgckwb/trpfrog-diffusion",
]
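
# Display name -> path of a local Diffusers checkpoint for models not hosted on the Hub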
EXTERNAL_MODEL_MAPPING = {
"Beautiful Realistic Asians": "checkpoints/diffusers/Beautiful Realistic Asians v7",
}
MODEL_CHOICES = DIFFUSERS_MODEL_IDS + list(EXTERNAL_MODEL_MAPPING.keys())
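
# Load the default pipeline once at startup; inference() swaps it out when another model is selected.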
current_model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 on GPU; fall back to float32 on CPU so that `dtype` is always defined.
dtype = torch.float16 if device == 'cuda' else torch.float32

pipe = DiffusionPipeline.from_pretrained(
    current_model_id,
    torch_dtype=dtype,
)
pipe.enable_sequential_cpu_offload()
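
# Default generation settings for one example; to_list() returns the values in the
# same positional order as the `inputs` list wired to inference() below.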
@dataclasses.dataclass
class Input:
    prompt: str
    model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers"
    negative_prompt: str = ''
    width: int = 1024
    height: int = 1024
    guidance_scale: float = 7.5
    num_inference_step: int = 28
    num_images: int = 4
    safety_checker: bool = True

    def to_list(self):
        return [
            self.prompt, self.model_id, self.negative_prompt,
            self.width, self.height, self.guidance_scale,
            self.num_inference_step, self.num_images, self.safety_checker
        ]
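
# Preset inputs shown in the gr.Examples table at the bottom of the UI.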
EXAMPLES = [
    Input(prompt='A cat holding a sign that says Hello world').to_list(),
    Input(
        prompt='Beautiful pixel art of a Wizard with hovering text "Achievement unlocked: Diffusion models can spell now"'
    ).to_list(),
    Input(prompt='A corgi wearing sunglasses says "U-Net is OVER!!"').to_list(),
    Input(
        prompt='Cinematic Photo of a beautiful korean fashion model bokeh train',
        model_id='Beautiful Realistic Asians',
        negative_prompt='(worst_quality:2.0) (MajicNegative_V2:0.8) BadNegAnatomyV1-neg bradhands cartoon, cgi, render, illustration, painting, drawing',
        width=512,
        height=512,
        guidance_scale=5.0,
        num_inference_step=50,
    ).to_list()
]
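
# Runs on a ZeroGPU worker for at most 120 seconds per call; torch.inference_mode() disables autograd.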
@spaces.GPU(duration=120)
@torch.inference_mode()
def inference(
    prompt: str,
    model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers",
    negative_prompt: str = "",
    width: int = 512,
    height: int = 512,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    num_images: int = 4,
    safety_checker: bool = True,
    progress=gr.Progress(track_tqdm=True),
) -> Image.Image:
    progress(0, "Starting inference...")

    global current_model_id, pipe

    progress(0.1, 'Loading pipeline...')
    if model_id != current_model_id:
        # Resolve display names of non-Hub models to their local checkpoint path
        if model_id not in DIFFUSERS_MODEL_IDS:
            model_id = EXTERNAL_MODEL_MAPPING[model_id]

        pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
        )
        pipe.enable_sequential_cpu_offload()
        current_model_id = model_id

    if not safety_checker:
        pipe.safety_checker = None

    if model_id not in DIFFUSERS_MODEL_IDS:
        progress(0.3, 'Loading Textual Inversion...')

        # Load Textual Inversion
        pipe.load_textual_inversion(
            "checkpoints/embeddings/BadNegAnatomyV1 neg.pt", token='BadNegAnatomyV1-neg'
        )
    # Generation
    progress(0.4, 'Generating images...')
    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
    ).images
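
    # Tile the generated images into a single grid for the output gr.Image.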
    if num_images % 2 == 1:
        image = make_image_grid(images, rows=num_images, cols=1)
    else:
        image = make_image_grid(images, rows=2, cols=num_images // 2)

    return image
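
# Build the Gradio UI and wire the Generate button to inference().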
if __name__ == "__main__":
    theme = gr.themes.Default(primary_hue=gr.themes.colors.emerald)

    with gr.Blocks(theme=theme) as demo:
        gr.Markdown("# Stable Diffusion Demo")

        with gr.Row():
            with gr.Column():
                prompt = gr.Text(label="Prompt", placeholder="Enter a prompt here")
                model_id = gr.Dropdown(
                    label="Model ID",
                    choices=MODEL_CHOICES,
                    value="stabilityai/stable-diffusion-3-medium-diffusers",
                )

                # Additional Input Settings
                with gr.Accordion("Additional Settings", open=False):
                    negative_prompt = gr.Text(label="Negative Prompt", value="")

                    with gr.Row():
                        width = gr.Number(label="Width", value=512, step=64, minimum=64, maximum=2048)
                        height = gr.Number(label="Height", value=512, step=64, minimum=64, maximum=2048)
                        num_images = gr.Number(label="Num Images", value=4, minimum=1, maximum=10, step=1)

                    guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, step=0.5, minimum=0, maximum=10)
                    num_inference_step = gr.Slider(
                        label="Num Inference Steps", value=50, minimum=1, maximum=100, step=2
                    )

                    with gr.Row():
                        safety_checker = gr.Checkbox(value=True, label='Use Safety Checker')

            with gr.Column():
                output_image = gr.Image(label="Image", type="pil")
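
        # Components passed positionally to inference(); keep this order in sync with its signature.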
        inputs = [
            prompt,
            model_id,
            negative_prompt,
            width,
            height,
            guidance_scale,
            num_inference_step,
            num_images,
            safety_checker
        ]

        btn = gr.Button("Generate")
        btn.click(
            fn=inference,
            inputs=inputs,
            outputs=output_image
        )

        gr.Examples(
            examples=EXAMPLES,
            inputs=inputs,
        )

    demo.queue().launch()